#!/usr/bin/env python
#
# file: $NEDC_NFC/util/python/nedc_dpath_post_proc/nedc_dpath_post_proc.py

# revision history:
#
# 20250126 (DH): initial version
#
# This script can call post-processing operations on decoded dpath annotations
#------------------------------------------------------------------------------

# import system modules
#
import atexit
import os
from pathlib import Path
import random
import shutil
import sys
import tempfile
import time

# import NEDC modules
#
import nedc_cmdl_parser as ncp
import nedc_debug_tools as ndt
import nedc_file_tools as nft 
import nedc_dpath_ann_tools as nda
import nedc_dpath_pproc_tools as dpt

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# name of this script with any leading directories removed
#
__FILE__ = os.path.basename(__file__)

# location of the help file
#
HELP_FILE = "$NEDC_NFC/docs/help/nedc_dpath_post_proc.help"

# location of the usage file
#
USAGE_FILE = "$NEDC_NFC/docs/usage/nedc_dpath_post_proc.usage"

# location of the default parameter file
#
DEF_PARAM = "$NEDC_NFC/docs/params/nedc_dpath_post_proc_params_v00.toml"

# command line options: the output directory, replace directory and
# output extension options reuse the names defined by the command line
# parser module; the parameter file option is defined locally
#
ARG_ODIR, ARG_ABRV_ODIR = ncp.ARG_ODIR, ncp.ARG_ABRV_ODIR
ARG_RDIR, ARG_ABRV_RDIR = ncp.ARG_RDIR, ncp.ARG_ABRV_RDIR
ARG_OEXT, ARG_ABRV_OEXT = ncp.ARG_OEXT, ncp.ARG_ABRV_OEXT
ARG_PARM, ARG_ABRV_PARM = "--pfile", "-p"

# keys used to access the parameter files:
#  DPATH_PPROC is the section loaded from the post-processing parameter file
#  DPATH_PRIORITY_MAP is the section loaded from the priority map file
#  PPROC_PARAM_KEY_* are the entries read from the post-processing section
#
DPATH_PPROC = "DPATH_PPROC"
DPATH_PRIORITY_MAP = nda.DPATH_PRIORITY_MAP
PPROC_PARAM_KEY_ALG = "algorithm"
PPROC_PARAM_KEY_PRIORITY_MAP = "priority_map"
PPROC_PARAM_KEY_THRESHOLD = "threshold"

# default output directory
#
DEF_ODIR = "./output"

# default output file extension
#
DEF_OEXIT = nft.DEF_EXT_CSV

# global debug object, shared by the functions below
#
dbgl = ndt.Dbgl()

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

def apply_threshold(patch_graph, threshold):
    """
    function: apply_threshold

    arguments:
     patch_graph: the output graph of the post processor. it is indexed by
                  region id, and each region holds at least a label
                  (nda.CKEY_TEXT) and a confidence (nda.CKEY_CONFIDENCE)
     threshold: the minimum confidence a region needs to keep its label

    return: reference to the changed graph (the same object that was
            passed in), or None if no graph was provided

    description:
     apply a threshold to the patch graph confidence field. every region
     that is not already background and whose confidence is below the
     threshold is relabeled as background. the graph is modified in place.
    """

    # nothing to threshold if the post processor did not produce a graph;
    # hand None back so the caller's own None check can report the error
    #
    if patch_graph is None:
        return None

    # iterate over each region in the graph
    #
    for region in patch_graph.values():

        # fetch the region label and probability/confidence
        #
        lbl = region[nda.CKEY_TEXT]
        prob = region[nda.CKEY_CONFIDENCE]

        # apply confidence threshold to region:
        #  the label lives under CKEY_TEXT, so that is the entry that must
        #  be overwritten. indexing the region with the old label value
        #  would only add a new entry and leave the label untouched.
        #
        if prob < threshold and lbl != nda.DEF_BCKG:
            region[nda.CKEY_TEXT] = nda.DEF_BCKG

    # exit gracefully
    #  return patch_graph reference
    #
    return patch_graph
#
# end of function

