#!/usr/bin/env python

# file: decoder.py
#                                                                              
# revision history:                                                            
#
# 20230316 (ML): initial version                                              
#
# This file contains a Python implementation of our ResNet18 decoder. Details
# about this system can be found here:
#
#  Khalkhali, V., Shawki, N., Shah, V., Golmohammadi, M., Obeid, I., &
#  Picone, J. (2021). Low Latency Real-Time Seizure Detection Using
#  Transfer Deep Learning. In I. Obeid, I. Selesnick, & J. Picone (Eds.),
#  Proceedings of the IEEE Signal Processing in Medicine and Biology
#  Symposium (SPMB) (pp. 1–7). IEEE.
#  https://doi.org/10.1109/SPMB52430.2021.9672285
#  https://www.isip.piconepress.com/publications/conference_proceedings/2021/ieee_spmb/eeg_transfer_learning/
# 
# The API is very simple:
#  constructor: creates the class (called at the top of the program)
#  init: initializes the system (can be called any time)
#  process: runs the decoder (should be called after init)
#  flush: flushes the system (init should be called after this)
#
# Data processing runs in real-time with a small latency (300 ms).
# Postprocessing adds a significant amount of delay (about 150 secs).
# Postprocessing can be tuned to minimize delay.
#------------------------------------------------------------------------------

# import required system modules                                               
#
import copy
import numpy as np
import os
from PIL import Image
import sys
import torch
import torch.nn as nn
from torchvision import models, transforms

# import NEDC modules
#
import nedc_postprocess as ppr
import nedc_file_tools as nft
import nedc_edf_tools as net
import nedc_ann_eeg_tools as nae

#------------------------------------------------------------------------------
#                                                                              
# global variables are listed here
#                                                                              
#------------------------------------------------------------------------------

# path of the pretrained ResNet-18 model file; it is expanded with
# nft.get_fullpath() and loaded with torch.load() in Decoder::init
#
DEF_PRETRAINED_MODEL = '$NEDC_NFC/models/resnet-18_pretrained.pth'

# define a small constant that is added to the per-channel maximum before
# the signal is divided by it in Decoder::get_image_dataset, so that a
# flat (all-equal) channel does not cause a division by zero
#
EPSILON = 1e-6

# define a constant for the image mode passed to Image.merge when the
# grayscale image is converted to a three-band image
#
DEF_RGB = 'RGB'

# decoding parameters:
#  RESNET_DECODE is passed to nft.load_parameters together with the
#  parameter file name (presumably the section name - confirm against
#  nedc_file_tools); the remaining names are the keys read from the
#  loaded parameters in Decoder::init. CSV_BI_EXT is the extension
#  passed to nft.create_filename for the output file.
#
RESNET_DECODE = 'RESNET_DECODE'
SAMPLE_FREQUENCY = 'sample_frequency'
CHANNEL_ORDER = 'channel_order'
MONTAGE_ORDER = 'montage_order'
TRANSFORMS = 'transforms'
DEVICE_STR = 'device'
CPU_STR = 'cpu'
SEIZ_TH = 'seizure_threshold'
FRAME_DURATION = 'frame_duration'
MDL_FNAME = 'mdl_fname'
MIN_BCKG = 'minimum_background_duration'
MIN_SEIZ = 'minimum_seizure_duration'
WINSIZE = 'postprocessing_window_duration'
TRANSFORM_CROP = 'transform_crop'
CSV_BI_EXT = 'csv_bi'

# define dictionary key names:
#  DEF_TIMESTART and DEF_DATA are read from the signal dictionary given
#  to Decoder::process; DEF_TIMESTART, DEF_TIMEEND, DEF_LABEL and
#  DEF_CONFIDENCE are read from the hypothesis written by
#  Decoder::write_csvbi_data. DEF_SAMPLERATE is not referenced in this
#  file.
#
DEF_TIMESTART = 'timeStart'
DEF_TIMEEND = 'timeEnd'
DEF_SAMPLERATE = 'sampleRate'
DEF_LABEL = 'label'
DEF_CONFIDENCE = 'confidence'
DEF_DATA = 'data'

