#!/usr/bin/env python
#
# file: $NEDC_NFC/util/python/nedc_imld/v5.0.1/app/backend/nedc_imld_tools.py
#
# revision history:
#
# 20251003 (SA): refactored code to match IMLD standards
#
# This module encapsulates functions that call on ML Tools for the IMLD app.
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
#
# Section 1: Imports
#
#-------------------------------------------------------------------------------

# import system modules
#
from contextlib import redirect_stdout
from io import StringIO, BytesIO
from math import ceil, floor, sqrt
import numpy as np
import pickle
from scipy.spatial import KDTree

# import nedc modules
#
from nedc_file_tools import load_parameters
import nedc_ml_tools as mlt
import nedc_ml_tools_data as mltd

#-------------------------------------------------------------------------------
#
# Section 2: Global Constants
#
#-------------------------------------------------------------------------------

# define keys for model dictionary access:
#  KEY_FILE_NAME and KEY_FILE_POINTER are the keyword-argument names this
#  file passes to the ML Tools save_model/load_model methods.
#  KEY_MAPPING_LABEL is the entry of the model dictionary (alg_d.model_d)
#  under which save_model() stores the label mapping and from which
#  load_model() reads it back.
#
KEY_FILE_NAME = 'fname'
KEY_FILE_POINTER = 'fp'
KEY_MAPPING_LABEL = 'mapping_label'

#-------------------------------------------------------------------------------
#
# Section 3: Module-level Functions
#
#-------------------------------------------------------------------------------

def check_return(func, *args, **kwargs):
    """
    function: check_return

    args:
     func (function): the callable to invoke
     *args (list)   : positional arguments forwarded to func
     **kwargs (dict): keyword arguments forwarded to func

    return:
     res (any): whatever func returned

    description:
     Run a function while capturing everything it writes to standard output.
     ML Tools reports failures by printing a message to the console rather
     than raising, so the captured text is inspected: if it contains the
     substring 'Error:', an MLToolsError carrying the stripped text is
     raised. Otherwise the function's return value is passed through
     unchanged and the captured text is discarded. Intended for ML Tools
     calls such as model.train(), model.predict() and model.get_info().
    """

    # divert standard output into an in-memory buffer for the duration of
    # the call, and keep a copy of whatever was printed
    #
    with StringIO() as sink, redirect_stdout(sink):
        result = func(*args, **kwargs)
        printed = sink.getvalue()

    # no error marker in the output: hand back the result untouched
    #
    if 'Error:' not in printed:
        return result

    # the callee reported a failure on the console: surface it as an exception
    #
    raise MLToolsError(printed.strip())

def create_model(alg_name:str, params=None) -> mlt.Alg:
    """
    function: create_model

    args:
     alg_name (str): the name of the algorithm the model should use
     params (dict) : the parameters to apply to the algorithm [optional]

    return:
     mlt.Alg: the configured ML Tools object, or None when either the
              algorithm name or the parameters are rejected (i.e. the
              corresponding ML Tools setter returns False)

    description:
     Instantiate a ML Tools algorithm object, select the algorithm by name
     and, when parameters are supplied, apply them.
    """

    # start from an unconfigured ML Tools algorithm object
    #
    alg = mlt.Alg()

    # select the algorithm; bail out if ML Tools rejects the name
    #
    if alg.set(alg_name) is False:
        return None

    # apply the parameters only when the caller provided some;
    # bail out if ML Tools rejects them
    #
    if params is not None and alg.set_parameters(params) is False:
        return None

    # exit gracefully
    #
    return alg

def create_data(x:list, y:list, labels:list) -> mltd.MLToolsData:
    """
    function: create_data

    args:
     x (list)     : the first feature (x coordinate) of every sample
     y (list)     : the second feature (y coordinate) of every sample
     labels (list): the labels handed to the ML Tools data object

    return:
     mltd.MLToolsData: the ML Tools data object built from the samples

    description:
     Pair up the x and y coordinates into one feature matrix and wrap it,
     together with the labels, in a MLToolsData object.
    """

    # place x and y side by side so that each row is one sample:
    #  x = [1,2,3], y = [4,5,6]  ->  [[1,4], [2,5], [3,6]]
    #
    features = np.column_stack((x, y))

    # build the ML Tools data object and exit gracefully
    #
    return mltd.MLToolsData.from_data(features, labels)

