#!/usr/bin/env python

# file: postprocess.py
#                                                                              
# revision history:
#
# 20230320 (ML): initial version                                              
#
# This file contains a Python implementation of our real-time postprocessor.
# Details about this system can be found here:
#
#  Khalkhali, V., Shawki, N., Shah, V., Golmohammadi, M., Obeid, I., &
#  Picone, J. (2021). Low Latency Real-Time Seizure Detection Using
#  Transfer Deep Learning. In I. Obeid, I. Selesnick, & J. Picone (Eds.),
#  Proceedings of the IEEE Signal Processing in Medicine and Biology
#  Symposium (SPMB) (pp. 1–7). IEEE.
#  https://doi.org/10.1109/SPMB52430.2021.9672285
#  https://www.isip.piconepress.com/publications/conference_proceedings/2021/ieee_spmb/eeg_transfer_learning/
#
# The API is very simple:
#  constructor: creates the class (called at the top of the program)
#  init: initializes the system (can be called any time)
#  process: this is the only function that needs to be called to postprocess
#     data. It takes a confidence and adds it to a buffer; once the buffer
#     holds enough data, it returns seizure events to the decoder class
#     (background events are returned as None)
#
# Postprocessing adds a significant amount of delay (about 150 secs).
# Postprocessing can be tuned to minimize delay.
#
# After postprocessing, the program returns to decoder.py either None
# (if the event is a background event) or a Python dictionary
# (if the event is a seizure event). The format of the dictionary is:
# 
#  {'timeStart': '5.5000', 'timeEnd': '9.5000',
#    'label': 'seiz', 'confidence': '0.9530'}
#
# All values are strings formatted to 4 decimal places, and the seizure
# confidence is calculated by averaging the confidences of the
# above-threshold frames merged into a single event.
#
#------------------------------------------------------------------------------

# import required system modules                                               
#
import os
import sys
import numpy as np

#------------------------------------------------------------------------------
#                                                                              
# global variables are listed here
#                                                                              
#------------------------------------------------------------------------------

# decoding data: the two class labels
#
BCKG = 'bckg'
SEIZ = 'seiz'
CLASSES = [BCKG, SEIZ]

# dictionary key names and the numeric format used for output values
#
DEF_TIMESTART, DEF_TIMEEND = 'timeStart', 'timeEnd'
DEF_CONFIDENCE, DEF_LABEL = 'confidence', 'label'
DEF_DATA = 'data'
DEF_OUTPUT_FORMAT = '.4f'

# basename of this file
#
__FILE__ = os.path.basename(__file__)

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

