#!/usr/bin/env python
#
# file: $NEDC_NFC/util/python/nedc_imld/v5.0.1/app/extensions/blueprint.py
#
# revision history:
#
# 20250925 (SA): refactored code to fit ISIP standards
#
# Defines the Flask Blueprint and API functions for IMLD, including model
# caching, data generation, and algorithm parameter handling. 
#-------------------------------------------------------------------------------

# import system modules
#
from collections import OrderedDict
from datetime import datetime
from flask import Blueprint, render_template, request, jsonify, current_app, \
                  send_file
import io
import json
import numpy as np
import os
import pickle
import subprocess
import toml

# import nedc modules
#
import nedc_debug_tools as ndt
import nedc_imld_tools as imld
import nedc_ml_tools as mlt

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# set the filename using basename (used to identify this module by its
# file name only, without the directory part)
#
__FILE__ = os.path.basename(__file__)

# names of the default parameter files. the route handlers join these with
# the directory stored in the application config under CONFIG_KEY_BACKEND
#
ALGO_PARAMS_FILE = 'algo_params_v00.toml'
DATA_PARAMS_FILE = 'data_params_v00.toml'

# template rendered for the root page, and the filenames suggested to the
# browser for downloaded algorithm parameters and pickled models
#
TEMPLATE_INDEX = 'index.shtml'
DOWNLOAD_ALG_TOML = 'alg.toml'
DOWNLOAD_MODEL_PKL = 'model.pkl'

# application config keys
#
CONFIG_KEY_BACKEND = 'BACKEND'

# keys of each per-user entry stored in the model cache
#
CACHE_KEY_MODEL = 'model'
CACHE_KEY_TIMESTAMP = 'timestamp'

# cache expiration time in seconds (300 seconds = 5 minutes)
#
CACHE_EXPIRE_TIME = 300

# keys used in JSON/POST request bodies, form fields and JSON responses
#
JSON_KEY_DATA = 'data'
JSON_KEY_NAME = 'name'
JSON_KEY_ALG_NAME = 'algoName'
JSON_KEY_PARAMS = 'params'
JSON_KEY_USER_ID = 'userID'
JSON_KEY_XRANGE = 'xrange'
JSON_KEY_YRANGE = 'yrange'
JSON_KEY_PLOT_DATA = 'plotData'
JSON_KEY_METHOD = 'method'
JSON_KEY_ALGO = 'algo'
JSON_KEY_MODEL = 'model'
JSON_KEY_OLD_BOUNDS = 'oldBounds'
JSON_KEY_BOUNDS = 'bounds'
JSON_KEY_LABEL_MAPPINGS = 'label_mappings'
JSON_KEY_LABELS = 'labels'
JSON_KEY_X = 'x'
JSON_KEY_Y = 'y'
JSON_KEY_Z = 'z'
JSON_KEY_DECISION_SURFACE = 'decision_surface'
JSON_KEY_MAPPING_LABEL = 'mapping_label'
JSON_KEY_ERROR = 'error'
JSON_KEY_METRICS = 'metrics'
JSON_KEY_PARAMETER_OUTPUT = 'parameter_output'
JSON_KEY_ISSUE_NUM = 'issueNum'
JSON_KEY_TYPE='type'
JSON_KEY_DEFAULT='default'
JSON_KEY_FILE='file'

# MIME types used for JSON responses and for file downloads
#
MIME_JSON = 'application/json'
MIME_OCTET = 'application/octet-stream'

# values of the 'method' field accepted by the normalize route
#
NORMALIZE_METHOD = 'normalize'
DENORMALIZE_METHOD = 'denormalize'
RENORMALIZE_METHOD = 'renormalize'

# default issue number, used when the environment variable is not set, and
# the text encoding used to decode uploaded TOML files and to encode
# downloaded ones
#
DEFAULT_ISSUE_NUM = '000'
DEFAULT_FILE_TYPE = 'utf-8'

# name of the environment variable that holds the issue number
#
ENV_ISSUE_NUM = 'IMLD_ISSUE_NUM'

# declare a global debug object so we can use it in functions
#
dbgl = ndt.Dbgl()