def normalize_data(x:list, y:list, xrange:list, yrange:list):
    """
    function: normalize_data

    args:
     x (list): x-values, assumed to lie in [-1, 1]
     y (list): y-values, assumed to lie in [-1, 1]
     xrange (list): target range [min, max] for the x-values
     yrange (list): target range [min, max] for the y-values

    return:
     (list, list): the rescaled x-values and the rescaled y-values

    description:
     Linearly map x and y from the interval [-1, 1] onto the target ranges:
     -1 lands on the range minimum and +1 on the range maximum.
    """

    # work on NumPy arrays so the arithmetic is element-wise
    #
    xs = np.array(x)
    ys = np.array(y)

    # unpack the target bounds
    #
    x_lo, x_hi = xrange
    y_lo, y_hi = yrange

    # shift [-1, 1] to [0, 2], scale to [0, 1], stretch to the width of the
    # target range, then offset by the range minimum
    #
    xs = (xs + 1) / 2 * (x_hi - x_lo) + x_lo
    ys = (ys + 1) / 2 * (y_hi - y_lo) + y_lo

    # exit gracefully
    #
    return xs.tolist(), ys.tolist()

def denormalize_data(x: list, y: list, xrange: list, yrange: list):
    """
    function: denormalize_data

    args:
     x (list): x-values expressed in the range given by xrange
     y (list): y-values expressed in the range given by yrange
     xrange (list): the range [min, max] the x-values currently occupy
     yrange (list): the range [min, max] the y-values currently occupy

    return:
     (list, list): the x-values and y-values mapped onto [-1, 1]

    description:
     Inverse of normalize_data: linearly map values from [min, max] back
     onto the interval [-1, 1], so that min lands on -1 and max on +1.
     The width of each range is used as a divisor, so a zero-width range
     results in a division by zero.
    """

    # work on NumPy arrays so the arithmetic is element-wise
    #
    xs = np.array(x)
    ys = np.array(y)

    # unpack the bounds the data currently occupies
    #
    x_lo, x_hi = xrange
    y_lo, y_hi = yrange

    # scale to [0, 1] relative to the range, stretch to [0, 2], shift to
    # [-1, 1]
    #
    xs = (xs - x_lo) / (x_hi - x_lo) * 2 - 1
    ys = (ys - y_lo) / (y_hi - y_lo) * 2 - 1

    # exit gracefully
    #
    return xs.tolist(), ys.tolist()

def generate_data(dist_name:str, params:dict):
    """
    function: generate_data

    arguments:
     dist_name (str): name of the distribution to generate data from
     params (dict): the parameters for the distribution. for the 'gaussian'
                    distribution the keys are expected to be a parameter
                    name followed by a numeric group index (e.g. 'mean1',
                    'cov1', 'mean2', ...)

    return:
     labels (list): the labels generated from the distribution
     x (list): the first column of the generated data
     y (list): the second column of the generated data

    description:
     Generate data through MLToolsData.generate_data given a distribution
     name and its parameters. For the 'gaussian' distribution the flat
     parameter dictionary is first regrouped into a list holding one
     dictionary per numeric suffix, in order of first appearance. A key
     that does not end in a number raises a ValueError.
    """

    if dist_name == 'gaussian':

        # group keys by their numeric suffix
        #
        grouped_data = {}

        # iterate through the parameters and group them by their numeric suffix
        #
        for key, value in params.items():

            # split the key into its base name and its trailing number.
            # strip ALL trailing digits (not just the last character) so
            # that indices of 10 and above are grouped correctly
            #
            prefix = key.rstrip('0123456789')
            num = int(key[len(prefix):])

            # means are handed to ML Tools as a flat list
            #
            if prefix == 'mean':
                value = list(np.array(value).flatten())

            # assign value to the corresponding group, creating the group
            # the first time its number is seen
            #
            grouped_data.setdefault(num, {})[prefix] = value

        # create a list of the grouped values
        #
        params = list(grouped_data.values())

    # create a ML Tools data object using the class method
    #
    data_obj = mltd.MLToolsData.generate_data(dist_name, params)
    data = data_obj.data

    # get the data and labels from the ML Tools data object
    #
    labels = data_obj.labels.tolist()
    x = data[:, 0].tolist()
    y = data[:, 1].tolist()

    # exit gracefully
    #
    return labels, x, y

