#!/usr/bin/env python
#
# file: $(NEDC_NFC)/tool/nedc_dpath_resnet_decode_slide/nedc_dpath_resnet_decode_slide.py
#
# revision history:
#
# 20230929 (SM): prep for v2.0.0 release
# 20211221 (PM): refactored code
# 20210416 (VK): adapted to a single svs slide
# 20210128 (VK): second version
# 20200904 (VK): initial version
#
# This Python script decodes a whole slide and writes a csv detection
# file and an image mask with predefined colors.
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import time
import numpy as np
import torch
from torchvision import transforms, models
from datetime import datetime

# import NEDC support modules
#
import nedc_cmdl_parser as ncp
import nedc_dpath_decode_slide as ndds
import nedc_debug_tools as ndt
import nedc_file_tools as nft

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# key handed to the parameter file loader in main()
#
DECODE_PARAM = 'DECODE_SLIDE'

# labels for the train, dev and eval sets
#
TRAIN = 'train'
DEV = 'dev'
EVAL = 'eval'

# file pattern for image files
#
IMG_EXT = '*.tif'

# name of this file without its directory
#
__FILE__ = os.path.basename(__file__)

# directory that holds this script (symbolic links resolved)
#
SCRIPT_LOC = os.path.split(os.path.realpath(__file__))[0]

# help and usage messages shown by the command line parser
#
HELP_FILE = (
    "$NEDC_NFC/src/python/util/nedc_dpath_resnet_decode_slide/nedc_dpath_resnet_decode_slide.help"
)

USAGE_FILE = (
    "$NEDC_NFC/src/python/util/nedc_dpath_resnet_decode_slide/nedc_dpath_resnet_decode_slide.usage"
)

# command line options (long form first, abbreviated form second):
#  note that you cannot separate them by spaces
#
ARG_PARM, ARG_ABR_PARM = "--parameters", "-p"
ARG_MODEL, ARG_ABR_MODEL = "--model", "-m"
ARG_CSV_DIR, ARG_ABR_CSV_DIR = "--csv", "-c"
ARG_MASK_DIR, ARG_ABR_MASK_DIR = "--mask", "-k"

# number of positional arguments main() requires
#
NUM_ARGS = 1

# fallback values used when an option is not given on the command line:
#  note we assume the parameter file is in the same
#  directory as the source code.
#
DEF_PARAM_FILE = 'nedc_dpath_resnet_decode_slide'
DEF_CSV_DIR = './csv'
DEF_MASK_DIR = './csv'

# pretrained model parameters:
#  this part is used to load the pretrained model from PyTorch.
#
#  The statements below run at import time and their order matters:
#   - PRETRAINED_MODEL_DIR must be set in the environment, otherwise the
#     lookup raises a KeyError before main() is ever reached;
#   - the directory is created if it does not exist yet;
#   - TORCH_HOME is pointed at that directory before the model is built
#     (presumably so the pretrained weights are cached there - see the
#     torch.hub documentation);
#   - a ResNet-18 with pretrained weights is instantiated and later
#     handed to the decoder in main() as the initial model.
#
DEF_MODEL_DIR = os.environ['PRETRAINED_MODEL_DIR']
os.makedirs(DEF_MODEL_DIR, exist_ok=True)
os.environ['TORCH_HOME'] = DEF_MODEL_DIR
DEF_INITIAL_MODEL = models.resnet18(pretrained=True)

# Transforms are different for each class.
# Note: if one of the sets is not used, set its value to an empty string.
#
# The pipeline applied here resizes the image to 256, center-crops it to
# 224, converts it to a tensor and normalizes each of the three channels
# (the mean/std constants are presumably the standard ImageNet
# statistics - TODO confirm).
#
MODEL_TRANSFORMS = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# module-level default device: the first CUDA device when CUDA is
# available, the cpu otherwise. Note that main() builds its own device
# from the parameter file and does not read this value.
#
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

# declare a global debug object:
#  main() compares it against ndt.NONE to decide whether the command
#  line arguments are echoed to stdout.
#
dbgl = ndt.Dbgl()

