#!/usr/bin/env python
#
# file: $(NEDC_NFC)/util/python/nedc_eeg_decode/classes/nedc_resnet_decode.py
#
# revision history:
#
# 20220703 (ML): initial version
# 20220705 (PM): code review
#
# This file contains a Python implementation of our ResNet18 decoder. Details
# about this system can be found here:
#
#  Khalkhali, V., Shawki, N., Shah, V., Golmohammadi, M., Obeid, I., &
#  Picone, J. (2021). Low Latency Real-Time Seizure Detection Using
#  Transfer Deep Learning. In I. Obeid, I. Selesnick, & J. Picone (Eds.),
#  Proceedings of the IEEE Signal Processing in Medicine and Biology
#  Symposium (SPMB) (pp. 1–7). IEEE.
#  https://doi.org/10.1109/SPMB52430.2021.9672285
#  https://www.isip.piconepress.com/publications/conference_proceedings/2021/ieee_spmb/eeg_transfer_learning/
#
# The API is very simple:
#  constructor: creates the class (called at the top of the program)
#  decode_mp: runs the entire decoding system (should be called after init)
#
# This decoder can decode both individual edf files and list files
# containing edf files.
#------------------------------------------------------------------------------

# import system modules
#
import copy
import os
import sys
import time
import torch
import torch.nn as nn
from torchvision import models
import pickle
import numpy as np
import scipy.signal

# import NEDC modules
#
import nedc_ann_eeg_tools as nae
import nedc_edf_tools as net
import nedc_eeg_image as nei
import nedc_debug_tools as ndt
import nedc_file_tools as nft

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------
#

# record this file's basename for use in error messages
#
__FILE__ = os.path.basename(__file__)

# define the class labels: the position in CLASSES is the label index
# used by the decoder (index 0 is background, index 1 is seizure)
#
BCKG = 'bckg'
SEIZ = 'seiz'
CLASSES = [BCKG, SEIZ]

# define values for reading data (CSV_BI_EXT is also the extension of the
# output annotation files)
#
FILE_DURATION = 'file_duration'
CSVBI_DATA = 'csvbi_data'
MONTAGES_DATA = 'montages_data'
CSV_BI_EXT = 'csv_bi'
DEF_SECONDS = 'secs'
SAMP_FREQ = 'samp_freq'

# define the scipy convolution modes used by the postprocessing filters
#
SET_FULL = 'full'
SET_VALID = 'valid'

# declare a global debug object so functions can share it
#
dbgl = ndt.Dbgl()

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

