#!/usr/bin/env python
#
# file: $(NEDC_NFC)/src/classes/nedc_eeg_resnet_train.py
#
# revision history:
#
# 20211118 (PM): refactored code
# 20210422 (VK): initial release
# 20221009 (ML): refactored code
#
# This file contains a python class that handles training for an EEG
# ResNet18 system. Details about this system can be found here:
#
#  Khalkhali, V., Shawki, N., Shah, V., Golmohammadi, M., Obeid, I., &
#  Picone, J. (2021). Low Latency Real-Time Seizure Detection Using
#  Transfer Deep Learning. In I. Obeid, I. Selesnick, & J. Picone (Eds.),
#  Proceedings of the IEEE Signal Processing in Medicine and Biology
#  Symposium (SPMB) (pp. 1–7). IEEE.
#  https://doi.org/10.1109/SPMB52430.2021.9672285
#  https://www.isip.piconepress.com/publications/conference_proceedings/2021/ieee_spmb/eeg_transfer_learning/
#
# The API is very simple:
#  constructor: creates the class (called at the top of the program)
#  train: the main method of the class, will run the entire model
#   training process (should be called after init)
#  decode: this method will decode a dataset and compute the confusion
#   matrix and accuracy (last method to be called to test model performance)
#
#------------------------------------------------------------------------------

#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import numpy as np
import time
import copy
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.utils as utils

# import nedc_modules
#
import nedc_eeg_image as nei
import nedc_file_tools as nft
import nedc_ann_eeg_tools as nae
import nedc_edf_tools as net
import nedc_edf_downsample as ned

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# header row that names the columns of a csv_bi annotation file; lines
# containing this text are skipped when the annotations are parsed
#
DEF_CSV_LABELS = "channel,start_time,stop_time,label,confidence"

# names of the dataset partitions
#
TRAIN = 'train'
DEV = 'dev'
EVAL = 'eval'

# names of the two event classes
#
BCKG = 'bckg'
SEIZ = 'seiz'

# name of the attribute that holds the image transformations of a dataset
#
TRANSFORM = 'transform'

# file extension used for models
#
DEF_PCKL = '.pckl'

# stochastic gradient descent settings: the learning rate and the momentum
#
LEARNING_RATE = 0.001
MOMENTUM = 0.900

# learning rate schedule settings: the number of epochs between two
# learning rate changes, and the factor applied at each change
#
STEP_SIZE = 7
GAMMA = 0.100

# string that float() converts to positive infinity
#
INFINITY = 'inf'

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

