#!/usr/bin/env python
#
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/nedc_decode_mdl.py
#
# revision history:
#  20190925 (TE): first version
#
# usage:
#  python nedc_decode_mdl.py odir eval_set mdl_path
#
#   odir:     directory where the hypothesis (.hyp) file is written
#   eval_set: text file with one sample per line: a label followed by
#             NUM_FEATS feature values, separated by white space
#   mdl_path: path to the saved model weights (a state dict)
#
# This script decodes a simple MLP model
#------------------------------------------------------------------------------

# import pytorch modules
#
import torch

# import the model
#
from model import Model

# import modules
#
import numpy as np
import random
import sys
import os

#-----------------------------------------------------------------------------
#
# global variables are listed here
#
#-----------------------------------------------------------------------------

# general global values
#
NUM_FEATS = 26
NUM_NODES = 26
NUM_CLASSES = 2
NUM_ARGS = 3
SEED1 = 1337
NEW_LINE = "\n"
SPACE = " "
HYP_EXT = ".hyp"

#------------------------------------------------------------------------------
#
# helper function listed here
#
#------------------------------------------------------------------------------

# function: set_seed
#
# arguments: seed - the seed for all the rng
#
# returns: none
#
# this method seeds the torch (cpu and cuda), numpy and python random
# number generators, puts cudnn in deterministic mode and sets
# PYTHONHASHSEED in the environment of this process
#
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

# seed everything as soon as the module is loaded
#
set_seed(SEED1)

# function: get_data
#
# arguments: fp - an open, readable text file pointer
#
# returns: data - the signals/features (a list of lists of floats)
#          labels - the correct labels for them (a list of ints)
#
# this method takes in a fp and returns the data and labels. lines that
# do not contain exactly NUM_FEATS + 1 white space separated fields are
# skipped. fp is always closed before this method returns, even when a
# field cannot be converted to a number (the ValueError is propagated).
#
def get_data(fp):

    # initialize the data and labels
    #
    data = []
    labels = []

    # make sure the file gets closed even if a line fails to parse
    #
    try:

        # for each line of the file (streamed, not read all at once)
        #
        for line in fp:

            # split the string by white space
            #
            temp = line.split()

            # skip the line if we dont have 26 feats + 1 label
            #
            if len(temp) != NUM_FEATS + 1:
                continue

            # append the labels and data
            #
            labels.append(int(temp[0]))
            data.append([float(sample) for sample in temp[1:]])

    finally:

        # close the file
        #
        fp.close()

    # exit gracefully
    #
    return data, labels
#
# end of function

#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------

# function: main
#
# arguments: argv - the command line arguments without the program name:
#                   [ODIR, EVAL_SET, MDL_PATH]
#
# return: True on success. the process exits with status -1 if the
#         argument count is wrong or a file cannot be opened.
#
# This method is the main function. It loads the evaluation set and the
# model weights, decodes every sample and writes one line per sample to
# ODIR/<basename of EVAL_SET>.hyp: the predicted class index followed by
# the feature values of that sample.
#
def main(argv):

    # ensure we have the correct number of arguments
    #
    if len(argv) != NUM_ARGS:
        print("usage: python nedc_decode_mdl.py [ODIR] [EVAL_SET] [MDL_PATH]")
        sys.exit(-1)

    # define local variables
    #
    odir = argv[0]
    fname = argv[1]
    mdl_path = argv[2]

    # get the hyp file name and its full path
    #
    hyp_name = os.path.splitext(os.path.basename(fname))[0] + HYP_EXT
    hyp_path = os.path.join(odir, hyp_name)

    # set the device to use GPU if available
    #
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # get a file pointer
    #
    try:
        eval_fp = open(fname, "r")
    except IOError as e:
        print("[%s]: %s" % (fname, e))
        sys.exit(-1)

    # get array of the data (get_data closes the file pointer)
    # data: [[0, 1, ... 26], [27, 28, ...] ...]
    # labels: [0, 0, 1, ...] (not needed for decoding)
    #
    eval_data, _ = get_data(eval_fp)

    # instantiate a model
    #
    model = Model(NUM_FEATS, NUM_NODES, NUM_CLASSES)

    # moves the model to the device
    #
    model.to(device)

    # set the model to evaluate
    #
    model.eval()

    # load the weights
    #
    model.load_state_dict(torch.load(mdl_path, map_location=device))

    # the output file
    #
    try:
        ofile = open(hyp_path, 'w+')
    except IOError as e:
        print(hyp_path)
        print("[%s]: %s" % (hyp_name, e.strerror))
        sys.exit(-1)

    # get the number of data points
    #
    num_points = len(eval_data)

    # the with statement closes the output file even if decoding fails;
    # gradients are not needed for decoding, so autograd is disabled
    #
    with ofile, torch.no_grad():

        # for each data point
        #
        for index, data_point in enumerate(eval_data):

            # print informational message
            #
            print("decoding %4d out of %d" % (index+1, num_points))

            # pass the input through the model
            #
            feats = torch.tensor(data_point, dtype=torch.float32).to(device)
            output = model(feats)

            # write the index of the highest output, followed by the
            # features, to the file
            #
            hyp = str(int(output.max(0)[1]))
            feat_str = SPACE.join([str(point) for point in data_point])
            ofile.write(hyp + SPACE + feat_str + NEW_LINE)

    # exit gracefully
    #
    return True
#
# end of function

# begin gracefully
#
if __name__ == '__main__':
    main(sys.argv[1:])

#
# end of file