#!/usr/bin/env python
#
# file: $ISIP_EXP/tuh_dpath/exp_0122/scripts/dataset.py
#
# revision history:
#
# 20210424 (NS): first version
#
# usage:
#
# This script shows an example of a dataset class with the essential functions
# of PyTorch Datasets. Additional functions can be added to manipulate the
# samples as needed, such as adding noise, zero padding, resizing images, etc.
#
#------------------------------------------------------------------------------

# import pytorch modules
#
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np

# global variables
#
NEW_LINE = "\n"

# function: get_data
#
# arguments: fp - an open text file pointer (any iterable of text lines)
#            num_feats - the number of features in a sample
#
# returns: data - the signals/features: a list with one list of floats
#                 per accepted line
#          labels - the correct labels for them: a list of ints, in the
#                   same order as data
#
# raises: ValueError - if an accepted line has a label that int() cannot
#                      parse or a feature that float() cannot parse
#
# this method takes in a fp and returns the data and labels. each accepted
# line holds one label followed by num_feats whitespace-separated feature
# values. lines with any other number of fields (including blank lines)
# are skipped.
#
def get_data(fp, num_feats):

    # initialize the data and labels
    #
    data = []
    labels = []

    # an accepted line has exactly num_feats features + 1 label
    #
    num_fields = num_feats + 1

    # for each line of the file: iterate the file pointer directly so the
    # whole file is never held in memory at once
    #
    for line in fp:

        # split the string by white space (this also drops the
        # trailing line terminator)
        #
        temp = line.split()

        # if we don't have num_feats feats + 1 label, skip the line
        #
        if len(temp) != num_fields:
            continue

        # append the labels and data: the first field is the label,
        # the remaining fields are the feature values
        #
        labels.append(int(temp[0]))
        data.append([float(sample) for sample in temp[1:]])

    # exit gracefully
    #
    return data, labels
#
# end of function

# the dataset class
#
class CustomDataset(Dataset):

    # function: __init__
    #
    # arguments: feats - features in the set (indexable by sample)
    #            labels - corresponding labels (indexable by sample)
    #
    # returns: None
    #
    # this method initializes the CustomDataset class. the inputs are
    # stored as given; no copy or validation is performed.
    #
    def __init__(self, feats, labels):

        # set the data as the input
        #
        self.feats = feats

        # set the labels
        #
        self.labels = labels
    #
    # end of method

    # function: __len__
    #
    # arguments: None
    #
    # returns: number of samples
    #
    # this method is a mandatory function of the Dataset
    # class that returns the number of samples. the count
    # is taken from the labels.
    #
    def __len__(self):
        return len(self.labels)
    #
    # end of method

    # function: __getitem__
    #
    # arguments: index - the user does not need to be concerned
    #                    about this argument. PyTorch handles this.
    #
    # returns: feat - feature values as a numpy array of shape
    #                 (1, number of features)
    #          label - label of the sample
    #
    # this method is a mandatory function of the Dataset
    # class that returns the samples. Note
    # that the return values can be changed as needed.
    # For example, an autoencoder only requires the sample
    # values, not the labels.
    #
    def __getitem__(self, index):

        # get features and turn them into a single-row 2D array
        #
        feat = self.feats[index]
        feat = np.array(feat).reshape(1, len(feat))

        # get label
        #
        label = self.labels[index]

        # return the features and the label (in that order)
        #
        return feat, label
    #
    # end of method
#
# end of class