# create the Blueprint on which all routes in this file are registered
#
main = Blueprint('main', __name__)

# create a global variable to hold the models. each entry maps a user ID to
# a dictionary holding the model (CACHE_KEY_MODEL) and the time at which it
# was stored (CACHE_KEY_TIMESTAMP)
#
model_cache = {}

#------------------------------------------------------------------------------
#
# functions are listed here
#
#------------------------------------------------------------------------------

def clean_cache():
    """
    function: clean_cache

    arguments: None

    return:
     a Boolean value indicating status (always True)

    description:
     Iterates through the model cache and removes every cached model whose
     timestamp is more than CACHE_EXPIRE_TIME seconds in the past.
    """

    # get the current time once so that every entry is judged against the
    # same instant
    #
    now = datetime.now()

    # iterate over a snapshot of the keys so entries can be deleted while
    # looping.
    #
    # the age must be computed with total_seconds(): the .seconds attribute
    # holds only the seconds component of the timedelta (0 to 86399) and
    # ignores whole days, so an entry older than one day could look fresh
    # and never be removed
    #
    for key in list(model_cache):
        age = (now - model_cache[key][CACHE_KEY_TIMESTAMP]).total_seconds()
        if age > CACHE_EXPIRE_TIME:
            del model_cache[key]

    # exit gracefully
    #
    return True

@main.route('/')
def index():
    """
    function: index

    arguments: None

    return:
     the rendered HTML of the main page

    description:
     Handles requests for the root URL by rendering the index template.
    """

    # render the landing page template
    #
    page = render_template(TEMPLATE_INDEX)

    # exit gracefully: hand the rendered page back to Flask
    #
    return page

@main.route('/api/get_alg_params/', methods=['GET'])
def get_alg_params():
    """
    function: get_alg_params

    arguments: None

    return:
     JSON response containing the default algorithm parameters

    description:
     Reads the default algorithm parameter file from the backend directory
     and returns its contents as JSON, keeping the order in which the
     parameters were loaded.
    """

    # locate the parameter file. the backend directory has to be looked up
    # here rather than at import time because 'current_app.config' is only
    # available while a request is being handled
    #
    backend_dir = current_app.config[CONFIG_KEY_BACKEND]
    params_path = os.path.join(backend_dir, ALGO_PARAMS_FILE)

    # read the algorithm parameters
    #
    algo_params = imld.load_params(params_path)

    # serialize by hand so that the ordering of the parameters is kept
    #
    payload = json.dumps(OrderedDict(algo_params))

    # exit gracefully: return the data as JSON
    #
    return current_app.response_class(payload, mimetype=MIME_JSON)

@main.route('/api/get_data_params/', methods=['GET'])
def get_data_params():
    """
    function: get_data_params

    arguments: None

    return:
     JSON response containing the default data generation parameters

    description:
     Reads the default data generation parameter file from the backend
     directory and returns its contents as JSON, keeping the order in which
     the parameters were loaded.
    """

    # locate the parameter file. the backend directory has to be looked up
    # here rather than at import time because 'current_app.config' is only
    # available while a request is being handled
    #
    backend_dir = current_app.config[CONFIG_KEY_BACKEND]
    params_path = os.path.join(backend_dir, DATA_PARAMS_FILE)

    # read the data generation parameters
    #
    data_params = imld.load_params(params_path)

    # serialize by hand so that the ordering of the parameters is kept
    #
    payload = json.dumps(OrderedDict(data_params))

    # exit gracefully: return the data as JSON
    #
    return current_app.response_class(payload, mimetype=MIME_JSON)