def _process_file(ffile, pproc, ann, thresh, odir, oext, rdir):
    """
    function: _process_file

    arguments:
     ffile: full path of the annotation file to post-process
     pproc: the post processor object used to load the file and predict
            regions
     ann: the annotation tool used to read the file header
     thresh: the confidence threshold passed to apply_threshold
     odir: the output directory
     oext: the output file extension
     rdir: the replace directory

    return: a boolean value indicating status (True if an output file
            was written)

    description:
     run one annotation file through the post processor, apply the
     confidence threshold and write the result to the output location.
     errors are reported on stdout and the file is skipped.
    """

    # load annotation graph into post processor:
    #  do not continue on failure, otherwise the output would be generated
    #  from whatever state the post processor currently holds
    #
    if pproc.load(ffile) == False:
        print("Error: %s (line: %s) %s: %s (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__,
               "error loading into post processor", ffile))
        return False

    # run post processor, fetch results
    #
    patch_graph = pproc.predict_regions()

    # make sure there is a graph before it is thresholded
    #
    if patch_graph is None:
        print("Error: %s (line: %s) %s: %s (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__,
               "output error", ffile))
        return False

    # run patch graph through a confidence threshold
    #
    patch_graph = apply_threshold(patch_graph, thresh)

    # load file in dpath ann to get image width/height
    #
    ann.load(ffile)
    hdr = ann.get_header()

    # the header is required to write the output file
    #
    if hdr is None:
        print("Error: %s (line: %s) %s: %s (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__,
               "output error", ffile))
        return False

    # create output filename
    #
    ofile = nft.create_filename(ffile, odir, oext, rdir)

    # create header and save annotation to csv file
    #
    nda.write_data_to_file(
        patch_graph, hdr[nda.CKEY_WIDTH],
        hdr[nda.CKEY_HEIGHT], ofile
    )

    # exit gracefully
    #
    return True
#
# end of function

