#!/usr/bin/env python
#
# file: $NEDC_NFC/src/nedc_dpath_eb0_decode/nedc_dpath_eb0_decode.py
#
# revision history:
#

# 20250126 (AM): initial version
#
# This script decodes a set of image files (SVS, or any format supported by
# PIL/cuCIM) with a given EfficientNet-B0 model.
#------------------------------------------------------------------------------

# import system modules except torch
#
import os
from pathlib import Path
import sys
import time

# import NEDC modules
#
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
import nedc_file_tools as nft
import nedc_dpath_ann_tools as nda
import nedc_dpath_ml_eb0_tools as ndr
import nedc_image_tools as nit

# set up (and, per nft.TempDirManager.create(), presumably clean) the
# temporary file directories for this tool ("dpath_eb0"). this must run
# before torch is imported below -- presumably so torch and its data-loader
# workers pick up these directories at import time.
# NOTE(review): confirm the exact reason against nedc_file_tools.
#
tempDirManager = nft.TempDirManager("dpath_eb0")
tempDirManager.create()

# import torch modules
#
import torch
import torch.backends.cudnn as cudnn

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# set the filename using basename (used as a prefix in error messages)
#
__FILE__ = os.path.basename(__file__)

# define a file version (taken from the EB0 machine learning tools module)
#
NEDC_VERSION = ndr.NEDC_DPATH_EB0_VERSION

# define the name of the summary log file written to the output directory
#
NEDC_SUMMARY_FILE = "eb0_decode.log"

# define the location of the help file
#
HELP_FILE = "$NEDC_NFC/docs/help/nedc_dpath_eb0_decode.help"

# define the location of the usage file
#
USAGE_FILE = "$NEDC_NFC/docs/usage/nedc_dpath_eb0_decode.usage"

# define the command line argument options (long and abbreviated forms):
#
ARG_ODIR = ncp.ARG_ODIR
ARG_ABRV_ODIR = ncp.ARG_ABRV_ODIR

ARG_RDIR = ncp.ARG_RDIR
ARG_ABRV_RDIR = ncp.ARG_ABRV_RDIR

ARG_PARM = "--pfile"
ARG_ABRV_PARM = "-p"

ARG_OEXT = ncp.ARG_OEXT
ARG_ABRV_OEXT = ncp.ARG_ABRV_OEXT

# define the default parameter file path:
#
DEF_PFILE = "$NEDC_NFC/docs/params/nedc_dpath_eb0_decode_params_v00.toml"

# define the default output directory
#
DEF_ODIR = "./output"

# define the default output file extension. DEF_OEXT follows the
# ARG_OEXT naming convention; DEF_OEXIT is the original (misspelled)
# name and is kept as an alias for backward compatibility.
#
DEF_OEXT = nft.DEF_EXT_CSV
DEF_OEXIT = DEF_OEXT

# define the section name and the keys used to read the decoder's
# parameters from the parameter file
#
EB0_DECODE = "EB0_DECODE"
DECODE_PARAM_KEY_BATCH_SIZE = "batch_size"
DECODE_PARAM_KEY_COMP_DEVICE = "device"
DECODE_PARAM_KEY_INPUT_WEIGHTS_PATH = "input_weights_path"
DECODE_PARAM_KEY_INPUT_MODEL_PATH = "input_model_path"
DECODE_PARAM_KEY_NUM_THREADS = "num_threads"
DECODE_PARAM_KEY_NUM_WORKERS = "num_workers"
DECODE_PARAM_KEY_FRMSIZE = "frmsize"
DECODE_PARAM_KEY_RESIZE_VAL = "resize_val"
DECODE_PARAM_KEY_WINDOW_SIZE = "window_size"
DECODE_PARAM_KEY_FRAME_PTHRESH = "frame_pthresh"
DECODE_PARAM_KEY_PRE_PROCESS_ALGORITHM = "pre_process_algorithm"

# define the device names utilized by the decoder (shared with the
# EB0 machine learning tools module)
#
DEVICE_CPU = ndr.DEVICE_CPU
DEVICE_GPU = ndr.DEVICE_GPU
DEVICE_CUDA = ndr.DEVICE_CUDA

# declare a global debug object so we can use it in functions; it is
# compared against ndt.NONE to decide whether to print debug output
#
dbgl = ndt.Dbgl()

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