@main.route('/api/load_alg_params/', methods=['POST'])
def load_alg_params():
    """
    function: load_alg_params

    arguments:
     None (input comes from POST request; the TOML file is expected in the
     'file' field of the uploaded files)

    return:
     JSON response with the keys 'algoName' and 'params' on success, or a
     plain-text error message with status 500 on failure

    description:
     Parses a user-uploaded TOML file and returns the contents of its first
     top-level table. Used by the frontend to dynamically update the
     algorithm configuration.
    """

    try:

        # get the file from the request
        #
        file = request.files[JSON_KEY_FILE]

        # read the file and use toml parser
        #
        content = file.read().decode(DEFAULT_FILE_TYPE)
        toml_data = toml.loads(content)

        # an empty file has no table to read. report this explicitly
        # instead of letting next() raise StopIteration, whose message
        # is empty
        #
        if not toml_data:
            raise ValueError('the uploaded file does not contain an '
                             'algorithm table')

        # extract the algorithm data: the first top-level entry of the file
        #
        algo_key = next(iter(toml_data))
        algo_data = toml_data[algo_key]

        # the first entry must be a table, otherwise the lookups below
        # would fail with an unhelpful attribute error
        #
        if not isinstance(algo_data, dict):
            raise ValueError(f'top-level entry {algo_key!r} is not a table')

        # get the algorithm name. files written by save_alg_params() store
        # it under 'name' rather than 'algoName', so fall back to that key
        # to let a saved file be loaded again
        #
        algo_name = algo_data.get(JSON_KEY_ALG_NAME)
        if algo_name is None:
            algo_name = algo_data.get(JSON_KEY_NAME)

        # format the response.
        # NOTE(review): 'params' carries the whole table (including the
        # name and the nested params entry), exactly as before - confirm
        # that this is what the frontend expects
        #
        response = {
            JSON_KEY_ALG_NAME: algo_name,
            JSON_KEY_PARAMS: algo_data
        }

        # exit gracefully: return the jsonified response
        #
        return jsonify(response)

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return f'Failed to load algorithm parameters: {e}', 500

@main.route('/api/load_model/', methods=['POST'])
def load_model():
    """
    function: load_model

    arguments:
     None (input comes from POST request: the model file in the 'model'
     field of the uploaded files, and the form fields 'userID', 'xrange'
     and 'yrange', the latter two JSON encoded)

    return:
     JSON response containing decision surface data and label mapping on
     success, or a JSON object with an 'error' key and status 500 on failure

    description:
     Loads a user-uploaded pickled model, caches it under the user's ID,
     and generates a decision surface for the given x/y ranges.
    """

    try:

        # get the file, userID, and plot bounds from the request
        #
        file = request.files[JSON_KEY_MODEL]
        user_ID = request.form.get(JSON_KEY_USER_ID)
        xrange = json.loads(request.form.get(JSON_KEY_XRANGE))
        yrange = json.loads(request.form.get(JSON_KEY_YRANGE))

        # read the model file
        #
        model_bytes = file.read()

        # load the model.
        # NOTE(review): the bytes come straight from the client and are
        # described as a pickled model. unpickling untrusted data can
        # execute arbitrary code - confirm that imld.load_model() is safe
        # for this use
        #
        model, mapping_label = imld.load_model(model_bytes)

        # save the model to the corresponding userID
        #
        model_cache[user_ID] = {
            CACHE_KEY_MODEL: model,
            CACHE_KEY_TIMESTAMP: datetime.now()
        }

        # collect the arguments for the decision surface. the previous
        # code passed classes=list(set(labels)), but no 'labels' variable
        # exists in this function, so every request failed with a NameError.
        # NOTE(review): the class names are taken from the model's own label
        # mapping when it has one, otherwise 'classes' is left out -
        # presumably the mapping values correspond to the labels used in
        # train(); confirm against imld.generate_decision_surface()
        #
        surface_kwargs = {'xrange': xrange, 'yrange': yrange}
        if isinstance(mapping_label, dict) and mapping_label:
            surface_kwargs['classes'] = list(mapping_label.values())

        # get the x y and z values from the decision surface
        # x and y will be 1D and z will be 2D
        #
        x, y, z = imld.generate_decision_surface(model, **surface_kwargs)

        # if no mapping label is provided, create a default one. tolist()
        # converts the grid values to native Python numbers first, because
        # numpy integer scalars are not accepted as dictionary keys by the
        # JSON encoder
        #
        if mapping_label is None:
            mapping_label = {i: f'Class {i+1}'
                             for i in set(z.flatten().tolist())}

        # format the response
        #
        response = {
            JSON_KEY_DECISION_SURFACE: {
                JSON_KEY_X: x.tolist(),
                JSON_KEY_Y: y.tolist(),
                JSON_KEY_Z: z.tolist()
            },
            JSON_KEY_MAPPING_LABEL: mapping_label
        }

        # exit gracefully: return the jsonified response
        #
        return jsonify(response)

    # handle any exceptions and return an error message
    #
    except Exception as e:
        response = {JSON_KEY_ERROR: str(e)}
        return jsonify(response), 500