def main(argv):
    """
    function: main

    arguments:
     argv: the command line (unused here; the command line parser object
           performs the parsing through its parse_args method)

    return: boolean value indicating status

    description:
     This function is the main function. It parses the command line,
     loads the post-processing parameters and the label priority map,
     and then post-processes every annotation file named on the command
     line, either directly or through a file list.
    """

    # create a command line parser
    #
    cmdl = ncp.Cmdl(USAGE_FILE, HELP_FILE)

    # define the command line arguments
    #
    cmdl.add_argument("files", type = str, nargs = "*")
    cmdl.add_argument(ARG_ABRV_ODIR, ARG_ODIR, type = str)
    cmdl.add_argument(ARG_ABRV_RDIR, ARG_RDIR, type = str)
    cmdl.add_argument(ARG_ABRV_OEXT, ARG_OEXT, type = str)
    cmdl.add_argument(ARG_ABRV_PARM, ARG_PARM, type = str)

    # parse the command line
    #
    args = cmdl.parse_args()

    # make sure there is a file argument
    #
    if not args.files:
        cmdl.print_usage()
        sys.exit(os.EX_SOFTWARE)

    # process the command line arguments:
    #  fall back to the defaults for anything that was not specified,
    #  including the parameter file
    #
    if args.ext is None:
        args.ext = DEF_OEXIT
    if args.rdir is None:
        args.rdir = ncp.DEF_RDIR
    if args.odir is None:
        args.odir = DEF_ODIR
    if args.pfile is None:
        args.pfile = nft.get_fullpath(DEF_PARAM)

    # display debug information
    #
    if dbgl > ndt.NONE:
        print("command line arguments:")
        print(f" file extention = {args.ext}")
        print(f" output directory = {args.odir}")
        print(f" replace directory = {args.rdir}")
        print(f" parameter file = {args.pfile}")
        print(f" argument files = {args.files}")

    # fetch parameters from command line
    #
    pfile, odir, rdir, oext = \
        (args.pfile, args.odir, args.rdir, args.ext)

    # load the parameter file
    #
    params = nft.load_parameters(pfile, DPATH_PPROC)

    # make sure the parameter file was loaded
    #
    if params is None:
        print("Error: %s (line: %s) %s: error loading parameter file (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, pfile))
        sys.exit(os.EX_SOFTWARE)

    # fetch the post-processing algorithm name
    #
    alg = params[PPROC_PARAM_KEY_ALG]

    # make sure the post-processing algorithm is specified
    #
    if alg is None:
        print("Error: %s (line: %s) %s: %s" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__,
               "must specify a post-processing algorithm"))
        sys.exit(os.EX_SOFTWARE)

    # fetch the post-processing threshold
    #
    thresh = params[PPROC_PARAM_KEY_THRESHOLD]

    # fetch label priority path
    #
    label_pmap_fname = params[PPROC_PARAM_KEY_PRIORITY_MAP]

    # make sure the priority map file is specified
    #
    if label_pmap_fname is None:
        print("Error: %s (line: %s) %s: must specify a priority map" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__))
        sys.exit(os.EX_SOFTWARE)

    # get the priority map file's full path
    #
    label_pmap_fpath = nft.get_fullpath(label_pmap_fname)

    # ensure priority map file exists
    #
    if not os.path.isfile(label_pmap_fpath):
        print("Error: %s (line: %s) %s: priority map doesn't exist (%s)" %
              (__FILE__, ndt.__LINE__, ndt.__NAME__, label_pmap_fpath))
        sys.exit(os.EX_SOFTWARE)

    # fetch the priority map
    #
    label_pmap = nft.load_parameters(label_pmap_fpath, DPATH_PRIORITY_MAP)

    # keep track of the start time
    #
    init_time = time.time()

    # counters for the number of files attempted and processed
    #
    num_files_att = int(0)
    num_files_proc = int(0)

    # create ann tool
    #
    ann = nda.AnnDpath()

    # create pproc object
    #
    pproc = dpt.PostProcessor(alg, label_pmap)

    # make the output directory
    #
    os.makedirs(odir, exist_ok=True)

    # main processing loop: loop over all input filenames
    #
    for arg_file in args.files:

        # fetch arg_files full path
        #
        ffile = nft.get_fullpath(arg_file)

        # check if the argument file exists
        #
        if not os.path.isfile(ffile):
            print("Error: %s (line: %s) %s: file doesn't exist (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, ffile))
            sys.exit(os.EX_SOFTWARE)

        # check if the file is an annotation file or a list file
        #
        if nft.is_ann(ffile):

            # display debug information
            #
            if dbgl > ndt.NONE:
                print("%s (line: %s) %s: decoding image file (%s)" %
                      (__FILE__, ndt.__LINE__, ndt.__NAME__, ffile))

            # display information
            #
            num_files_att += int(1)
            print("%3ld: %s" % (num_files_att, ffile))

            # post-process the file and count it if it succeeded
            #
            if _process_file(ffile, pproc, ann, thresh, odir, oext, rdir):
                num_files_proc += int(1)

            # nothing else to do for a single annotation file
            #
            continue

        # if the file isn't an annotation file assume it is a file list:
        #  display debug information
        #
        if dbgl > ndt.NONE:
            print("%s (line: %s) %s: opening list (%s)" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__, ffile))

        # fetch file list
        #
        flist = nft.get_flist(ffile)

        # ensure the flist method worked:
        #  report the name of the list file (flist itself is None here)
        #
        if flist is None:
            print("Error: %s (line: %s) %s: %s (%s)"
                  % (__FILE__, ndt.__LINE__, ndt.__NAME__,
                     "error retrieving file list", ffile))
            sys.exit(os.EX_SOFTWARE)

        # expand environment variables of each file in flist and
        # ensure it exists
        #
        for fname in flist:

            # fetch the full path
            #
            lfile = nft.get_fullpath(fname)

            # check if the listed file exists
            #
            if not os.path.isfile(lfile):
                print("Error: %s (line: %s) %s: file doesn't exist (%s)" %
                      (__FILE__, ndt.__LINE__, ndt.__NAME__, lfile))
                sys.exit(os.EX_SOFTWARE)

            # display debug information
            #
            if dbgl > ndt.NONE:
                print("%s (line: %s) %s: decoding image file list (%s)" %
                      (__FILE__, ndt.__LINE__, ndt.__NAME__, lfile))

            # display information
            #
            num_files_att += int(1)
            print("%3ld: %s" % (num_files_att, lfile))

            # post-process the file and count it if it succeeded
            #
            if _process_file(lfile, pproc, ann, thresh, odir, oext, rdir):
                num_files_proc += int(1)

    # calculate the elapsed time
    #
    elapsed = time.time() - init_time

    # display the results
    #
    print("processed %ld out of %ld files successfully" %
          (num_files_proc, num_files_att))

    # display the finished message
    #
    print(f'decoding of all files finished in {elapsed:.0f} seconds.')

    # exit gracefully
    #
    return True
#
# end of function

# run the driver only when this file is executed as a script:
#  the command line is handed to main
#
if __name__ == "__main__":
    main(sys.argv)
#
# end of file
