#!/usr/bin/env python
#
# file: $(NEDC_NFC)/tool/nedc_dpath_resnet_train/nedc_dpath_resnet_train.py
#
# revision history:
#
# 20230929 (SM): prep for v2.0.0 release
# 20211221 (PM): refactored code
# 20210202 (VK): replaced file lists instead of ImageFolder
# 20210221 (VK): dev set samples can be selective too.
#                The best model is selected based on loss, not accuracy.
# 20210128 (VK): second version
# 20200904 (VK): initial version
#
# This is a Python script that retrains a ResNet18 model to fit new slides.
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import numpy as np
from sklearn.metrics import confusion_matrix
import torch
from torchvision import models, transforms

# import NEDC support modules
#
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
from nedc_dpath_image import ImagesList
import nedc_file_tools as nft
import nedc_dpath_train as ndpt

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# key passed as the second argument to nft.load_parameters in main()
#
TRAIN_PARAM = 'TRAIN'

# labels for the dataset splits: main() uses TRAIN and DEV as dictionary keys
#
TRAIN, DEV, EVAL = 'train', 'dev', 'eval'

# filename pattern for image files
#
IMG_EXT = '*.tif'

# set the filename using basename
#
__FILE__ = os.path.basename(__file__)

# define script location:
#  the directory containing this file, with symbolic links resolved
#
SCRIPT_LOC = os.path.dirname(os.path.realpath(__file__))

# define the help file and usage message:
#  both are handed to the command line parser (ncp.Cmdl) in main()
#
HELP_FILE = \
    "$NEDC_NFC/src/python/util/nedc_dpath_resnet_train/nedc_dpath_resnet_train.help"

USAGE_FILE = \
    "$NEDC_NFC/src/python/util/nedc_dpath_resnet_train/nedc_dpath_resnet_train.usage"

# define the program options:
#  note that you cannot separate them by spaces
#
#  each option has a long form and a one-letter abbreviation; main()
#  registers both forms with the command line parser
#
ARG_PARM = "--parameters"
ARG_ABRV_PARM = "-p"

ARG_TRAIN_LIST = "--train_list"
ARG_ABRV_TRAIN_LIST = "-t"

ARG_DEV_LIST = "--dev_list"
ARG_ABRV_DEV_LIST = "-d"

# define default values for arguments:
#  note we assume the parameter file is in the same
#  directory as the source code.
#
#  NOTE(review): main() hands DEF_PARAM_FILE to nft.load_parameters
#  unchanged; confirm whether that name is resolved against SCRIPT_LOC
#  or against the current working directory.
#
DEF_PARAM_FILE = 'nedc_dpath_resnet_train_param'
DEF_MODEL_FNAME = "model.pckl"

# load the pretrained model:
#  these statements run at import time and their order matters.
#
#  PRETRAINED_MODEL_DIR must be set in the environment; a KeyError is
#  raised here otherwise. The directory is created if it does not exist
#  and is exported as TORCH_HOME before the ResNet18 model is built.
#
#  NOTE(review): presumably TORCH_HOME makes torch cache the downloaded
#  weights in that directory - confirm against the torch.hub documentation.
#  NOTE(review): newer torchvision releases prefer a "weights" argument
#  over "pretrained" - confirm against the installed torchvision version.
#
DEF_MODEL_DIR = os.environ['PRETRAINED_MODEL_DIR']
os.makedirs(DEF_MODEL_DIR, exist_ok=True)
os.environ['TORCH_HOME'] = DEF_MODEL_DIR
DEF_INITIAL_MODEL = models.resnet18(pretrained=True)

# define one image transform pipeline per dataset split:
#  the train pipeline applies a random resized crop and a random
#  horizontal flip; the dev and eval pipelines use a fixed resize
#  followed by a center crop. All three end with a tensor conversion
#  and the same per-channel normalization.
#
TRAIN_TRANSFORMS = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

DEV_TRANSFORMS = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

EVAL_TRANSFORMS = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# define the default compute device:
#  the first CUDA device when CUDA is available, the CPU otherwise
#
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