def _require_file(ffile):
    """
    function: _require_file

    arguments:
     ffile: full path of a file that must exist

    return: none

    description:
     This function prints an error message and exits the program with
     status os.EX_SOFTWARE if ffile is not an existing regular file.
    """

    # fail fast on a missing input file
    #
    if not os.path.isfile(ffile):
        print("Error: %s (line: %s) %s: file doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, ffile))
        sys.exit(os.EX_SOFTWARE)
#
# end of function

def _decode_file(eb0_decode, ffile, file_num, fp, odir, oext, rdir):
    """
    function: _decode_file

    arguments:
     eb0_decode: the decoder object (ndr.DPATHEB0Decode)
     ffile: full path of the image file to decode
     file_num: 1-based index of this file among all files attempted
     fp: open log file pointer (also passed through to the decoder)
     odir: output directory
     oext: output file extension
     rdir: replace directory used when building the output filename

    return: a boolean value that is True if an annotation file was written

    description:
     This function displays and logs the file being decoded, then decodes
     it. If the decoder returns an annotation graph, it writes the graph to
     an output file. Otherwise it prints an error message.
    """

    # display information and write it to the log file
    #
    print("%3ld: %s" % (file_num, ffile))
    fp.write("%3ld: %s\n" % (file_num, ffile))

    # decode the file
    #
    ann_graph = eb0_decode.decode(ffile, fp)

    # a None result indicates the decode failed
    #
    if ann_graph is None:
        print("Error: %s (line: %s) %s: %s (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__,
               "error decoding", ffile))
        return False

    # create the output filename and save the annotation. the image
    # dimensions are read from the decoder's preprocessor (presumably
    # populated by the decode() call above)
    #
    ofile = nft.create_filename(ffile, odir, oext, rdir)
    nda.write_data_to_file(
        ann_graph, eb0_decode.preprocessor.impl.img_width,
        eb0_decode.preprocessor.impl.img_height, ofile
    )
    return True
#
# end of function