def train(model:mlt.Alg, data:mltd.MLToolsData):
    """
    function: train

    args:
     model (mlt.Alg)        : the ML Tools algorithm to train
     data (mltd.MLToolsData): the data to train the model on

    return:
     model (mlt.Alg)       : the same model object, after training
     metrics (dict)        : the performance metrics produced by predict()
                             on the training data
     parameter_output (any): the value returned by model.get_info(), as
                             produced by predict()

    description:
     Train a ML Tools model on a MLToolsData object and immediately
     evaluate it on that same data. A failure reported by ML Tools during
     training is raised as an MLToolsError by check_return().
    """

    # fit the model; check_return converts ML Tools console errors into
    # exceptions
    #
    check_return(model.train, data)

    # evaluate the freshly trained model on the data it was trained on
    #
    evaluation = predict(model, data)
    metrics, parameter_output = evaluation

    # exit gracefully
    #
    return model, metrics, parameter_output

def predict(model:mlt.Alg, data:mltd.MLToolsData):
    """
    function: predict

    args:
     model (mlt.Alg)        : the ML Tools trained model to use for predictions
     data (mltd.MLToolsData): the data to predict

    return:
     metrics (dict): the performance metrics computed by score():
                         - confusion matrix
                         - sensitivity
                         - specificity
                         - precision
                         - accuracy
                         - error rate
                         - F1 score
     parameter_output (any): the value returned by model.get_info()

    description:
     Run a trained ML Tools model on a data set, score the hypothesis
     labels against the labels stored in the data, and collect the model's
     parameter information. The second value returned by model.predict()
     is not used.
    """

    # run the model; only the hypothesis labels are needed here
    #
    hyp_labels, _unused = check_return(model.predict, data)

    # score the hypotheses first, then query the model's parameter
    # information (both are ML Tools calls; get_info goes through
    # check_return so console errors become exceptions)
    #
    return score(model, data, hyp_labels), check_return(model.get_info)

def score(model:mlt.Alg, data:mltd.MLToolsData, hyp_labels:list):
    """
    function: score

    args:
     model (mlt.Alg): the ML Tools model whose score() method is used
     data (mltd.MLToolsData): the input data including reference labels
     hyp_labels (list): the hypothesis labels

    return: (dict) a dictionary containing the following metrics:
        'Confusion Matrix' (list): the confusion matrix
        'Sensitivity' (float): the sensitivity, multiplied by 100
        'Specificity' (float): the specificity, multiplied by 100
        'Precision' (float): the precision, multiplied by 100
        'Accuracy' (float): the accuracy, multiplied by 100
        'Error Rate' (float): the error rate, exactly as returned by ML Tools
        'F1 Score' (float): the F1 score, multiplied by 100

    description:
     Compare hypothesis labels against the reference labels held in the
     data object and report the resulting performance metrics. The number
     of classes handed to ML Tools is the larger of the class count of the
     data and the number of distinct hypothesis labels.
    """

    # use whichever is larger: the class count of the reference data or the
    # number of distinct hypothesis labels, so that scoring has room for
    # every label that occurs
    #
    n_hyp = len(set(hyp_labels))
    num_classes = data.num_of_classes if data.num_of_classes > n_hyp else n_hyp

    # convert the hypothesis labels to an integer array and map them into
    # the label format of the data object
    #
    mapped = data.map_label(np.array(hyp_labels, dtype=int))

    # score the model
    #
    results = model.score(num_classes, data, mapped)
    conf_matrix, sens, spec, prec, acc, err, f1 = results

    # assemble the report; every rate except the error rate is scaled by 100
    #
    report = {
        'Confusion Matrix': conf_matrix.tolist(),
        'Sensitivity': sens * 100,
        'Specificity': spec * 100,
        'Precision': prec * 100,
        'Accuracy': acc * 100,
        'Error Rate': err,
        'F1 Score': f1 * 100
    }

    # exit gracefully
    #
    return report