# function: main
#
def main(argv):
    """
    function: main

    arguments:
     argv: the command line argument list (not referenced in this
           function; the command line parser object does the parsing)

    return: none

    description:
     Parses the command line, loads the decoding parameters, decodes
     every slide named in the input file list, and writes the resulting
     probabilities and image masks to the csv and mask directories.

     The program prints the usage message and exits with os.EX_SOFTWARE
     when the number of positional arguments is not NUM_ARGS or when no
     model file was given.
    """

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments
    #
    cmdl.add_argument("files", type = str, nargs = '*')
    cmdl.add_argument(ARG_ABR_PARM, ARG_PARM, type = str)
    cmdl.add_argument(ARG_ABR_MODEL, ARG_MODEL, type = str)
    cmdl.add_argument(ARG_ABR_CSV_DIR, ARG_CSV_DIR, type = str)
    cmdl.add_argument(ARG_ABR_MASK_DIR, ARG_MASK_DIR, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    # check if the proper number of file lists has been provided
    #
    if len(args.files) != NUM_ARGS:
        cmdl.print_usage('stdout')
        sys.exit(os.EX_SOFTWARE)

    svs_fnames, pfile, model_fname, csv_dir, mask_dir = \
        (args.files[0], args.parameters, args.model, args.csv, args.mask)

    # model file name must have been specified:
    #  this is checked before any path manipulation because
    #  os.path.abspath() raises a TypeError when it is given None
    #
    if model_fname is None:
        cmdl.print_usage('stdout')
        sys.exit(os.EX_SOFTWARE)

    # set default values for the optional arguments
    #
    if pfile is None:
        pfile = DEF_PARAM_FILE
    if csv_dir is None:
        csv_dir = DEF_CSV_DIR
    if mask_dir is None:
        mask_dir = DEF_MASK_DIR

    # echo the raw command line values:
    #  these are strings (or None), so they are formatted with %s
    #
    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(" parameter file = %s" % (args.parameters,))
        print(" model = %s" % (args.model,))
        print(" csv directory = %s" % (args.csv,))
        print(" mask directory = %s" % (args.mask,))
        print("")

    # get absolute paths for the four locations:
    #    svs_fnames, model_fname, csv_dir, mask_dir
    #
    svs_fnames, model_fname, csv_dir, mask_dir = (
        os.path.abspath(path)
        for path in (svs_fnames, model_fname, csv_dir, mask_dir))

    # read the svs file names from the list:
    #  blank lines are skipped so that no empty file name reaches the
    #  decoder
    #
    with open(svs_fnames, mode='r') as file:
        stripped = (line.strip() for line in file)
        svs_flist = [fname for fname in stripped if fname]

    # read parameters
    #
    params = nft.load_parameters(pfile, DECODE_PARAM)

    # set device:
    #  use the device named in the parameter file when CUDA is available
    #  and fall back to the cpu otherwise. A local name is used so the
    #  module-level DEVICE constant is not shadowed.
    #
    device = torch.device(params['device']
                          if torch.cuda.is_available() else "cpu")

    # make an instance of Decode class
    #
    classes = params['acc_lbls']
    dcd = ndds.DecodeSlide(DEF_INITIAL_MODEL, model_fname, svs_flist,
                           classes,
                           MODEL_TRANSFORMS, int(params['bat_size']),
                           int(params['num_workers']), int(params['win_len']),
                           int(params['frm_len']), device)

    # decoding the svs files
    #
    init_time = time.time()
    print(f"Start decoding at {datetime.now().strftime('%m/%d/%Y, %H:%M:%S')}")
    decode_list = dcd.decode_flist()

    # write the probabilities to a confidence level csv, nedc csv, and an xml
    # file
    #
    dcd.write_probabilities(decode_list, csv_dir, classes)

    # save the masks:
    #  the colors are paired with the class labels by position
    #
    pixel_size = [int(value) for value in params['pixel_size']]

    # NOTE(review): eval() executes whatever expression the parameter
    # file contains. Only use trusted parameter files; if the value is
    # always a plain literal, ast.literal_eval() would be a safer choice.
    #
    colors = eval(params['colors'])
    colors = {key: colors[i] for i, key in enumerate(classes)}
    dcd.save_masks(decode_list, mask_dir, classes, colors, pixel_size)

    print(f'Finished in {time.time() - init_time:.0f} seconds')
#
# end of main

# run main only when this file is executed as a script:
#  main receives its own copy of the argument list
#
if __name__ == "__main__":
    main(list(sys.argv))

#
# end of file
