#!/usr/bin/env python
#
# file: nedc_dpath_image_tools.py
#
# revision history:
#
#  20250529 (DH): updated DecodeImage to accept external coords list
#
# This module provides utilities for reading whole-slide images (WSI) and
# generating patch-based datasets for training and decoding tasks.
#------------------------------------------------------------------------------

# import system modules
#
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

# import NEDC modules
#
from nedc_image_tools import Mil, Pil

#-------------------------------------------------------------------------------
#
# classes listed here
#
#-------------------------------------------------------------------------------

class DecodeImage(Dataset):
    """
    class: DecodeImage

    description:
      PyTorch Dataset that serves image patches from a single slide in the
      order given by a list of top-left coordinates. The slide reader (Mil
      or Pil, depending on the file) is opened lazily on first access, so
      each worker process creates its own handle instead of sharing one
      across processes.
    """

    def __init__(self, slide_path, coords, patch_size, transform=None):
        """
        method: __init__

        arguments:
         slide_path: str, path to the image file.
         coords: list of (x, y) top-left coordinates for patches.
         patch_size: int, side length of each patch.
         transform: optional torchvision transforms pipeline.

        return: None

        description:
         Store parameters that are safe to pickle. The reader is NOT opened
         here; it is opened lazily in __getitem__, i.e. per worker.
        """
        # call superclass constructor
        #
        super().__init__()

        # store slide file path
        #
        self.slide_path = slide_path

        # store list of top-left patch coordinates
        #
        self.coords = coords

        # store patch side length
        #
        self.patch_size = patch_size

        # store transform pipeline
        #
        self.transform = transform

        # per-worker state: reader handle (None until first __getitem__)
        #
        self._img = None

    #
    # end of method

    def __getstate__(self):
        """
        method: __getstate__

        arguments:
         none

        return:
         dict, a copy of the instance state with the reader handle removed

        description:
         Drop any open reader when the dataset is pickled (for example,
         when a DataLoader ships it to worker processes), so an open slide
         handle is never serialized. The receiving copy reopens the slide
         lazily on its first __getitem__ call.
        """
        # copy the state so the live object keeps its own handle
        #
        state = self.__dict__.copy()
        state["_img"] = None
        return state

    #
    # end of method

    def __len__(self):
        """
        method: __len__

        arguments:
         none

        return: number of patches

        description:
         Provide required length for PyTorch Dataset.
        """
        # return total number of patches
        #
        return len(self.coords)

    #
    # end of method

    def __getitem__(self, idx):
        """
        method: __getitem__

        arguments:
         idx: int, index of desired patch

        return:
         img: image patch (a PIL Image, or whatever transform returns)

        description:
         Lazily open the slide reader in this process, read the patch at
         coordinate idx, apply optional transforms, and return the image.
        """
        # open the slide on first use in this process; the handle is only
        # cached after open() returns, so if open() raises, no unopened
        # reader is left behind to be read from on later calls
        #
        if self._img is None:
            reader = Mil() if Mil().is_mil(self.slide_path) else Pil()
            reader.open(self.slide_path)
            self._img = reader

        # unpack x, y coordinate for this index
        #
        x, y = self.coords[idx]

        # read NumPy array patch of size (patch_size x patch_size)
        #
        patch_np = self._img.read_data(
            coordinate=(x, y),
            npixx=self.patch_size,
            npixy=self.patch_size
        )

        # convert NumPy array (cast to uint8) to PIL Image so torchvision
        # transforms can be applied
        #
        img = Image.fromarray(patch_np.astype(np.uint8))

        # apply transform pipeline if provided
        #
        if self.transform is not None:
            img = self.transform(img)

        # return image patch
        #
        return img

    #
    # end of method

    def __del__(self):
        """
        method: __del__

        arguments:
         none

        return:
         none

        description:
         Best-effort close of the reader when the Dataset is garbage-collected
         or the worker exits. Errors are ignored because raising from a
         destructor cannot be handled by the caller.
        """
        # getattr guards against a partially constructed instance
        #
        reader = getattr(self, "_img", None)
        if reader is None:
            return

        try:
            reader.close()
        except Exception:

            # ignore errors during close
            #
            pass

    #
    # end of method

#
# end of class

