#!/usr/bin/env python
#
# file: $(NEDC_NFC)/src/classes/nedc_eeg_image.py
#
# revision history:
#
# 20211118 (PM): refactored code
# 20210426 (VK): added a new class to handle single EEG pickle
#                for decoding purpose.
# 20210422 (VK): initial release
# 20221020 (ML): refactored code
#
# This file contains a class for making image datasets from a single EDF file,
# and a class for making an image dataset from edf and csvbi data.
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import numpy as np
import pickle
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
from torchvision import datasets, models, transforms, io

# import nedc_modules
#
import nedc_edf_tools as net
import nedc_edf_downsample as ned

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------
#

# define the binary class labels; a label's position in CLASSES matches
# the integer target assigned to it (0 = bckg, 1 = seiz)
#
BCKG, SEIZ = 'bckg', 'seiz'
CLASSES = [BCKG, SEIZ]

# define a small constant added to the denominator when a frame of signal
# data is normalized into a grayscale image, so that a flat (constant)
# channel does not cause a division by zero
#
EPSILON = 1e-6

# define a constant for reading montages_data
#
MONTAGES_DATA = "montages_data"

# define the PIL image mode the grayscale image will be converted to
#
DEF_RGB = 'RGB'

# define the mode string used for reading binary data
#
DEF_RB = 'rb'

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

# class: Images
#
# this class will convert a dataset (edf and csvbi data in separate numpy
# arrays contained in a dictionary) into an image dataset. each array is
# cut along its second axis into consecutive, non-overlapping frames of
# frmsize samples, and every frame becomes one RGB image.
#
class Images(utils.data.Dataset):

    # number of gray levels used when quantizing a frame: the normalized
    # signal is multiplied by this value and then clipped to
    # [0, _GRAY_LEVELS - 1] so it always fits in an 8-bit pixel
    #
    _GRAY_LEVELS = 256

    # method: Images::constructor
    #
    # arguments:
    #   dataset: a dictionary of background and seizure Numpy arrays, keyed
    #            by BCKG and SEIZ; samples run along axis 1
    #   frmsize: amount of samples per frame (must be positive)
    #   transform: all the transformations that should be done on images.
    #
    # return: none
    #
    # raises:
    #   ValueError: if frmsize is not positive
    #
    # this is the constructor for Images, it will set important class data
    #
    def __init__(self, dataset, frmsize, transform=None):

        # a non-positive frame size cannot partition the signals
        #
        if frmsize <= 0:
            raise ValueError("frmsize must be positive, got %r" % (frmsize,))

        # classes will hold a list of the name of each class (bckg, seiz)
        #
        self.classes = CLASSES

        # set the dataset
        #
        self.samples = dataset

        # frmsize is the samples per frame, a single hypothesis is created
        # from frmsize amount of samples
        #
        self.frmsize = frmsize

        # self.transform will be the image transformations to be done
        #
        self.transform = transform

        # compute the number of complete, non-overlapping frames available
        # for bckg and seiz; a trailing partial frame is discarded. floor
        # division of the sample count keeps the last complete frame and
        # can never go negative for signals shorter than one frame.
        #
        self.len_bckg = self.samples[BCKG].shape[1] // self.frmsize
        self.len_seiz = self.samples[SEIZ].shape[1] // self.frmsize

        # compute the total amount of images for bckg and seiz
        #
        self.length = self.len_bckg + self.len_seiz

        # create an array of the target values: all bckg frames come first
        # (label 0), followed by all seiz frames (label 1)
        #
        self.targets = np.array([0] * self.len_bckg + [1] * self.len_seiz,
                                dtype = int)

        # exit gracefully
        #
        return None
    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # set/get methods
    #
    #--------------------------------------------------------------------------

    # method: Images::__len__
    #
    # arguments: None
    #
    # return:
    #  number of images needed for bckg and seiz
    #
    # this simple method will return a value for total amount
    # of images needed
    #
    def __len__(self):

        # exit gracefully
        #
        return self.length
    #
    # end of method

    # method: Images::_frame_to_rgb
    #
    # arguments:
    #  sig_mat: 2-D float array holding one frame; each row (presumably one
    #           EEG channel) holds frmsize samples
    #
    # return:
    #  rgb: a frmsize x frmsize RGB image with three identical gray bands
    #
    # this private helper normalizes each row of the frame independently to
    # [0, 1], quantizes it to 8-bit gray levels, and builds the image
    #
    def _frame_to_rgb(self, sig_mat):

        # shift every row so its minimum is zero; the transpose lets the
        # per-row minima broadcast across the columns
        #
        shifted = sig_mat.T - sig_mat.min(axis = 1)

        # scale each row into [0, 1] (EPSILON guards flat rows against a
        # division by zero), quantize, and clip so the value can never
        # overflow or wrap around when cast to uint8
        #
        scaled = shifted / (shifted.max(axis = 0) + EPSILON) \
            * self._GRAY_LEVELS
        image = np.clip(np.transpose(scaled), 0,
                        self._GRAY_LEVELS - 1).astype(np.uint8)

        # create and resize the grayscale image
        #
        gray = Image.fromarray(image)
        gray = gray.resize((int(self.frmsize), int(self.frmsize)))

        # convert the grayscale image to an RGB image
        #
        return Image.merge(DEF_RGB, (gray, gray, gray))
    #
    # end of method

    # method: Images::__getitem__
    #
    # arguments:
    #  idx: index of the requested image; negative values count back from
    #       the end, as for a Python sequence
    #
    # return:
    #  rgb: RGB image (after self.transform, if one was given)
    #  label: label of sample (e.g. 0=bckg or 1=seiz)
    #
    # raises:
    #  IndexError: if idx is outside the dataset
    #
    # this is the main method of this class: it slices one frame out of the
    # bckg or seiz array and converts it into an image
    #
    def __getitem__(self, idx):

        # cast idx as an int and resolve negative indices
        #
        idx = int(idx)
        if idx < 0:
            idx += self.length

        # reject out-of-range indices instead of slicing past the data,
        # which would silently yield an empty or truncated frame
        #
        if idx < 0 or idx >= self.length:
            raise IndexError("Images index out of range")

        # compare idx with lengths to find which part has been requested
        #
        if idx < self.len_bckg:

            # get the background signal matrix
            #
            start = idx * self.frmsize
            sig_mat = np.double(
                self.samples[BCKG][:, start:start + self.frmsize])

        else:

            # get the seizure signal matrix
            #
            start = (idx - self.len_bckg) * self.frmsize
            sig_mat = np.double(
                self.samples[SEIZ][:, start:start + self.frmsize])

        # find the label of this current sample
        #
        label = self.targets[idx]

        # scale the frame to a grayscale image and convert it to RGB
        #
        rgb = self._frame_to_rgb(sig_mat)

        # apply the image transformations if they are not None
        #
        if self.transform:
            rgb = self.transform(rgb)

        # exit gracefully
        #
        return rgb, label
    #
    # end of method

