#!/usr/bin/env python
#
# file: $NEDC_NFC/src/class/nedc_dpath_slide.py
#
# revision history:
#
# 20230929 (SM): prep for v2.0.0 release
# 20211224 (PM): refactored code
# 20210416 (VK): initial release
#
# Usage:
#  import nedc_dpath_slide as nds
#
# This file contains a class for building an image dataset from a single SVS
# slide. It is a very simple dataset, similar to ImageFolder, that serves
# patches cut from one slide instead of images read from a folder.
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import numpy as np
import openslide
import torch
import torch.utils as utils

class Slide(utils.data.Dataset):
    '''
    class: Slide

    description:
     A very simple dataset to extract square patches from a single SVS slide.
     Each patch is win_len x win_len pixels read from level 0, consecutive
     patches start frm_len pixels apart in each direction, and patches are
     numbered in row-major order (left to right, then top to bottom).
     Note: this class always truncates the slide for patch extraction: pixels
           at the right or bottom edge that cannot hold a full window are not
           covered by any patch.
     Call close() to release the underlying slide handle when done.
    '''

    # method: __init__ constructor
    #
    # arguments:
    #   svs_fname: the name of svs file
    #   win_len: window length to make the patches
    #   frm_len: frame length to make the patches
    #   transform: all the PyTorch Vision transformations that should be
    #              done on every patch.
    #
    # return:
    #   None
    #
    def __init__(self, svs_fname, win_len, frm_len, transform=None):
        '''
        method: Slide::constructor

        arguments:
         svs_fname: the name of a svs file
         win_len: window length (patch side, in level-0 pixels); must be > 0
         frm_len: frame length (stride between patch origins, in level-0
                  pixels); must be > 0
         transform: an optional callable (e.g. PyTorch vision transforms)
                    applied to every patch

        return:
         None

        raises:
         ValueError: if win_len or frm_len is not positive
        '''

        # validate the geometry before opening the file so a bad argument
        # neither leaves an open slide handle behind nor causes a division
        # by zero below
        #
        if win_len <= 0 or frm_len <= 0:
            raise ValueError("win_len and frm_len must be positive (got "
                             f"win_len={win_len}, frm_len={frm_len})")

        # open the svs file
        #
        self.slide = openslide.open_slide(svs_fname)
        self.wlen = win_len
        self.flen = frm_len
        self.transform = transform

        # compute the patch grid
        #
        self.width, self.height = self.slide.dimensions
        self.n_horizontal = self._count_windows(self.width, self.wlen,
                                                self.flen)
        self.n_vertical = self._count_windows(self.height, self.wlen,
                                              self.flen)
        self.nwindows = self.n_horizontal * self.n_vertical
    #
    # end of method

    @staticmethod
    def _count_windows(length, wlen, flen):
        '''
        method: Slide::_count_windows

        arguments:
         length: the slide extent along one axis (pixels)
         wlen: window length
         flen: frame length (stride)

        return:
         the number of full windows that fit along the axis; 0 when the
         window is longer than the axis
        '''

        # (length - wlen) // flen + 1 windows fit when length >= wlen; clamp
        # at zero so a slide smaller than one window yields no patches
        # (two negative per-axis counts would otherwise multiply into a
        # positive number of windows)
        #
        return max(0, int((length - wlen + flen) // flen))
    #
    # end of method

    def __len__(self):
        '''
        method: Slide::len

        arguments:
         None

        return:
         nwindows: the number of windows in the slide
        '''

        # exit gracefully
        #
        return(self.nwindows)
    #
    # end of method

    def __getitem__(self, indx):
        '''
        method: Slide::getitem

        arguments:
         indx: the index of the patch to get; negative values count back
               from the end, as with Python sequences

        return:
         (col, row): a tuple with the patch's column and row in the grid
                     (index % n_horizontal, index // n_horizontal)
         topleft_coords: the (x, y) level-0 coordinates of the top-left
                         corner of the window
         image: the region returned by read_region converted to 'RGB', or
                whatever self.transform returns for it

        raises:
         IndexError: if indx is outside [-len(self), len(self))
        '''

        # check the index: raise IndexError (instead of returning None) so
        # out-of-range access fails loudly and plain iteration over the
        # dataset (which stops on IndexError) terminates
        #
        pos = indx + self.nwindows if indx < 0 else indx
        if not 0 <= pos < self.nwindows:
            raise IndexError(f"Error: index {indx} out of range "
                             f"for {self.nwindows} windows")

        # map the flat row-major index to a grid cell, then to the top-left
        # corner of the window in the openslide level-0 coordinate system
        #
        row, col = divmod(pos, self.n_horizontal)
        topleft_coords = (col * self.flen, row * self.flen)

        # read the window from level 0 and convert it to RGB
        #
        image = self.slide.read_region(topleft_coords, 0,
                                       (self.wlen, self.wlen)).convert('RGB')

        # apply the optional transform
        #
        if self.transform is not None:
            image = self.transform(image)

        # exit gracefully
        #
        return (col, row), topleft_coords, image
    #
    # end of method

    def close(self):
        '''
        method: Slide::close

        arguments:
         None

        return:
         None

        description:
         Release the underlying slide handle. The dataset must not be
         indexed after this call.
        '''

        # exit gracefully
        #
        self.slide.close()
    #
    # end of method
    
#
# end of class

#
# end of file