@main.route('/api/save_alg_params/', methods=['POST'])
def save_alg_params():
    """
    function: save_alg_params

    arguments:
     None (input comes from POST request; the JSON body is expected to hold
     a 'data' object with the keys 'name' and 'params')

    return:
     a downloadable TOML file containing the algorithm parameters on
     success, or a plain-text error message with status 500 on failure

    description:
     Accepts algorithm name and parameters from the frontend, structures
     them into a TOML-compliant format, and returns the file for download.
     The file is built in memory (this relies on the io module being
     imported at the top of the file).
    """

    try:

        # get the data from the request
        #
        data = request.get_json()

        # get the algo name and params. tolerate a missing body or a
        # missing 'data' object here so that the checks below can report
        # the problem clearly
        #
        algo_info = (data or {}).get(JSON_KEY_DATA) or {}
        algo_name_raw = algo_info.get(JSON_KEY_NAME)
        params = algo_info.get(JSON_KEY_PARAMS)

        # validate the input up front: without this a missing name or
        # params value surfaces as an opaque "'NoneType' object has no
        # attribute ..." message
        #
        if not isinstance(algo_name_raw, str) or not algo_name_raw:
            raise ValueError('no algorithm name was provided')
        if not isinstance(params, dict):
            raise ValueError('algorithm parameters must be a JSON object')

        # replace spaces and symbols for TOML-compliant table name: spaces
        # and hyphens become underscores, parentheses are dropped
        #
        algo_key = algo_name_raw.translate(
            str.maketrans({' ': '_', '-': '_', '(': None, ')': None}))

        # build nested TOML structure
        #
        toml_data = {
            algo_key: {
                JSON_KEY_NAME: algo_name_raw,
                JSON_KEY_PARAMS: {}
            }
        }

        # iterate through params to populate toml file. only the type and
        # the default value of each parameter are kept, and the default is
        # always written as a string
        #
        for param_name, param_info in params.items():
            toml_data[algo_key][JSON_KEY_PARAMS][param_name] = {
                JSON_KEY_TYPE: param_info.get(JSON_KEY_TYPE, ""),
                JSON_KEY_DEFAULT: str(param_info.get(JSON_KEY_DEFAULT, ""))
            }

        # convert toml file to byte stream
        #
        toml_str = toml.dumps(toml_data)
        file_data = io.BytesIO(toml_str.encode(DEFAULT_FILE_TYPE))

        # exit gracefully: return the toml file
        #
        return send_file(
            file_data,
            mimetype=MIME_OCTET,
            as_attachment=True,
            download_name=DOWNLOAD_ALG_TOML
        )

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return f'Failed to save algorithm parameters: {e}', 500

@main.route('/api/save_model/', methods=['POST'])
def save_model():
    """
    function: save_model

    arguments:
     None (input comes from POST request; the JSON body holds the keys
     'userID' and 'label_mappings')

    return:
     a downloadable pickled model (.pkl) file on success, or a JSON object
     with an 'error' key and status 500 on failure

    description:
     Looks up the model cached for the given user ID, passes it together
     with the supplied label mappings to imld.save_model(), and sends the
     result back to the client as a file download.
    """

    try:

        # parse the JSON body of the request
        #
        payload = request.get_json()

        # pull out the id of the requesting user
        #
        user_key = payload[JSON_KEY_USER_ID]

        # a missing or empty cache entry means there is nothing to save
        #
        if not model_cache.get(user_key):
            raise ValueError(
                f'User ID {user_key} not found in cache '
                'or model is empty. Please refresh.'
            )

        # fetch the cached entry of this user
        #
        cached_entry = model_cache[user_key]

        # serialize the model together with the updated label mappings
        #
        serialized = imld.save_model(cached_entry[CACHE_KEY_MODEL],
                                     payload[JSON_KEY_LABEL_MAPPINGS])

        # exit gracefully:
        # send the pickled model as a response, without writing to a file.
        # NOTE(review): an earlier comment stated that the bytes have to be
        # reopened because ML Tools save_model() closes the file pointer.
        # nothing is reopened here, so presumably imld.save_model() already
        # returns a readable object - confirm
        #
        return send_file(
            serialized,
            as_attachment=True,
            download_name=DOWNLOAD_MODEL_PKL,
            mimetype=MIME_OCTET,
        )

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return jsonify({JSON_KEY_ERROR: str(e)}), 500
    
