#!/usr/bin/env python
#
# file: $(NEDC_NFC)/tool/nedc_dpath_resnet_extract_patch/nedc_dpath_resnet_extract_patch.py
#
# revision history:
#
# 20230929 (SM): prep for v2.0.0 release
# 20211220 (PM): refactored code
# 20210324 (VK): make it compatible with NEDC command line tools and param files
# 20210315 (VK): extraction will be done before distribution of patches
# 20210313 (VK): shuffeling capability plus weighted slide selection
# 20210127 (VK): second version
# 20200904 (VK): initial version
#
# This is a Python script that extracts square patches of a predefined size
# from the input svs slide files.
#------------------------------------------------------------------------------

# import system modules
#
import os
import pickle
import sys
from datetime import datetime

# import NEDC support modules
#
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
import nedc_file_tools as nft
import nedc_dpath_extract_patch as ndep

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# name of the parameter-file section; main() passes it to
# nft.load_parameters()
#
EXTRACT_PATCH = 'EXTRACT_PATCH'

# names of the data sets
#
TRAIN = 'train'
DEV = 'dev'
EVAL = 'eval'

# parameter file name and image file extension
#
DEF_PARAM_FILE = 'extract_patch_param'
IMG_EXT = '.tif'

# the name of this file with any leading directories stripped
#
__FILE__ = os.path.basename(__file__)

# the directory that holds this script (symbolic links are resolved)
#
SCRIPT_LOC = os.path.dirname(os.path.realpath(__file__))

# locations of the help and usage messages:
#  both are handed to the command line parser in main()
#
HELP_FILE = (
    "$NEDC_NFC/src/python/util/nedc_dpath_resnet_extract_patch/"
    "nedc_dpath_resnet_extract_patch.help"
)

USAGE_FILE = (
    "$NEDC_NFC/src/python/util/nedc_dpath_resnet_extract_patch/"
    "nedc_dpath_resnet_extract_patch.usage"
)

# command line options: each option has a long form and an abbreviation
#  note that the option strings must not contain spaces
#

# output directory
#
ARG_ODIR = "--odir"
ARG_ABRV_ODIR = "-o"

# parameter file
#
ARG_PARM = "--parameters"
ARG_ABRV_PARM = "-p"

# patch directory
#
ARG_PATCH_DIR = "--patch_dir"
ARG_ABRV_PATCH_DIR = "-pd"

# log directory
#
ARG_LOG_DIR = "--log_dir"
ARG_ABRV_LOG_DIR = "-l"

# default values for the command line arguments:
#  all of these are written relative to "./"; the default patch
#  directory is a subdirectory of the default output directory
#
DEF_PFILE = "./nedc_dpath_resnet_extract_patch_param.txt"
DEF_ODIR = "./output"
DEF_PDIR = os.path.join(DEF_ODIR, "patch")

# name of the pickle file that receives the patch list of a data set
#
STAT_FNAME = 'stats.pckl'

# number of file lists that must be given on the command line
#
NUM_ARGS = 1

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

# declare a global debug object:
#  main() compares it against ndt.NONE and ndt.BRIEF to decide whether
#  diagnostic messages are printed
#
dbgl = ndt.Dbgl()