# declare a global debug object:
#  main() compares it against ndt.NONE to decide whether to print the
#  command line arguments
#
dbgl = ndt.Dbgl()

# function: main
#
# arguments:
#  argv: the command line arguments (not referenced directly in this
#        function: the parser is invoked without arguments)
#
# return: none
#
# This function parses the command line, loads the training parameters,
# builds an ndpt.DpathTrain object from the train and dev lists, trains
# the pretrained model and saves the resulting state dictionary.
# It exits with os.EX_SOFTWARE when the train or dev list is missing or
# when the parameters cannot be loaded.
#
def main(argv):

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments
    #
    cmdl.add_argument("files", type = str, nargs = '*')
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)
    cmdl.add_argument(ARG_ABRV_TRAIN_LIST, ARG_TRAIN_LIST, type = str)
    cmdl.add_argument(ARG_ABRV_DEV_LIST, ARG_DEV_LIST, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    # place command line arguments into variables:
    #  the output model filename is optional, so fall back to the default
    #  when no positional argument was given (indexing an empty list
    #  would raise an IndexError before any default could be applied)
    #
    model_fname = args.files[0] if args.files else DEF_MODEL_FNAME
    pfile = args.parameters if args.parameters is not None \
        else DEF_PARAM_FILE
    train_list, dev_list = args.train_list, args.dev_list

    # display the resolved arguments:
    #  the parameter file is a string, so it must be formatted with %s
    #
    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(" parameter file = %s" % (pfile))
        print("")

    # both the train list and the dev list must be specified before training
    #
    if train_list is None or dev_list is None:
        cmdl.print_usage('stdout')
        sys.exit(os.EX_SOFTWARE)

    # read parameters
    #
    params = nft.load_parameters(pfile, TRAIN_PARAM)

    # NOTE(review): presumably load_parameters returns None when the file
    # or the section cannot be read - confirm against nedc_file_tools.
    # Failing here gives a clearer message than a TypeError on the first
    # lookup below.
    #
    if params is None:
        print("Error: %s: unable to load parameters from %s" %
              (__FILE__, pfile))
        sys.exit(os.EX_SOFTWARE)

    # select the compute device:
    #  use the device named in the parameter file when CUDA is available
    #  and the CPU otherwise. A local name is used so that the
    #  module-level DEVICE constant is not shadowed.
    #
    device = torch.device(params['device']
                          if torch.cuda.is_available() else "cpu")

    # make an instance of DpathTrain
    #
    nsamples_per_class = {TRAIN: int(params['nsamples_train']),
                          DEV: int(params['nsamples_dev'])}
    dptr = ndpt.DpathTrain(train_list, dev_list, nsamples_per_class,
                           TRAIN_TRANSFORMS, DEV_TRANSFORMS, device)

    # find some information about dataset and print them out
    #
    dataset_sizes = {TRAIN: len(dptr.train_subdataset),
                     DEV: len(dptr.dev_subdataset)}
    print(f'Class names: {dptr.class_names}')

    # print train information
    #
    print(f'Number of train available samples: {dptr.train_nsamples}')
    print(f'Number of train samples that will be used:'
          f' {dptr.train_nsamples_subdataset}')
    print(f'Train subdataset weights: {dptr.train_weights}')

    # print dev information
    #
    print(f'Number of dev available samples: {dptr.dev_nsamples}')
    print(f'Number of dev samples that will be used:'
          f' {dptr.dev_nsamples_subdataset}')
    print(f'Dev subdataset weights: {dptr.dev_weights}')
    print(f'Dataset size: {dataset_sizes}')

    # print device information
    #
    print(f'Device: {dptr.device}')

    # train all layers
    #
    model = dptr.train(DEF_INITIAL_MODEL, int(params['bat_size']),
                       int(params['num_epochs']), int(params['num_workers']))

    # save the trained model:
    #  create the output directory first when the filename contains one
    #
    model_dir = os.path.dirname(model_fname)
    if model_dir != '':
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_fname)
#
# end of main

# begin gracefully:
#  call main() with a copy of the command line arguments, but only when
#  this file is executed as a script
#
if __name__ == "__main__":
    main(list(sys.argv))

#                                                                              
# end of file
