#!/usr/bin/env python
#
# file: NEDC_NFC/class/python/nedc_dpath_eb0_tools.py
#
# revision history:
#
# 20250921 (AM): adjusted for eb0
# 20250529 (DH): initial version
# 
# This file contains a Python implementation of the EfficientNet-B0 decoder
# and trainer
#------------------------------------------------------------------------------

# import system modules
#
import copy
import numpy as np
import os
import random
import sys
import time
import warnings
from pathlib import Path

# import NEDC modules
#
import nedc_debug_tools as ndt
import nedc_image_tools as nit
import nedc_dpath_ann_tools as nda
import nedc_file_tools as nft
import nedc_dpath_image_tools as ndi
import nedc_dpath_pproc_tools as dpt

# import torch and immediately configure it for repeatable runs: the torch
# random number generator (and the CUDA generators, when CUDA is available)
# are seeded with ndt.RANDSEED before the remaining torch modules are
# imported below. the cuDNN deterministic flag is set, the cuDNN benchmark
# flag is cleared, and the MKL-DNN backend is disabled.
#
import torch
torch.manual_seed(ndt.RANDSEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(ndt.RANDSEED)
    torch.cuda.manual_seed_all(ndt.RANDSEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.mkldnn.enabled = False
 
# import rest of the necessary torch modules
#
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.utils as utils
from torch.utils.data import Subset
from torchvision import models

# import polygon modules for patch ref label extraction
#
from shapely.geometry import Polygon, box
from shapely.ops import unary_union
from shapely.strtree import STRtree

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# define the version of DPATH EfficientNet-B0
#
NEDC_DPATH_EB0_VERSION = "v1.0.0"

#------------------------------------------------------------------------------
# decoder variables are located here:
#

# re-export the default class list from the dpath annotation tools
#
DEF_CLASSES = nda.DEF_CLASSES

#------------------------------------------------------------------------------
# trainer variables are located here:
#

# define values for common strings needed:
#  TRAIN/DEV are used as dictionary keys for the per-split containers in
#  the trainer; TRANSFORM is the attribute name checked with hasattr()
#  when transforms are assigned to an image set
#
TRAIN, DEV= nft.ML_TRAIN , nft.ML_DEV 
TRANSFORM = 'transform'

# define the learning rate of the stochastic gradient descent function
#
DEF_LEARNING_RATE = 0.001

# define the momentum of the stochastic gradient descent function
#
DEF_MOMENTUM = 0.900

# define a variable for after how many epochs to decrease learning rate
# of the stochastic gradient descent optimizer
#
DEF_STEP_SIZE = 7

# define a variable for how much to decrease the learning rate after
# STEP_SIZE amount of epochs
#
DEF_GAMMA = 0.100

# define the initial value for the worst loss. note that this is the
# string 'inf', not a float; presumably it is converted with float()
# where it is used - TODO confirm against the train method
#
DEF_WORST_LOSS_INIT = 'inf'

#------------------------------------------------------------------------------
# variables used by both the trainer and decoder are located here: 
#

# define strings that identify the device type in use; DEVICE_CPU is
# compared against device.type and is also used as the map_location when
# the model bundle is loaded
#
DEVICE_CPU = nft.DEVICE_TYPE_CPU 
DEVICE_GPU = nft.DEVICE_TYPE_GPU
DEVICE_CUDA = nft.DEVICE_TYPE_CUDA

# define keys for the dpath mapping file
#
DPATH_LABEL_MAP = nda.DPATH_LABEL_MAP
PARAM_KEY_LABEL_MAP = "label_map"
DPATH_PRIORITY_MAP = nda.DPATH_PRIORITY_MAP

# define keys used to read the saved model bundle:
#  ATTR_MODEL_CLASS_TO_IDX_MAP: the class name to index mapping
#  ATTR_MODEL_STATE_DICT: the model state dict
#  ATTR_MODEL_DEV_TRANSFORMS: the dev transforms applied at decode time
# ATTR_MODEL_MODULE is the key prefix that is stripped from state dict
# entries before they are loaded
#
ATTR_MODEL_CLASS_TO_IDX_MAP = "class_to_idx"
ATTR_MODEL_STATE_DICT = "state_dict"
ATTR_MODEL_MODULE = "module."
ATTR_MODEL_DEV_TRANSFORMS = "dev_transforms"

# define formatting strings for pretty printing a confusion matrix:
# this is the label placed in the top-left cell of the table
#
DELIM_CM_HEADER = "true\\pred"

# set the filename using basename
#
__FILE__ = os.path.basename(__file__)

# declare a global debug object so we can use it in functions
#
dbgl = ndt.Dbgl()

#------------------------------------------------------------------------------
#
# helper functions listed here
#
#------------------------------------------------------------------------------

def worker_init_fn(worker_id):
    """
    function: worker_init_fn

    arguments:
     worker_id: the integer id of the data loader worker being initialized

    return: None

    description:
     Derives a per-worker seed by offsetting the global seed (ndt.RANDSEED)
     with the worker id, then seeds Python's random module, NumPy and torch
     with that value so each worker draws a repeatable random sequence.
    """

    # offset the global seed by the worker id
    #
    seed = ndt.RANDSEED + worker_id

    # apply the seed to every random number generator in use
    #
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
#
# end of function

def pretty_cm(cm, class_names, normalize=False):
    """
    function: pretty_cm

    arguments:
     cm: a 2-D confusion matrix, indexed as cm[row, col], whose rows are
      true labels and columns are predicted labels
     class_names: an iterable of class names in the exact row/column order
     normalize: a boolean flag; if True, values are printed as floats with
      two decimals, otherwise they are converted to integers

    return:
     table_str: a labeled, column-aligned string representation of the
      confusion matrix

    description:
     Formats a confusion matrix into a human-readable table. The first column
     lists the true labels; subsequent columns list predicted labels. This
     function does not modify the input matrix. If class_names is empty,
     the table consists of the header and separator lines only.
    """

    # convert class names to strings (safety for non-string labels)
    #
    class_names = [str(c) for c in class_names]
    num_classes = len(class_names)

    # build the header row (true\pred, then predicted class names)
    #
    header = [DELIM_CM_HEADER] + class_names

    # choose a formatter based on normalization preference
    #
    if normalize:
        def fmt(x):
            return f"{float(x):.2f}"
    else:
        def fmt(x):
            return f"{int(x)}"

    # build body rows: [row_label] + formatted cells
    #
    body_rows = [
        [rlabel] + [fmt(cm[i, j]) for j in range(num_classes)]
        for i, rlabel in enumerate(class_names)
    ]

    # compute column widths from header and body for proper alignment.
    # the header is scanned together with the body rows so that max()
    # always receives at least one value, even when there are no classes.
    #
    all_rows = [header] + body_rows
    widths = [
        max(len(row[c]) for row in all_rows) for c in range(len(header))
    ]

    # create a row formatter: left-justified cells joined by
    # nft.DELIM_SPACE
    # NOTE(review): an earlier comment said "two spaces between columns";
    # confirm the value of nft.DELIM_SPACE
    #
    line_tmpl = nft.DELIM_SPACE.join("{:<" + str(w) + "}" for w in widths)

    # build the output table: header, separator, then rows
    #
    out_lines = [
        line_tmpl.format(*header),
        line_tmpl.format(*(nft.DELIM_DASH * w for w in widths))
    ]
    out_lines.extend(line_tmpl.format(*row) for row in body_rows)

    # exit gracefully
    #  return the assembled table string
    #
    return nft.DELIM_NEWLINE.join(out_lines)
#
# end of function

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

class DPATHEB0Decode:
    """
    class: DPATHEB0Decode

    arguments:
     none

    description:
     This class decodes image files with an EfficientNet-B0 model. The
     model weights, the class-to-index map and the dev transforms are read
     from a saved model bundle.
    """

    def __init__(self, window_size, frmsize, device, num_workers, num_threads,
                 mdl_path, mdl_weights_path, batch_size, threshold, pproc_alg):
        """
        method: constructor

        arguments:
         window_size: the window size
         frmsize: the frame size
         device: the device that the decoding will occur on (an object
          with a 'type' attribute, e.g. a torch.device)
         num_workers: the number of workers to use for data loader
         num_threads: determine the number of threads to use (for cpu)
         mdl_path: optional path to a state dict that is loaded into the
          default model before the classifier is replaced (None to skip)
         mdl_weights_path: the path to the saved model bundle
         batch_size: size of array input into Efficientnet-B0 model
         threshold: the frame level (intermediate) model output threshold
         pproc_alg: the dpath pre processing algorithm name

        returns:
         None

        description:
         This method loads the model bundle, builds the EfficientNet-B0
         model with a classifier sized to the number of classes in the
         bundle, loads the trained weights and creates the preprocessor.
         The program exits with os.EX_SOFTWARE if the bundle has no
         background class.
        """

        # set the class name
        #
        DPATHEB0Decode.__CLASS_NAME__ = self.__class__.__name__

        # display debug information
        #
        if dbgl > ndt.BRIEF:
            print("%s (line: %s) %s::%s: creating EB0 decode object" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Decode.__CLASS_NAME__,
                   ndt.__NAME__))

        # get values we need to derive other variables
        #
        self.frmsize = frmsize

        # set the window size
        #
        self.window_size = window_size

        # set the device information
        #
        self.device = device

        # set the batch size
        #
        self.batch_size = batch_size

        # set the number of workers
        #
        self.num_workers = num_workers

        # set the frame level threshold
        #
        self.threshold = threshold

        # if the device that will be utilized is a cpu
        # limit the number of threads PyTorch can use
        #
        if self.device.type == DEVICE_CPU:
            torch.set_num_threads(num_threads)

            # torch only allows the inter-op thread count to be set once
            # per process and only before parallel work has started; a
            # later call raises RuntimeError. treat the setting as best
            # effort so that a second object can still be constructed.
            #
            try:
                torch.set_num_interop_threads(1)
            except RuntimeError as err:
                if dbgl > ndt.BRIEF:
                    print("%s (line: %s) %s::%s: interop threads not set (%s)"
                          % (__FILE__, ndt.__LINE__,
                             DPATHEB0Decode.__CLASS_NAME__,
                             ndt.__NAME__, err))

        # load the saved model data
        # NOTE: weights_only=False unpickles arbitrary objects (the bundle
        # also stores the dev transforms); only load trusted model files
        #
        bundle = torch.load(
            mdl_weights_path,
            map_location=DEVICE_CPU,
            weights_only=False
        )

        # get the class-to-index map and build the inverse (index to
        # class name) map used when writing annotations
        #
        class_to_idx = bundle[ATTR_MODEL_CLASS_TO_IDX_MAP]
        self.index_map = {v: k for k, v in class_to_idx.items()}

        # if 'bckg' is not among the available classes throw
        # error
        #
        if nda.DEF_BCKG not in class_to_idx:
            print("Error: %s (line: %s) %s: bckg label not detected %s %s" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__,
                   "a class named 'bckg' (background)",
                   "must be included during training"))
            sys.exit(os.EX_SOFTWARE)

        # get the background index
        #
        self.bckg_idx = class_to_idx[nda.DEF_BCKG]

        # fetch dev transforms
        #
        self.transforms = bundle[ATTR_MODEL_DEV_TRANSFORMS]

        # instantiate default efficientnet-b0 model
        #
        model = models.efficientnet_b0(weights=None)

        # load a pre-trained state dict if given. load_state_dict()
        # updates the model in place and returns a report of
        # missing/unexpected keys, so its result must not be assigned
        # back to 'model'. map to the cpu like the bundle load above; the
        # model is moved to the target device at the end.
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # trusted model files
        #
        if mdl_path is not None:
            state_dict = torch.load(
                mdl_path,
                map_location=DEVICE_CPU,
                weights_only=False
            )
            model.load_state_dict(state_dict)

        # replace the final linear layer of the classifier so that it has
        # one output per class in the bundle
        #
        model.classifier[1] = nn.Linear(
            model.classifier[1].in_features, len(class_to_idx)
        )

        # strip the leading DataParallel prefix ("module.") from the state
        # dict keys, leaving other keys untouched
        #
        raw_state = bundle[ATTR_MODEL_STATE_DICT]
        clean = {}
        for k, v in raw_state.items():
            if k.startswith(ATTR_MODEL_MODULE):
                clean_key = k.replace(ATTR_MODEL_MODULE, nft.DELIM_NULL, 1)
            else:
                clean_key = k
            clean[clean_key] = v

        # load the models state dict and move the model to the device
        #
        model.load_state_dict(clean)
        self.model = model.to(self.device)

        # set the preprocess algorithm
        #
        self.preprocessor = dpt.PreProcessor(
            alg =str(pproc_alg),
            window_size = self.window_size,
            frmsize = self.frmsize,
            transforms = self.transforms
        )

        return None
    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # file-based methods
    #
    #--------------------------------------------------------------------------

    def decode(self, argfile, fp):
        """
        method: decode

        arguments:
         argfile: file name to process
         fp: pointer to log file (not used by this method)

        return:
         dpath_graph: the dpath annotation dictionary built from the
          predictions (see create_ann_graph)

        description:
         This function decodes an img file: it builds a dataset of frames
         with the preprocessor, runs the model over it and places the
         predictions in a dpath annotation graph.
        """

        # display debug information
        #
        if dbgl > ndt.BRIEF:
            print("%s (line: %s) %s::%s: decoding file (%s)" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Decode.__CLASS_NAME__,
                   ndt.__NAME__, argfile))

        # make a dataset
        #
        self.dataset, coords_frms = \
            self.preprocessor.get_image_dataset(argfile)

        # For decoding, shuffling is not required as we
        # are processing all data sequentially.
        # However, for consistency, we can still set a generator.
        #
        generator = torch.Generator()
        generator.manual_seed(ndt.RANDSEED)

        # create a dataloader for faster processing. the worker-related
        # options are only enabled when worker processes are used.
        #
        use_workers = self.num_workers > 0
        data_loader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            generator=generator,
            worker_init_fn=worker_init_fn,
            pin_memory=False,
            persistent_workers=use_workers,
            prefetch_factor=(2 if use_workers else None)
        )

        # get detections based on the dataset
        #
        lbls, probs = self.get_detects(data_loader)

        # create and return annotation graph
        #
        return self.create_ann_graph(coords_frms, lbls, probs)
    #
    # end of method

    def create_ann_graph(self, coords, lbls, probs):
        """
        method: create_ann_graph

        arguments:
         coords: list of the prediction coordinates (frames) (ordered);
          each entry is an (x, y) pair
         lbls: list of the prediction labels (class indices) (ordered)
         probs: the prediction probabilities (ordered)

        return:
         dpath_graph: a dpath annotation dictionary keyed by the
          zero-based prediction index; the region id stored in each entry
          is that index plus one

        description:
         this method places our predictions into a dpath annotation to
         save to a file
        """

        # create dpath graph dict
        #
        dpath_graph = dict()

        # enumerate over the coordinates, labels, and probabilities
        # or confidences
        #
        for idx, ((fx, fy), lidx, prob) in enumerate(zip(coords, lbls, probs)):

            # region id can not be zero, increment by 1
            #
            region_id = idx + 1

            # fetch the class name
            #
            class_name = self.index_map[lidx]

            # build a frmsize x frmsize square whose first corner is
            # (fx, fy). a separate name is used so that the 'coords'
            # argument being iterated is not shadowed.
            #
            region_coords = [
                (fx, fy, 0),
                (fx + self.frmsize, fy, 0),
                (fx + self.frmsize, fy + self.frmsize, 0),
                (fx, fy + self.frmsize, 0)
            ]

            # construct the annotation entry for the current region id
            #
            dpath_graph[idx] = {
                nda.CKEY_REGION_ID: region_id,
                nda.CKEY_TEXT: class_name,
                nda.CKEY_COORDINATES: region_coords,
                nda.CKEY_CONFIDENCE: prob,
                nda.CKEY_TISSUE_TYPE: nda.DEF_GRAPH_TISSUE,
                nda.CKEY_GEOM_PROPS: {
                    nda.CKEY_MICRON_LENGTH: 0.0,
                    nda.CKEY_MICRON_AREA: 0.0,
                    nda.CKEY_LENGTH: 0.0,
                    nda.CKEY_AREA: 0.0
                }
            }

        # exit gracefully
        # return dpath graph
        #
        return dpath_graph

    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # processing-based methods
    #
    #--------------------------------------------------------------------------

    def get_detects(self, data_loader):
        """
        method: get_detects

        arguments:
         data_loader: an iterable yielding batches of model inputs
          (tensors); presumably images shaped [B,C,H,W] - confirm against
          the preprocessor

        return:
         lbls: predicted class indices per sample (list)
         probs: the probability reported for each prediction (list)

        description:
         Runs the model in evaluation mode over the dataset. For each
         sample the top class and its softmax probability are returned,
         unless that probability is less than or equal to the threshold;
         in that case the background class and the background probability
         are returned instead.
        """

        # switch to eval mode
        #
        self.model.eval()

        # prepare output containers
        #
        lbls = []
        probs = []

        # disable gradient computation
        #
        with torch.no_grad():

            # iterate over DataLoader batches
            #
            for batch in data_loader:

                # send inputs to device
                #
                batch = batch.to(
                    self.device, dtype=torch.float32, non_blocking=True
                )

                # forward pass
                #
                logits = self.model(batch)

                # softmax probabilities
                #
                prob_mat = torch.softmax(logits, dim=1)

                # top class per sample
                #
                top_prob, top_idx = torch.max(prob_mat, dim=1)

                # background probability (used when falling back to
                # background)
                #
                p_bkg = prob_mat[:, self.bckg_idx]

                # if the best class confidence is not above the threshold
                # force background
                #
                below_th = (top_prob <= self.threshold)
                final_idx = torch.where(
                    below_th,
                    torch.full_like(top_idx, self.bckg_idx),
                    top_idx)
                final_prob = torch.where(below_th, p_bkg, top_prob)

                # collect on CPU
                #
                lbls.extend(final_idx.cpu().tolist())
                probs.extend(final_prob.cpu().tolist())

        # return predictions and probabilities
        #
        return lbls, probs
    #
    # end of method

