#!/usr/bin/env python
#
# file: $NEDC_NFC/src/python/util/nedc_driver/nedc_driver.py
#
# revision history:
#
# 20230422 (JP): reviewed for release
# 20230320 (ML): initial version
#
# This is the driver program for the real-time EEG decoder.
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import json
import struct

# import nedc_modules:
#
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
import nedc_decoder
import nedc_file_tools as nft
import nedc_edf_tools as net

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# the base name of this source file, used as a prefix in error messages
#
__FILE__ = os.path.basename(__file__)

# usage and help files handed to the command line parser
#  NOTE(review): these "./" paths are presumably resolved against the
#  current working directory, not this script's location -- confirm
#
USAGE_FILE = "./driver.usage"
HELP_FILE = "./driver.help"

# command line option names: (long form, abbreviated form)
#
ARG_PARM, ARG_ABRV_PARM = "--parameters", "-p"
ARG_BASE, ARG_ABRV_BASE = "--basename", "-b"

# output/replace directory option names are taken from the parser module
#  so they stay identical to the other NEDC tools: (long, abbreviated)
#
ARG_ODIR, ARG_ABRV_ODIR = ncp.ARG_ODIR, ncp.ARG_ABRV_ODIR
ARG_RDIR, ARG_ABRV_RDIR = ncp.ARG_RDIR, ncp.ARG_ABRV_RDIR

# names used to read the parameter file: the block name handed to
#  nft.load_parameters() and the keys looked up in the result
#
RESNET_DECODE = 'RESNET_DECODE'
CHANNEL_ORDER, MONTAGE_ORDER = 'channel_order', 'montage_order'

# defaults used when -p / -o are not given on the command line
#  (both are passed through nft.get_fullpath() before use)
#
DEF_PARAMETER_FILE = "$NEDC_NFC/lib/params_v01.txt"
DEF_ODIR = "$NEDC_NFC/test/output"

# an input file argument equal to this (case-insensitive) means "read stdin"
#
DEF_NONE = "none"

# raw binary input layout: DEF_INP_SAMP_FREQ samples per channel per
#  second, each a struct 'h' value occupying DEF_SHORT_INT_BSIZE bytes;
#  DEF_BINARY_EXT is the extension of the binary copy that is written out
#
DEF_INP_SAMP_FREQ = 50
DEF_SHORT_INT, DEF_SHORT_INT_BSIZE = 'h', 2
DEF_BINARY_EXT = 'bin'

# keys of the per-second dictionary handed to the decoder
#
DEF_TIMESTART, DEF_SAMPRATE, DEF_DATA = "timeStart", "sampleRate", "data"

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

# module-level objects created once at import time and used by main():
#  decoder - the Nedc_Decoder that main() initializes, feeds and flushes
#  dbgl    - the debug level, compared against ndt.NONE to enable the
#            command line argument printout
#
decoder = nedc_decoder.Nedc_Decoder()
dbgl = ndt.Dbgl()