# set the filename using basename
#
__FILE__ = os.path.basename(__file__)

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

# class: Decoder
#
# This class is a Python implementation of our ResNet18 decoder.
#
class Nedc_Decoder:
    
    # method: Decoder::constructor
    #
    # arguments: none
    #
    # returns: none
    #
    # This method constructs a Decoder object and initializes
    # its internal data.
    #
    def __init__(self):
        """Construct a decoder and clear its bookkeeping state.

        No arguments, no return value. Only the flags and counters below
        are set here; everything else is created by init().
        """

        # nothing has been detected and no frame is waiting to be decoded
        #
        self.previous_detection = None
        self.process_ready = False

        # start time of the most recently loaded signal window
        #
        self.dict_st_t = 0

        # publish the class name as a class-level attribute
        #
        Nedc_Decoder.__CLASS_NAME__ = type(self).__name__
    #
    # end of method
    
    # method: Decoder::init
    #
    # arguments:
    #  param_fname: path to a parameter file
    #  basename: the output csvbi file basename
    #  odir: the output csvbi file directory    
    #  rdir: the replace directory
    #
    # return: True if the models were successfully loaded
    #
    # This method initializes a Decoder object. It resets all relevant
    # parameters related to the decoder.
    #
    def init(self, param_fname, basename, odir, rdir):
        """Load the parameter file and (re)build all decoder state.

        arguments:
         param_fname: path to a parameter file
         basename: the output csvbi file basename
         odir: the output csvbi file directory
         rdir: the replace directory

        return: True once the model has been loaded; False if the
        parameter file could not be loaded
        """

        # load the parameter file
        #
        params = nft.load_parameters(param_fname, RESNET_DECODE)

        # stop early with a readable message instead of failing on the
        # first parameter lookup
        #
        if params is None:
            print("Decoder::init: unable to load parameters (%s)" %
                  (param_fname))
            return False

        # read the parameter file
        #
        # SECURITY NOTE: eval() executes whatever Python expression the
        # parameter file contains (here, and for the channel order and
        # the transforms below). Only use parameter files from a trusted
        # source.
        #
        self.mon_order = eval(params[MONTAGE_ORDER])

        # check if we need to use the channel order or montage order
        #
        if self.mon_order is not None:
            self.ch_order = self.mon_order
        else:
            self.ch_order = eval(params[CHANNEL_ORDER])

        # use the requested device only when CUDA is available,
        # otherwise fall back to the cpu
        #
        self.DEVICE = torch.device(params[DEVICE_STR]
                                   if torch.cuda.is_available() else CPU_STR)
        self.samp_rate = int(float(params[SAMPLE_FREQUENCY]))

        # NOTE(review): because of str(), a missing value becomes the text
        # 'None', so the "is not None" test further down only fails if
        # nft.get_fullpath itself returns None - confirm the intent
        #
        self.mdl_fname = nft.get_fullpath(str(params[MDL_FNAME]))

        # set the montage file
        #
        montage_file = nae.DEFAULT_MONTAGE_FNAME

        # create the final csvbi filename
        #
        self.csv_fname = nft.create_filename(basename, odir, CSV_BI_EXT, rdir)

        # initialize a Csv object
        #
        self.ann = nae.Csv(montage_f = montage_file)

        # get values that need to be processed
        #
        transform_crop = int(params[TRANSFORM_CROP])
        frame_duration = float(params[FRAME_DURATION])

        # calculate the amount of samples needed per frame
        #
        self.frmsize = int(frame_duration * self.samp_rate)

        # calculate the transform_order for cropping the RGB image
        #
        transform_order = self.frmsize - transform_crop

        # get all the necessary image transformations. The evaluated
        # expression may refer to the local name transform_order and to
        # self.frmsize, so neither may be renamed (see the security note
        # above about eval).
        #
        self.transforms = eval(params[TRANSFORMS])

        # post processor parameters
        #
        self.sec_fr_postp = float(params[WINSIZE])
        self.seiz_ths = float(params[SEIZ_TH])
        self.min_seiz_sec = float(params[MIN_SEIZ])
        self.min_bckg_sec = float(params[MIN_BCKG])

        # instantiate post processor
        #
        self.postprocesser = ppr.Nedc_Postprocess()

        self.postprocesser.init(self.seiz_ths, self.frmsize,
                                self.samp_rate, self.sec_fr_postp,
                                self.min_seiz_sec, self.min_bckg_sec)

        # initialize dictionaries: diction holds the frame being
        # assembled, buff holds the samples that did not fit into it
        #
        self.buff = {channel: [] for channel in self.ch_order}
        self.diction = {channel: [] for channel in self.ch_order}

        # set a boolean for updating the csvbi file
        #
        self.update_header = True

        # load the pretrained model. map_location places its tensors on
        # the selected device, so a model that was saved on a GPU can
        # still be loaded on a cpu-only host.
        #
        # NOTE(review): this unpickles a complete model object; recent
        # torch releases may require weights_only=False for that -
        # confirm against the installed version.
        #
        model = torch.load(nft.get_fullpath(DEF_PRETRAINED_MODEL),
                           map_location=self.DEVICE)

        # replace the final layer with one output per class
        #
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(ppr.CLASSES))

        # load the trained weights
        #
        if self.mdl_fname is not None:
            model.load_state_dict(torch.load(self.mdl_fname,
                                             map_location=self.DEVICE))

        # the model is local to this method, so it can be moved to the
        # device directly without copying it first
        #
        self.model = model.to(self.DEVICE)

        # exit gracefully
        #
        return True
    #
    # end of method
    
    # method: Decoder::process
    #
    # arguments:
    #  sig: a window of a signal to be processed
    #
    # return: None
    #
    # This method starts the decoding process when we have enough data
    #
    def process(self, sig):
        """Add a window of signal data and decode a frame once one is full.

        arguments:
         sig: a window of a signal to be processed

        return: None

        The program exits (SystemExit) if sig is None.
        """

        # make sure signal isn't none. An identity test is used because
        # "== None" is evaluated element-wise by array-like objects.
        #
        if sig is None:
            print("Decoder::process: unable to read signal")

            # os.EX_SOFTWARE only exists on Unix; 70 is its value there
            #
            sys.exit(getattr(os, "EX_SOFTWARE", 70))

        # load a signal of data into the buffer
        #
        self.load_frame(sig)

        # nothing to decode until load_frame reports a complete frame
        #
        if not self.process_ready:
            return None

        # reset process_ready flag
        #
        self.process_ready = False

        # convert the frame to a numpy array, one row per channel in
        # ch_order (this copies the samples)
        #
        sigs = np.asarray([self.diction[chan] for chan in self.ch_order])

        # reset diction
        #
        for chan in self.ch_order:
            self.diction[chan] = []

        # decode the data
        #
        self.decode_signals(sigs)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: Decoder::flush
    #
    # arguments: None
    #
    # returns: None
    #
    # This method is called once no more data is available from stdin.
    # It terminates postprocessing and outputs any hypotheses in progress.
    #
    def flush(self):
        """Decode whatever samples are still waiting, zero-padded to a frame.

        arguments: None

        returns: None

        This method is called once no more data is available. The
        leftover samples are decoded as one final frame and the internal
        sample stores are emptied afterwards, so calling flush again
        does not decode the same samples a second time.
        """

        # check if theres data in the buffer
        #
        if len(self.buff[self.ch_order[0]]) != 0:
            for chan in self.ch_order:

                # add buffer data to diction and empty the buffer so the
                # samples cannot be merged twice
                #
                self.diction[chan].extend(self.buff[chan])
                self.buff[chan] = []

        # check if there is excess data in the dictionary
        #
        if len(self.diction[self.ch_order[0]]) != 0:

            # set the final end time
            #
            self.postprocesser.end_t_fin = self.dict_st_t + 1

            # zero out the rest of the frame
            #
            for chan in self.ch_order:
                missing = self.frmsize - len(self.diction[chan])
                if missing > 0:
                    self.diction[chan].extend([0.0] * missing)

            # create an array with all the data, one row per channel
            # (this copies the samples)
            #
            sigs = np.asarray([self.diction[chan] for chan in self.ch_order])

            # the frame has been handed over: reset diction
            #
            for chan in self.ch_order:
                self.diction[chan] = []

            # decode the excess data
            #
            self.decode_signals(sigs)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: Decoder::decode_signals
    #
    # arguments:
    #  sigs: numpy array of signals to be processed
    #
    # return: None
    #
    # This method is the main driver to decode a window of signals. It is
    # called by process and flush.
    #
    def decode_signals(self, sigs):
        """Decode one frame of signals and write any resulting hypothesis.

        arguments:
         sigs: numpy array of signals to be processed

        return: None

        Called by process and flush.
        """

        # signals -> image -> seizure probability
        #
        image = self.get_image_dataset(sigs)
        seiz_prob = self.get_detects(image)

        # post-process the probability; a falsy result means there is
        # nothing to write yet
        #
        hyp = self.postprocesser.process(seiz_prob)
        if not hyp:
            return None

        # from now on the complete csvbi file is rewritten, not only
        # its header
        #
        self.update_header = False

        # write the csvbi data
        #
        self.write_csvbi_data(hyp)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: Decoder::get_detects
    #
    # arguments:
    #  dataset: image dataset to decode
    #
    # return:
    #  all_probs: the seizure probability (a float) for the dataset
    #
    # This method converts an image dataset
    # to the probability of a seizure event
    #
    def get_detects(self, dataset):

        # save the model state to restore it at the end
        #
        was_training = self.model.training

        # set model to eval mode
        #
        self.model.eval()

        # extract the probabilities
        #
        with torch.no_grad():

            dataset = dataset.unsqueeze(dim=0)
            dataset = dataset.to(self.DEVICE)
            outputs = self.model(dataset)
            seiz_probs = torch.softmax(outputs, 1)[:, 1]
            all_probs = seiz_probs.cpu().numpy().tolist()
           
        # check if the length of the file is less than
        # self.frmsize
        #
        if len(all_probs) == 0:
            all_probs = 0.0           
        else:
            all_probs = all_probs[0]
            
        # change the model state to last state
        #
        self.model.train(mode=was_training)

        # exit gracefully
        #
        return all_probs
    #
    # end of method

    # method: Decoder::load_frame
    #
    # arguments:
    #  sig: signal input
    #
    # return: None
    #
    # This method adds a signal to the signal dictionary.
    #
    def load_frame(self, sig):
        """Append a window of samples to the frame being assembled.

        arguments:
         sig: signal input; sig[DEF_TIMESTART] is its start time and
              sig[DEF_DATA][chan] holds the samples of channel chan

        return: None

        Samples fill the frame (diction) up to frmsize samples per
        channel; samples that do not fit are kept in the buffer (buff)
        for the next call. process_ready is set when a channel has a
        full frame plus at least one sample left over.
        """

        self.dict_st_t = sig[DEF_TIMESTART]

        # update the csv file
        #
        self.update_csv_dur(self.dict_st_t + 1)

        for chan in self.ch_order:

            # samples not yet in the frame, in arrival order: leftovers
            # from earlier calls first, then the new window
            #
            queued = self.buff[chan]
            queued.extend(sig[DEF_DATA][chan])

            # move only as many samples as the frame can still take. The
            # frame therefore never grows beyond frmsize, also when the
            # leftovers alone are longer than a frame.
            #
            room = max(self.frmsize - len(self.diction[chan]), 0)
            self.diction[chan].extend(queued[:room])

            # keep the rest for the next call
            #
            self.buff[chan] = queued[room:]

            # leftover samples mean the frame is full
            #
            # NOTE: one frame is decoded per call to process, so windows
            # that are persistently longer than a frame accumulate in
            # the buffer
            #
            if self.buff[chan]:
                self.process_ready = True

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: Decoder::get_image_dataset
    #
    # argument:
    #  sigs: numpy array of sigs to be processed
    #
    # return:
    #  rgb: RGB image containing the signal data
    #
    # This method converts frm_len amount of a signal into an RGB image.
    #
    def get_image_dataset(self, sigs):
        """Convert one frame of signals into an image for the model.

        argument:
         sigs: numpy array of sigs to be processed, one row per channel

        return:
         rgb: the three-band image built from the signal data, after
              self.transforms has been applied when it is set

        The program exits (SystemExit) if no image could be created.
        """

        # shift every channel so that its minimum is zero (the array is
        # transposed so the per-channel minimum broadcasts over samples)
        #
        sig_mat = sigs.T - sigs.min(axis = 1)

        # divide every channel by its maximum (EPSILON avoids a division
        # by zero for a flat channel), scale, and transpose back to one
        # row per channel
        #
        # NOTE(review): the scale factor is self.frmsize, not 255, so
        # the values lie in [0, frmsize) before the cast to uint8. A
        # frame longer than 256 samples would not fit into uint8.
        # Confirm that this matches the scaling used to train the model
        # before changing it.
        #
        image = np.uint8(np.transpose(
                sig_mat / (sig_mat.max(axis = 0) + EPSILON) * self.frmsize))

        # create and resize the grayscale image
        #
        gray = Image.fromarray(image)
        gray = gray.resize((int(self.frmsize), int(self.frmsize)))

        # convert the grayscale image to an RGB image
        #
        rgb = Image.merge(DEF_RGB, (gray, gray, gray))

        # apply the image transformations
        #
        if self.transforms:
            rgb = self.transforms(rgb)

        # make sure RGB image isn't none. An identity test is used
        # because "== None" is evaluated by the image/tensor type.
        #
        if rgb is None:
            print("Decoder::get_image_dataset: unable to create image")

            # os.EX_SOFTWARE only exists on Unix; 70 is its value there
            #
            sys.exit(getattr(os, "EX_SOFTWARE", 70))

        # exit gracefully
        #
        return rgb
    #
    # end of method

    # method: Decoder::write_csvbi_data
    #
    # argument:
    #  hyp: data to write to the csvbi files
    #
    # return:
    #  None
    #
    # This method writes the hypothesis to a csvbi file
    #
    def write_csvbi_data(self, hyp):
        """Add a hypothesis to the annotation graph and rewrite the file.

        argument:
         hyp: data to write to the csvbi files; its start time, end
              time, label and confidence entries are used

        return:
         None
        """

        # pull the event out of the hypothesis
        #
        start_time = float(hyp[DEF_TIMESTART])
        stop_time = float(hyp[DEF_TIMEEND])
        label_conf = {hyp[DEF_LABEL]: float(hyp[DEF_CONFIDENCE])}

        # create a graph object of the hypothesis (the leading 0, 0, -1
        # are passed through unchanged - presumably level, sublevel and
        # channel index; confirm against nedc_ann_eeg_tools)
        #
        self.ann.graph_d.create(0, 0, -1, start_time, stop_time, label_conf)

        # write to the csvbi file
        #
        self.ann.write(self.csv_fname, 0, 0)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: Decoder::update_csv_dur
    #
    # argument:
    #  file_dur: the new csvbi file duration
    #
    # return:
    #  None
    #
    # This method updates the csvbi file duration
    #
    def update_csv_dur(self, file_dur):
        """Store a new file duration and rewrite the csvbi file.

        argument:
         file_dur: the new csvbi file duration

        return:
         None

        Only the header is written while update_header is set; otherwise
        the complete file is written.
        """

        # set the file duration
        #
        self.ann.set_file_duration(float(file_dur))

        # write the complete file once update_header has been cleared
        #
        if not self.update_header:
            self.ann.write(self.csv_fname, 0, 0)
            return None

        # otherwise writing the header is enough
        #
        self.ann.write_header(self.csv_fname)

        # exit gracefully
        #
        return None
    #
    # end of method

#
# end of Decoder

#
# end of file
