#!/usr/bin/env python
#
# file: $NEDC_NFC/class/python/nedc_eeg_eval_tools/nedc_eeg_eval_common.py
#
# revision history:
#
# 20251014 (JP): moved *_table to file tools
# 20240716 (JP): refactored after the rewrite of ann_tools
# 20230623 (AB): refactored code to new comment style
# 20220514 (JP): refactored the code to use the new annotation tools library
# 20200813 (LV): added parse_files method
# 20200622 (LV): first version
#
# Usage:
#  import nedc_eeg_eval_common as nec
#
# This file contains a collection of functions and variables commonly used
# across EEG evaluation tools.
#------------------------------------------------------------------------------

# import system modules
#
import os
from operator import itemgetter
from pathlib import Path
import sys

# import nedc_modules
#
import nedc_eeg_ann_tools as neat
import nedc_debug_tools as ndt
import nedc_file_tools as nft

#------------------------------------------------------------------------------
#                                                                              
# global variables are listed here                                             
#                                                                              
#------------------------------------------------------------------------------

# set the filename using basename: this short name is used as the prefix
# of the debug and error messages printed by the functions in this file
#
__FILE__ = os.path.basename(__file__)

# define a constant used to indicate a null choice: format_hyp compares
# each reference/hypothesis label against this value to tell insertions
# and deletions apart from substitutions
#
NULL_CLASS = "***"

# define constant that appears in the parameter file
# (not referenced by the functions visible in this file)
#
PARAM_MAP = "MAP"

# define standard delimiters for ROC/DET curves
# (not referenced by the functions visible in this file)
#
DELIM_ROC = "ROC_CURVE"
DELIM_DET = "DET_CURVE"

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

# declare a global debug object so we can use it in functions:
# parse_files compares it against ndt.BRIEF before printing its
# informational message
#
dbgl = ndt.Dbgl()

def format_hyp(ref, hyp):
    """
    function: format_hyp

    arguments:
     ref: the aligned reference labels as a list of strings
     hyp: the aligned hypothesis labels as a list of strings

    return:
     refo: a string displaying the alignment of the reference
     hypo: a string displaying the alignment of the hypothesis
     hits: the number of correct labels
     subs: the number of substitution errors
     inss: the number of insertion errors
     dels: the number of deletion errors

    description:
     This function walks two aligned label sequences position by position,
     skipping the first and last positions, and tallies hits, substitutions,
     insertions and deletions. A label equal to NULL_CLASS marks a gap in
     that sequence. Labels that take part in an error are converted to
     upper case in the output strings. Every label is right-justified in a
     column one character wider than the longest label found in either
     input, and is followed by a single space.
    """

    # set the column width: the longest label in either list plus one
    #
    width = max(map(len, [*ref, *hyp]), default=0) + 1

    # declare the error counters
    #
    hits = subs = inss = dels = 0

    # collect the formatted columns here and join them once at the end
    #
    ref_cells = []
    hyp_cells = []

    # loop over the interior positions: the first and last labels are
    # skipped, so the enumeration starts at index 1
    #
    for pos, ref_lbl in enumerate(ref[1:-1], start=1):

        # fetch the hypothesis label aligned with this position
        #
        hyp_lbl = hyp[pos]

        # classify the pair: matching labels are hits; otherwise a null
        # reference is an insertion, a null hypothesis is a deletion and
        # anything else is a substitution. the label(s) responsible for
        # an error are upper-cased so they stand out in the report.
        #
        if ref_lbl == hyp_lbl:
            hits += 1
        elif ref_lbl == NULL_CLASS:
            inss += 1
            hyp_lbl = hyp_lbl.upper()
        elif hyp_lbl == NULL_CLASS:
            dels += 1
            ref_lbl = ref_lbl.upper()
        else:
            subs += 1
            ref_lbl = ref_lbl.upper()
            hyp_lbl = hyp_lbl.upper()

        # format both labels into fixed-width columns
        #
        ref_cells.append("%*s " % (width, ref_lbl))
        hyp_cells.append("%*s " % (width, hyp_lbl))

    # assemble the output strings
    #
    refo = nft.STRING_EMPTY + "".join(ref_cells)
    hypo = nft.STRING_EMPTY + "".join(hyp_cells)

    # exit gracefully
    #
    return (refo, hypo, hits, subs, inss, dels)
#
# end of function

def parse_files(files, scmap = None):
    """
    function: parse_files

    arguments:
     files: list of hypothesis or reference annotation file names
     scmap: a scoring map used to augment the annotations (accepted for
            interface compatibility; not used by this function)

    return:
     odict: a dictionary whose keys are pathlib.Path objects built from
            the entries of files and whose values are the lists of
            corresponding annotation events. False is returned if any
            file fails to load.

    description:
     This function parses each file in a list of annotation files into a
     dictionary keyed by file path. Each value is a list of events of the
     form:
      [[0.0, 24.0, {'bckg': 1}], [24.0, 151.0, {'seiz': 1}]]
     The full path is used as the key because it is used to display the
     name of the file being processed during scoring.

     For each file, the events are sorted by start time, passed through
     neat.augment_annotation (to fill gaps in the annotations with
     background) and then through neat.remove_repeated_events (to collapse
     repeated consecutive events). A file with no events is represented by
     one background event that spans the duration read from the file's
     comment header.
    """

    # display informational message
    #
    if dbgl > ndt.BRIEF:
        print("%s (line: %s) %s: parsing files" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__))

    # declare local variables
    #
    ann = neat.AnnEeg()
    odict = {}

    # load annotations
    #
    for afile in files:

        if ann.load(afile) == False:
            print("Error: %s (line: %s) %s: %s (%s)" %
                  (__FILE__, ndt.__LINE__,
                   ndt.__NAME__, "error loading references",
                   afile))
            return False

        # get the file duration after stripping off the units:
        # only the first whitespace-separated token is converted
        #
        cdict = nft.extract_comments(afile)
        duration = float(cdict[neat.CSV_KEY_DURATION].split()[0])

        # get the events
        #
        events = ann.get()

        # if there are annotations, sort them based on start time
        # else: add one background event that spans the entire file
        #
        if events is not False:
            events_sorted = sorted(events, key=itemgetter(0))
        else:
            events_sorted = [[float(0.0), duration,
                              {neat.DEF_BCKG: neat.PROBABILITY}]]

        # augment the annotation with background intervals
        #
        events_new = neat.augment_annotation(events_sorted, duration,
                                             neat.DEF_BCKG)

        # reduce multiple background events
        #
        events_reduced = neat.remove_repeated_events(events_new)

        # store full file path because this is used to print out
        # the filenames being processed
        #
        key = Path(afile)

        # create dictionary with the file path as key
        #
        if key not in odict:
            odict[key] = events_reduced
        else:
            # the same path was listed more than once: add its events to
            # the existing list. extend (not append) is used so the value
            # stays a flat list of events instead of gaining a nested list.
            #
            odict[key].extend(events_reduced)

    # return a dictionary
    #
    return odict

#
# end of function

# end of file
#