# function: main
#
def main(argv):
    """Extract patches for every file list named on the command line.

    The command line is parsed with ncp.Cmdl, the EXTRACT_PATCH section of
    the parameter file is loaded, and an ndep.ExtractPatch object is run
    once per file list. Patches of a list are written to
    <patch_dir>/<list basename>. When a log directory is given, the state
    dictionary, the per-label patch lists and a pickled copy of the patch
    list (STAT_FNAME) are saved to <log_dir>/<list basename>.

    arguments:
     argv: the argument vector supplied by the caller; it is not read
           here because cmdl.parse_args() is called without arguments
           (presumably it reads sys.argv itself - confirm in ncp.Cmdl)

    return: none

    The process exits with os.EX_SOFTWARE, after printing the usage
    message, when the number of file lists differs from NUM_ARGS. An
    OSError raised while writing the stats file is not caught.
    """

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments:
    #  every option lists its abbreviation first and its long form second
    #
    cmdl.add_argument("files", type = str, nargs = '*')
    cmdl.add_argument(ARG_ABRV_ODIR, ARG_ODIR, type = str)
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)
    cmdl.add_argument(ARG_ABRV_PATCH_DIR, ARG_PATCH_DIR, type = str)
    cmdl.add_argument(ARG_ABRV_LOG_DIR, ARG_LOG_DIR, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    # check if the proper number of lists has been provided
    #
    if len(args.files) != NUM_ARGS:
        cmdl.print_usage('stdout')
        sys.exit(os.EX_SOFTWARE)

    # process the command line arguments:
    #  options that were not given fall back to their defaults; the log
    #  directory has no default because logging is optional
    #
    odir = DEF_ODIR if args.odir is None else args.odir
    pfile = DEF_PFILE if args.parameters is None else args.parameters
    patch_dir = DEF_PDIR if args.patch_dir is None else args.patch_dir
    log_dir = args.log_dir
    flists = args.files

    # if debug is on, print debug info
    #
    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(" output directory = %s" % (odir))
        print(" parameter file = %s" % (pfile))
        print(" patch directory = %s" % (patch_dir))
        print("")

    # make the output directory
    #
    odir = os.path.abspath(odir)
    os.makedirs(odir, exist_ok=True)

    # read parameters
    #
    params = nft.load_parameters(pfile, EXTRACT_PATCH)

    # create a patch directory for all data sets:
    #  a relative patch directory is resolved against the output directory
    #  (an absolute one is used as is), so it is only created after it has
    #  been resolved; this avoids leaving a stray directory behind in the
    #  current working directory
    #
    patch_dir = os.path.realpath(os.path.join(odir, patch_dir))
    os.makedirs(patch_dir, exist_ok=True)

    # extract patches in accordance to each individual file list
    #
    start_time = datetime.now()
    print_time = start_time.strftime("%m/%d/%Y, %H:%M:%S")
    print(f"Extract patch starting at {print_time}\n")
    for flist in flists:

        # set the start time for the current list
        #
        print_time = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")

        # get the basename of the file list (train, dev, eval, ...)
        #
        flist_bname = os.path.basename(flist).split(os.path.extsep)[0]
        print(f"\tExtract patch for {flist_bname} started at {print_time}\n")

        # make an instance of ExtractPatch for the current file list
        #
        extp = ndep.ExtractPatch(params, flist)

        # make patch directory for the current file list
        #
        set_dir = os.path.realpath(os.path.join(patch_dir, flist_bname))
        os.makedirs(set_dir, exist_ok=True)

        # extract patches and return the corresponding information
        #
        state_dict, patches_list = \
            extp.extract_patch(set_dir, extp.acc_lbls)

        # save the results only if a log directory was requested
        #
        if log_dir is not None:

            # make a log directory for the current set and save the state dict
            #
            ldir = os.path.join(log_dir, flist_bname)
            os.makedirs(ldir, exist_ok=True)
            extp.save_state_dict(state_dict, ldir)

            # save list files
            # patches list for every label individually
            #
            extp.save_labels_patches(patches_list, ldir)

            # announce the file that is about to be opened
            #
            stat_path = os.path.join(ldir, STAT_FNAME)
            if dbgl > ndt.BRIEF:
                print("%s (line: %s) %s: opening (%s)\n" %
                      (__FILE__, ndt.__LINE__, ndt.__NAME__, stat_path))

            # write the patch list into the stats file:
            #  the file is opened exactly once and the context manager
            #  closes it even if pickling fails; open() raises OSError on
            #  failure (it never returns None), so no explicit check for a
            #  missing handle is needed
            #
            with open(stat_path, mode="wb") as fp:
                pickle.dump(patches_list, fp)

        # print the finishing time for the specific list
        #
        print_time = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
        print(f"\t{flist_bname} files finished at {print_time}\n")

    # print the total patching time for all the lists
    #
    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"All files finished in {elapsed:.2f} seconds")
#
# end of main

# run main() only when this file is executed as a script:
#  main() receives a copy of the complete argument vector, including
#  the program name
#
if __name__ == "__main__":
    cmdline_args = list(sys.argv)
    main(cmdline_args)

#
# end of file