# class: EEGResNet
#
# This class can retrain a pretrained ResNet18
#
class EEGResNet:

    # method: EEGResNet::constructor
    #
    # arguments:
    #  train_edf_list: list of all training edf files
    #  train_csvbi_list: list of all training csv_bi files
    #  dev_edf_list: list of all dev edf files     
    #  dev_csvbi_list: list of all dev csv_bi files
    #  frmsize: number of samples per frame
    #  train_transforms: the image transformations on the train dataset
    #  dev_transforms: the image transformations on the dev dataset
    #  device: the device that the training will be run on
    #  channel_order: list of channels in the edf files
    #  new_samp_rate: the sample rate to downsample to
    #
    # returns: none
    #
    # This method constructs an EEGResNet object and initializes
    # its internal data.
    #
    def __init__(self, train_edf_list, train_csvbi_list, dev_edf_list,
                 dev_csvbi_list, frmsize, train_transforms, dev_transforms,
                 device, channel_order, new_samp_rate):
        """Build the train/dev image datasets and their class weights.

        The edf/csv_bi file lists are turned into per-class signal
        matrices with build_dataset, wrapped into nei.Images datasets,
        and the number of samples per class is used to derive loss
        weights that sum to one for each partition.
        """

        # keep the file lists of both partitions
        #
        self.train_edf_list, self.train_csv_list = \
            train_edf_list, train_csvbi_list
        self.dev_edf_list, self.dev_csv_list = dev_edf_list, dev_csvbi_list

        # keep the frame size (samples per frame) and the training device
        #
        self.frmsize = frmsize
        self.device = device

        # load the signals of each partition (train first, then dev)
        #
        self.train_data = self.build_dataset(self.train_edf_list,
                                             self.train_csv_list,
                                             new_samp_rate, channel_order)
        self.dev_data = self.build_dataset(self.dev_edf_list,
                                           self.dev_csv_list,
                                           new_samp_rate, channel_order)

        # wrap the signals into image datasets
        #
        self.train_dataset = nei.Images(self.train_data, self.frmsize)
        self.dev_dataset = nei.Images(self.dev_data, self.frmsize)

        # the class names are taken from the train dataset
        #
        self.class_names = self.train_dataset.classes
        nclasses = len(self.class_names)

        # local helper: count how many targets fall into each class index
        #
        def count_per_class(dataset):
            targets = np.array(dataset.targets)
            return [np.sum(targets == cls) for cls in range(nclasses)]

        # local helper: attach the transforms to the dataset itself when it
        # exposes a transform attribute, otherwise to the dataset it wraps
        #
        def attach_transforms(dataset, transforms):
            holder = dataset if hasattr(dataset, TRANSFORM) \
                else dataset.dataset
            holder.transform = transforms

        # local helper: weight each class by (largest count / its count)
        # and normalize the weights so that they sum to one
        #
        def normalized_weights(counts):
            ratios = max(counts) / torch.Tensor(counts)
            return ratios / ratios.sum()

        # number of samples per class for both partitions
        #
        self.train_nsamples = count_per_class(self.train_dataset)
        self.dev_nsamples = count_per_class(self.dev_dataset)

        # assign the image transformations
        #
        attach_transforms(self.train_dataset, train_transforms)
        attach_transforms(self.dev_dataset, dev_transforms)

        # class weights used by the loss functions
        #
        self.train_weights = normalized_weights(self.train_nsamples)
        self.dev_weights = normalized_weights(self.dev_nsamples)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: EEGResNet::train
    #
    # argument:
    #  initial_model: the pretrained model that should be trained
    #  batch_size: the training batch size
    #  num_epochs: the number of epochs for training
    #  nworkers: number of workers to load data
    #  save_epoch_model: boolean value for saving a model each epoch
    #  out_model: the output model filename
    #  input_model_fname: initial model to load data from
    #
    # return:
    #  model: the trained model
    #
    # this method is the main method of this class, this method will
    # run the model training process num_epochs times
    #
    def train(self, initial_model, batch_size, num_epochs, nworkers,
              save_epoch_model, out_model, input_model_fname=None):
        """Configure the model, loss, optimizer and scheduler, then train.

        Arguments:
         initial_model: path of the pickled model object to start from
          (environment variables in the path are expanded)
         batch_size: the batch size of the train and dev dataloaders
         num_epochs: the number of training epochs
         nworkers: the number of dataloader worker processes
         save_epoch_model: if true, a state dict is saved after each epoch
         out_model: the output model filename (used to name epoch models)
         input_model_fname: optional state dict file whose weights are
          loaded into the model before training

        Returns:
         the trained model (also stored in self.model)
        """

        # create the train dataloader (shuffled) and the dev dataloader
        # (fixed order) over the image datasets
        #
        self.train_dataloader = \
            torch.utils.data.DataLoader(self.train_dataset, batch_size,
                                        shuffle=True, num_workers=nworkers)
        self.dev_dataloader = \
            torch.utils.data.DataLoader(self.dev_dataset, batch_size,
                                        shuffle=False, num_workers=nworkers)

        # load the model object, mapping its tensors onto the training
        # device the same way the state dict below is mapped
        #
        # NOTE(review): torch.load unpickles the file, so only model files
        # from a trusted source should be passed in here.
        #
        model_ft = torch.load(os.path.expandvars(initial_model),
                              map_location=self.device)

        # create the new output layer: one output per class
        #
        new_fc = nn.Linear(model_ft.fc.in_features, len(self.class_names))

        if input_model_fname is None:

            # no weights to restore, only resize the output layer
            #
            model_ft.fc = new_fc

        else:

            # load the weights and look at the size of their output layer
            #
            state_dict = torch.load(input_model_fname,
                                    map_location=self.device)
            saved_fc = state_dict.get('fc.weight')

            if saved_fc is not None and \
               tuple(saved_fc.shape) == tuple(new_fc.weight.shape):

                # the weights already contain an output layer of the
                # size we need (e.g. an epoch model saved by this class):
                # resize the layer first so that the weights fit and the
                # restored output layer is kept
                #
                model_ft.fc = new_fc
                model_ft.load_state_dict(state_dict)

            else:

                # the weights match the original output layer: load them
                # first and replace the output layer afterwards
                #
                model_ft.load_state_dict(state_dict)
                model_ft.fc = new_fc

        # send the model to device
        #
        model_ft = model_ft.to(self.device)

        # define the class-weighted cross entropy loss functions for the
        # train and dev datasets
        #
        train_criterion = nn.CrossEntropyLoss(
            weight=self.train_weights.to(self.device))
        dev_criterion = nn.CrossEntropyLoss(
            weight=self.dev_weights.to(self.device))

        # all parameters are optimized with stochastic gradient descent
        #
        optimizer_ft = torch.optim.SGD(model_ft.parameters(),
                                       lr=LEARNING_RATE, momentum=MOMENTUM)

        # lr_scheduler.StepLR multiplies the learning rate by GAMMA every
        # STEP_SIZE epochs
        #
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft,
                                               step_size=STEP_SIZE,
                                               gamma=GAMMA)

        # train the model using __train_model
        #
        self.model = self.__train_model(model_ft, train_criterion,
                                        dev_criterion, optimizer_ft,
                                        exp_lr_scheduler, self.train_dataloader,
                                        self.dev_dataloader, self.device,
                                        save_epoch_model, out_model, num_epochs)

        # exit gracefully
        #
        return self.model
    #
    # end of method

    # method: EEGResNet::decode
    #
    # arguments:
    #   model: neural network model
    #   dataloader: the dataloader that yields (inputs, labels) batches
    #   device: device that decoding will be run on
    #
    # return:
    #   all_labels: list of correct classifications (bckg or seiz)
    #   all_preds: list of predicted classifications (bckg or seiz)
    #   accuracy: the total classification accuracy
    #
    # This function will compute the values needed to compute the
    # confusion matrix and accuracy
    #
    def decode(self, model, dataloader, device):
        """Run the model over a dataloader and collect its predictions.

        Arguments:
         model: the neural network model
         dataloader: iterable that yields (inputs, labels) batches
         device: the device the inputs are sent to

        Returns:
         all_labels: float64 array with the reference class index of
          every sample that was decoded
         all_preds: float64 array with the predicted class index of
          every sample that was decoded
         accuracy: the fraction of samples whose prediction equals the
          label (nan if the dataloader yields no samples)

        The model is put in eval mode while decoding and its previous
        train/eval mode is restored afterwards, even if decoding fails.
        """

        # save the model mode to restore it at the end
        #
        was_training = model.training
        model.eval()

        # the labels and predictions of each batch are collected here, so
        # the results hold exactly one entry per decoded sample even when
        # the last batch is smaller than the batch size
        #
        label_chunks = []
        pred_chunks = []

        try:

            # gradients are not needed for decoding
            #
            with torch.no_grad():

                # iterate through the dataloader
                #
                for inputs, labels in dataloader:

                    # send the inputs through the model
                    #
                    outputs = model(inputs.to(device))

                    # the prediction is the class with the highest score
                    #
                    _, preds = torch.max(outputs, 1)

                    # keep the labels and predictions of this batch
                    #
                    label_chunks.append(labels.cpu().numpy())
                    pred_chunks.append(preds.cpu().numpy())

        finally:

            # change the model mode back to what it was
            #
            model.train(mode=was_training)

        # merge the batches into flat float arrays
        #
        if label_chunks:
            all_labels = np.concatenate(label_chunks).astype(np.float64)
            all_preds = np.concatenate(pred_chunks).astype(np.float64)
        else:
            all_labels = np.zeros(0)
            all_preds = np.zeros(0)

        # the accuracy is the fraction of predictions that match the labels
        #
        if all_labels.size > 0:
            accuracy = np.mean(all_labels == all_preds)
        else:
            accuracy = np.nan

        # exit gracefully
        #
        return all_labels, all_preds, accuracy
    #
    # end of method
    
    # method: EEGResNet::__train_model
    #
    # arguments:
    #   model: the neural network model which should be trained
    #   train_criterion: loss function for training (such as mse or
    #     cross entropy)
    #   dev_criterion: loss function for development (such as mse or
    #     cross entropy)
    #   optimizer: optimizer function (such as SGD or Adam)
    #   scheduler: scheduler object that adjusts the optimizer's parameters
    #   train_dataloader: the train dataloader
    #   dev_dataloader: the dev dataloader
    #   device: the device training will be run on (gpu or cpu)
    #   save_epoch_model: boolean value to determine saving a model every epoch
    #   out_model: the output model filename
    #   num_epochs: number of epochs which neural network will be trained
    #
    # return:
    #   model: the trained model
    #
    # This function accepts a model and trains it.
    #
    def __train_model(self, model, train_criterion, dev_criterion, optimizer,
                      scheduler, train_dataloader, dev_dataloader, device,
                      save_epoch_model, out_model, num_epochs=1):
        """Train a model and return it with its best dev-loss weights.

        Arguments:
         model: the neural network model to train
         train_criterion: loss function applied to the training batches
         dev_criterion: loss function applied to the dev batches
         optimizer: optimizer stepped once per training batch
         scheduler: scheduler stepped once per epoch
         train_dataloader: yields (inputs, labels) training batches
         dev_dataloader: yields (inputs, labels) dev batches
         device: the device the batches are sent to
         save_epoch_model: if true, save the state dict after each epoch
         out_model: output model filename; epoch models reuse its
          directory, basename and extension with "_<epoch>" appended
         num_epochs: the number of epochs to train

        Returns:
         the model, loaded with the weights of the epoch that had the
         lowest dev loss (the initial weights if no epoch was run)

        Raises:
         ZeroDivisionError: if a dataloader yields no samples
        """

        # keep track of the start time
        #
        since = time.time()

        # the best weights so far are the initial ones
        #
        best_model_wts = copy.deepcopy(model.state_dict())

        # set the initial loss as infinity so the first epoch always wins
        #
        best_loss = float(INFINITY)

        # iterate through amount of epochs
        #
        for epoch in range(num_epochs):

            # keep track of the time for each epoch
            #
            epoch_time = time.time()

            # print out which epoch is currently running, followed by a
            # separator line
            #
            print(f'Epoch {epoch+1}/{num_epochs}')
            print(nft.DELIM_DASH * 10)

            # training phase: set the model to training mode
            #
            model.train()

            # running totals: the summed loss, the number of correct
            # predictions and the number of samples actually seen (the
            # last batch may be smaller than the batch size)
            #
            running_loss = 0.0
            running_corrects = 0
            nsamples = 0

            # iterate over the train dataloader
            #
            for inputs, labels in train_dataloader:

                # send the inputs and labels to the device
                #
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                #
                optimizer.zero_grad()

                # forward and backward pass with gradients enabled
                #
                with torch.set_grad_enabled(True):

                    # forward pass: the prediction is the class with the
                    # highest score
                    #
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)

                    # calculate the loss of this batch
                    #
                    loss = train_criterion(outputs, labels)

                    # backward pass: compute the gradients and update
                    # the model parameters
                    #
                    loss.backward()
                    optimizer.step()

                # update the running totals
                #
                nbatch = inputs.size(0)
                running_loss += loss.item() * nbatch
                running_corrects += torch.sum(preds == labels.data).item()
                nsamples += nbatch

            # execute the scheduler
            #
            scheduler.step()

            # calculate the loss and accuracy for this epoch over the
            # samples that were actually processed
            #
            epoch_loss = running_loss / nsamples
            epoch_acc = running_corrects / nsamples

            # print a message with the elapsed time, the loss, and
            # the accuracy for this epoch
            #
            print(f'Train \t Elapsed: {time.time()-epoch_time:.2f} sec '
                  + f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # check if we want to save model parameters for each epoch
            #
            if save_epoch_model:

                # split the output model filename into its directory,
                # extension (without the dot) and basename
                #
                file_path = os.path.dirname(out_model)
                file_ext = os.path.splitext(out_model)[1][1:]
                file_bname = os.path.splitext(os.path.basename(out_model))[0]

                # append the epoch number to the basename
                #
                file_bname = file_bname + nft.DELIM_USCORE + str(epoch + 1)

                # create the final model filename for this epoch
                #
                epoch_fname = nft.create_filename(file_bname, file_path,
                                                  file_ext, None)

                # save the epoch model
                #
                torch.save(model.state_dict(), epoch_fname)

            # validation phase: set the model to eval mode
            #
            model.eval()

            # reset the running totals
            #
            running_loss = 0.0
            running_corrects = 0
            nsamples = 0

            # iterate over the dev dataloader
            #
            for inputs, labels in dev_dataloader:

                # send the inputs and labels to the device
                #
                inputs = inputs.to(device)
                labels = labels.to(device)

                # gradients are not needed for validation
                #
                with torch.set_grad_enabled(False):

                    # forward pass: the prediction is the class with the
                    # highest score
                    #
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)

                    # calculate the loss of this batch
                    #
                    loss = dev_criterion(outputs, labels)

                # update the running totals
                #
                nbatch = inputs.size(0)
                running_loss += loss.item() * nbatch
                running_corrects += torch.sum(preds == labels.data).item()
                nsamples += nbatch

            # calculate the loss and accuracy for this epoch over the
            # samples that were actually processed
            #
            epoch_loss = running_loss / nsamples
            epoch_acc = running_corrects / nsamples

            # print a message with the elapsed time (measured from the
            # start of the epoch), the loss, and the accuracy
            #
            print(f'Devel \t Elapsed: {time.time()-epoch_time:.2f} sec '
                  + f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # check if this epoch improved the model by seeing if this
            # epoch decreased the dev loss
            #
            if epoch_loss < best_loss:

                # remember this epoch's loss and a copy of its weights
                #
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        # print a single line
        #
        print()

        # calculate the total time
        #
        time_elapsed = time.time() - since

        # print a final message with total training time
        #
        print(f'Training completed in {time_elapsed // 60:.0f}m ' +
              f'{time_elapsed % 60:.0f}s')

        # load the best model weights
        #
        model.load_state_dict(best_model_wts)

        # exit gracefully
        #
        return model
    #
    # end of method

    # method: EEGResNet::build_dataset
    #
    # arguments:
    #   edf_list: list of all edf files
    #   csv_bi_list: list of all csv_bi files
    #   new_samp_rate: sample rate to downsample to
    #   channel_order: list of channels from edf files
    #
    # return:
    #   data_set: the final dictionary containing numpy array datasets
    #
    # This function creates a dataset with edf and csv_bi files
    #
    def build_dataset(self, edf_list, csv_bi_list, new_samp_rate,
                      channel_order):
        """Collect the bckg and seiz signal segments of a list of files.

        Each edf file is read, downsampled to new_samp_rate and cut at
        the event boundaries found in the matching csv_bi file. The
        segments of each class are joined along the column axis.

        Returns a dictionary with one 2-D numpy array per class (keys
        BCKG and SEIZ), holding the joined segments of that class.
        """

        # objects that read and downsample the edf signals
        #
        edf_reader = net.Edf()
        resampler = ned.Edf_Downsample()

        # the signal segments collected for each class
        #
        segments = {BCKG: [], SEIZ: []}

        # process the edf and csvbi files pairwise
        #
        for edf_fname, csvbi_fname in zip(edf_list, csv_bi_list):

            # read the signals and their sample rate, then downsample
            #
            samp_rate, sigs = edf_reader.channels_data(edf_fname,
                                                       channel_order)
            lpsig = resampler.downsample(sigs, samp_rate, new_samp_rate,
                                         filter_order = self.frmsize // 2)

            # read the event end points and convert them to sample ranges
            # (first list: background, second list: seizure)
            #
            end_points = self.read_csvbi(csvbi_fname)
            ranges = self.csvbi_slice(end_points, new_samp_rate)

            # cut the signals and file each piece under its class
            #
            for label, spans in zip(segments, ranges):
                for span in spans:
                    segments[label].append(lpsig[:, span[0]: span[1]])

        # every segment has the same number of rows as the first
        # background segment
        #
        nrows = segments[BCKG][0].shape[0]

        # join the segments of each class into a single matrix
        #
        data_set = {}
        for label, pieces in segments.items():

            # allocate a matrix wide enough for all segments
            #
            width = sum(piece.shape[1] for piece in pieces)
            joined = np.zeros((nrows, width))

            # copy the segments in one after the other
            #
            offset = 0
            for piece in pieces:
                ncols = piece.shape[1]
                joined[:, offset:offset + ncols] = piece
                offset += ncols

            data_set[label] = joined

        # exit gracefully
        #
        return data_set
    #
    # end of method
    
    # method: EEGResNet::read_csvbi
    #
    # arguments:
    #   csvbi_fpath: csv_bi full file path and filename
    #
    # return:
    #   ends: list that contains the end points of every event
    #
    # This method reads a csvbi file by assuming that the second and third
    # columns are the start and stop times, that the fourth column is the
    # label, and that the sequence of events starts and finishes with a
    # bckg event
    #
    def read_csvbi(self, csvbi_fpath):
        """Read the event end points of a csv_bi annotation file.

        Arguments:
         csvbi_fpath: csv_bi full file path and filename

        Returns:
         ends: list of times (floats) that alternately close a bckg and
          a seiz interval: the start and stop time of every seiz event,
          followed by the file duration if the last seiz event stops
          before it. A file with no seiz events yields [duration].

        Raises:
         ValueError: if a data row does not have five comma-separated
          fields, or a time or the duration is not a number
        """

        # list that will hold the end points
        #
        ends = []

        # get the header
        #
        header_comments = nft.extract_comments(csvbi_fpath)

        # open the file
        #
        with open(csvbi_fpath, mode=nft.MODE_READ_TEXT) as file:

            # iterate through each line in the file
            #
            for line in file:

                # ignore blank lines, comment lines and the column header
                #
                record = line.strip()
                if not record or line.startswith(nft.DELIM_COMMENT) or \
                   DEF_CSV_LABELS in line:
                    continue

                # split the line by comma: channel, start time, stop time,
                # label and confidence
                #
                _, t0, t1, ev, _ = record.split(nft.DELIM_COMMA)

                # convert the start and end times to a float
                #
                t0, t1 = float(t0), float(t1)

                # if this is a seizure event add start and stop times to ends
                #
                if ev == SEIZ:
                    ends += [t0, t1]

        # get the total duration of the file: keep the text in front of the
        # decimal point, i.e. the whole number of seconds
        #
        duration = header_comments[nae.CKEY_DURATION]
        whole, dot, _ = duration.partition(nft.DELIM_DOT)

        # without a decimal point, slicing at the position returned by
        # find() would silently drop the last character; keep the leading
        # token instead so that a trailing unit is still ignored
        #
        if not dot:
            fields = whole.split()
            whole = fields[0] if fields else whole
        duration = float(whole)

        # check if there are no seiz events
        #
        if len(ends) == 0:
            ends = [duration]

        # check if we are missing the closing bckg event
        #
        elif ends[-1] < duration:
            ends.append(duration)

        # exit gracefully
        #
        return ends
    # 
    # end of method

    # method: EEGResNet::csvbi_slice
    #
    # arguments:
    #   csvbi_data: list of event end points in seconds (see read_csvbi)
    #   samp_freq: the sample rate of the data
    #
    # return:
    #   slices: csvbi data broken up into slices
    #
    # This function converts all timed events into samples indices and returns
    # the slices
    #
    def csvbi_slice(self, csvbi_data, samp_freq):
        """Convert event end times into alternating sample index ranges.

        Each end time is multiplied by samp_freq and truncated to an
        integer index. Consecutive indices (starting from 0) form
        (start, stop) tuples that are filed alternately under
        background (first list) and seizure (second list).
        """

        # the sample index at which every event ends
        #
        stops = [int(end_time * samp_freq) for end_time in csvbi_data]

        # every event starts where the previous one ended, the first
        # event starts at sample 0
        #
        starts = [0] + stops[:-1]

        # even positions are background events, odd positions are
        # seizure events
        #
        slices = [[], []]
        for position, span in enumerate(zip(starts, stops)):
            slices[position % 2].append(span)

        # exit gracefully
        #
        return slices
    #
    # end of method

#
# end of class

#
# end of file