# class: ResNetDecode
#
# This class can decode edf files using a given resnet 18 model
#
class ResNetDecode:

    # the scipy convolution method used by the postprocessing filters.
    # The filters compare window sums for exact equality with the window
    # length. scipy's default ('auto') may compute large convolutions with
    # an FFT, whose floating-point round-off could break those comparisons.
    # 'direct' keeps the sums exact.
    #
    _CONV_METHOD = 'direct'

    # method: ResNetDecode::constructor
    #
    # arguments:
    #  channels_order: the order of raw channels
    #  frmsize: the amount of samples per frame
    #  samp_rate: the sample rate of the data in Hz
    #  mdl_fname: the model filename (a state dict), or None to keep the
    #    initial model's weights
    #  init_model: the initial model filename (a complete pickled model
    #    exposing an 'fc' layer)
    #  transforms: the transformations on the input image
    #  device: the device that the decoding will be run on
    #  seiz_ths: the minimum probability as to when a hypothesis is designated
    #    as seizure (a number or a numeric string)
    #  min_bckgs: the minimum time in seconds for a hypothesis to be designated
    #    as background (a number or a numeric string; truncated to an int)
    #  min_seizs: the minimum time in seconds for a hypothesis to be designated
    #    as seizure (a number or a numeric string; truncated to an int)
    #
    # returns:
    #  None
    #
    # This method is the constructor for the class. It loads the initial
    # model, replaces its final layer with a two-class head, optionally
    # loads trained parameters, and moves the model to the device.
    #
    def __init__(self, channels_order, frmsize, samp_rate, mdl_fname,
                 init_model, transforms, device, seiz_ths, min_bckgs,
                 min_seizs):

        # declare class data
        #
        self.channels_order = channels_order
        self.frmsize = frmsize
        self.samp_rate = samp_rate
        self.mdl_fname = mdl_fname
        self.transforms = transforms
        self.device = device

        # convert the thresholds explicitly so that numeric strings work as
        # well as numbers. The threshold is compared against a float array
        # in postprocess, and int(float(x)) also accepts strings like "2.0".
        #
        self.seiz_threshold = float(seiz_ths)
        self.min_bckg_secs = int(float(min_bckgs))
        self.min_seiz_secs = int(float(min_seizs))

        # load the initial model, mapping its tensors onto the target device
        # so that a model saved on a GPU can still be used on a CPU host
        #
        model = self._load_full_model(init_model, self.device)

        # replace the final layer with one output per class
        #
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(CLASSES))

        # load the parameters from the designated model onto the initial
        # model
        #
        if self.mdl_fname is not None:
            model.load_state_dict(torch.load(self.mdl_fname,
                                             map_location=self.device))
        self.model = model.to(self.device)

        # exit gracefully
        #
        return None

    #
    # end of method

    # method: ResNetDecode::_load_full_model
    #
    # arguments:
    #  fname: the filename of a complete pickled torch model
    #  device: the device to map the model's tensors onto
    #
    # return:
    #  the unpickled model
    #
    # This private helper loads a whole pickled model, not a state dict.
    # Starting with torch 2.6, torch.load defaults to weights_only=True,
    # which refuses to unpickle full model objects. So weights_only=False
    # is requested whenever this torch version knows the keyword.
    # SECURITY NOTE: unpickling can execute arbitrary code - only load model
    # files from trusted sources.
    #
    @staticmethod
    def _load_full_model(fname, device):

        import inspect

        # older torch releases do not accept the weights_only keyword
        #
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(fname, map_location=device, weights_only=False)
        return torch.load(fname, map_location=device)

    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # file-based methods
    #
    #--------------------------------------------------------------------------

    # method: ResNetDecode::decode_mp
    #
    # arguments:
    #   arg_list: list of files to process
    #   odir: the output directory to write the csvbi files to
    #   rdir: the replace directory
    #
    # return:
    #  the total time in seconds needed to process all files
    #
    # This function decodes each file given in arg_list and writes the
    # output to a csv_bi file.
    #
    def decode_mp(self, arg_list, odir, rdir):

        # keep track of the start time
        #
        init_time = time.time()

        # make sure the output directory exists
        #
        os.makedirs(odir, exist_ok=True)

        # decode each file in turn
        #
        for argfile in arg_list:
            self.decode_single_file(argfile, odir, rdir)

        # exit gracefully
        #
        return time.time() - init_time

    #
    # end of method

    # method: ResNetDecode::decode_single_file
    #
    # arguments:
    #   argfile: file name to process
    #   odir: the output directory to write the csv files
    #   rdir: the replace directory
    #
    # return:
    #  None
    #
    # This function decodes an edf file and writes the output to a csv_bi
    # file. It prints an error and exits the program (os.EX_SOFTWARE) if the
    # file does not exist or is not a valid edf file.
    #
    def decode_single_file(self, argfile, odir, rdir):

        # make sure to expand environment variables
        #
        argfile = nft.get_fullpath(argfile)

        # check if the argument file exists
        #
        if not os.path.isfile(argfile):
            print("Error: %s (line: %s) %s: file doesn't exist (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, argfile))
            sys.exit(os.EX_SOFTWARE)

        # make sure the file is a valid edf file
        #
        check = net.Edf()
        if not check.is_edf(argfile):
            print("Error: %s (line: %s) %s: file isn't a valid edf file (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, argfile))
            sys.exit(os.EX_SOFTWARE)

        # make sure the output directory exists (this method may be called
        # directly, not only through decode_mp)
        #
        os.makedirs(odir, exist_ok=True)

        # make dataset
        #
        dataset = self.get_image_dataset(argfile)

        # get seizure probabilities based on the dataset
        #
        all_probs = self.get_detects(dataset)

        # postprocess the probabilities into integer 0/1 frame decisions
        #
        all_probs_final = np.int64(self.postprocess(all_probs,
                                   self.min_seiz_secs, self.min_bckg_secs))

        # convert the frame decisions to time based detections
        #
        time_detects = self.convert_to_time_based(all_probs_final)

        # write the csv_bi file
        #
        self.write_data_to_csv_bi(argfile, time_detects, odir, rdir)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: ResNetDecode::get_image_dataset
    #
    # arguments:
    #  fname: filename to get data from
    #
    # return:
    #  dataset: the created image dataset
    #
    # This function turns edf file data into an image dataset. It prints an
    # error and exits the program if the file does not exist.
    #
    def get_image_dataset(self, fname):

        # make sure file exists
        #
        if not os.path.exists(fname):
            print("Error: %s (line: %s) %s: file does not exist (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, fname))
            sys.exit(os.EX_SOFTWARE)

        # create the dataset
        #
        dataset = nei.SingleEDFImages(fname, self.channels_order,
                                      self.samp_rate, self.frmsize,
                                      self.transforms)

        # exit gracefully
        #
        return dataset
    #
    # end of method

    # method: ResNetDecode::write_data_to_csv_bi
    #
    # arguments:
    #  argfile: filename to get data from
    #  tdets: list of [start, stop, label_index] time based detections
    #    (not modified)
    #  odir: output csv_bi file directory
    #  rdir: the replace directory
    #
    # return:
    #  None
    #
    # This function writes the detections to a csv_bi file. The last event
    # is stretched to end at the file duration read from the edf header.
    #
    def write_data_to_csv_bi(self, argfile, tdets, odir, rdir):

        # create an Edf object and load the header data
        #
        edf = net.Edf()
        edf.get_header_from_file(argfile)

        # get the file duration from the header
        #
        file_duration = edf.get_duration()

        # set the montage file
        #
        montage_file = nae.DEFAULT_MONTAGE_FNAME

        # create the final csvbi filename
        #
        dec_csv_fname = nft.create_filename(argfile, odir, CSV_BI_EXT, rdir)

        # initialize a Csv object and set the file duration
        #
        ann = nae.Csv(montage_f = montage_file)
        ann.set_file_duration(float(file_duration))

        # work on a copy so the caller's detections are left untouched, and
        # make the last event end at the end of the file
        #
        events = [list(event) for event in tdets]
        if len(events) > 0:
            events[-1][1] = file_duration

        # add each event to the annotation graph with probability 1.0.
        # NOTE(review): the meaning of the leading (0, 0, -1) arguments is
        # defined by nae (presumably level, sublevel and an all-channels
        # index) - confirm
        #
        for event in events:
            ann.graph_d.create(0, 0, -1,
                               float(event[0]), float(event[1]),
                               {CLASSES[int(event[2])]: 1.0})

        # write to the csvbi file
        #
        ann.write(dec_csv_fname, 0, 0)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: ResNetDecode::get_detects
    #
    # arguments:
    #  dataset: image dataset to decode (an iterable of image tensors)
    #
    # return:
    #  all_probs: list of seizure probabilities, one per dataset item;
    #    [0.0] if the dataset yields no items
    #
    # This function converts an image dataset to probabilities of a seizure
    # event. The model's training/eval mode is restored on exit, even if
    # inference raises.
    #
    def get_detects(self, dataset):

        # save the model state to restore it at the end
        #
        was_training = self.model.training

        all_probs = []
        try:

            # set model to eval mode
            #
            self.model.eval()

            with torch.no_grad():

                # add a batch dimension to each item, run the model, and
                # keep the softmax probability of class 1 (seizure)
                #
                for inputs in dataset:
                    inputs = inputs.unsqueeze(dim=0).to(self.device)
                    outputs = self.model(inputs)
                    seiz_probs = torch.softmax(outputs, 1)[:, 1]
                    all_probs += seiz_probs.cpu().numpy().tolist()
        finally:

            # change the model state back to the saved state
            #
            self.model.train(mode=was_training)

        # an EDF file shorter than one window yields no items; report a
        # single zero probability so later stages always get a frame
        #
        if len(all_probs) == 0:
            all_probs = [0.0]

        # exit gracefully
        #
        return all_probs
    #
    # end of method

    # method: ResNetDecode::convert_to_time_based
    #
    # arguments:
    #  all_probs: sequence of per-frame decisions (label indices)
    #
    # return:
    #  tdets: list of [start_secs, stop_secs, label] lists, one per run of
    #    equal consecutive decisions; empty if all_probs is empty
    #
    # This function merges runs of identical frame decisions into time based
    # detections. Each frame lasts frmsize / samp_rate seconds.
    #
    def convert_to_time_based(self, all_probs):

        # nothing to convert
        #
        if len(all_probs) == 0:
            return []

        # start the first event at time 0 with the first decision
        #
        curr_event = all_probs[0]
        tdets = [[0, None, curr_event]]

        # open a new event whenever the decision changes
        #
        for ecounter, event in enumerate(all_probs):
            if event != curr_event:
                curr_event = event

                # the start time of the current frame also ends the
                # previous event
                #
                curr_time = ecounter * self.frmsize / self.samp_rate
                tdets[-1][1] = curr_time
                tdets.append([curr_time, None, curr_event])

        # set the end time of the last event
        #
        tdets[-1][1] = len(all_probs) * self.frmsize / self.samp_rate

        # exit gracefully
        #
        return tdets
    #
    # end of method

    # method: ResNetDecode::postprocess
    #
    # arguments:
    #   all_probs: list of all seizure probabilities
    #   mins: minimum acceptable seizure duration in seconds
    #   minb: minimum acceptable background duration in seconds
    #
    # return:
    #   detection: the postprocessed binary (0/1) detection sequence
    #
    # This function thresholds the seizure probabilities and then applies
    # the duration filters. A filter is skipped when its duration rounds to
    # zero frames.
    #
    def postprocess(self, all_probs, mins, minb):

        # threshold the probabilities into a binary detection sequence
        #
        detects = np.array(all_probs)
        detection = 1 * (detects > self.seiz_threshold)

        # convert the minimum durations from seconds to frames
        #
        min_lb = int(np.round(minb * self.samp_rate / self.frmsize))
        min_ls = int(np.round(mins * self.samp_rate / self.frmsize))

        # dilate then erode by the same amount. Together this acts as a
        # morphological closing: background gaps of at most
        # 2 * (min_lb - 1) frames between seizure events are filled.
        # The guard uses the frame count: a positive duration can still
        # round to zero frames, and a zero-length window cannot be
        # convolved.
        #
        if min_lb > 0:
            detection = self.pp_closing(detection, min_lb)
            detection = self.pp_opening(detection, min_lb)

        # remove seizure events shorter than the minimum seizure duration
        #
        if min_ls > 0:
            detection = self.pp_rm_seiz(detection, min_ls)

        # exit gracefully
        #
        return detection
    #
    # end of method

    # method: ResNetDecode::pp_rm_seiz
    #
    # arguments:
    #   dets: sequence of binary detections
    #   mins: minimum acceptable seizure duration in frames (> 0)
    #
    # return:
    #   dets: the processed binary detection (floats 0.0/1.0)
    #
    # This function removes seizure runs shorter than mins frames. Runs of
    # at least mins frames are kept unchanged.
    #
    def pp_rm_seiz(self, dets, mins):

        # mark the frames that end a window of mins seizure frames, then
        # spread each mark back over its window to restore the full run
        #
        conv_win = np.ones(mins)
        dets = 1.0 * (scipy.signal.convolve(dets, conv_win, mode=SET_FULL,
                                            method=self._CONV_METHOD) == mins)
        dets = 1.0 * (scipy.signal.convolve(dets, conv_win, mode=SET_VALID,
                                            method=self._CONV_METHOD) > 0)

        # exit gracefully
        #
        return dets
    #
    # end of method

    # method: ResNetDecode::pp_closing
    #
    # arguments:
    #   dets: sequence of binary detections
    #   diameter: window length in frames (> 0)
    #
    # return:
    #   dets: the processed binary detection (floats 0.0/1.0)
    #
    # This function dilates the detections: every seizure event is extended
    # by (diameter - 1) frames on each side, clipped to the sequence length.
    #
    def pp_closing(self, dets, diameter):

        # the first pass extends events to the right, the second to the left
        #
        conv_win = np.ones(diameter)
        dets = 1.0 * (scipy.signal.convolve(dets, conv_win, mode=SET_FULL,
                                            method=self._CONV_METHOD) > 0)
        dets = 1.0 * (scipy.signal.convolve(dets, conv_win, mode=SET_VALID,
                                            method=self._CONV_METHOD) > 0)

        # exit gracefully
        #
        return dets
    #
    # end of method

    # method: ResNetDecode::pp_opening
    #
    # arguments:
    #   dets: sequence of binary detections
    #   diameter: window length in frames (> 0)
    #
    # return:
    #   dets: the processed binary detection (floats 0.0/1.0)
    #
    # This function erodes the detections: (diameter - 1) frames are removed
    # from the left and right side of every seizure event.
    # NOTE(review): frames beyond the ends of the sequence count as
    # background (zero padding), so a seizure event touching either end of
    # the recording is also trimmed at that end.
    #
    def pp_opening(self, dets, diameter):

        # a frame survives only if every window covering it is all seizure
        #
        conv_win = np.ones(diameter)
        dets = 1.0 * (scipy.signal.convolve(dets, conv_win, mode=SET_FULL,
                                            method=self._CONV_METHOD)
                      == diameter)
        dets = 1.0 * (scipy.signal.convolve(dets, conv_win, mode=SET_VALID,
                                            method=self._CONV_METHOD)
                      == diameter)

        # exit gracefully
        #
        return dets
    #
    # end of method

#
# end of class

#
# end of file