@main.route('/api/train/', methods=['POST'])
def train():
    """
    function: train

    arguments:
     None (input comes from POST request; the JSON body holds the keys
     'userID', 'params', 'algo', 'plotData', 'xrange' and 'yrange')

    return:
     JSON response containing decision surface data, model evaluation
     metrics, and parameter output on success, or a JSON error string with
     status 500 if training fails

    description:
     Builds a model from the requested algorithm and parameters, trains it
     on the supplied points, computes the decision surface over the given
     ranges, and stores the trained model in the cache under the user's ID.
    """

    # parse the JSON body of the request
    #
    payload = request.get_json()

    # unpack the request: user, algorithm settings, points and plot bounds
    #
    user_key = payload[JSON_KEY_USER_ID]
    algo_params = payload[JSON_KEY_PARAMS]
    algo_name = payload[JSON_KEY_ALGO]
    plot = payload[JSON_KEY_PLOT_DATA]
    x_pts = plot[JSON_KEY_X]
    y_pts = plot[JSON_KEY_Y]
    labels = plot[JSON_KEY_LABELS]
    x_bounds = payload[JSON_KEY_XRANGE]
    y_bounds = payload[JSON_KEY_YRANGE]

    try:

        # build the untrained model from the algorithm name and parameters
        #
        classifier = imld.create_model(algo_name, algo_params)

        # wrap the points and their labels in a data object
        #
        dataset = imld.create_data(x_pts, y_pts, labels)

        # fit the model
        #
        classifier, metrics, parameter_output = \
            imld.train(classifier, dataset)

        # compute the decision surface. the grid size is set to the number
        # of rows in the data object
        #
        surface = imld.generate_decision_surface(
            classifier,
            xrange=x_bounds,
            yrange=y_bounds,
            classes=list(set(labels)),
            grid_size=dataset.data.shape[0],
        )
        grid_x, grid_y, grid_z = surface

        # assemble the response body
        #
        surface_json = {
            JSON_KEY_X: grid_x.tolist(),
            JSON_KEY_Y: grid_y.tolist(),
            JSON_KEY_Z: grid_z.tolist(),
        }
        response = {
            JSON_KEY_DECISION_SURFACE: surface_json,
            JSON_KEY_METRICS: metrics,
            JSON_KEY_PARAMETER_OUTPUT: parameter_output,
        }

        # remember the trained model for this user, stamped with the
        # current time
        #
        model_cache[user_key] = {
            CACHE_KEY_MODEL: classifier,
            CACHE_KEY_TIMESTAMP: datetime.now(),
        }

        # exit gracefully: return the jsonified response
        #
        return jsonify(response)

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return jsonify(f'Failed to train model: {str(e)}'), 500
    
@main.route('/api/eval/', methods=['POST'])
def eval():
    """
    function: eval

    arguments:
     None (input comes from POST request; the JSON body holds the keys
     'userID' and 'plotData')

    return:
     JSON response containing model evaluation metrics and parameter output
     on success, or a JSON error string with status 500 on failure

    description:
     Runs the model cached for the given user on the supplied points and
     returns the resulting metrics and parameter output.
    """

    # parse the JSON body of the request
    #
    payload = request.get_json()

    # unpack the user id and the points to evaluate
    #
    user_key = payload[JSON_KEY_USER_ID]
    plot = payload[JSON_KEY_PLOT_DATA]
    x_pts = plot[JSON_KEY_X]
    y_pts = plot[JSON_KEY_Y]
    labels = plot[JSON_KEY_LABELS]

    try:

        # look up the model previously cached for this user
        #
        cached_model = model_cache[user_key][CACHE_KEY_MODEL]

        # wrap the points and their labels in a data object
        #
        dataset = imld.create_data(x_pts, y_pts, labels)

        # score the model on the data
        #
        metrics, parameter_output = imld.predict(cached_model, dataset)

        # exit gracefully: return the jsonified response
        #
        return jsonify({
            JSON_KEY_METRICS: metrics,
            JSON_KEY_PARAMETER_OUTPUT: parameter_output,
        })

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return jsonify(f'Failed to evaluate model: {str(e)}'), 500
    