def load_params(pfile:str) -> dict:
    """
    function: load_params

    args:
     pfile (str): the parameter file to read

    return:
     params (dict): one entry per name found in the "LIST" section of the
                    file, holding the parameters loaded for that name

    description:
     Read the "LIST" section of a parameter file to discover the available
     algorithms, then load the parameter section of each one.
    """

    # the "LIST" section enumerates the algorithms described in the file
    #
    alg_names = load_parameters(pfile, "LIST")

    # load the section of each algorithm and exit gracefully
    #
    return {name: load_parameters(pfile, name) for name in alg_names}

def factor_pair(n: int):
    """
    function: factor_pair

    args:
     n (int): the number to factor

    return:
     (int, int): a tuple (a, b) with a * b == n and a <= b

    description:
     Find the pair of factors of n that are as close to each other as
     possible. The search starts at the integer part of sqrt(n) and walks
     downwards, so the first divisor found is the largest one not exceeding
     the square root. This is used to lay out a grid with a given number of
     points that is as square as possible, e.g. to cap the number of points
     predicted for QSVM decision surfaces. Raises a ValueError when n is
     not positive.
    """

    # reject values that have no meaningful factor pair
    #
    if n <= 0:
        raise ValueError("n must be positive")

    # try candidates from floor(sqrt(n)) down to 1 and return the first one
    # that divides n without remainder, together with its cofactor
    #
    for small in reversed(range(1, int(sqrt(n)) + 1)):
        large, remainder = divmod(n, small)
        if remainder == 0:
            return small, large

    # fallback when no candidate divided n
    #
    return 1, n

def find_boundary(xx_i, yy_i, z_i, nx_i, ny_i):
    """
    function: find_boundary

    args:
     xx_i (np.ndarray): the x values of the initial meshgrid
     yy_i (np.ndarray): the y values of the initial meshgrid
     z_i (np.ndarray): the predicted labels of the initial meshgrid
     nx_i (int): the number of points in the x direction of the initial
                 meshgrid (not used by this function)
     ny_i (int): the number of points in the y direction of the initial
                 meshgrid (not used by this function)

    return:
     np.ndarray: an array of (x, y) coordinates, one row per detected class
                 transition; an empty array when no transition exists

    description:
     Approximate the decision boundary on a coarse meshgrid. Every pair of
     neighbouring grid points (first along axis 0, then along axis 1) whose
     predicted labels differ contributes the midpoint between the two
     points.
    """

    # transitions between neighbours along axis 0: keep x, average the two
    # y values
    #
    vertical = [
        (xx_i[r, c], (yy_i[r, c] + yy_i[r + 1, c]) / 2)
        for r, c in np.argwhere(z_i[:-1, :] != z_i[1:, :])
    ]

    # transitions between neighbours along axis 1: average the two x
    # values, keep y
    #
    horizontal = [
        ((xx_i[r, c] + xx_i[r, c + 1]) / 2, yy_i[r, c])
        for r, c in np.argwhere(z_i[:, :-1] != z_i[:, 1:])
    ]

    # exit gracefully
    #
    return np.array(vertical + horizontal)

