#!/usr/bin/env python
#
# file: $(NEDC_NFC)/src/tools/nedc_eeg_resnet_train/nedc_eeg_resnet_train.py
#
# revision history:
#
# 20211117 (PM): refactored code
# 20210422 (VK): initial version
#
# This is a Python script that trains the ResNet18 model with the given datasets
#------------------------------------------------------------------------------

# import system modules
#
import numpy as np
import os
import sys
from sklearn.metrics import confusion_matrix
import torch
from torchvision import models, transforms

# import nedc_modules
#
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
import nedc_eeg_image as nei
import nedc_resnet_train as nrt
import nedc_file_tools as nft

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# system name: handed to nft.load_parameters() together with the
# parameter file when the training parameters are loaded
#
EEG_RESNET = "EEG_RESNET"

# default compute device: the first CUDA device when CUDA is available,
# otherwise the CPU
#
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# name of this script without its directory (used in error messages)
#
__FILE__ = os.path.basename(__file__)

# directory that holds this script, with symbolic links resolved
#
SCRIPT_LOC = os.path.dirname(os.path.realpath(__file__))

# help and usage files handed to the command line parser
#
HELP_FILE = (
    "$NEDC_NFC/src/python/util/nedc_eeg_resnet_train/"
    "nedc_eeg_resnet_train.help"
)

USAGE_FILE = (
    "$NEDC_NFC/src/python/util/nedc_eeg_resnet_train/"
    "nedc_eeg_resnet_train.usage"
)

# command line options, each given as a long form and an abbreviation:
#  an option name must not contain spaces
#
ARG_PARM, ARG_ABRV_PARM = "--parameters", "-p"
ARG_TRAIN_EDF, ARG_ABRV_TRAIN_EDF = "--train_edf", "-t"
ARG_TRAIN_CSV, ARG_ABRV_TRAIN_CSV = "--train_csv", "-c"
ARG_DEV_EDF, ARG_ABRV_DEV_EDF = "--dev_edf", "-e"
ARG_DEV_CSV, ARG_ABRV_DEV_CSV = "--dev_csv", "-d"

# parameter file used when none is given on the command line
#
DEF_PARAM_FILE = "$NEDC_NFC/lib/nedc_eeg_resnet_train_param.txt"

# model file passed to the trainer as the initial model
#
DEF_INITIAL_MODEL = "$NEDC_NFC/models/resnet-18_pretrained.pth"

# value passed as the normalize argument of confusion_matrix() for the
# ratio form of the confusion matrix
#
DEF_TRUE = "true"

# keys of the training parameters
#
BAT_SIZE = "bat_size"
NUM_EPOCHS = "num_epochs"
NSAMPLES_TRAIN = "nsamples_train"
NSAMPLES_DEV = "nsamples_dev"
NUM_WORKERS = "num_workers"
SAMPLE_FREQUENCY = "sample_frequency"
FRAME_DURATION = "frame_duration"
TRANSFORM_CROP = "transform_crop"
TRAIN_TRANSFORMS = "train_transforms"
DEV_TRANSFORMS = "dev_transforms"
EVAL_TRANSFORMS = "eval_transforms"
COMP_DEVICE = "device"
CPU = "cpu"
PRECISION = "precision"
SAVE_EPOCH_MODEL = "save_model_every_epoch"
CH_ORDER = "channel_order"

# number of required (positional) arguments
#
NUM_ARGS = 1

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

# declare a global debug object so we can use it in functions:
#  main() compares it against ndt.NONE to decide whether the parsed
#  command line arguments are echoed
#
dbgl = ndt.Dbgl()

