#!/usr/bin/env python
#
# file: $ISIP_EXP/tuh_dpath/exp_0074/scripts/train.py
#
# revision history:
#  20210424 (NS): added dataset and dataloader
#  20190925 (TE): first version
#
# usage:
#  python train.py mdl_path data
#
# arguments:
#  mdl_path: the full path of the output model file
#  data: the input data list
#
# This script trains a simple MLP model
#------------------------------------------------------------------------------

# import pytorch modules
#
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

# import sklearn to split data into multiple sets
#
from sklearn.model_selection import train_test_split

# import the modules and all of their variables/functions
#
from model import *
from dataset import *

# import system modules
#
import sys
import os
import copy
import random

# import numpy so its rng can be seeded below
#
import numpy as np

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# general global values
#
NUM_ARGS = 2
NUM_EPOCHS = 10
BSIZE = 8
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"
BATCH_SIZE = "batch_size"
SHUFFLE = "shuffle"
NWORKERS = "num_workers"

# initial value for the best validation loss: any large value works,
# since it only needs to be beaten by the first epoch
#
VAL_LOSS = 10000

# for reproducibility, we seed the rngs
#
SEED = 1337
os.environ['PYTHONHASHSEED'] = str(SEED)

# Torch RNG
#
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Python and numpy RNGs
#
np.random.seed(SEED)
random.seed(SEED)
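# note: seeding alone does not make GPU runs bit-exact. if full
# determinism is required, a common (optional) addition is to pin the
# cuDNN backend, at some cost in speed:
#
#  torch.backends.cudnn.deterministic = True
#  torch.backends.cudnn.benchmark = False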
#------------------------------------------------------------------------------
#
# the main program starts here
#
#------------------------------------------------------------------------------

# function: main
#
# arguments:
#  argv: command line arguments (mdl_path, train_set)
#
# return: none
#
# This method is the main function.
#
def main(argv):

    # ensure we have the correct amount of arguments
    #
    if(len(argv) != NUM_ARGS):
        print("usage: python train.py [MDL_PATH] [TRAIN_SET]")
        exit(-1)

    # define local variables
    #
    mdl_path = argv[0]
    fname = argv[1]
    num_feats = DEF_NUM_FEATS
    if("DL_NUM_FEATS" in os.environ):
        num_feats = int(os.environ["DL_NUM_FEATS"])

    # get the output directory name
    #
    odir = os.path.dirname(mdl_path)

    # if the odir doesn't exist, we make it
    #
    if odir and not os.path.exists(odir):
        os.makedirs(odir)

    # set the device to use GPU if available
    #
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # get a file pointer
    #
    try:
        train_fp = open(fname, "r")
    except IOError as e:
        print("[%s]: %s" % (fname, e.strerror))
        exit(-1)

    # get arrays of the data and labels
    #
    data, labels = get_data(train_fp, num_feats)

    # split the data into train and validation sets
    #
    train_data, test_data, train_labels, test_labels = \
        train_test_split(data, labels, test_size = 0.25)

    # get the train and test sizes
    #
    train_len = len(train_data)
    test_len = len(test_data)

    # initialize the train and test sets
    #
    train_set = CustomDataset(train_data, train_labels)
    test_set = CustomDataset(test_data, test_labels)

    # parameters for the dataloader
    #
    dl_params = {
        BATCH_SIZE: BSIZE,
        SHUFFLE: True,
        NWORKERS: 1,
    }

    # pass the dataset objects to the DataLoader to
    # get an iterable dataset
    #
    train_loader = DataLoader(train_set, **dl_params)
    test_loader = DataLoader(test_set, **dl_params)

    # close the file
    #
    train_fp.close()

    # instantiate a model
    #
    model = Model(num_feats, NUM_NODES, NUM_CLASSES)

    # move the model to the device (cpu in our case so no change)
    #
    model.to(device)

    # set the adam optimizer parameters
    #
    opt_params = {
        LEARNING_RATE: 0.005,
        BETAS: (0.9, 0.999),
        EPS: 1e-08,
        WEIGHT_DECAY: 0.00001
    }

    # set the loss function
    #
    loss_fx = nn.CrossEntropyLoss()
    loss_fx.to(device)

    # create an optimizer, and pass the model params to it
    #
    adam_opt = Adam(model.parameters(), **opt_params)

    # get the number of epochs to train on
    #
    epochs = NUM_EPOCHS

    # number of batches per epoch, for the informational message
    # (rounded up so a smaller final batch is counted)
    #
    num_batches = (train_len + BSIZE - 1) // BSIZE

    # best validation loss so far, for weight selection
    #
    val_loss = VAL_LOSS
    best_model_wts = copy.deepcopy(model.state_dict())

    # for each epoch
    #
    for epoch in range(epochs):

        # set the model in training mode
        #
        model.train()
        is_train = True
        train_loss = 0
        index = 0

        # get data from the dataloader
        #
        for feats, lbls in train_loader:

            # set all gradients to 0
            #
            adam_opt.zero_grad()

            # collect the samples and send them to the device
            #
            feats = feats.float().to(device)
            lbls = lbls.long().to(device)

            # train the model with tracking history
            #
            with torch.set_grad_enabled(is_train):

                # feed the network the batch
                #
                output = model(feats)

                # get the loss
                #
                loss = loss_fx(output.squeeze(1), lbls)

                # add to the current loss
                #
                train_loss += loss.item() * feats.size(0)

                # perform back propagation
                #
                loss.backward()
                adam_opt.step()

            # display an informational message
            #
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, epochs, index + 1, num_batches,
                          loss.item()))

            # increment the batch number
            #
            index += 1

        # compute the average train loss per sample
        #
        train_loss = train_loss / train_len
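        # note: the accumulation above weighted loss.item() by the batch
        # size, so dividing by train_len gives the average loss per sample
        # for the epoch, even when the final batch is smaller than BSIZE
        #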
        # validation step
        # set model to eval mode
        #
        model.eval()
        is_train = False

        # local variable for accumulating the current validation loss
        #
        cur_val_loss = 0

        # evaluate the model
        #
        for feats, lbls in test_loader:

            # collect the samples and send them to the device
            #
            feats = feats.float().to(device)
            lbls = lbls.long().to(device)

            # the model will not learn in the evaluation mode
            #
            with torch.no_grad():

                # feed the network the batch
                #
                output = model(feats)

                # get the loss
                #
                loss = loss_fx(output.squeeze(1), lbls)

                # add to the current loss
                #
                cur_val_loss += loss.item() * feats.size(0)

        # get the average validation loss per sample
        #
        cur_val_loss = cur_val_loss / test_len

        # save the current weights if the current validation loss is
        # less than the best seen so far
        #
        if (cur_val_loss < val_loss):
            val_loss = cur_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

    # restore the best weights and save the model
    #
    print("Best validation loss: ", val_loss)
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), mdl_path)

    # exit gracefully
    #
    return True
#
# end of function

# begin gracefully
#
if __name__ == '__main__':
    main(sys.argv[1:])

#
# end of file