#
# end of class

class DPATHEB0Train:
    """
    class: DPATHEB0Train

    arguments:
     none

    description:
     This class can train a pre trained Efficientnet-B0 Model
    """

    def __init__(self, winsize, frmsize, train_transforms, dev_transforms,
                 device, save_epoch_model, num_workers, num_threads, lbl_map,
                 num_gpus):
        """
        method: __init__

        arguments:
         winsize: the size of the window
         frmsize: the size of the frame
         train_transforms: the image transformations on the train dataset
         dev_transforms: the image transformations on the dev dataset
         device: the device that the training will be run on (an object
          with a 'type' attribute, e.g. a torch.device)
         save_epoch_model: a boolean value to determine whether to save a model
          every epoch
         num_workers: number of workers
         num_threads: the number of threads to use for cpu training
         lbl_map: a mapping of source labels to one target variable to an
          integer code
         num_gpus: the number of gpus to use

        returns: none

        description:
         This method constructs a DPATHEB0Train object: it stores the
         configuration, creates the empty per-split containers (keyed by
         TRAIN/DEV) and, for cpu devices, limits the torch thread counts.
        """

        # set the class name
        #
        DPATHEB0Train.__CLASS_NAME__ = self.__class__.__name__

        # display debug information
        #
        if dbgl > ndt.BRIEF:
            print("%s (line: %s) %s::%s: creating eb0 train object" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Train.__CLASS_NAME__,
                   ndt.__NAME__))

        # save the label map
        #
        self.label_map = lbl_map

        # set the window size
        #
        self.window_size = winsize

        # set the frame size
        #
        self.frmsize = frmsize

        # create a variable to store the epoch model saving preferences
        #
        self.save_epoch_model = save_epoch_model

        # set the device training will be run on
        #
        self.device = device

        # store the number of gpus to use
        #
        self.num_gpus = num_gpus

        # store the number of workers
        #
        self.num_workers = num_workers

        # set the dev/train transforms
        #
        self.transforms = dict()
        self.transforms[TRAIN] = train_transforms
        self.transforms[DEV] = dev_transforms

        # create a variable to hold the datasets (i.e. train or dev)
        #
        self.data_sets = dict()

        # create a variable to hold the image datasets (i.e. train or dev)
        #
        self.image_sets = dict()

        # create a variable to store the number of samples within
        # a data set
        #
        self.num_samples = dict()

        # create a variable to hold the data weights
        #
        self.weights = dict()

        # create a variable to hold the torch data loaders
        #
        self.data_loaders = dict()

        # silence autocast warning
        #
        warnings.filterwarnings(
            "ignore",
            message="`torch.cuda.amp.autocast\\(args\\.\\.\\.\\)` is deprecated.*",
            category=FutureWarning,
        )

        # limit the number of intra-op threads to num_threads if using
        # cpu
        #
        if self.device.type == DEVICE_CPU:
            torch.set_num_threads(num_threads)

            # torch only allows the inter-op thread count to be set once
            # per process and only before parallel work has started; a
            # later call raises RuntimeError. treat the setting as best
            # effort so that a second object can still be constructed.
            #
            try:
                torch.set_num_interop_threads(1)
            except RuntimeError as err:
                if dbgl > ndt.BRIEF:
                    print("%s (line: %s) %s::%s: interop threads not set (%s)"
                          % (__FILE__, ndt.__LINE__,
                             DPATHEB0Train.__CLASS_NAME__,
                             ndt.__NAME__, err))

        # exit gracefully
        #
        return None
    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # training set up based methods
    #
    #--------------------------------------------------------------------------

    def setup_train_env(self, train_img_list, train_ann_list, dev_img_list,
                        dev_ann_list, out_model_path, input_model_path,
                        input_weights_path, num_samps_train, num_samps_dev):
        """
        method: setup_train_env

        arguments:
         train_img_list: list of all training img files
         train_ann_list: list of all training ann files
         dev_img_list: list of all dev img files
         dev_ann_list: list of all dev ann files
         out_model_path: the output model file path
         input_model_path: optional initial input model to load data from
         input_weights_path: optional initial model weights path to load
         num_samps_train: a dictionary of the maximum number of train samples
          per class
         num_samps_dev: a dictionary of the maximum number of dev samples per
          class

        return: boolean value indicating status (True on completion)

        description:
         This method has three functions: 1. generate and process the train/dev
         data sets, 2. calculates the weights for train/dev data sets, 3. loads
         model file path information for future use in the train method.
         The program exits with os.EX_SOFTWARE if no background class is
         present in the combined train/dev labels.
        """

        # create a variable to store the output path of the final model
        #
        self.out_model_path = out_model_path

        # set the input model
        #
        self.input_model_path = input_model_path

        # set the input weights path
        #
        self.input_weights_path = input_weights_path

        # build the training dataset
        #
        self.data_sets[TRAIN] = \
            self.build_dataset(train_img_list, train_ann_list)

        # build the dev dataset
        #
        self.data_sets[DEV] = \
            self.build_dataset(dev_img_list, dev_ann_list)

        # create a common index map for TRAIN/DEV to share
        # and to save along side the best model. the label is the third
        # element of each sample tuple.
        #
        all_classes = sorted({s[2] for s in self.data_sets[TRAIN]} |
                             {s[2] for s in self.data_sets[DEV]})
        self.class_names = all_classes
        self.index_map = {c: i for i, c in enumerate(all_classes)}

        # ensure background label is included
        #
        if nda.DEF_BCKG not in self.class_names:
            print("Error: %s (line: %s) %s: bckg label not detected %s %s" %
                  (__FILE__, ndt.__LINE__, ndt.__NAME__,
                   "a class named 'bckg' (background)",
                   "must be included during training"))
            sys.exit(os.EX_SOFTWARE)

        # create the train image dataset
        #
        self.image_sets[TRAIN] = \
            ndi.TrainImages(
                self.data_sets[TRAIN],
                self.frmsize,
                self.window_size,
                index_map=self.index_map
            )

        # create the dev image dataset
        #
        self.image_sets[DEV] = \
            ndi.TrainImages(
                self.data_sets[DEV],
                self.frmsize,
                self.window_size,
                index_map=self.index_map
            )

        # select the number of class events/samples randomly
        # for the train dataset (train is processed before dev so the
        # order of the selections is preserved)
        #
        self.image_sets[TRAIN], self.num_samples[TRAIN] = \
            self.__select_samples(self.image_sets[TRAIN],
                                  num_samps_train)

        # select the number of class events/samples randomly
        # for the dev dataset
        #
        self.image_sets[DEV], self.num_samples[DEV] = \
            self.__select_samples(self.image_sets[DEV],
                                  num_samps_dev)

        # assign the image transformations for each split. if the selected
        # image set has no 'transform' attribute of its own, the transform
        # is set on the object it wraps (its 'dataset' attribute).
        #
        for split in (TRAIN, DEV):
            image_set = self.image_sets[split]
            if hasattr(image_set, TRANSFORM):
                image_set.transform = self.transforms[split]
            else:
                image_set.dataset.transform = self.transforms[split]

        # compute the class weights for each split
        #
        for split in (TRAIN, DEV):

            # Build a tensor of per-class sample counts, ordered the same
            # way as self.class_names
            #
            counts = torch.tensor(
                [self.num_samples[split][str(cls)]
                 for cls in self.class_names],
                dtype=torch.float32,
            )

            # Guard against division-by-zero in case a class has a zero
            # count
            #
            counts = torch.clamp(counts, min=1.0)

            # Inverse-frequency weighting, normalised to sum to one
            #
            split_weights = counts.max() / counts
            split_weights = split_weights / split_weights.sum()

            # send weights to device. Tensor.to() returns the converted
            # tensor rather than modifying it in place, so the result
            # must be stored.
            #
            self.weights[split] = split_weights.to(self.device)

        # display debugging information
        #
        if dbgl > ndt.BRIEF:

            print("%s (line: %s) %s::%s: EB0 training object loaded" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Train.__CLASS_NAME__,
                   ndt.__NAME__))
            print("training weights = %s" % self.weights[TRAIN])
            print("development weights = %s" % self.weights[DEV])

        # exit gracefully
        #
        return True
    #
    # end of method

    def build_dataset(self, image_paths, ann_paths, overlap_thresh=float(0.3)):
        """
        method: build_dataset

        arguments:
         image_paths: paths to image files
         ann_paths: corresponding annotation file paths
         overlap_thresh: minimum fraction overlap to assign non-background label
        
        return:
         samples: list of tuples (img_path, (x, y), label)

        description:
         For each image/annotation pair, read header for dimensions, build a
         grid of non-overlapping patch coordinates, clean annotation polygons,
         and assign each patch a label based on maximum overlap or mark as
         background.
        """

        # initialize empty list to hold all samples
        #
        samples = []

        # loop over each image path and its corresponding annotation path
        #
        for img_path, ann_path in zip(image_paths, ann_paths):

            # create an annotation reader instance
            #
            ann = nda.AnnDpath()

            # load the annotation file
            #
            ann.load(ann_path)
            
            # get remapped ann graph
            #
            graph = nda.remap_labels(
                ann.get_graph(),
                self.label_map
            )
            
            # retrieve header info (including width & height)
            #
            header = ann.get_header()

            # parse image width from header
            #
            width = int(header[nda.CKEY_WIDTH])

            # parse image height from header
            #
            height = int(header[nda.CKEY_HEIGHT])
            
            # skip if there is more than one annotation
            #
            if len(graph) == 1:
                
                # get the single annotation entry
                #
                entry = next(iter(graph.values()))
                
                # build a tight bounding box around its coordinates
                #
                xs = [px for px, _, _ in entry[nda.CKEY_COORDINATES]]
                ys = [py for _, py, _ in entry[nda.CKEY_COORDINATES]]
                bbox_w = max(xs) - min(xs)
                bbox_h = max(ys) - min(ys)
                
                # if the bbox and the image match the frame size
                #
                if (bbox_w == self.frmsize and bbox_h == self.frmsize and
                    width == self.frmsize and height == self.frmsize):
                    
                    # use the annotation’s label
                    #
                    lbl = entry[nda.CKEY_TEXT]
                    
                    # append the sample and continue
                    #
                    samples.append((img_path, (0, 0), lbl))
                    continue

            # prepare list for patch grid coordinates
            #
            coords = []

            # iterate x from 0 to width-patch_size in steps of patch_size
            #
            for x in range(0, width - self.frmsize + 1, self.frmsize):

                # iterate y from 0 to height-patch_size in steps of patch_size
                #
                for y in range(0, height - self.frmsize + 1, self.frmsize):

                    # append this top-left coordinate to coords
                    #
                    coords.append((y, x))

            # prepare lists to hold cleaned polygons and their labels
            #
            polys, labels = [], []

            # iterate through each annotation graph entry
            #
            for _, v in graph.items():

                # extract raw polygon vertex coordinates
                #
                coords_poly = [(px, py) for px, py, _ in v[nda.CKEY_COORDINATES]]

                # build a Shapely polygon from these coords
                #
                raw_poly = Polygon(coords_poly)

                # buffer by zero to clean invalid geometries
                #
                fixed = raw_poly.buffer(0)

                # if the result has multiple parts
                #
                if hasattr(fixed, 'geoms'):

                    # collect all polygon parts
                    #
                    parts = [g for g in fixed.geoms if isinstance(g, Polygon)]

                    # if parts exist, pick the largest by area
                    #
                    if parts:
                        poly = max(parts, key=lambda p: p.area)
                    else:

                        # otherwise union and pick the largest piece
                        #
                        unioned = unary_union(fixed)
                        if isinstance(unioned, Polygon):
                            poly = unioned
                        else:
                            poly = max(unioned.geoms, key=lambda poly: poly.area)
                else:

                    # single-part geometry; use it directly
                    #
                    poly = fixed

                # store the cleaned polygon
                #
                polys.append(poly)

                # store the corresponding label text
                #
                labels.append(v[nda.CKEY_TEXT])

            # now assign a label to each patch based on polygon overlap
            #
            for x, y in coords:

                # create a patch bounding box
                #
                win_x = x + self.frmsize // 2 - self.window_size // 2
                win_y = y + self.frmsize // 2 - self.window_size // 2
                patch_box = box(
                    win_x, win_y,
                    win_x + self.window_size,
                    win_y + self.window_size
                )
                
                # initialize best overlap score
                #
                best_ov = float(0.0)

                # default label is background
                #
                best_lbl = nda.DEF_BCKG

                # check overlap with each cleaned polygon
                #
                for poly, lbl in zip(polys, labels):

                    # compute the area of the patch window
                    #
                    patch_area = patch_box.area

                    # compute intersection area with patch
                    #
                    inter = patch_box.intersection(poly).area

                    # compute fraction of the patch covered by the annotation
                    #
                    patch_ov = inter / patch_area

                    # if this overlap is the largest seen so far
                    #
                    if patch_ov > best_ov:

                        # update best overlap and label
                        #
                        best_ov = patch_ov
                        best_lbl = lbl

                # if the best overlap is below threshold, dont append
                #
                if best_ov < overlap_thresh:
                    continue
                
                # append the final (image, coord, label) tuple to samples
                #
                samples.append((img_path, (win_x, win_y), best_lbl))
                
        # return the full list of labeled samples
        #
        return samples
    #
    # end of method

    def __select_samples(self, image_set, nsamples_per_class):
        """
        method: __select_samples

        arguments:
         image_set: Dataset with classes (list of class names)
           and targets (list of ints)
         nsamples_per_class: dict mapping class name -> desired sample count
         
        return:
         subset: torch.utils.data.Subset of image_set
         nsamples_subdataset: dict mapping class name -> actual selected count

        description:
         Selects up to a specified number of samples per class from a dataset
         and logs the requested, available, and selected counts.
        """
        
        # get list of class names and convert targets to numpy
        #
        classes = list(image_set.classes)
        targets = np.array(image_set.targets)

        # compute how many are available per class
        #
        available = {cls: int((targets == idx).sum())
                     for idx, cls in enumerate(classes)}

        # determine how many to select per class (min(requested, available))
        #
        nsamples_subdataset = {}
        for cls in classes:
            req = nsamples_per_class.get(cls, 0)
            avail = available.get(cls, 0)
            if req == int(-1):
                sel = avail
            else:
                sel = min(req, avail)
            nsamples_subdataset[cls] = sel

        # collect indices for selected samples
        #
        selected_indices = []
        for idx, cls in enumerate(classes):
            count = nsamples_subdataset[cls]
            if count > 0:
                cls_idx = np.where(targets == idx)[0]
                chosen = np.random.choice(cls_idx, count, replace=False)
                selected_indices.extend(chosen.tolist())

        # exit gracefully
        # return the Subset and the per-class counts
        #
        return Subset(image_set, selected_indices), nsamples_subdataset
    #
    # end of method
    
    #--------------------------------------------------------------------------
    #
    # training methods
    #
    #--------------------------------------------------------------------------
     
    def train(self, batch_size, num_epochs, lr_rate, fp):
        """
        method: train

        arguments:
         batch_size: the training batch size
         num_epochs: the number of epochs for training
         lr_rate: user specified learning rate
         fp: file pointer to log file

        return:
         model: the trained model

        description:
         This method is the main method of this class. It builds the train
         and dev dataloaders, instantiates an EfficientNet-B0 model
         (optionally initialized from a model file and/or a weights file),
         configures the loss functions, the optimizer and the learning
         rate scheduler, writes a summary of this configuration to the log
         file, and then runs the training process for num_epochs epochs.
        """

        # display debugging information
        #
        if dbgl > ndt.BRIEF:

            print("%s (line: %s) %s::%s: preparing to train the model" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Train.__CLASS_NAME__,
                   ndt.__NAME__))
            print("number of workers = %s" % self.num_workers)
            print("model file name = %s" % self.input_model_path)
            print("batch size = %s" % batch_size)
            print("number of epochs = %s" % num_epochs)
            print("save each epoch model = %s" % self.save_epoch_model)

        # keep the batch size for later
        #
        self.batch_size = batch_size

        # set a random seed for the DataLoader shuffle
        #
        generator = torch.Generator()
        generator.manual_seed(ndt.RANDSEED)

        # worker-related options are only meaningful when worker
        # processes are actually used
        #
        use_workers = self.num_workers > 0

        # create the train dataloader, this function creates an iterable
        # type that will allow us to iterate over the dataset
        #
        self.data_loaders[TRAIN] = torch.utils.data.DataLoader(
            self.image_sets[TRAIN], batch_size,
            shuffle=True,
            num_workers = self.num_workers,
            generator = generator,
            worker_init_fn=worker_init_fn,
            pin_memory=False,
            persistent_workers=use_workers,
            prefetch_factor=(2 if use_workers else None)
        )

        # create the dev dataloader, this function creates an iterable
        # type that will allow us to iterate over the dataset
        #
        self.data_loaders[DEV] = torch.utils.data.DataLoader(
            self.image_sets[DEV], batch_size,
            shuffle=False,
            num_workers = self.num_workers,
            worker_init_fn=worker_init_fn,
            pin_memory=False,
            persistent_workers=use_workers,
            prefetch_factor=(2 if use_workers else None)
        )

        # instantiate Efficientnet-b0 model
        #
        model_ft = models.efficientnet_b0(weights=None)

        # load an nn.module (pre-trained model) if given
        #
        # NOTE: weights_only=False unpickles arbitrary Python objects, so
        # only files from a trusted source should be loaded here
        #
        if self.input_model_path is not None:
            state_dict = \
                torch.load(self.input_model_path, weights_only = False,
                           map_location=torch.device('cpu'))
            model_ft.load_state_dict(state_dict)

        # replace the final linear layer so that the number of outputs
        # matches the number of classes
        #
        num_ftrs = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(num_ftrs, len(self.class_names))
        model_ft = model_ft.to(self.device)

        # load the model weights path
        #
        if self.input_weights_path is not None:

            # load the raw state dict (weights) and potentially additional
            # data; map the tensors to the cpu so that a checkpoint saved
            # on a gpu can also be loaded on a machine without one
            # (load_state_dict copies the values onto the model's device)
            #
            # NOTE: weights_only=False unpickles arbitrary Python objects,
            # so only files from a trusted source should be loaded here
            #
            bundle = torch.load(
                self.input_weights_path, weights_only = False,
                map_location=torch.device('cpu')
            )

            # check to find the correct state to load
            #
            if (isinstance(bundle, dict) and ATTR_MODEL_STATE_DICT in bundle):
                raw_state = bundle[ATTR_MODEL_STATE_DICT]
            else:
                raw_state = bundle

            # clean up state dict naming issues caused by dataparallel if
            # present: strip only the leading prefix, never an occurrence
            # further inside the key
            #
            prefix_len = len(ATTR_MODEL_MODULE)
            clean = {
                (k[prefix_len:] if k.startswith(ATTR_MODEL_MODULE) else k): v
                for k, v in raw_state.items()
            }

            # load the state dict
            #
            model_ft.load_state_dict(clean)

        # multi-gpu training support
        #
        if (self.device.type == DEVICE_CUDA) and (self.num_gpus > 1):

            # fetch the number of available cuda gpus
            #
            available = torch.cuda.device_count()

            # clamp num_gpus depending on the amount actually
            # available
            #
            if available < self.num_gpus:
                self.num_gpus = available

            # instantiate a dataparallel model
            #
            model_ft = nn.DataParallel(
                model_ft,
                device_ids=list(range(self.num_gpus))
            )

        # send model to device
        #
        model_ft = model_ft.to(self.device)

        # define the cross entropy loss function for the dev and train
        # datasets, this will allow us to calculate the error the model is
        # generating, and allow the model to adjust weights to find a
        # better performance
        #
        train_criterion = \
            nn.CrossEntropyLoss(
                weight=self.weights[TRAIN].to(self.device),
            )

        dev_criterion = \
            nn.CrossEntropyLoss(
                weight=self.weights[DEV].to(self.device),
            )

        # observe that all parameters are being optimized by using a
        # stochastic gradient descent algorithm, which will allow us to
        # find optimum points of a function by testing values, lr (learning
        # rate) is the distance between each point tested, momentum will
        # monitor the values so that we are moving in the correct direction
        # to find the optimum point
        #
        optimizer_ft = \
            torch.optim.SGD(model_ft.parameters(),
                            lr = lr_rate,
                            momentum = DEF_MOMENTUM,
                            weight_decay=1e-3,
                            nesterov = True)

        # lr_scheduler.StepLR will allow us to adjust the learning rate of
        # the stochastic gradient descent algorithm by multiplying it by
        # gamma once every step_size epochs
        #
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft,
                                               step_size = DEF_STEP_SIZE,
                                               gamma = DEF_GAMMA)

        # save loss function, optimizer, and lr scheduler names
        #
        loss_fn_name = type(train_criterion).__name__
        optimizer_name = type(optimizer_ft).__name__
        sched_name = type(exp_lr_scheduler).__name__

        # grab optimizer hyper-params from its defaults dict
        #
        opt_hparams = optimizer_ft.defaults.copy()

        # pull scheduler settings from its attributes
        #
        sched_hparams = {
            "step_size": exp_lr_scheduler.step_size,
            "gamma": exp_lr_scheduler.gamma,
        }

        # build the configuration summary once so that the console and
        # the log file always report the same information
        #
        summary = [
            "loss function = %s" % loss_fn_name,
            "optimizer= %s" % optimizer_name,
            "learning rate scheduler = %s" % sched_name,
            "optimizer hyper parameters = %s" % opt_hparams,
            "learning rate scheduler params = %s" % sched_hparams,
        ]

        # display debugging information (followed by a blank line)
        #
        if dbgl > ndt.BRIEF:
            for line in summary:
                print(line)
            print()

        # write debug info to log file, one entry per line
        #
        fp.write("\n".join(summary) + "\n")

        # train the model
        #
        model = \
            self.__train_model(model_ft, train_criterion, dev_criterion,
                               optimizer_ft, exp_lr_scheduler, fp, num_epochs)

        # exit gracefully
        #  return the trained model
        #
        return model
    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # training methods
    #
    #--------------------------------------------------------------------------

    def __train_model(self, model, train_criterion, dev_criterion, optimizer,
                      scheduler, fp, num_epochs=16):
        """
        method: __train_model

        arguments:
         model: the neural network model which should be trained
         train_criterion: loss function for training (such as mse or
          cross entropy)
         dev_criterion: loss function for development (such as mse or
          cross entropy)
         optimizer: optimizer function (such as SGD or Adam)
         scheduler: scheduler object to change the optimizers parameters
         fp: file pointer to log file
         num_epochs: number of epochs which neural network will be trained

        return:
         model: the trained model, loaded with the weights of the epoch
          that achieved the highest dev accuracy

        description:
         This function accepts a model and trains it. Each epoch consists
         of a training pass followed by an evaluation pass over the dev
         set; per-epoch results are printed and written to the log file.
        """

        # display debugging information
        #
        if dbgl > ndt.BRIEF:
            print("%s (line: %s) %s::%s: training model" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Train.__CLASS_NAME__,
                   ndt.__NAME__))

        # compute the number of samples processed per epoch for the train
        # and dev dataloaders: the running statistics are accumulated per
        # sample, so the last (possibly partial) batch must not be counted
        # as a full batch unless incomplete batches are dropped
        #
        nsamples = {}
        for split in (TRAIN, DEV):
            loader = self.data_loaders[split]
            if loader.drop_last:
                nsamples[split] = len(loader) * loader.batch_size
            else:
                nsamples[split] = len(loader.dataset)
        train_len = nsamples[TRAIN]
        dev_len = nsamples[DEV]

        # keep track of the start time
        #
        since = time.time()

        # store the state_dict using deepcopy
        #
        best_model_wts = copy.deepcopy(model.state_dict())

        # set the initial accuracy as 0.0
        #
        best_acc = 0.0

        # set the initial loss as the worst possible value; it is tracked
        # alongside best_acc but does not take part in model selection
        #
        best_loss = float(DEF_WORST_LOSS_INIT)

        # iterate through amount of epochs
        #
        for epoch in range(num_epochs):

            # keep track of the time for each epoch
            #
            train_start = time.time()

            # print out which epoch is currently running
            #
            print(f'Epoch {epoch+1}/{num_epochs}')

            # write results to log file
            #
            fp.write(f'Epoch {epoch+1}/{num_epochs}\n')

            # print out a line to make the output look nice
            #
            print(nft.DELIM_DASH * 10)

            # each epoch has a training and validation phase
            # set the model to training mode for the training phase
            #
            model.train()

            # perform all necessary operations on the train
            # data set and fetch its loss and corrects
            #
            running_loss, running_corrects = \
                self.train_loop(optimizer, model, train_criterion)

            # execute the scheduler to update learning rate
            #
            scheduler.step()

            # calculate the loss and accuracy for this epoch; float()
            # accepts both a zero-dimensional tensor and a plain number
            #
            train_loss = running_loss / train_len
            train_acc = float(running_corrects) / train_len

            # save dev start time
            #
            dev_start = time.time()

            # build a message with the elapsed time, the loss, and the
            # accuracy for this epoch; it is built once so the console and
            # the log file report identical values
            #
            train_msg = (f'Train \t Elapsed: {time.time()-train_start:.2f} sec '
                         + f'Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
            print(train_msg)

            # write results to log file
            #
            fp.write(train_msg + '\n')

            # check if we want to save model parameters for each epoch
            #
            if self.save_epoch_model:

                # save the epoch model (the epoch number becomes part of
                # the output file name)
                #
                self.save_epoch_model_m(model, epoch)

            # set the model to eval mode for the validation phase
            #
            model.eval()

            # perform all necessary operations on the dev data
            # set and fetch its loss and corrects
            #
            running_loss, running_corrects = \
                self.dev_loop(model, dev_criterion)

            # calculate the loss and accuracy for this epoch
            #
            dev_loss = running_loss / dev_len
            dev_acc = float(running_corrects) / dev_len

            # build a message with the elapsed time, the loss, and
            # the accuracy for this epoch
            #
            dev_msg = (f'Devel \t Elapsed: {time.time()-dev_start:.2f} sec '
                       + f'Loss: {dev_loss:.4f} Acc: {dev_acc:.4f}')

            # print the message followed by a blank line
            #
            print(dev_msg + '\n')

            # write results to log file
            #
            fp.write(dev_msg + '\n')

            # check if this epoch improved the model by seeing if this
            # epoch increased the dev acc
            #
            if dev_acc > best_acc:

                # set values for best accuracy and loss to this
                # epoch's values
                #
                best_acc = dev_acc
                best_loss = dev_loss

                # store the best model's weights by using deepcopy
                #
                best_model_wts = copy.deepcopy(model.state_dict())

        # calculate the total time
        #
        time_elapsed = time.time() - since

        # build a final message with total training time
        #
        done_msg = (f'Training completed in {time_elapsed // 60:.0f}m ' +
                    f'{time_elapsed % 60:.0f}s')
        print(done_msg)

        # write results to log file
        #
        fp.write(done_msg + '\n')

        # load the best model weights
        #
        model.load_state_dict(best_model_wts)

        # exit gracefully
        #  return the best model
        #
        return model
    #
    # end of method

    def train_loop(self, optimizer, model, criterion):
        """
        method: train_loop

        arguments:
         optimizer: the chosen optimizer for training the model
         model: the model to train
         criterion: the loss function (cross entropy loss)

        return:
         running loss: the sum over all batches of the batch loss
          multiplied by the batch size (a float)
         running corrects: the number of correctly classified samples
          (a zero-dimensional integer tensor on self.device)

        description:
         This function encompasses all train data set operations needed
         for an epoch of training in the train_model method: for every
         batch it runs the forward pass, computes the loss, backpropagates
         and updates the model parameters.
        """

        # set initial value for the running loss as 0
        #
        running_loss = 0.0

        # set initial value for running correct answers as 0; a tensor is
        # used so that the return type does not depend on whether the
        # dataloader produced any batch
        #
        running_corrects = torch.zeros((), dtype=torch.long,
                                       device=self.device)

        # iterate over the train dataloader
        #
        for inputs, labels in self.data_loaders[TRAIN]:

            # send the inputs and labels to the device
            #
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # zero the parameter gradients
            #
            optimizer.zero_grad()

            # start the forward pass (flow of information
            # from the input to the output of the neural network)
            # set set_grad_enabled to true to allow torch to
            # calculate the gradients
            #
            with torch.set_grad_enabled(True):

                # send the inputs to the model
                #
                outputs = model(inputs)

                # get the index of the maximum confidence, i.e. the
                # predicted class of each sample
                #
                _, preds = torch.max(outputs, 1)

                # calculate the loss using the cross entropy loss function
                #
                loss = criterion(outputs, labels)

                # start the backward pass (adjusting model weights)
                #
                # use loss.backward() to compute gradients
                #
                loss.backward()

                # update the model parameters
                #
                optimizer.step()

            # calculate the statistics, running loss, and the running
            # corrects
            #
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        # exit gracefully
        #  return the running loss and corrects
        #
        return running_loss, running_corrects
    #
    # end of method

    def dev_loop(self, model, criterion):
        """
        method: dev_loop

        arguments:
         model: the model to evaluate
         criterion: the loss function (cross entropy loss)

        return:
         running loss: the sum over all batches of the batch loss
          multiplied by the batch size (a float)
         running corrects: the number of correctly classified samples
          (a zero-dimensional integer tensor on self.device)

        description:
         This method encompasses all the dev data set operations
         needed for an epoch of training in the train_model method. No
         gradients are computed and the model parameters are not updated.
        """

        # reset running loss
        #
        running_loss = 0.0

        # reset running corrects; a tensor is used so that the return
        # type does not depend on whether the dataloader produced any batch
        #
        running_corrects = torch.zeros((), dtype=torch.long,
                                       device=self.device)

        # iterate over the dev dataloader
        #
        for inputs, labels in self.data_loaders[DEV]:

            # send the inputs and labels to the device
            #
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            # disable gradient tracking as we won't be changing gradients
            #
            with torch.no_grad():

                # send the inputs to the model
                #
                outputs = model(inputs)

                # get the index of the maximum confidence, i.e. the
                # predicted class of each sample
                #
                _, preds = torch.max(outputs, 1)

                # calculate the loss using the cross entropy loss function
                #
                loss = criterion(outputs, labels)

            # calculate the statistics, running loss, and the running
            # corrects
            #
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        # exit gracefully
        #  return running loss and corrects
        #
        return running_loss, running_corrects
    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # model-based methods
    #
    #--------------------------------------------------------------------------

    def save_epoch_model_m(self, model, epoch):
        """
        method: save_epoch_model_m

        arguments:
         model: model to save
         epoch: current epoch number (zero-based; the file name carries
          epoch + 1)

        return:
         boolean value indicating status

        description:
         This method saves an epoch model to a file derived from
         out_model_path: the directory and extension are kept, and the
         basename is extended with an underscore delimiter and the
         one-based epoch number.
        """

        # split the output model path into its directory and file name
        #
        out_dir, out_name = os.path.split(self.out_model_path)

        # split the file name into its basename and extension, and drop
        # the leading dot of the extension
        #
        stem, dotted_ext = os.path.splitext(out_name)
        ext = dotted_ext[1:]

        # extend the basename with the epoch value
        #
        epoch_stem = stem + nft.DELIM_USCORE + str(epoch + 1)

        # create the final model filename for this epoch
        #
        epoch_fname = nft.create_filename(epoch_stem, out_dir, ext, None)

        # unwrap a dataparallel model so that the saved parameter names
        # come from the wrapped module itself
        #
        net = model.module if isinstance(model, torch.nn.DataParallel) \
            else model

        # assemble the bundle to save
        #
        # NOTE(review): the TRAIN transforms are stored under the
        # dev-transforms key - confirm that this is intended
        #
        payload = {
            ATTR_MODEL_STATE_DICT: net.state_dict(),
            ATTR_MODEL_CLASS_TO_IDX_MAP: self.index_map,
            ATTR_MODEL_DEV_TRANSFORMS: self.transforms[TRAIN]
        }

        # save epoch model
        #
        torch.save(payload, epoch_fname)

        # exit gracefully
        #
        return True
    #
    # end of method
    
    def read_stats(self, model, dataloader):
        """
        method: read_stats

        arguments:
         model: neural network model
         dataloader: an iterable yielding (inputs, labels) batches

        return:
         all_labels: numpy int64 array of the reference class indices
         all_preds: numpy int64 array of the predicted class indices
         accuracy: the fraction of samples for which the prediction
          matches the reference label

        description:
         This function will compute the values needed to compute the
         confusion matrix and accuracy. The model is switched to eval mode
         while the data is processed and its previous mode is restored
         afterwards, even if an error occurs.
        """

        # display debugging information
        #
        if dbgl > ndt.BRIEF:
            print("%s (line: %s) %s::%s: fetching confusion matrix statistics" %
                  (__FILE__, ndt.__LINE__, DPATHEB0Train.__CLASS_NAME__,
                   ndt.__NAME__))

        # save the model state to restore it at the end
        #
        was_training = model.training
        model.eval()

        # create lists
        #
        all_labels = []
        all_preds = []

        try:

            # use torch.no_grad() to specify torch to not calculate the
            # gradients
            #
            with torch.no_grad():

                # iterate through the dataloader
                #
                for inputs, labels in dataloader:

                    # send the inputs to the device; the labels are only
                    # collected, so they do not need to be moved
                    #
                    inputs = inputs.to(self.device)

                    # send the inputs to the model
                    #
                    outputs = model(inputs)

                    # get the index of the maximum confidence, i.e. the
                    # predicted class of each sample
                    #
                    _, preds = torch.max(outputs, 1)

                    # update all_labels to have the correct labels
                    #
                    all_labels.extend(labels.tolist())

                    # update all_preds to have the predicted values
                    #
                    all_preds.extend(preds.tolist())

        finally:

            # change the model state to last state, even when the
            # evaluation above failed
            #
            model.train(mode=was_training)

        # convert the labels/preds to numpy arrays
        #
        all_labels = np.asarray(all_labels, dtype=np.int64)
        all_preds  = np.asarray(all_preds,  dtype=np.int64)

        # compute the accuracy as the fraction of predictions that match
        # the reference labels
        #
        accuracy = np.mean(all_labels == all_preds)

        # exit gracefully
        #  return all necessary statistics
        #
        return all_labels, all_preds, accuracy
    #
    # end of method

#
# end of class

#
# end of file