#
# end of class

# class: SingleEDFImages
#
# This class will convert a single EEG EDF file into an image dataset: the
# signals for the requested channels are downsampled on load and cut into
# consecutive, non-overlapping frames of frmsize samples, each of which
# becomes one RGB image
#
class SingleEDFImages:

    # number of gray levels used when quantizing a frame: the normalized
    # signal is multiplied by this value and then clipped to
    # [0, _GRAY_LEVELS - 1] so it always fits in an 8-bit pixel
    #
    _GRAY_LEVELS = 256

    # method: SingleEDFImages::constructor
    #
    # arguments:
    #   edf_fname: the name of EEG EDF file
    #   channels_order: list of the channel names in order
    #   new_samp_rate: the sample rate the signals are downsampled to
    #   frmsize: the amount of samples per frame (must be positive)
    #   transform: all the transformations that should be done on images.
    #
    # return: none
    #
    # raises:
    #   ValueError: if frmsize is not positive
    #
    # this method is the constructor for the class, it reads and downsamples
    # the EDF signals and sets important class values
    #
    def __init__(self, edf_fname, channels_order, new_samp_rate, frmsize,
                 transform=None):

        # a non-positive frame size cannot partition the signals; check it
        # before doing any file I/O
        #
        if frmsize <= 0:
            raise ValueError("frmsize must be positive, got %r" % (frmsize,))

        # edf_fname is the EDF filename
        #
        self.edf_fname = edf_fname

        # classes will hold a list of the name of each class (bckg, seiz)
        #
        self.classes = CLASSES

        # frmsize is the samples per frame, a single hypothesis is created
        # from frmsize amount of samples
        #
        self.frmsize = frmsize

        # self.transform is the image transformations to be done
        #
        self.transform = transform

        # create an instance of the EEG class
        #
        eeg = net.Edf()

        # create an instance of Edf_Downsample class
        #
        dsample = ned.Edf_Downsample()

        # get the current sample rate and signals from the
        # edf file
        #
        samp_rate, sigs = eeg.channels_data(edf_fname, channels_order)

        # calculate the filter order (frmsize / 2)
        #
        filter_order = int(self.frmsize // 2)

        # down sample the signals to new_samp_rate frequency; self.samples
        # holds the result, with samples along axis 1
        #
        self.samples = dsample.downsample(sigs, samp_rate, new_samp_rate,
                                          filter_order)

        # compute the number of complete, non-overlapping frames; a
        # trailing partial frame is discarded
        #
        self.length = self.samples.shape[1] // self.frmsize

        # start the iteration counter so next() also works without a prior
        # call to iter()
        #
        self.counter = 0

        # exit gracefully
        #
        return None
    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # set/get methods
    #
    #--------------------------------------------------------------------------

    # method: SingleEDFImages::__len__
    #
    # arguments: None
    #
    # return:
    #  number of images needed
    #
    # this simple method will return the total amount of images needed for
    # the edf file
    #
    def __len__(self):

        # exit gracefully
        #
        return self.length
    #
    # end of method

    # method: SingleEDFImages::_frame_to_rgb
    #
    # arguments:
    #  sig_mat: 2-D float array holding one frame; each row (presumably one
    #           EEG channel) holds frmsize samples
    #
    # return:
    #  rgb: a frmsize x frmsize RGB image with three identical gray bands
    #
    # this private helper normalizes each row of the frame independently to
    # [0, 1], quantizes it to 8-bit gray levels, and builds the image
    #
    def _frame_to_rgb(self, sig_mat):

        # shift every row so its minimum is zero; the transpose lets the
        # per-row minima broadcast across the columns
        #
        shifted = sig_mat.T - sig_mat.min(axis = 1)

        # scale each row into [0, 1] (EPSILON guards flat rows against a
        # division by zero), quantize, and clip so the value can never
        # overflow or wrap around when cast to uint8
        #
        scaled = shifted / (shifted.max(axis = 0) + EPSILON) \
            * self._GRAY_LEVELS
        image = np.clip(np.transpose(scaled), 0,
                        self._GRAY_LEVELS - 1).astype(np.uint8)

        # create and resize the grayscale image
        #
        gray = Image.fromarray(image)
        gray = gray.resize((int(self.frmsize), int(self.frmsize)))

        # convert the grayscale image to an RGB image
        #
        return Image.merge(DEF_RGB, (gray, gray, gray))
    #
    # end of method

    # method: SingleEDFImages::__getitem__
    #
    # arguments:
    #  idx: index of the requested image; negative values count back from
    #       the end, as for a Python sequence
    #
    # return:
    #  rgb: RGB image (after self.transform, if one was given)
    #
    # raises:
    #  IndexError: if idx is outside the dataset
    #
    # this is the main method of the class: it slices one frame out of the
    # downsampled signals and converts it into an image
    #
    def __getitem__(self, idx):

        # cast the index as an integer and resolve negative indices
        #
        idx = int(idx)
        if idx < 0:
            idx += self.length

        # reject out-of-range indices instead of slicing past the data,
        # which would yield an empty or truncated frame
        #
        if idx < 0 or idx >= self.length:
            raise IndexError("SingleEDFImages index out of range")

        # create the signal matrix (in double precision, as Images does)
        #
        start = idx * self.frmsize
        sig_mat = np.double(self.samples[:, start:start + self.frmsize])

        # scale the frame to a grayscale image and convert it to RGB
        #
        rgb = self._frame_to_rgb(sig_mat)

        # apply the image transformations if they are not None
        #
        if self.transform:
            rgb = self.transform(rgb)

        # exit gracefully
        #
        return rgb
    #
    # end of method

    # method: SingleEDFImages::__iter__
    #
    # arguments:
    #  None
    #
    # return:
    #  self: updated class object (the object is its own iterator)
    #
    # this simple method will reset the counter to 0
    #
    def __iter__(self):

        # reset the counter to 0
        #
        self.counter = 0

        # exit gracefully
        #
        return self
    #
    # end of method

    # method: SingleEDFImages::__next__
    #
    # arguments:
    #  None
    #
    # return:
    #  the next RGB image, as produced by self.__getitem__
    #
    # raises:
    #  StopIteration: once all self.length images have been returned
    #
    # this method will control iteration and end iteration when we
    # have all of the images needed
    #
    def __next__(self):

        # stop once every image of the file has been produced
        #
        if self.counter >= self.length:
            raise StopIteration

        # advance the counter and return the image at the old position
        #
        self.counter += 1
        return self.__getitem__(self.counter - 1)
    #
    # end of method
#
# end of class

#
# end of file
