#!/usr/bin/env python
#
# file: $NEDC_NFC/src/class/nedc_dpath_image.py
#
# revision history:
#
# 20212412 (PM): refactored code
# 20210326 (VK): initial release
#
# Usage:
#  import nedc_dpath_image as ndi
#
# This file contains a class for building an image dataset: a very simple
# dataset similar to torchvision's ImageFolder, but built from a list of files.
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import numpy as np
import PIL
import torch
import torch.utils as utils

#------------------------------------------------------------------------------
#
# classes are listed here
#
#------------------------------------------------------------------------------

class ImagesList(utils.data.Dataset):
    """
    class: ImagesList

    arguments:
     fname: the name of a csv file. the first column holds the labels and
            the second column holds the file paths
     transform: the transformations applied to every image (default=None)

    return:
     None

    description:
     A very simple dataset similar to ImageFolder, but built from a list of
     (label, path) pairs read from a csv file instead of a directory tree.
    """

    def __init__(self, fname, transform=None):
        """
        method: ImagesList::constructor

        arguments:
         fname: the file name of a patches.csv generated by
                nedc_dpath_extract_patch. each non-blank line is
                "label,path"; everything after the first comma is the path
         transform: all the PyTorch Vision transformations that should be
                    done on every patch (default=None)

        return:
         None

        raises:
         ValueError: if a non-blank line contains no comma

        description:
         sets the following attributes:
          lbfn: a list of [label, path] pairs, in file order
          labels: the label of every entry
          classes: the sorted list of unique labels
          class_to_idx: a dict mapping each label to its index in classes
          targets: the class index of every entry
          patch_list: the file basename (without extension) of every entry
          image_set: the set of image names derived from patch_list
          image_dict: a dict mapping each image name to its [label, path]
                      entries, in file order
        """

        # read the csv file. split only on the first comma so a path that
        # contains a comma stays intact, and skip blank lines (e.g. an
        # extra trailing newline) instead of failing to unpack them
        #
        self.lbfn = []
        with open(fname, mode='r') as fp:
            for lnum, line in enumerate(fp, start=1):
                if not line.strip():
                    continue
                fields = line.split(',', 1)
                if len(fields) != 2:
                    raise ValueError("%s:%d: expected 'label,path', got %r"
                                     % (fname, lnum, line))
                self.lbfn.append([fields[0].strip(), fields[1].strip()])
        self.transform = transform

        # compute the parameters: the sorted unique labels and the integer
        # target (index into self.classes) of every entry. a dict lookup
        # replaces a linear list.index() search per entry
        #
        self.labels = [lb for lb, _ in self.lbfn]
        self.classes = sorted(set(self.labels))
        self.class_to_idx = {lb: i for i, lb in enumerate(self.classes)}
        self.targets = [self.class_to_idx[lb] for lb in self.labels]

        # get the patch names: the file basename without its extension
        #
        self.patch_list = [os.path.splitext(os.path.basename(path))[0]
                           for _, path in self.lbfn]

        # derive the image name of every patch and collect the unique ones
        #
        image_keys = [self._image_name(patch) for patch in self.patch_list]
        self.image_set = set(image_keys)

        # map every image to its entries from self.lbfn, kept in file order.
        # group by the exact derived image name: a substring match on the
        # path would also file "img1" patches under "img10" (or under any
        # image whose name appears in a directory component)
        #
        self.image_dict = {}
        for key, entry in zip(image_keys, self.lbfn):
            self.image_dict.setdefault(key, []).append(entry)

        # exit gracefully
        #
        return None
    #
    # end of method

    @staticmethod
    def _image_name(patch):
        """
        method: ImagesList::_image_name

        arguments:
         patch: a patch name (file basename without extension)

        return:
         the patch name with its last two "_"-separated fields removed

        description:
         NOTE(review): presumably those two fields are the patch coordinates
         written by nedc_dpath_extract_patch - confirm. a name with fewer
         than two "_" only loses its last character (rfind returns -1);
         this is kept as-is for compatibility
        """
        return patch[:patch.rfind("_", 0, patch.rfind("_"))]
    #
    # end of method

    def __len__(self):
        """
        method: ImagesList::__len__

        arguments:
         none

        return:
         the number of [label, path] entries in the dataset
        """

        # exit gracefully
        #
        return len(self.lbfn)
    #
    # end of method

    def __getitem__(self, indx):
        """
        method: ImagesList::__getitem__

        arguments:
         indx: the index of the entry to fetch

        return:
         image: the RGB image of the entry, after self.transform (if any)
         lbl_val: the class index of the entry's label
         fname: the file path of the entry
        """

        # a bare "import PIL" does not load the PIL.Image submodule, so
        # import it explicitly here
        #
        from PIL import Image

        label, fname = self.lbfn[indx]
        lbl_val = self.class_to_idx[label]

        # convert() returns a new, fully loaded image, so the opened file
        # can be closed right away instead of leaking its handle
        #
        with Image.open(fname) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # exit gracefully
        #
        return (image, lbl_val, fname)
    #
    # end of method

#
# end of class

#
# end of file
