#!/usr/bin/env python
#
# file: imld/data/imld_data_io.py
#
# revision history:
#
# 20210923 (TC): clean up. Added write_data()
# 20200101 (..): initial version
#
# This file contains a collection of classes and functions that deal with
# data handling. The structure of the DataIO class includes:
#  load_file() --> read_data()
#  save_file() --> write_data()
#
#------------------------------------------------------------------------------
#
# imports are listed here
#
#------------------------------------------------------------------------------

# import system modules
#
import numpy as np
from PyQt5 import QtWidgets
import os

# import nedc modules
#
import imld_constants_file as icf
import nedc_ml_tools as ml

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# strings used in menu prompts
#
LOAD_TRAIN_DATA = "Load Train Data"
LOAD_EVAL_DATA = "Load Eval Data"
SAVE_TRAIN_DATA = "Save Train Data as..."
SAVE_EVAL_DATA = "Save Eval Data as..."
LOAD_MODEL = "Load Model"
SAVE_MODEL = "Save Model as ..."
LOAD_PARAM = "Load Parameters"
SAVE_PARAM = "Save Parameters as ..."

# strings used in csv formatted files
#
# NOTE(review): EXT has the same value as NAME ("NAME") - confirm that this
# is intended before relying on EXT
#
FILE_HEADER_INFO = ["classes", "colors", "limits"]
TRAIN = "TRAIN"
EVAL = "EVAL"
MODEL = "MODEL"
PARAM = "PARAM"
EXT = "NAME"
STR = "STR"
NAME = "NAME"
FILE = "FILE"

# default filenames
#
DEFAULT_TRAIN_FNAME = "imld_train.csv"
DEFAULT_EVAL_FNAME = "imld_eval.csv"

# default model extensions
#
MODEL_EXTS = ".pkl"
MODEL_EXTS_STR = "Binary Pickle Files (*.pkl)"

# default data extensions
#
DATA = "csv"
DATA_EXTS = ".csv"
DATA_EXTS_STR = "CSV Formatted Files (*.csv)"

# default parameter extensions
#
PARAM_EXTS = ".txt"
PARAM_EXTS_STR = "TXT ML Tools Formatted Parameter Files (*.txt)"

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

# class: IMLD_MLToolData
#
# This class wraps IMLD's per-class point lists in the attribute layout
# (labels, data, mapping_label, ...) that is set up in its constructor.
#
class IMLD_MLToolData(ml.MLToolData):

    # method: IMLD_MLToolData::constructor
    #
    # arguments:
    #  ui: the user interface object (accepted for interface compatibility,
    #      not used by this constructor)
    #  _imld_data: a sequence holding one sequence of points per class
    #
    # return: None
    #
    # This method flattens the per-class point lists into a single data
    # array and a parallel array of integer labels, where the label of a
    # point is the index of the class list it came from. Note that the
    # constructor of the parent class is not invoked here: every attribute
    # is assigned directly.
    #
    def __init__(self, ui, _imld_data):

        # keep a reference to the original data and set the bookkeeping
        # attributes
        #
        self.imld_data = _imld_data
        self.dir_path = ""
        self.lndx = 0
        self.nfeats = -1
        self.num_of_classes = len(self.imld_data)

        labels = []
        data = []
        mapping_label = {}

        # convert the data into the flattened format: class i contributes
        # len(points) copies of the label i and all of its points
        #
        for i, points in enumerate(self.imld_data):
            mapping_label[i] = i
            labels.extend([i] * len(points))
            data.extend(points)

        self.labels = np.asarray(labels)
        self.data = np.asarray(data)
        self.mapping_label = mapping_label

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: IMLD_MLToolData::predict_decision_surface
    #
    # arguments:
    #  ax: an object exposing canvas.axes.get_xlim() (the plot the data is
    #      drawn on)
    #  alg: an object whose predict() method accepts this object and
    #       returns a pair whose first element holds one label per row
    #       of self.data
    #
    # return:
    #  xx: the x coordinates of the mesh grid
    #  yy: the y coordinates of the mesh grid
    #  Z: the predicted label at each point of the mesh grid
    #
    # This method builds a mesh grid that covers the data (padded by one
    # unit on each side), asks the algorithm to label every grid point and
    # returns the labels in the shape of the grid.
    #
    def predict_decision_surface(self, ax, alg):

        # remember the real data: self.data is temporarily replaced by
        # the grid points below
        #
        original_data = self.data

        # view the data as a list of two-dimensional points
        #
        X = np.concatenate(self.data, axis=0)
        X = np.reshape(X, (-1, 2))

        # use one hundredth of the visible x range as the grid resolution
        # (both ends of the range must come from the same axis)
        #
        x_lim = ax.canvas.axes.get_xlim()
        res = (x_lim[1] - x_lim[0]) / 100

        # create the mesh grid
        #
        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, res),
                             np.arange(y_min, y_max, res))

        # classify the grid points. the original data is restored even if
        # the prediction fails so that this object is never left holding
        # the grid
        #
        self.data = np.c_[xx.ravel(), yy.ravel()]
        try:
            labels, _ = alg.predict(self)
        finally:
            self.data = original_data

        Z = np.array(labels).reshape(xx.shape)

        # exit gracefully
        #
        return xx, yy, Z
    #
    # end of method