@main.route('/api/set_bounds/', methods=['POST'])
def rebound():
    """
    function: rebound

    arguments:
     None (input comes from POST request; the JSON body holds the keys
     'userID', 'xrange' and 'yrange', and may optionally hold
     'plotData' with a 'labels' list)

    return:
     JSON response containing updated decision surface data (x, y, z values)
     on success, or a JSON error string with status 500 on failure

    description:
     Regenerates the decision surface of the model cached for the given
     user over the provided x/y ranges and returns the new surface.
    """

    # get the data from the request
    #
    data = request.get_json()

    # get the user id and the new plot bounds
    #
    userID = data[JSON_KEY_USER_ID]
    xrange = data[JSON_KEY_XRANGE]
    yrange = data[JSON_KEY_YRANGE]

    try:

        # get the model from the cache, reporting a missing entry with the
        # same message used when saving a model instead of a bare KeyError
        #
        entry = model_cache.get(userID)
        if not entry:
            raise ValueError(
                f'User ID {userID} not found in cache '
                'or model is empty. Please refresh.'
            )
        model = entry[CACHE_KEY_MODEL]

        # collect the arguments for the decision surface. the previous
        # code passed classes=list(set(labels)), but no 'labels' variable
        # exists in this function, so every request failed with a NameError.
        # the labels are now read from the request when the client sends
        # them; otherwise 'classes' is left out.
        # NOTE(review): confirm whether imld.generate_decision_surface()
        # requires 'classes' and whether the frontend sends plotData here
        #
        plot_data = data.get(JSON_KEY_PLOT_DATA) or {}
        labels = plot_data.get(JSON_KEY_LABELS)
        surface_kwargs = {'xrange': xrange, 'yrange': yrange}
        if labels:
            surface_kwargs['classes'] = list(set(labels))

        # get the x y and z values from the decision surface
        # x and y will be 1D and z will be 2D
        #
        x, y, z = imld.generate_decision_surface(model, **surface_kwargs)

        # format the response
        #
        response = {
            JSON_KEY_DECISION_SURFACE: {
                JSON_KEY_X: x.tolist(),
                JSON_KEY_Y: y.tolist(),
                JSON_KEY_Z: z.tolist()
            }
        }

        # exit gracefully: return the jsonified response
        #
        return jsonify(response)

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return \
        jsonify(f'Failed to re-bound the decision surface: {str(e)}'), 500
    