# function: main
#
def main(argv):
    """
    Train the ResNet18 EEG model and report its train/dev performance.

    The function parses the command line, validates the parameter file
    and the four list files, builds an nrt.EEGResNet, trains it starting
    from DEF_INITIAL_MODEL, saves the resulting state dict to the
    requested model file, and prints the accuracy and the confusion
    matrices (raw counts and ratio form) for the train and dev sets.

    arguments:
     argv: the command line argument list; accepted for interface
           compatibility but not read here (cmdl.parse_args() is called
           without arguments)

    return: None

    The process exits with os.EX_SOFTWARE when the output model name or
    one of the list files is not specified, or when the parameter file
    or one of the list files does not exist.
    """

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments
    #
    cmdl.add_argument("files", type = str, nargs = '*')
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)
    cmdl.add_argument(ARG_ABRV_TRAIN_EDF, ARG_TRAIN_EDF, type = str)
    cmdl.add_argument(ARG_ABRV_TRAIN_CSV, ARG_TRAIN_CSV, type = str)
    cmdl.add_argument(ARG_ABRV_DEV_EDF, ARG_DEV_EDF, type = str)
    cmdl.add_argument(ARG_ABRV_DEV_CSV, ARG_DEV_CSV, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(f" parameter file = {args.parameters}")
        print(f" train edf list = {args.train_edf}")
        print(f" train csv list = {args.train_csv}")
        print(f" dev edf list = {args.dev_edf}")
        print(f" dev csv list = {args.dev_csv}")
        print(f" model filename = {args.files}")

    # display an informational message
    #
    print("beginning training...")

    # fetch the model filename:
    #  "files" is declared with nargs = '*', which presumably (argparse
    #  semantics) gives an empty list when no name is supplied, so the
    #  first element must not be indexed unconditionally
    #
    model_fname = args.files[0] if args.files else None

    # set the parameter file
    #
    pfile = args.parameters
    if pfile is None:
        pfile = nft.get_fullpath(DEF_PARAM_FILE)
    else:
        pfile = nft.get_fullpath(pfile)

    # make sure the model file is specified
    #
    if model_fname is None:
        print("Error: %s (line: %s) %s: must specify an output model name" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__))
        sys.exit(os.EX_SOFTWARE)

    # set the model filename
    #
    model_fname = nft.get_fullpath(model_fname)

    # check if the parameter file exists
    #
    if not os.path.isfile(pfile):
        print("Error: %s (line: %s) %s: parameter file doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, pfile))
        sys.exit(os.EX_SOFTWARE)

    # resolve and check the four list files in a fixed order:
    #  an option that was not given is reported explicitly instead of
    #  being handed to nft.get_fullpath() as None
    #
    list_fnames = []
    for label, option, fname in (("train edf", ARG_TRAIN_EDF, args.train_edf),
                                 ("train csv", ARG_TRAIN_CSV, args.train_csv),
                                 ("dev edf", ARG_DEV_EDF, args.dev_edf),
                                 ("dev csv", ARG_DEV_CSV, args.dev_csv)):

        # make sure the list file is specified
        #
        if fname is None:
            print("Error: %s (line: %s) %s: must specify a %s list (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, label, option))
            sys.exit(os.EX_SOFTWARE)

        # set the list file and check if it exists
        #
        fname = nft.get_fullpath(fname)
        if not os.path.isfile(fname):
            print("Error: %s (line: %s) %s: %s file doesn't exist (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, label, fname))
            sys.exit(os.EX_SOFTWARE)

        list_fnames.append(fname)

    train_edf, train_csv, dev_edf, dev_csv = list_fnames

    # load parameter file
    #
    params = nft.load_parameters(pfile, EEG_RESNET)

    # read the training parameters from the parameter file
    #
    # NOTE(review): eval() executes text taken from the parameter file, so
    # the parameter file must come from a trusted source
    #
    bat_size = int(params[BAT_SIZE])
    num_epochs = int(params[NUM_EPOCHS])
    num_workers = int(params[NUM_WORKERS])
    save_epoch_model = eval(params[SAVE_EPOCH_MODEL].capitalize())
    channel_order = eval(params[CH_ORDER])

    # get values that need to be processed
    #
    sample_rate = int(float(params[SAMPLE_FREQUENCY]))
    transform_crop = int(params[TRANSFORM_CROP])
    frame_duration = float(params[FRAME_DURATION])

    # calculate the amount of samples needed per frame
    #
    frmsize = int(frame_duration * sample_rate)

    # calculate the transform_order for cropping the RGB image
    #
    transform_order = frmsize - transform_crop

    # get all the necessary image transformations, transform_order and
    # frmsize values will be used in the transformations:
    #  the expressions are evaluated in this scope, so the local names
    #  above must not be renamed (see the eval() note above)
    #
    train_transforms = eval(params[TRAIN_TRANSFORMS])
    dev_transforms = eval(params[DEV_TRANSFORMS])

    # set the device for training:
    #  this is a local name that shadows the module-level default
    #
    DEVICE = torch.device(params[COMP_DEVICE]
                          if torch.cuda.is_available() else CPU)

    # get the train file list
    #
    train_edf_list = [file.strip() for file in nft.get_flist(train_edf)]
    train_csvbi_list = [file.strip() for file in nft.get_flist(train_csv)]

    # get the dev file list
    #
    dev_edf_list = [file.strip() for file in nft.get_flist(dev_edf)]
    dev_csvbi_list = [file.strip() for file in nft.get_flist(dev_csv)]

    # make an instance of EEGResNet
    #
    resnet = nrt.EEGResNet(train_edf_list, train_csvbi_list, dev_edf_list,
                           dev_csvbi_list, frmsize, train_transforms,
                           dev_transforms, DEVICE, channel_order, sample_rate)

    # print the information about the training samples
    #
    print(f'Train dataset: {resnet.train_nsamples}')
    print(f'Train weights: {resnet.train_weights}')

    # print information about the dev samples
    #
    print(f'Dev dataset: {resnet.dev_nsamples}')
    print(f'Dev weights: {resnet.dev_weights}')

    # train the model
    #
    model = resnet.train(DEF_INITIAL_MODEL, bat_size, num_epochs,
                         num_workers, save_epoch_model, model_fname)

    # save the trained model
    #
    if model_fname is not None:
        if os.path.dirname(model_fname) != nft.DELIM_NULL:
            os.makedirs(os.path.dirname(model_fname), exist_ok=True)
        torch.save(model.state_dict(), model_fname)

    # evaluate the model on the train set and then on the dev set:
    #  the context manager applies the two-digit print format and puts
    #  back the caller's numpy print options on exit, including the
    #  original precision and suppress values
    #
    with np.printoptions(precision=2, suppress=True):
        for set_name, dataloader in (("train", resnet.train_dataloader),
                                     ("dev", resnet.dev_dataloader)):

            # decode the set
            #
            all_labels, all_preds, acc = resnet.decode(model, dataloader,
                                                       device=DEVICE)

            # create the confusion matrices
            #
            cm_raw = confusion_matrix(all_labels, all_preds, normalize=None)
            cm_rows = confusion_matrix(all_labels, all_preds,
                                       normalize=DEF_TRUE)

            # print accuracy and the confusion matrices
            #
            print(f'\nAccuracy on {set_name} set: {acc:.2f}')
            print(f'Confusion matrix on {set_name} set: \n', cm_raw,
                  nft.DELIM_NEWLINE)
            print(f'Confusion matrix on {set_name} set (ratio): \n', cm_rows,
                  nft.DELIM_NEWLINE)

    # exit gracefully
    #
    return None
#
# end of main

# script entry point: run main() on a copy of the command line arguments
# when this file is executed directly
#
if __name__ == "__main__":
    main(list(sys.argv))

#
# end of file
