#!/usr/bin/env python
#
# file: $(NEDC_NFC)/util/python/nedc_eeg_resnet_decode
#
# revision history:
#
# 20220703 (ML): initial version
# 20220705 (PM): code review
#
# This script can decode a set of EDF or PCKL files
# with a given resnet 18 model
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import torch
from torchvision import transforms

# import NEDC modules
#
import nedc_file_tools as nft
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
import nedc_resnet_decode as nrd
import nedc_edf_tools as net

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# record this script's own name (directory part removed); it is used
# as the prefix of the error messages printed by main
#
__FILE__ = os.path.split(__file__)[1]

# locations of the help and usage text handed to the command line parser
#
HELP_FILE = (
    "$NEDC_NFC/src/python/util/nedc_eeg_resnet_decode/"
    "nedc_eeg_resnet_decode.help"
)

USAGE_FILE = (
    "$NEDC_NFC/src/python/util/nedc_eeg_resnet_decode/"
    "nedc_eeg_resnet_decode.usage"
)

# command line options, each given as (long form, abbreviated form);
# the output and replace directory options reuse the names defined by
# the command line parser module
#
ARG_MODEL, ARG_ABRV_MODEL = "--model", "-m"
ARG_ODIR, ARG_ABRV_ODIR = ncp.ARG_ODIR, ncp.ARG_ABRV_ODIR
ARG_RDIR, ARG_ABRV_RDIR = ncp.ARG_RDIR, ncp.ARG_ABRV_RDIR
ARG_PARM, ARG_ABRV_PARM = "--parameters", "-p"

# fallback values used when an option is not given on the command line:
#  DEF_PARAM_FILE: parameter file used when no parameter file is given
#  DEF_PRETRAINED_MODEL: initial model passed to the decoder
#  DEF_ODIR: output directory used when no output directory is given
#
DEF_PARAM_FILE = "$NEDC_NFC/lib/nedc_eeg_resnet_decode_param.txt"
DEF_PRETRAINED_MODEL = "$NEDC_NFC/models/resnet-18_pretrained.pth"
DEF_ODIR = "$NEDC_NFC/test/output"

# name handed to the parameter file loader to select the values to read
#
RESNET_DECODE = "RESNET_DECODE"

# keys used to look up signal and framing values
#
SAMPLE_FREQUENCY = "sample_frequency"
SAMP_FREQ = "samp_freq"
FILTER_ORDER = "filter_order"
CH_ORDER = "channel_order"
FRM_LEN = "frm_len"
FRAME_DURATION = "frame_duration"

# keys used to look up the image transformation settings
#
TRANSFORMS = "transforms"
TRANSFORM_CROP = "transform_crop"

# key for the requested compute device and the name of the cpu fallback
#
DEVICE_STR = "device"
CPU_STR = "cpu"

# keys used to look up the post processing values
#
SEIZ_TH = "seizure_threshold"
MIN_BCKG = "minimum_background_duration"
MIN_SEIZ = "minimum_seizure_duration"

# the required number of file arguments
#
NUM_ARGS = 1

# declare a global debug object so we can use it in functions:
# main compares it against ndt.NONE to decide whether the parsed
# command line arguments are echoed before decoding starts
#
dbgl = ndt.Dbgl()