class TrainImages(Dataset):
    """
    class: TrainImages

    description:
     PyTorch Dataset that serves labeled patches while keeping one
     slide reader per DataLoader worker. The actual reader handle (Mil or
     Pil, chosen per file) is created lazily inside each worker so that
     cuCIM / OpenSlide pointers are never fork-shared or pickled.
    """

    def __init__(self, samples, frmsize, window_size, transform=None,
                 index_map=None):
        """
        method: constructor

        arguments:
         samples: list of (img_path, coord, label)
         frmsize: the frame size (stored only; not used by this class)
         window_size: int, window edge length in pixels
         transform: optional torchvision transform applied to every patch
         index_map: optional dict mapping label string -> model output
          integer, e.g. {"bckg": 0, ...}. When None, the map is built from
          the sorted set of labels in samples, numbered from 0.

        return:
         none

        description:
         Store metadata that is safe to pickle. The reader itself is not
         created here, because the whole object will be copied to every
         worker when num_workers > 0.
        """
        super().__init__()

        # save constructor parameters
        #
        self.samples = samples
        self.window_size = window_size
        self.frmsize = frmsize
        self.transform = transform

        # per-worker state (initialised on first access inside the worker)
        #
        self._img = None
        self._current_path = None

        # index_map defaults to None: derive a deterministic label -> index
        # map from the sorted unique labels found in the samples
        #
        if index_map is None:
            labels = sorted({lbl for _, _, lbl in samples})
            index_map = {lbl: i for i, lbl in enumerate(labels)}

        # label -> index map and sorted class list
        #
        self.index_map = index_map
        self.classes = sorted(index_map.keys())

        # numeric target for every sample, in sample order
        #
        self.targets = [self.index_map[lbl] for _, _, lbl in samples]

    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # internal helper methods
    #
    #--------------------------------------------------------------------------

    def _ensure_open(self, img_path):
        """
        method: _ensure_open

        arguments:
         img_path: path to the slide that must be open

        return: None

        description:
         Make sure the current worker has img_path open. If a different
         slide is open, it is closed first. The reader type (Mil or Pil) is
         chosen for each file, so datasets that mix formats are read with
         the right reader. This guarantees that every worker has its own
         handle and avoids sharing cuCIM / OpenSlide pointers across
         processes.
        """
        # fast path: the requested slide is already open in this worker
        #
        if self._img is not None and img_path == self._current_path:
            return

        # close the previously opened slide (best effort: a failed close
        # must not prevent opening the next slide)
        #
        if self._img is not None and self._current_path is not None:
            try:
                self._img.close()
            except Exception:
                pass

        # record that nothing is open until the new open succeeds, so a
        # failed open cannot leave a stale path pointing at a closed reader
        #
        self._current_path = None

        # pick the reader class for THIS file and reuse the existing reader
        # instance only when it is already of that class
        #
        reader_cls = Mil if Mil().is_mil(img_path) else Pil
        if type(self._img) is not reader_cls:
            self._img = reader_cls()

        # open slide and remember which path is loaded
        #
        self._img.open(img_path)
        self._current_path = img_path

    #
    # end of method

    #--------------------------------------------------------------------------
    #
    # Dataset interface methods
    #
    #--------------------------------------------------------------------------

    def __getstate__(self):
        """
        method: __getstate__

        arguments:
         none

        return:
         dict, a copy of the instance state with the reader handle removed

        description:
         Drop any open reader when the dataset is pickled (for example,
         when a DataLoader ships it to worker processes), so an open slide
         handle is never serialized. The receiving copy reopens slides
         lazily through _ensure_open.
        """
        # copy the state so the live object keeps its own handle
        #
        state = self.__dict__.copy()
        state["_img"] = None
        state["_current_path"] = None
        return state

    #
    # end of method

    def __len__(self):
        """
        method: __len__

        arguments:
         none

        return:
         int, number of samples in the dataset

        description:
         Provide required length for the PyTorch Dataset interface.
        """
        return len(self.samples)

    #
    # end of method

    def __getitem__(self, idx):
        """
        method: __getitem__

        arguments:
         idx: index of the desired sample

        return:
         tuple (image, label_idx)

        description:
         Ensure the correct slide is open in the current worker, extract the
         patch at the requested coordinate, apply optional transforms, and
         return the image together with its numeric class label.
        """

        # unpack sample
        #
        img_path, coord, lbl = self.samples[idx]

        # make sure this worker has the right slide open
        #
        self._ensure_open(img_path)

        x, y = coord

        # read raw patch data
        #
        patch = self._img.read_data(
            (x, y),
            npixx=self.window_size,
            npixy=self.window_size
        )

        # convert to PIL (cast to uint8) for torchvision transforms
        #
        img = Image.fromarray(patch.astype(np.uint8))

        # apply user-specified transforms (if any)
        #
        if self.transform is not None:
            img = self.transform(img)

        # exit gracefully
        #
        return img, self.index_map[lbl]

    #
    # end of method

    def __del__(self):
        """
        method: __del__

        arguments:
         none

        return:
         none

        description:
         Best-effort close of the reader when the Dataset is garbage-collected
         or the worker exits. Errors are ignored because raising from a
         destructor cannot be handled by the caller.
        """
        # getattr guards against a partially constructed instance
        #
        reader = getattr(self, "_img", None)
        if reader is None:
            return

        try:
            reader.close()
        except Exception:

            # ignore errors during close
            #
            pass

    #
    # end of method

#
# end of class

#
# end of file