def generate_decision_surface(model:mlt.Alg,
                              xrange:list=[-1, 1], 
                              yrange:list=[-1, 1],
                              classes:list = None,
                              grid_size:int = 100):
    """
    function: generate_decision_surface

    args:
     model (mlt.Alg): the trained model to use to generate the decision surface
     xrange (list)  : the x range of the data to use to generate the decision 
                      surface. default = [-1, 1]
     yrange (list)  : the y range of the data to use to generate the decision
                      surface. default = [-1, 1]
     classes (list) : the list of classes in the data. this is required to
                      create the MLToolsData object for the meshgrid
     grid_size (int): the total number of grid points to use when the model
                      is a QSVM; ignored for every other algorithm.
                      default = 100

    return:
     x (np.ndarray) : the x coordinates of the decision surface (1D)
     y (np.ndarray) : the y coordinates of the decision surface (1D)
     z (np.ndarray) : the predicted class at every grid point (2D, indexed
                      as [y index, x index])

    description:
     Generate the decision surface of a model over a rectangular region.
     A coarse grid (100x100, or a near-square factorization of grid_size
     for QSVM) is predicted first and used to locate the decision boundary.
     If no class change is found, the coarse grid is returned as is.
     Otherwise a final grid (500x500, or the same size as the coarse grid
     for QSVM) is filled from the coarse predictions, and only the final
     grid points that lie within a small buffer around the boundary
     (selected with a KDTree) are predicted again and overwritten, which
     yields a smoother boundary at a fraction of the cost of predicting the
     whole final grid.
    """

    # define initial (fast) and final meshgrid sizes
    #
    nx_i, ny_i = 100, 100
    nx, ny = 500, 500

    # get the x and y bounds of the data
    #
    x_min, x_max = xrange
    y_min, y_max = yrange
    
    # define buffer size around initial boundary
    # note: the buffer radius is derived from the width of the x range only
    #
    bf_ratio = 0.05
    bf_size = bf_ratio * (x_max - x_min)

    # for QSVM, limit the number of points in the grid to the grid_size
    # parameter to avoid excessive computation. find a factor pair of
    # grid_size to create a grid that is as square as possible. the coarse
    # and the final grid have the same size in this case
    #
    if model.alg_d.model_d[mlt.ALG_NAME_ALG] == mlt.QSVM_NAME:
        nx_i, ny_i = factor_pair(grid_size)
        nx, ny = nx_i, ny_i

    # define evenly spaced coordinates that make up the initial grid points
    #
    x_i = np.linspace(x_min, x_max, nx_i)
    y_i = np.linspace(y_min, y_max, ny_i)

    # create the initial, faster meshgrid. xx_i and yy_i are 2D arrays of
    # shape (ny_i, nx_i) holding the x and y coordinate of each grid point
    #
    xx_i, yy_i = np.meshgrid(x_i, y_i)
    
    # flatten the grid into a 2D array of points of shape (n, 2), one
    # (x, y) row per grid point
    #
    XX_i = np.c_[xx_i.ravel(), yy_i.ravel()]

    # model.predict needs MLToolsData as an input, so wrap the coarse grid
    # points in a MLToolsData object
    #
    meshgrid_i = mltd.MLToolsData.from_data(XX_i, np.array(classes))

    # get the predicted label of each point of the coarse meshgrid
    #
    labels_i, _ = check_return(model.predict, meshgrid_i)
    labels_i = np.asarray(labels_i)

    # reshape the labels to be the same shape as the xx and yy arrays
    #
    z_i = labels_i.reshape(xx_i.shape)

    # get the boundary points of the initial meshgrid
    #
    boundary_points = find_boundary(xx_i, yy_i, z_i, nx_i, ny_i)

    # if there are no boundary points, the whole region belongs to a single
    # class and the coarse grid already describes it: exit gracefully
    #
    if boundary_points.size == 0: 
        return x_i, y_i, z_i

    # define evenly spaced coordinates that make up the final grid points
    #
    x = np.linspace(x_min, x_max, nx)
    y = np.linspace(y_min, y_max, ny)

    # create final (finer) meshgrid structure
    #
    xx, yy = np.meshgrid(x, y)

    # flatten the final grid into a 2D array of points of shape (n, 2)
    #
    XX = np.c_[xx.ravel(), yy.ravel()]

    # for every point of the final grid, find the index of the coarse grid
    # coordinate at or immediately below it. np.digitize with its default
    # (right=False) returns i such that bins[i-1] <= v < bins[i], so
    # subtracting one gives that index, and a final grid point that
    # coincides with a coarse grid point maps onto exactly that point
    # (with right=True such points were shifted by one cell, which
    # displaced the whole surface whenever both grids are identical).
    # the clip keeps the upper end of the range inside the coarse grid
    #
    x_indices = np.clip(np.digitize(xx, x_i) - 1, 0, nx_i - 1)
    y_indices = np.clip(np.digitize(yy, y_i) - 1, 0, ny_i - 1)

    # initialize z with the coarse predictions to quickly fill the plot
    #     
    z = z_i[y_indices, x_indices]

    # make KDTree of boundary points
    #
    boundary_tree = KDTree(boundary_points)

    # for each final grid point, find the boundary points that fall within
    # the buffer size
    #
    indices = boundary_tree.query_ball_point(XX, r = bf_size)

    # keep the (flat) indices of the grid points that have at least one
    # boundary point within the buffer
    #
    boundary_indices = np.array(
        [i for i, sublist in enumerate(indices) if sublist]
    )

    # if no points fall within buffer, skip slow prediction, exit gracefully
    #
    if boundary_indices.size == 0: 
        return x, y, z

    # extract points for targeted finer prediction
    #
    XX_t = XX[boundary_indices, :]

    # predict only at targeted points
    #
    meshgrid_t = mltd.MLToolsData.from_data(XX_t, np.array(classes))
    
    labels_t, _= check_return(model.predict, meshgrid_t)

    # overwrite z array with more accurate targeted predictions
    #
    z_flat = z.ravel()
    z_flat[boundary_indices] = np.asarray(labels_t)
    z = z_flat.reshape(z.shape)

    # return the 1D x and y coordinates and the 2D class array of the
    # decision surface
    #
    return x, y, z