# function: main
#
# arguments:
#  argv: the command line argument list; it is not referenced in the
#        body because the parser is invoked without it
#
# return: none
#
# This method is the main function. It parses the command line,
# resolves and validates the parameter file, output/replace directories
# and model file, builds a ResNetDecode object from the parameter file,
# and decodes every file argument. A file argument that is not an edf
# file is read as a list file holding one filename per line. The
# program exits with os.EX_SOFTWARE if a required file is missing, no
# model is given, or no file arguments are supplied.
#
def main(argv):

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments
    #
    cmdl.add_argument("files", type = str, nargs = '*')
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)
    cmdl.add_argument(ARG_ABRV_ODIR, ARG_ODIR, type = str)
    cmdl.add_argument(ARG_ABRV_MODEL, ARG_MODEL, type = str)
    cmdl.add_argument(ARG_ABRV_RDIR, ARG_RDIR, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(f" output directory = {args.odir}")
        print(f" replace directory = {args.rdir}")
        print(f" parameter file = {args.parameters}")
        print(f" model file = {args.model}")
        print(f" argument files = {args.files}")

    # display an informational message
    #
    print("beginning decoding...")

    pfile, odir, model_fname, rdir = \
        (args.parameters, args.odir, args.model, args.rdir)

    # set the parameter file, falling back to the default location
    #
    pfile = nft.get_fullpath(DEF_PARAM_FILE if pfile is None else pfile)

    # check if the parameter file exists
    #
    if not os.path.isfile(pfile):
        print("Error: %s (line: %s) %s: parameter file doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, pfile))
        sys.exit(os.EX_SOFTWARE)

    # set the output directory, falling back to the default location
    #
    odir = nft.get_fullpath(DEF_ODIR if odir is None else odir)

    # set the replace directory (it stays None when not specified)
    #
    if rdir is not None:
        rdir = nft.get_fullpath(rdir)

    # make sure a model filename is specified
    #
    if model_fname is None:
        print("Error: %s (line: %s) %s: must declare a model" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__))
        sys.exit(os.EX_SOFTWARE)
    model_fname = nft.get_fullpath(model_fname)

    # check if the model exists
    #
    if not os.path.isfile(model_fname):
        print("Error: %s (line: %s) %s: model file doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, model_fname))
        sys.exit(os.EX_SOFTWARE)

    # make sure there is a file argument: this is checked before the
    # parameter file is parsed and the decoder is built so that a
    # usage error is reported without doing that work first
    #
    if not args.files:
        cmdl.print_usage()
        sys.exit(os.EX_SOFTWARE)

    # set the initial model filename
    #
    init_model = nft.get_fullpath(DEF_PRETRAINED_MODEL)

    # load the parameter file
    #
    params = nft.load_parameters(pfile, RESNET_DECODE)

    # read the parameter file
    #
    # decoder parameters
    #
    # NOTE(review): eval executes whatever expression the parameter
    # file contains, so the parameter file must come from a trusted
    # source
    #
    channel_order = eval(params[CH_ORDER])

    # use the requested device only when cuda is available, else the cpu
    #
    DEVICE = torch.device(params[DEVICE_STR]
                          if torch.cuda.is_available() else CPU_STR)

    # get values that need to be processed
    #
    samp_rate = int(float(params[SAMPLE_FREQUENCY]))
    transform_crop = int(params[TRANSFORM_CROP])
    frame_duration = float(params[FRAME_DURATION])

    # calculate the amount of samples needed per frame
    #
    frmsize = int(frame_duration * samp_rate)

    # calculate the transform_order for cropping the RGB image
    #
    transform_order = frmsize - transform_crop

    # get all the necessary image transformations: the expression in
    # the parameter file is evaluated in this scope, so the local names
    # transform_order and frmsize must keep these exact names. The
    # result is stored under a different name so that the imported
    # transforms module is not shadowed inside this function.
    #
    # NOTE(review): same trusted-input caveat as the eval above
    #
    img_transforms = eval(params[TRANSFORMS])

    # post processor parameters
    #
    seiz_ths = float(params[SEIZ_TH])
    min_seiz_sec = float(params[MIN_SEIZ])
    min_bckg_sec = float(params[MIN_BCKG])

    # create a ResNetDecode object
    #
    resdec = nrd.ResNetDecode(channel_order, frmsize, samp_rate, model_fname,
                              init_model, img_transforms, DEVICE, seiz_ths,
                              min_bckg_sec, min_seiz_sec)

    # accumulate the decoding time over every file argument so that the
    # final message covers all of them rather than only the last one
    # (assumes decode_mp returns its elapsed time in seconds as a
    # number - TODO confirm)
    #
    total_elapsed = 0.0

    # process each file argument
    #
    for arg_file in args.files:

        # check if the argument file exists
        #
        if not os.path.isfile(arg_file):
            print("Error: %s (line: %s) %s: file doesn't exist (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, arg_file))
            sys.exit(os.EX_SOFTWARE)

        # initialize an Edf object
        #
        check = net.Edf()

        # an edf file is decoded on its own
        #
        if check.is_edf(arg_file):
            arg_list = [arg_file]

        # anything else is treated as a list file: collect one filename
        # per line, skipping blank lines so that an empty filename is
        # never handed to the decoder
        #
        else:
            with open(arg_file, mode=nft.MODE_READ_TEXT) as af:
                arg_list = [line.strip() for line in af if line.strip()]

        # decode the file(s) and add the time spent to the running total
        #
        total_elapsed += resdec.decode_mp(arg_list, odir, rdir)

    # display the finished message
    #
    print(f'decoding of all files finished in {total_elapsed:.0f} seconds.')

    # exit gracefully
    #
    return None

# run main only when this file is executed as a script, handing it a
# copy of the command line argument list
#
if __name__ == "__main__":
    main(list(sys.argv))
#
# end of file