# class: Nedc_Postprocess
#
# This class postprocesses the per-frame seizure confidences produced by
# the decoder and converts them into seizure events.
#
class Nedc_Postprocess:

    # method: Postprocess::constructor
    #
    # arguments:
    #   None
    #
    # returns: none
    #
    # This method constructs a Postprocessor object with empty state.
    # init() must be called before process(), since it sets the threshold
    # and frame counts that process() relies on.
    #
    def __init__(self):

        # set the class name
        #
        Nedc_Postprocess.__CLASS_NAME__ = self.__class__.__name__

        # initialize class internal data used to save state:
        #  end_t: not read anywhere in this class
        #  st_t: start time (secs) of the next event to be emitted
        #  end_t_fin: when not None, used as the end time of emitted events
        #    and triggers zero-padding of the buffer; it is never set in
        #    this class (presumably the caller sets it at end of file -
        #    TODO confirm)
        #  previous_detection: label (0/1) of the last emitted event
        #  postp_buff: raw confidences not yet assigned to an event
        #
        self.end_t = 0
        self.st_t = 0
        self.end_t_fin = None
        self.previous_detection = None
        self.postp_buff = []

    #
    # end of method

    # method: Postprocess::init
    #
    # arguments:
    #   seiz_ths: seizure threshold (confidences >= seiz_ths are seizure)
    #   frmsize: frame size (samples)
    #   samp_rate: sample rate (samples/sec)
    #   sec_fr_postp: length of the postprocessing window in seconds
    #   min_seiz_sec: minimum seizure seconds
    #   min_bckg_sec: minimum background seconds
    #
    # returns: none
    #
    # This method initializes the postprocessor parameters. It is usually
    # called at the start of file processing, though it can be called any
    # time. It does not clear the buffer or the event timing state.
    #
    def init(self, seiz_ths, frmsize, samp_rate,
             sec_fr_postp, min_seiz_sec, min_bckg_sec):

        # set the class internal data
        #
        self.seiz_ths = seiz_ths
        self.min_seiz_sec = min_seiz_sec
        self.min_bckg_sec = min_bckg_sec
        self.frmsize = frmsize
        self.samp_rate = samp_rate
        self.sec_fr_postp = sec_fr_postp

        # duration of one frame in seconds
        #
        frm_dur = self.frmsize / self.samp_rate

        # number of frames in the postprocessing window (an int, like the
        # other frame counts, so it compares cleanly with len())
        #
        self.postp_frmsize = int(np.ceil(sec_fr_postp / frm_dur))

        # minimum number of frames a seizure event must span
        #
        self.min_seiz_frm = int(np.ceil(self.min_seiz_sec / frm_dur))

        # background gaps between seizures shorter than this many frames
        # (2 * min_bckg_sec) are merged into seizure; a leading background
        # event must span at least half of it to be emitted on its own
        #
        self.min_bckg_frm = int(np.ceil(2 * (self.min_bckg_sec / frm_dur)))

    #
    # end of method

    # method: Postprocess::process
    #
    # arguments:
    #  seiz_c: seizure confidence of the next frame
    #
    # return:
    #   hyp: None (buffer not full yet, or a background event), or a
    #        dictionary describing a seizure event
    #
    # This method is the only method that needs to be called when post
    # processing a signal. It returns a dictionary if a seizure segment
    # has been found. Otherwise, it returns None.
    #
    def process(self, seiz_c):

        # add the confidence to the postp buffer
        #
        self.postp_buff.append(seiz_c)

        # once end_t_fin is set, pad the buffer with zero confidences up to
        # the window size so that a full window is available
        #
        if self.end_t_fin is not None:
            while len(self.postp_buff) < self.postp_frmsize:
                self.postp_buff.append(0.0)

        # only a full window can be postprocessed
        #
        if len(self.postp_buff) != self.postp_frmsize:
            return None

        # exit gracefully: None (background) or a seizure dictionary
        #
        return self.output_detection()

    #
    # end of method

    # method: Postprocess::output_detection
    #
    # arguments: none
    #
    # return:
    #  hyp: None for a background event, or a seizure dictionary
    #
    # This method converts the buffered confidences into binary detections,
    # finds the first event in the window, emits it through create_hyp()
    # and removes that event's frames from the buffer. Every call consumes
    # at least one frame, so the buffer can always refill to the window
    # size. This is meant to be a private method and not part of the API.
    #
    def output_detection(self):

        # nothing to decide on an empty buffer
        #
        if not self.postp_buff:
            return None

        # convert confidences to binary detections (1 = seizure)
        #
        postp_bin_buff = np.array([0 if x < self.seiz_ths else 1
                                   for x in self.postp_buff])

        # if the first frame continues the previously emitted event, emit
        # it right away as a one-frame event
        #
        if postp_bin_buff[0] == self.previous_detection:
            hyp = self.create_hyp(postp_bin_buff[:1], self.postp_buff[0],
                                  postp_bin_buff[0])
            del self.postp_buff[0]
            return hyp

        # merge short background gaps between seizures
        #
        postp_bin_buff = self.postprocess(postp_bin_buff)

        # scan for the first valid event boundary and emit the event
        # that precedes it
        #
        last = len(postp_bin_buff) - 1
        prev_bin = 0
        for count, detect in enumerate(postp_bin_buff):

            if prev_bin != detect:

                if prev_bin == 0:

                    # background -> seizure: emit the leading background
                    # if it is long enough; count > 0 prevents an empty
                    # event that would leave the buffer unchanged (possible
                    # when min_bckg_frm is 0)
                    #
                    if count > 0 and count >= self.min_bckg_frm / 2:
                        return self._emit_event(postp_bin_buff, count,
                                                None, prev_bin)

                elif count >= self.min_seiz_frm:

                    # seizure -> background: emit the seizure if it is
                    # long enough
                    #
                    return self._emit_event(postp_bin_buff, count,
                                            self._seiz_conf(count),
                                            prev_bin)

            # at the end of the window emit everything but the last frame;
            # for a one-frame window emit that frame so the buffer shrinks
            #
            if count == last:
                nfrms = max(count, 1)
                conf = self._seiz_conf(nfrms) if detect == 1 else None
                return self._emit_event(postp_bin_buff, nfrms, conf, detect)

            prev_bin = detect

        # not reachable for a non-empty buffer
        #
        return None

    #
    # end of method

    # method: Postprocess::_seiz_conf
    #
    # arguments:
    #  nfrms: number of leading buffer frames that form the seizure event
    #
    # return:
    #  the average of the above-threshold confidences among those frames
    #
    # If no frame in the span reaches the threshold on its own (possible
    # only for very short windows), this falls back to the mean of the
    # whole span instead of dividing by zero; an empty span gives 0.0.
    #
    def _seiz_conf(self, nfrms):

        span = self.postp_buff[:nfrms]
        conf_list = [x for x in span if x >= self.seiz_ths]
        if not conf_list:
            conf_list = span
        if not conf_list:
            return 0.0
        return sum(conf_list) / len(conf_list)

    #
    # end of method

    # method: Postprocess::_emit_event
    #
    # arguments:
    #  bin_buff: binary detections for the current window
    #  nfrms: number of leading frames that make up the event
    #  conf: confidence of the event (None for background)
    #  bin_detect: 0 for background, 1 for seizure
    #
    # return:
    #  hyp: the result of create_hyp() for the event
    #
    # This method creates the event and removes its frames from the buffer.
    #
    def _emit_event(self, bin_buff, nfrms, conf, bin_detect):

        hyp = self.create_hyp(bin_buff[:nfrms], conf, bin_detect)
        del self.postp_buff[:nfrms]
        return hyp

    #
    # end of method

    # method: Postprocess::postprocess
    #
    # arguments:
    #  dets: binary detections (1 = seizure, 0 = background)
    #
    # return:
    #  dets: the same sequence, modified in place
    #
    # This method converts a background gap to seizure when the gap is
    # bounded by seizure frames on both sides and is shorter than
    # min_bckg_frm frames (i.e. less than 2 * min_bckg_sec seconds).
    # Background touching either edge of the window is left unchanged.
    #
    def postprocess(self, dets):

        # index of the most recent seizure frame seen so far
        #
        last_seiz = None

        for idx in range(len(dets)):

            # only a seizure frame can close a gap
            #
            if dets[idx] != 1:
                continue

            # fill the gap since the previous seizure frame if it is short
            #
            if last_seiz is not None:
                gap = idx - last_seiz - 1
                if 0 < gap < self.min_bckg_frm:
                    for pos in range(last_seiz + 1, idx):
                        dets[pos] = 1

            last_seiz = idx

        # exit gracefully
        #
        return dets

    #
    # end of method

    # method: Postprocess::create_hyp
    #
    # arguments:
    #   data_list: the binary data belonging to the event (only its length
    #              is used)
    #   conf: the confidence of the event
    #   bin_detect: either 0 if the event is background or 1 if the
    #               event is seizure
    #
    # return:
    #   hyp: None for background, otherwise a dictionary with the start
    #        time, end time, label and confidence as formatted strings
    #
    # This method advances the event start time and records the label of
    # the event as the previous detection.
    #
    def create_hyp(self, data_list, conf, bin_detect):

        # the event ends at the file end if end_t_fin is set, otherwise
        # after len(data_list) frames
        #
        if self.end_t_fin is None:
            end_t = ((len(data_list) * (self.frmsize / self.samp_rate))
                     + self.st_t)
        else:
            end_t = self.end_t_fin

        # background: only update the state
        #
        if bin_detect == 0:
            self.previous_detection = 0
            self.st_t = end_t
            return None

        # create the python dictionary
        #
        hyp = {
            DEF_TIMESTART: format(self.st_t, DEF_OUTPUT_FORMAT),
            DEF_TIMEEND: format(end_t, DEF_OUTPUT_FORMAT),
            DEF_LABEL: SEIZ,
            DEF_CONFIDENCE: format(conf, DEF_OUTPUT_FORMAT),
        }

        # set the start time for next event and the previous detection
        #
        self.st_t = end_t
        self.previous_detection = 1

        # return the dictionary
        #
        return hyp

    #
    # end of method
#
# end of Postprocess

#
# end of file