@main.route('/api/normalize/', methods=['POST'])
def normalize():
    """
    function: normalize

    arguments:
     None (input comes from POST request; the JSON body holds the keys
     'plotData', 'bounds' and 'method', plus 'oldBounds' when the method
     is 'renormalize')

    return:
     JSON response containing the labels and the transformed x and y values
     on success, or a JSON error string with status 500 on failure

    description:
     Applies the requested transformation to the x and y values:
     'normalize' and 'denormalize' use the given bounds, 'renormalize'
     first denormalizes with the old bounds and then normalizes with the
     new ones. Any other method value returns the points unchanged.
    """

    try:

        # parse the JSON body of the request
        #
        payload = request.get_json()

        # unpack the points, the target bounds and the requested method
        #
        plot = payload[JSON_KEY_PLOT_DATA]
        x_vals = plot[JSON_KEY_X]
        y_vals = plot[JSON_KEY_Y]
        labels = plot[JSON_KEY_LABELS]
        bounds = payload[JSON_KEY_BOUNDS]
        x_bounds = bounds[JSON_KEY_XRANGE]
        y_bounds = bounds[JSON_KEY_YRANGE]
        method = payload[JSON_KEY_METHOD]

        # apply the transformation selected by the method
        #
        if method == NORMALIZE_METHOD:
            x_vals, y_vals = imld.normalize_data(x_vals, y_vals,
                                                 x_bounds, y_bounds)

        elif method == DENORMALIZE_METHOD:
            x_vals, y_vals = imld.denormalize_data(x_vals, y_vals,
                                                   x_bounds, y_bounds)

        elif method == RENORMALIZE_METHOD:

            # undo the old scaling first, then apply the new one
            #
            old_bounds = payload[JSON_KEY_OLD_BOUNDS]
            old_x_bounds = old_bounds[JSON_KEY_XRANGE]
            old_y_bounds = old_bounds[JSON_KEY_YRANGE]
            x_vals, y_vals = imld.denormalize_data(x_vals, y_vals,
                                                   old_x_bounds,
                                                   old_y_bounds)
            x_vals, y_vals = imld.normalize_data(x_vals, y_vals,
                                                 x_bounds, y_bounds)

        # exit gracefully: return the response in JSON format
        #
        return jsonify({
            JSON_KEY_LABELS: labels,
            JSON_KEY_X: x_vals,
            JSON_KEY_Y: y_vals,
        })

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return jsonify(f'Failed to normalize data: {str(e)}'), 500

@main.route('/api/data_gen/', methods=['POST'])
def data_gen():
    """
    function: data_gen

    arguments:
     None (input comes from POST request; the JSON body holds the keys
     'method', the distribution name, and 'params', its parameters)

    return:
     JSON response containing generated data (labels, x, y values) on
     success, or a JSON error string with status 500 on failure

    description:
     Generates synthetic data for the provided distribution name and
     parameters and returns the generated labels and data points.
    """

    # get the data sent in the POST request as JSON
    #
    data = request.get_json()

    try:

        # validate the request before using it. previously an empty body
        # left the distribution name and parameters undefined, so the
        # client received an "unbound local variable" message, and a
        # missing key raised outside of the error handling below
        #
        if not isinstance(data, dict):
            raise ValueError('request body must be a JSON object')
        missing = [key for key in (JSON_KEY_METHOD, JSON_KEY_PARAMS)
                   if key not in data]
        if missing:
            raise ValueError(
                f'missing required field(s): {", ".join(missing)}')

        # extract the distribution name and parameters
        #
        dist_name = data[JSON_KEY_METHOD]
        params_dict = data[JSON_KEY_PARAMS]

        # generate values for labels, x, y
        #
        labels, x, y = imld.generate_data(dist_name, params_dict)

        # prepare the response data
        #
        response_data = {
            JSON_KEY_LABELS: labels,
            JSON_KEY_X: x,
            JSON_KEY_Y: y
        }

        # exit gracefully: return the response in JSON format
        #
        return jsonify(response_data)

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return jsonify(f'Failed to generate data: {str(e)}'), 500

@main.route('/api/issue_number/', methods=['POST'])
def write_issue():
    """
    function: write_issue

    arguments:
     None (input comes from POST request)

    return:
     JSON response holding the current issue number under 'issueNum' with
     status 200, or a JSON error string with status 500 on failure

    description:
     Reads the issue number from the environment variable, falling back to
     the default when it is unset or empty, stores the number increased by
     one back into the environment variable of this process, and returns
     the number as it was before the increment.
    """

    try:

        # read the issue number from the environment variable that is
        # stored in the nedc_imld conda environment. use the default when
        # the variable is unset or empty
        #
        current_num = os.getenv(ENV_ISSUE_NUM) or DEFAULT_ISSUE_NUM

        # compute the next issue number
        #
        next_num = int(current_num) + 1

        # store it back in the environment variable, zero padded to at
        # least three digits
        #
        os.environ[ENV_ISSUE_NUM] = f'{next_num:03d}'

        # exit gracefully: return the issue number read above
        #
        payload = {JSON_KEY_ISSUE_NUM: current_num}
        return jsonify(payload), 200

    # handle any exceptions and return an error message
    #
    except Exception as e:
        return jsonify(str(e)), 500

#
# end of file