# function: _read_exact
#
# arguments:
#  read: a callable taking a byte count and returning at most that many
#        bytes (an empty bytes object at end of input)
#  nbytes: the number of bytes wanted
#
# return: the bytes read; shorter than nbytes only if end of input was
#         reached first
#
# os.read() on a pipe (e.g. stdin fed by a real-time acquisition process)
# may legally return fewer bytes than requested even though more data is
# on its way, so keep reading until a full block arrives or input ends.
#
def _read_exact(read, nbytes):
    chunks = []
    remaining = nbytes
    while remaining > 0:
        chunk = read(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


# function: _binary_seconds
#
# arguments:
#  read: a callable used to pull raw bytes from the input (see _read_exact)
#  ch_order: the channel names, in the order their blocks appear in input
#  nbytes: the number of bytes holding one second of one channel
#  form: the struct format used to unpack one second of one channel
#  bin_file: an open binary file that receives a copy of every block read
#  errmsg: the message printed before exiting on a truncated block
#
# return: a generator yielding, for every second of input, a dictionary
#         mapping each channel name to its list of samples
#
# The input is one nbytes block per channel (in ch_order) per second. The
# generator stops quietly when input ends on a block boundary; an
# incomplete final second is not yielded (its blocks were still copied to
# bin_file). A truncated block is a fatal error.
#
def _binary_seconds(read, ch_order, nbytes, form, bin_file, errmsg):
    while True:
        buff = {}
        for chan in ch_order:
            data = _read_exact(read, nbytes)

            # end of input: let the caller flush the decoder
            #
            if not data:
                return

            if len(data) != nbytes:
                print(errmsg)
                sys.exit(os.EX_SOFTWARE)

            bin_file.write(data)
            buff[chan] = list(struct.unpack(form, data))
        yield buff


# function: _edf_seconds
#
# arguments:
#  sigs: the signals returned by Edf.read_edf(), indexed by channel name
#  ch_order: the channel names to extract, in output order
#  form: the struct format used to pack one second of one channel
#  bin_file: an open binary file that receives a packed copy of the data
#
# return: a generator yielding, for every complete second of signal, a
#         dictionary mapping each channel name to a list of int samples
#
# Only complete seconds are produced: a trailing partial second cannot be
# packed with the fixed-size format (it would raise struct.error), so it
# is dropped.
#
def _edf_seconds(sigs, ch_order, form, bin_file):
    nsecs = len(sigs[ch_order[0]]) // DEF_INP_SAMP_FREQ
    for sec in range(nsecs):
        start = sec * DEF_INP_SAMP_FREQ
        stop = start + DEF_INP_SAMP_FREQ
        buff = {}
        for chan in ch_order:
            samples = [int(x) for x in sigs[chan][start:stop]]
            buff[chan] = samples
            bin_file.write(struct.pack(form, *samples))
        yield buff


# function: _decode_seconds
#
# arguments:
#  seconds: an iterable of per-second dictionaries mapping channel names
#           to lists of samples
#  mon_order: the montage definition from the parameter file, or None
#  apply_mont: the ApplyMontage object used when mon_order is not None
#
# return: none
#
# Each second is montaged (if requested), wrapped with its index and the
# sample rate, and sent to the global decoder as soon as it is available.
# Once the input is exhausted the decoder is flushed: this outputs any
# valid seizure hypotheses that are active but have not completed.
#
def _decode_seconds(seconds, mon_order, apply_mont):
    for timestart, buff in enumerate(seconds):
        if mon_order is not None:
            buff = apply_mont.apply_montage(buff)

        sig_data = {DEF_TIMESTART: timestart,
                    DEF_SAMPRATE: DEF_INP_SAMP_FREQ,
                    DEF_DATA: buff}
        decoder.process(sig_data)

    decoder.flush()


# function: main
#
# arguments:
#  argv: the command line; not read directly -- the arguments come from
#        cmdl.parse_args() (NOTE(review): presumably parses sys.argv)
#
# return: True once the input has been decoded; on any fatal error an
#         error message is printed and the process exits with
#         os.EX_SOFTWARE
#
# Decodes one second of multichannel data at a time from stdin (when the
# input file argument is "none"), from an EDF file, or from a raw binary
# file of native-order struct 'h' samples. A binary copy of the samples is
# written to a .bin file named from the basename/odir/rdir arguments.
#
def main(argv):

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)
    cmdl.add_argument("files", type = str, nargs = '*')
    cmdl.add_argument(ARG_ABRV_ODIR, ARG_ODIR, type = str)
    cmdl.add_argument(ARG_ABRV_RDIR, ARG_RDIR, type = str)
    cmdl.add_argument(ARG_ABRV_BASE, ARG_BASE, type = str)

    # parse the command line: this must precede the debug printout below,
    #  which reads the parsed values
    #
    args = cmdl.parse_args()

    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(f" output directory = {args.odir}")
        print(f" replace directory = {args.rdir}")
        print(f" parameter file = {args.parameters}")
        print(f" argument files = {args.files}")

    # an input file (or "none" for stdin) is required; only the first
    #  one given is used
    #
    if not args.files:
        print("Error: %s (line: %s) %s: no input file specified" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__))
        sys.exit(os.EX_SOFTWARE)

    pfile, odir, rdir, basename = \
        (args.parameters, args.odir, args.rdir, args.basename)
    arg_file = args.files[0]

    # resolve and check the parameter file
    #
    pfile = nft.get_fullpath(DEF_PARAMETER_FILE if pfile is None else pfile)
    if not os.path.isfile(pfile):
        print("Error: %s (line: %s) %s: parameter file doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, pfile))
        sys.exit(os.EX_SOFTWARE)

    # resolve the output and (optional) replace directories
    #
    odir = nft.get_fullpath(DEF_ODIR if odir is None else odir)
    if rdir is not None:
        rdir = nft.get_fullpath(rdir)

    # decide where the data comes from: stdin, an EDF file, or a raw
    #  binary file
    #
    from_stdin = arg_file.lower() == DEF_NONE
    edf = None
    is_edf_file = False

    if not from_stdin:
        arg_file = nft.get_fullpath(arg_file)
        if not os.path.isfile(arg_file):
            print("Error: %s (line: %s) %s: input file doesn't exist (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, arg_file))
            sys.exit(os.EX_SOFTWARE)

        edf = net.Edf()
        is_edf_file = edf.is_edf(arg_file)

    # load the decoder parameters
    #
    params = nft.load_parameters(pfile, RESNET_DECODE)

    # NOTE(review): eval() executes arbitrary code from the parameter file.
    #  If these values are always plain Python literals, ast.literal_eval()
    #  would be a safer drop-in -- confirm against the parameter files used.
    #
    ch_order = eval(params[CHANNEL_ORDER])
    mon_order = eval(params[MONTAGE_ORDER])

    # initialize the montage helper and the decoder
    #
    apply_mont = net.ApplyMontage(ch_order, mon_order)
    decoder.init(pfile, basename, odir, rdir)

    # one second of one channel is DEF_INP_SAMP_FREQ short ints, unpacked
    #  or packed with the struct format "<count>h" (native byte order)
    #
    nbytes = DEF_INP_SAMP_FREQ * DEF_SHORT_INT_BSIZE
    def_form = nft.DELIM_NULL.join([str(DEF_INP_SAMP_FREQ), DEF_SHORT_INT])

    # the binary copy of the input is closed on every exit path
    #
    bin_filename = nft.create_filename(basename, odir, DEF_BINARY_EXT, rdir)
    with open(bin_filename, nft.MODE_WRITE_BINARY) as bin_file:

        # decode from stdin (file descriptor 0)
        #
        if from_stdin:
            errmsg = ("Error: %s (line: %s) %s: error reading binary data"
                      " from stdin" % (__FILE__, ndt.__LINE__, ndt.__NAME__))
            seconds = _binary_seconds(lambda count: os.read(0, count),
                                      ch_order, nbytes, def_form,
                                      bin_file, errmsg)
            _decode_seconds(seconds, mon_order, apply_mont)

        # decode from an EDF file, which must be sampled at the input rate
        #
        elif is_edf_file:
            _, sigs = edf.read_edf(arg_file, False)

            samprate = edf.get_sample_frequency(0)
            if samprate != DEF_INP_SAMP_FREQ:
                print("Error: %s (line: %s) %s: file must be sampled at 50 Hz"
                      " (%s)" % (__FILE__, ndt.__LINE__, ndt.__NAME__,
                                 arg_file))
                sys.exit(os.EX_SOFTWARE)

            seconds = _edf_seconds(sigs, ch_order, def_form, bin_file)
            _decode_seconds(seconds, mon_order, apply_mont)

        # otherwise assume it is a raw binary file
        #
        else:
            try:
                fp = open(arg_file, nft.MODE_READ_BINARY)
            except OSError:
                print("Error: %s (line: %s) %s: error opening binary file (%s)" %
                      (__FILE__, ndt.__LINE__, ndt.__NAME__, arg_file))
                sys.exit(os.EX_SOFTWARE)

            with fp:
                errmsg = ("Error: %s (line: %s) %s: error reading binary "
                          "data from file (%s)" % (__FILE__, ndt.__LINE__,
                                                   ndt.__NAME__, arg_file))
                seconds = _binary_seconds(fp.read, ch_order, nbytes,
                                          def_form, bin_file, errmsg)
                _decode_seconds(seconds, mon_order, apply_mont)

    # exit gracefully
    #
    return True
                    
# run the driver only when this file is executed as a script, handing
# main() its own copy of the command line
#
if __name__ == "__main__":
    main(list(sys.argv))

#
# end of file