def main(argv):
    """
    function: main

    arguments:
     argv: the command line argument list (not read directly; the
           parser's parse_args() is called without it)

    return: a boolean value indicating status (True when all arguments
            were processed; fatal errors exit via sys.exit(os.EX_SOFTWARE))

    description:
     This function parses the command line and loads the decoder
     parameters. It builds an EfficientNet-B0 decoder and decodes every
     image named on the command line. Each argument is either an image
     file or a list file of images. Annotations are written to the
     output directory, and a summary log (NEDC_SUMMARY_FILE) is written
     to the same directory.
    """

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments
    #
    cmdl.add_argument("files", type = str, nargs = "*")
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)
    cmdl.add_argument(ARG_ABRV_ODIR, ARG_ODIR, type = str)
    cmdl.add_argument(ARG_ABRV_RDIR, ARG_RDIR, type = str)
    cmdl.add_argument(ARG_ABRV_OEXT, ARG_OEXT, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    # make sure there is at least one file argument
    #
    if not args.files:
        cmdl.print_usage()
        sys.exit(os.EX_SOFTWARE)

    # fill in defaults for options not given on the command line
    #
    if args.ext is None:
        args.ext = DEF_OEXIT
    if args.rdir is None:
        args.rdir = ncp.DEF_RDIR
    if args.odir is None:
        args.odir = DEF_ODIR
    if args.pfile is None:
        args.pfile = nft.get_fullpath(DEF_PFILE)

    # display debug information
    #
    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(f" file extension = {args.ext}")
        print(f" output directory = {args.odir}")
        print(f" replace directory = {args.rdir}")
        print(f" parameter file = {args.pfile}")
        print(f" argument files = {args.files}")

    # fetch parameters from command line
    #
    param_fname, odir, rdir, oext = \
        (args.pfile, args.odir, args.rdir, args.ext)

    # load the parameter file; treat a None result as a failed load so it
    # is reported instead of crashing on the first lookup below
    #
    params = nft.load_parameters(param_fname, EB0_DECODE)
    if params is None:
        print("Error: %s (line: %s) %s: error loading parameter file (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, param_fname))
        sys.exit(os.EX_SOFTWARE)

    # fetch the weight file name; .get() lets a missing key reach the
    # error message below instead of raising a KeyError
    #
    in_weights_fname = params.get(DECODE_PARAM_KEY_INPUT_WEIGHTS_PATH)

    # make sure the weight file is specified
    #
    if not in_weights_fname:
        print("Error: %s (line: %s) %s: must specify a weight file" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__))
        sys.exit(os.EX_SOFTWARE)

    # expand the weight file path and ensure the file exists
    #
    in_weights_path = nft.get_fullpath(in_weights_fname)
    if not os.path.isfile(in_weights_path):
        print("Error: %s (line: %s) %s: weight file doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, in_weights_path))
        sys.exit(os.EX_SOFTWARE)

    # fetch the (optional) model object file. if it is unspecified or
    # does not exist, in_model_path stays None (i.e. don't decode based
    # off a model object)
    #
    in_model_path = None
    in_model_file = params.get(DECODE_PARAM_KEY_INPUT_MODEL_PATH)
    if in_model_file:
        in_model_path = nft.get_fullpath(in_model_file)
        if not os.path.isfile(in_model_path):
            in_model_path = None

    # get the preprocessing algorithm name
    #
    pre_proc_alg = params[DECODE_PARAM_KEY_PRE_PROCESS_ALGORITHM]

    # get numeric values that need to be converted. resize_val is read
    # (and validated as an int) but not otherwise used here
    #
    resize_val = int(params[DECODE_PARAM_KEY_RESIZE_VAL])
    frmsize = int(params[DECODE_PARAM_KEY_FRMSIZE])
    window_size = int(params[DECODE_PARAM_KEY_WINDOW_SIZE])

    # get the number of workers, batch size, and threads
    #
    num_workers = int(params[DECODE_PARAM_KEY_NUM_WORKERS])
    batch_size = int(params[DECODE_PARAM_KEY_BATCH_SIZE])
    num_threads = int(params[DECODE_PARAM_KEY_NUM_THREADS])

    # get the user specified threshold
    #
    frame_pthresh = float(params[DECODE_PARAM_KEY_FRAME_PTHRESH])

    # use the configured device only when CUDA is available, else the CPU
    #
    device = torch.device(params[DECODE_PARAM_KEY_COMP_DEVICE]
                          if torch.cuda.is_available() else DEVICE_CPU)

    print(f'using device: {device}')

    # create a EB0 Decode object
    #
    eb0_decode = ndr.DPATHEB0Decode(
        window_size, frmsize, device, num_workers,
        num_threads, in_model_path, in_weights_path,
        batch_size, frame_pthresh, pre_proc_alg
    )

    # keep track of the start time and the file counters
    #
    init_time = time.time()
    num_files_att = 0
    num_files_proc = 0

    # default the device name to CPU; fetch the real name for a GPU
    #
    device_name = DEVICE_CPU
    if device.type == DEVICE_CUDA:
        device_name = torch.cuda.get_device_name()

    # make the output directory and open the log file inside it
    #
    os.makedirs(odir, exist_ok=True)
    logfile_fname = nft.concat_names(odir, NEDC_SUMMARY_FILE)
    fp = nft.make_fp(logfile_fname)

    # everything that writes to the log runs inside try/finally so the
    # log is closed even when sys.exit() is called on a bad input
    #
    try:

        # print a log message
        #
        fp.write("%s" % (dbgl.log(__file__, NEDC_VERSION) +
                         nft.DELIM_NEWLINE))
        fp.write("DEVICE : %s\n" % device_name)
        fp.write("MODEL OBJECTS PATH : %s\n" % in_model_path)
        fp.write("MODEL WEIGHTS PATH : %s\n" % in_weights_path)
        fp.write("BATCH SIZE : %s\n" % batch_size)
        fp.write("NUM WORKERS : %s\n" % num_workers)
        fp.write("NUM THREADS : %s\n" % num_threads)
        fp.write("FRAME SIZE : %s\n" % frmsize)
        fp.write("WINDOW SIZE : %s\n" % window_size)
        fp.write("DECODING SEED : %s\n" % ndt.RANDSEED)

        # display an informational message
        #
        print("beginning decoding...")

        # process each file argument
        #
        for arg_file in args.files:

            # fetch the argument's full path and ensure it exists
            #
            ffile = nft.get_fullpath(arg_file)
            _require_file(ffile)

            # an image file is decoded directly
            #
            if nit.Nil().is_image(ffile):

                if dbgl > ndt.NONE:
                    print("%s (line: %s) %s: decoding image file (%s)" %
                          (__FILE__, ndt.__LINE__, ndt.__NAME__, ffile))

                num_files_att += 1
                if _decode_file(eb0_decode, ffile, num_files_att, fp,
                                odir, oext, rdir):
                    num_files_proc += 1
                continue

            # otherwise assume the argument is a file list
            #
            if dbgl > ndt.NONE:
                print("%s (line: %s) %s: opening list (%s)" %
                      (__FILE__, ndt.__LINE__, ndt.__NAME__, ffile))

            flist = nft.get_flist(ffile)

            # ensure the flist method worked (report the list file itself)
            #
            if flist is None:
                print("Error: %s (line: %s) %s: %s (%s)"
                      % (__FILE__, ndt.__LINE__, ndt.__NAME__,
                         "error retrieving file list", ffile))
                sys.exit(os.EX_SOFTWARE)

            # expand, check and decode each file in the list in order
            #
            for fname in flist:

                lfile = nft.get_fullpath(fname)
                _require_file(lfile)

                if dbgl > ndt.NONE:
                    print("%s (line: %s) %s: decoding image file list (%s)" %
                          (__FILE__, ndt.__LINE__, ndt.__NAME__, lfile))

                num_files_att += 1
                if _decode_file(eb0_decode, lfile, num_files_att, fp,
                                odir, oext, rdir):
                    num_files_proc += 1

        # calculate the elapsed time
        #
        elapsed = time.time() - init_time

        # display the results
        #
        print("processed %ld out of %ld files successfully" %
              (num_files_proc, num_files_att))

        # display the finished message and write it to the log
        #
        finish_msg = f"decoding of all files finished in {elapsed:.0f} seconds."
        print(finish_msg)
        fp.write(finish_msg + "\n")

    finally:

        # close the file pointer
        #
        fp.close()

    # exit gracefully
    #
    return True
#
# end of function

# run the decoder only when this file is executed as a script,
# not when it is imported as a module
#
if __name__ == "__main__":
    main(argv=sys.argv)
#
# end of file