def save_model(model:mlt.Alg, mapping_label:dict):
    """
    function: save_model

    args:
     model (mlt.Alg) : the ML Tools model to save
     mapping_label (dict): the mapping labels for the model

    return:
     model_bytes (BytesIO): an in-memory stream holding the saved model,
                            positioned at its beginning

    description:
     Save a ML Tools model into an in-memory stream instead of a file on
     disk. The mapping labels are stored inside the model dictionary
     beforehand (this modifies the model that was passed in).
    """

    # attach the mapping labels to the model so they are saved with it
    #
    model.alg_d.model_d[KEY_MAPPING_LABEL] = mapping_label

    # the model is written into this in-memory stream
    #
    stream = BytesIO()

    # use the ML Tools save_model method with an empty file name and the
    # stream as the file pointer. go through check_return because it is a
    # ML Tools function
    #
    check_return(
        model.save_model,
        **{KEY_FILE_NAME: '', KEY_FILE_POINTER: stream}
    )

    # rewind the stream so it can be read from the start, which is
    # required to send the file as a response
    #
    stream.seek(0)

    # exit gracefully
    #
    return stream

def load_model(model_bytes:bytes):
    """
    function: load_model

    args:
     model_bytes (bytes): the raw bytes of a saved model

    return:
     model (mlt.Alg)      : the loaded ML Tools model
     mapping_label (dict) : the mapping labels stored in the model, or None
                            when the model does not carry any

    description:
     Load a ML Tools model from raw bytes by wrapping them in an in-memory
     stream and handing that stream to the ML Tools load_model method.
    """

    # start from an empty ML Tools object and let it load itself from a
    # stream wrapped around the bytes. go through check_return because it
    # is a ML Tools function
    #
    loaded = mlt.Alg()
    check_return(loaded.load_model, **{KEY_FILE_POINTER: BytesIO(model_bytes)})

    # pick up the mapping labels if the model carries them; otherwise use
    # None so they can be processed with the DS
    #
    stored = loaded.alg_d.model_d
    mapping_label = stored[KEY_MAPPING_LABEL] if KEY_MAPPING_LABEL in stored \
        else None

    # exit gracefully
    #
    return loaded, mapping_label

#-------------------------------------------------------------------------------
#
# Section 4: MLToolsError Class
#
#-------------------------------------------------------------------------------

# define the exception raised when an ML Tools call reports an error
#
class MLToolsError(Exception):
    """
    class: MLToolsError

    description:
     Exception raised when an error occurs in ML Tools. It stores a
     descriptive error message, available both through the message
     attribute and through str().
    """

    def __init__(self, message):
        """
        method: constructor

        args:
         message: description of error

        return: None

        description:
         Initializes the MLToolsError instance with a given message. The
         message is also forwarded to the Exception base class so that the
         args attribute is populated however the instance is constructed
         (positionally or by keyword).
        """

        # let the base class record the message in args
        #
        super().__init__(message)

        # keep the message available under its own name
        #
        self.message = message

    #
    # end of method

    def __str__(self):
        """
        method: __str__

        args: None

        return:
         str: the stored error message

        description:
         Returns the error message as a string. The message is converted
         with str() so that a non-string message does not make str() fail.
        """

        return str(self.message)

    #
    # end of method

#
# end of class
    
#
# end of file