#
# end of class

# class: DataIO
#
# This class contains methods to both save and load data for both
# the train and eval windows of the application.
#
class DataIO:

    # method: DataIO::constructor
    #
    # arguments:
    #  ui: the object passed as the parent of the file dialogs
    #
    # return: None
    #
    def __init__(self, ui):

        # declare a data structure to hold user data
        #
        self.user_data = None

        # set the ui
        #
        self.ui = ui

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: DataIO::_validate_fname
    #
    # arguments:
    #  file: the value returned by a file dialog
    #  ext: the required file extension, including the leading dot
    #
    # return:
    #  fname: the filename as a string, or None if it is empty (for
    #         example when the dialog was cancelled), consists of nothing
    #         but the extension, or does not end with the extension
    #
    @staticmethod
    def _validate_fname(file, ext):

        # convert the chosen file to a string for manipulation
        #
        fname = str(file)

        # make sure a valid file was selected
        #
        if len(fname) > len(ext) and fname.endswith(ext):
            return fname

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: DataIO::load_file
    #
    # arguments:
    #  mode: train (icf.DEF_MODE[0]) or eval (icf.DEF_MODE[1])
    #
    # return:
    #  classes: list of classes from user input file
    #  colors: list of colors from user input file
    #  limits: list of limits from user input file
    #  user_data: a dict of user data
    #
    # All four values are None if the mode is unknown, no valid csv file
    # was selected, or the file could not be parsed.
    #
    def load_file(self, mode):

        # get mode, either train or eval
        #
        self.mode = mode

        # select the prompt displayed at the top of the pop-up. the modes
        # are compared by value: an identity test on strings only
        # succeeds by accident
        #
        if mode == icf.DEF_MODE[0]:
            caption = LOAD_TRAIN_DATA
        elif mode == icf.DEF_MODE[1]:
            caption = LOAD_EVAL_DATA
        else:
            print("Warning: unknown mode (%s)" % mode)
            return None, None, None, None

        # Reference:
        #  static QFileDialog.getOpenFileName([parent=None[, caption=""[,
        #  dir=""[, filter=""[, selectedFilter=""[, options]]]]]])
        #  See more: https://doc.qt.io/qtforpython-5/PySide2/QtWidgets/QFileDialog.html
        #
        file, _ = QtWidgets.QFileDialog.getOpenFileName(
            self.ui, caption, icf.DELIM_NULL)

        # make sure a valid file was selected
        #
        fname = self._validate_fname(file, DATA_EXTS)
        if fname is None:
            return None, None, None, None

        # read data: read_data() returns a bare None when a line cannot
        # be parsed, so the result must be checked before it is unpacked
        #
        result = self.read_data(fname)
        if result is None or result[3] is None:
            print("Warning: Data not read or updated")
            return None, None, None, None

        classes, colors, limits, self.user_data = result

        # exit gracefully
        #
        return classes, colors, limits, self.user_data
    #
    # end of method

    # method: DataIO::save_file
    #
    # arguments:
    #  data: a dict of data wanted to save (see write_data)
    #  mode: train (icf.DEF_MODE[0]) or eval (icf.DEF_MODE[1])
    #  limits: the limits written to the file header
    #
    # return: None
    #
    # Nothing is written if the mode is unknown or no file name is given.
    #
    def save_file(self, data, mode, limits):

        # select the prompt and the default filename, depending on data
        #
        if mode == icf.DEF_MODE[0]:
            caption, default_fname = SAVE_TRAIN_DATA, DEFAULT_TRAIN_FNAME
        elif mode == icf.DEF_MODE[1]:
            caption, default_fname = SAVE_EVAL_DATA, DEFAULT_EVAL_FNAME
        else:
            print("Warning: unknown mode (%s)" % mode)
            return None

        # Reference:
        #  static QFileDialog.getSaveFileName([parent=None[, caption=""[,
        #  dir=""[, filter=""[, selectedFilter=""[, options]]]]]])
        #  See more: https://doc.qt.io/qtforpython-5/PySide2/QtWidgets/QFileDialog.html
        #
        file, _ = QtWidgets.QFileDialog.getSaveFileName(
            self.ui, caption, default_fname, DATA_EXTS_STR)
        fname = str(file)

        # make sure a file name was given
        #
        if not fname:
            return None

        # replace the extension if the user save name doesn't end with
        # the data extension
        #
        if not fname.endswith(DATA_EXTS):
            fname = os.path.splitext(fname)[0] + DATA_EXTS

        self.write_data(data, fname, limits)

        # exit gracefully
        #
        return None
    #
    # end of method

    # method: DataIO::load_model
    #
    # arguments: none
    #
    # return:
    #  fname: the selected model filename, or None if no file ending in
    #         the model extension was selected
    #
    def load_model(self):

        # load in a file from a file selection menu
        #
        file, _ = QtWidgets.QFileDialog.getOpenFileName(
            self.ui, LOAD_MODEL, icf.DELIM_NULL)

        # exit gracefully
        #
        return self._validate_fname(file, MODEL_EXTS)
    #
    # end of method

    # method: DataIO::save_model
    #
    # arguments:
    #  algo_name: the algorithm name used to build the default filename
    #
    # return:
    #  fname: the chosen model filename, or None if the chosen name does
    #         not end in the model extension
    #
    def save_model(self, algo_name):

        # load the save file gui with a default filename of
        # "imld_{algo}.pkl"
        #
        default_savename = f"imld_{algo_name.lower()}.pkl"
        file, _ = QtWidgets.QFileDialog.getSaveFileName(
            self.ui, SAVE_MODEL, default_savename, MODEL_EXTS_STR)

        # exit gracefully
        #
        return self._validate_fname(file, MODEL_EXTS)
    #
    # end of method

    # method: DataIO::load_params
    #
    # arguments: none
    #
    # return:
    #  fname: the selected parameter filename, or None if no file ending
    #         in the parameter extension was selected
    #
    def load_params(self):

        # load in a file from a file selection menu
        #
        file, _ = QtWidgets.QFileDialog.getOpenFileName(
            self.ui, LOAD_PARAM, icf.DELIM_NULL)

        # exit gracefully
        #
        return self._validate_fname(file, PARAM_EXTS)
    #
    # end of method

    # method: DataIO::save_params
    #
    # arguments:
    #  algo_name: the algorithm name used to build the default filename
    #
    # return:
    #  fname: the chosen parameter filename, or None if the chosen name
    #         does not end in the parameter extension
    #
    def save_params(self, algo_name):

        # load the save file gui with a default filename of
        # "imld_{algo}_param.txt"
        #
        default_savename = f"imld_{algo_name.lower()}_param{PARAM_EXTS}"
        file, _ = QtWidgets.QFileDialog.getSaveFileName(
            self.ui, SAVE_PARAM, default_savename, PARAM_EXTS_STR)

        # exit gracefully
        #
        return self._validate_fname(file, PARAM_EXTS)
    #
    # end of method

    # method: DataIO::_parse_header_list
    #
    # arguments:
    #  check: a header line with its spaces removed
    #
    # return:
    #  values: the comma-separated items found between the first and the
    #          second colon of the line, with the open and close
    #          delimiters removed, as a list of strings
    #
    @staticmethod
    def _parse_header_list(check):

        # get the text after the colon: a header without a colon yields
        # an empty value instead of raising an IndexError
        #
        parts = check.split(icf.DELIM_COLON)
        value = parts[1] if len(parts) > 1 else icf.DELIM_NULL
        value = value.replace(icf.DELIM_OPEN, icf.DELIM_NULL) \
                     .replace(icf.DELIM_CLOSE, icf.DELIM_NULL)

        # split to list
        #
        return value.split(icf.DELIM_COMMA)
    #
    # end of method

    # method: DataIO::read_data
    #
    # arguments:
    #  fname: input filename
    #
    # return:
    #  classes: list of classes from user input file (strings)
    #  colors: list of colors from user input file (strings)
    #  limits: list of limits from user input file (strings)
    #  user_data: a dict that maps each class name to a list of [x, y]
    #             points
    #
    # A single None is returned (instead of the four values) when a data
    # line cannot be parsed. Blank lines and comment lines that are not
    # one of the known headers are ignored.
    #
    @staticmethod
    def read_data(fname):

        classes = []
        colors = []
        limits = []
        data = []

        # open file
        #
        with open(fname, icf.MODE_READ_TEXT) as fp:

            # loop over lines in file (lines are numbered from one in
            # the error message)
            #
            for num_line, line in enumerate(fp, start=1):

                # clean up the line
                #
                line = line.replace(icf.DELIM_NEWLINE, icf.DELIM_NULL) \
                           .replace(icf.DELIM_CARRIAGE, icf.DELIM_NULL)
                check = line.replace(icf.DELIM_SPACE, icf.DELIM_NULL)

                # skip blank lines: they carry no data and must not
                # abort the load
                #
                if not check.strip():
                    continue

                # get the classes, colors and limits in csv file. any
                # other comment line is ignored
                #
                if check.startswith(icf.DELIM_COMMENT):
                    if check.startswith(icf.DELIM_COMMENT +
                                        FILE_HEADER_INFO[0]):
                        classes = DataIO._parse_header_list(check)
                    elif check.startswith(icf.DELIM_COMMENT +
                                          FILE_HEADER_INFO[1]):
                        colors = DataIO._parse_header_list(check)
                    elif check.startswith(icf.DELIM_COMMENT +
                                          FILE_HEADER_INFO[2]):
                        limits = DataIO._parse_header_list(check)
                    continue

                # get data: a ValueError is raised both when the line
                # has fewer than three fields and when a coordinate is
                # not a number
                #
                try:
                    class_name, x, y = check.split(icf.DELIM_COMMA)[0:3]
                    data.append([class_name, float(x), float(y)])
                except ValueError:
                    print("Error loading at line %d" % num_line)
                    return None
            #
            # end of for
        #
        # end of with open file

        # convert list of data to a dict
        #
        user_data = {}
        for class_name, x, y in data:
            user_data.setdefault(class_name, []).append(
                [np.array(x), np.array(y)])

        # exit gracefully
        #
        return classes, colors, limits, user_data
    #
    # end of method

    # method: DataIO::write_data
    #
    # arguments:
    #  data: data wanted to be saved. The structure is same as class_info.
    #        data = {classname: [[...], [X], [Y], [...], [color]], etc}
    #  fname: filename wanted to be saved as
    #  limits: an iterable of limits written to the file header
    #
    # return: True
    #
    # The file starts with comment lines holding the filename, classes,
    # colors and limits, followed by one "classname,x,y" line per point.
    #
    @staticmethod
    def write_data(data, fname, limits):

        # import all class_info from InputDisplay
        #
        class_info = data

        # build the header lists before the file is opened so that
        # malformed input does not leave a truncated file behind
        #
        colors = icf.DELIM_OPEN \
            + icf.DELIM_COMMA.join(str(class_info[name][4])
                                   for name in class_info) \
            + icf.DELIM_CLOSE
        classes = icf.DELIM_OPEN \
            + icf.DELIM_COMMA.join(str(name) for name in class_info) \
            + icf.DELIM_CLOSE
        limits = icf.DELIM_OPEN \
            + icf.DELIM_COMMA.join(str(e) for e in limits) \
            + icf.DELIM_CLOSE

        # creates the file and writes the data to it (the file is closed
        # by the with statement)
        #
        with open(fname, icf.MODE_WRITE_TEXT) as fp:

            # write comment with classes and colors lists
            #
            fp.write("# filename: %s\n" % fname)
            fp.write("# classes: %s\n" % classes)
            fp.write("# colors: %s\n" % colors)
            fp.write("# limits: %s\n" % limits)
            fp.write("# \n")

            for class_name in class_info:

                # retrieve class coordinate points
                #
                data_x = class_info[class_name][1]
                data_y = class_info[class_name][2]
                data_t = np.column_stack((data_x, data_y))

                # write each point of each class in a csv file
                #
                for point in data_t:

                    # write "classname,"
                    #
                    fp.write(class_name + icf.DELIM_COMMA)

                    # write "x,y": the columns are joined by position.
                    # deciding where the comma goes by comparing values
                    # drops the comma whenever x equals y
                    #
                    fp.write(icf.DELIM_COMMA.join(
                        "%8lf" % value for value in point))

                    # write new line
                    #
                    fp.write(icf.DELIM_NEWLINE)

        # exit gracefully
        #
        return True
    #
    # end of method

#
# end of class

#
# end of file