#!/usr/bin/env python
#
# file: nedc_dpath_tissue_mask_tools.py
#
# revision history:
#
#  20250709 (SP): add masking utilities
#
# This module provides utilities for tissue masking whole-slide images (WSI).
# It is based on Otsu's method for thresholding, followed by morphological
# operations for small region removal and dilation. The following classes
# implement these methods:
#  TissueMasker: Base class for tissue maskers.
#  OtsuTissueMasker: Tissue masker using Otsu's method.
#  MorphologicalMasker: Tissue masker using Otsu's method followed by
#                       morphological operations for small region removal
#                       and dilation.
# The majority of the code is adapted from the tiatoolbox library.
#
# Reference: Pocock J, Graham S, Vu QD, Jahanifar M, Deshpande S,
# Hadjigeorghiou G, Shephard A, Bashir RMS, Bilal M, Lu W, Epstein D, Minhas
# F, Rajpoot NM, Raza SEA. TIAToolbox as an end-to-end library for advanced
# tissue image analytics. Commun Med (Lond). 2022 Sep 24;2:120.
# doi: 10.1038/s43856-022-00186-5. PMID: 36168445; PMCID: PMC9509319.
#
# Github: https://github.com/TissueImageAnalytics/tiatoolbox
#------------------------------------------------------------------------------

# import system modules
#
from abc import ABC, abstractmethod
from typing import List, Tuple, Union
import time
import numpy as np
from PIL import Image
from pathlib import Path
import openslide
import cv2


# import NEDC modules
#
import nedc_debug_tools as ndt
import nedc_file_tools as nft

# define static variables for debug
#
dbgl_d = ndt.Dbgl()

#------------------------------------------------------------------------------
#
# default values are set here
#
#------------------------------------------------------------------------------

# default baseline microns per pixel (MPP) of a slide; used as the default
# for the baseline_mpp argument of run()
#
DEF_BASELINE_MPP           = 0.25

# value passed as the "thresh" argument of cv2.threshold when the Otsu
# threshold is computed in OtsuTissueMasker.fit
# NOTE(review): with cv2.THRESH_OTSU OpenCV is documented to compute the
# threshold itself, so this value should only be a placeholder - confirm
#
DEF_INIT_OSTU_THRESHOLD    = 0

# default (width, height) of the structuring element used for dilation;
# used as the default for the kernel_size argument of run()
#
DEF_KERNEL_SIZE            = (5, 5)

# value passed as the "maxval" argument of cv2.threshold in
# OtsuTissueMasker.fit
#
DEF_MAX_BINARY_VALUE       = 255

# default for the min_region_size argument of run(); connected components
# whose size is below this value are removed from the mask
#
DEF_MIN_REGION_SIZE        = 1024

# connectivity passed to cv2.connectedComponentsWithStats in
# MorphologicalMasker.transform
#
DEF_MORF_CONNECTIVITY      = 8

# numerator used by objective_power2mpp: mpp = factor / objective power
#
DEF_MPP_CONVERSION_FACTOR  = 10.0

# default for the objective_power argument of run()
#
DEF_OBJECTIVE_POWER        = 1.25

# default for the thumb_max_size (and thumbnail_max_size) argument of
# run(); upper bound on the longer side of a generated thumbnail
#
DEF_THUMB_MAX_SIZE         = 2048

#------------------------------------------------------------------------------
#
# functions and classes listed here
#
#------------------------------------------------------------------------------

def run(
    fpath: List[str],
    thumb_max_size: int = DEF_THUMB_MAX_SIZE,
    objective_power: float = DEF_OBJECTIVE_POWER,
    baseline_mpp: float = DEF_BASELINE_MPP,
    thumbnail_max_size: int = DEF_THUMB_MAX_SIZE,
    kernel_size: Union[int, Tuple[int, int], np.ndarray] = DEF_KERNEL_SIZE,
    min_region_size: int = DEF_MIN_REGION_SIZE,
) -> Tuple[bool, Union[List[np.ndarray], None]]:
    """function: run
    arguments:
     fpath: List[str], list of file paths to whole-slide images.
     thumb_max_size: int, maximum size of the thumbnail image.
     objective_power: float, objective power used to calculate the
                      target MPP.
     baseline_mpp: float, baseline microns per pixel (MPP) for the slide.
     thumbnail_max_size: int, accepted for backward compatibility only.
                         It is not used: thumb_max_size is the value that
                         is passed to the masker.
     kernel_size: int, pair of ints or np.ndarray, size of the structuring
                  element used for dilation.
     min_region_size: int, connected regions smaller than this are removed
                      from the masks.
    return: (status, masks) where status is a bool that is False when no
            masks were produced, and masks is the value returned by
            MorphologicalMasker.run (one boolean mask per input path).
    description:
     Run the tissue masking process on the provided whole-slide images and
     return a status flag together with the generated tissue masks.
    """

    # create an instance of MorphologicalMasker
    #
    masker = MorphologicalMasker(
        thumb_max_size=thumb_max_size,
        objective_power=objective_power,
        baseline_mpp=baseline_mpp,
        kernel_size=kernel_size,
        min_region_size=min_region_size,
    )

    # read the slides, fit the threshold and generate the masks
    #
    masks = masker.run(fpath)

    # the status reflects whether any masks were generated
    #
    status = masks is not None
    if not status:
        print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
              f"Error: No masks generated for the provided images.")

    # exit gracefully
    #
    return status, masks

class TissueMasker(ABC):
    """class: TissueMasker
    description:
      Base class for tissue maskers. Provides methods to fit the masker to
      a set of images and to transform images into tissue masks.
    """
    
    def __init__(self, thumb_max_size, objective_power, baseline_mpp) -> None:
        """
        method: __init__
        arguments:
         thumb_max_size: int, maximum size of the thumbnail image.
         objective_power: float or tuple of floats, objective power(s) used
                          to calculate the target MPP.
         baseline_mpp: float, baseline microns per pixel (MPP) for the slide.
        return: None
        description:
         Initialize the TissueMasker with the maximum thumbnail size, objective
         power, and baseline MPP. The fitted flag is set to False initially.
        """
        
        self.fitted = False
        
        # set the maximum thumbnail size, objective power, and baseline MPP
        # here objective power determines the target MPP which is microns per
        # pixel, and it's important for the tissue segmentation process.
        # The baseline MPP is used as a reference for the segmentation.
        #
        self.thumb_max_size = thumb_max_size
        self.objective_power = objective_power
        self.baseline_mpp = baseline_mpp
    # 
    # end method __init__
    
    @abstractmethod
    def fit(
        self,
        images: np.ndarray,
        masks: np.ndarray | None = None,
    ) -> None:
        """
        method: fit
        arguments:
         images: np.ndarray, array of images to fit the masker to.
         masks: np.ndarray | None, array of masks corresponding to the images.
        return: None
        description:
         Fit the tissue masker to the provided images and masks. This method
         is responsible for determining the threshold values for tissue
         segmentation.
        """
        pass
    # 
    # end method fit
    
    @abstractmethod
    def transform(self, images: np.ndarray) -> np.ndarray:
        """method: transform
        arguments:
         images: np.ndarray, array of images to transform into tissue masks.
        return: np.ndarray, array of tissue masks corresponding to the input
        images.
        description:
         Transform the provided images into tissue masks based on the fitted
         parameters. This method should be called after fit() has been called.
        """
        pass
    #
    # end method transform
        
    def objective_power2mpp(
        self, objective_power: float
    ) -> float:
        """method: objective_power2mpp
        arguments:
         objective_power: float or sequence of floats, objective power(s)
                          to convert.
        return: float (or np.ndarray for sequence input), microns per
                pixel (MPP) computed as
                DEF_MPP_CONVERSION_FACTOR / objective_power.
        raises:
         ValueError: if any objective power is zero.
        description:
         Convert an objective power to microns per pixel (MPP) using the
         module-level conversion factor. The conversion is applied to the
         value passed as an argument, element-wise for sequences.
        """

        # display an informational message
        #
        if dbgl_d == ndt.FULL:
            print("%s (line: %s) %s: calling objective_power2mpp (%s)" %
                  (ndt.__FILE__, ndt.__LINE__, ndt.__NAME__, objective_power))

        # work on an array view so scalars and sequences are handled alike
        #
        power = np.asarray(objective_power, dtype=float)

        # a zero objective power would produce an infinite MPP
        #
        if np.any(power == 0):
            raise ValueError("Objective power cannot be zero.")

        # exit gracefully: scalar input yields a NumPy scalar, sequence
        # input yields an array
        #
        return DEF_MPP_CONVERSION_FACTOR / power
    #
    # end method objective_power2mpp

    def run(
        self,
        fpath: List[str],
    ) -> np.ndarray:
        """
        method: run
        arguments:
         fpath: List[str], paths of the slides to process.
        return: the value returned by transform() for the thumbnails of
                the given slides.
        description:
         Build one thumbnail per path with slide_thumbnail(), fit the
         masker on all thumbnails, and then transform the thumbnails into
         tissue masks. Timing information is printed when the debug level
         is BRIEF.
        """

        if dbgl_d == ndt.FULL:
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"calling run with {len(fpath)} slides.")

        # stage 1: read every slide and reduce it to a thumbnail
        #
        read_started = time.time()
        images = []
        for slide_path in fpath:
            images.append(self.slide_thumbnail(slide_path))

        if dbgl_d == ndt.BRIEF:
            print(f"Number of slides read: {len(images)}")
            print(f"Time taken to read slides: {time.time() - read_started:.2f} seconds")

        # stage 2: estimate the masker parameters and generate the masks
        #
        mask_started = time.time()
        self.fit(images, masks=None)
        masks = self.transform(images)

        if dbgl_d == ndt.BRIEF:
            print(f"Time taken to generate masks: {time.time() - mask_started:.2f} seconds")

        # exit gracefully
        #
        return masks
    #
    # end method run
    
    def slide_thumbnail(
        self,
        slide_path: str,
    ) -> np.ndarray:
        """method: slide_thumbnail
        arguments:
         slide_path: str, path to the image. Files whose suffix matches
                     nft.DEF_EXT_SVS are opened with OpenSlide; every
                     other file is opened with PIL.
        return: np.ndarray, uint8 RGB thumbnail of the slide.
        raises:
         ValueError: if baseline_mpp is not positive, or if the objective
                     power is zero (raised by objective_power2mpp).
        description:
         Read the image and create a thumbnail. The slide dimensions are
         divided by target_mpp / baseline_mpp, where target_mpp is derived
         from the objective power. If the longer side still exceeds
         thumb_max_size, both sides are scaled down proportionally. The
         slide handle is always closed, even when an error occurs.
        """
        if dbgl_d == ndt.FULL:
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"calling slide_thumbnail with slide_path: {slide_path}")

        # a non-positive baseline would give a meaningless downsample
        # factor, so reject it before any file is opened
        #
        if self.baseline_mpp <= 0:
            raise ValueError("Baseline MPP must be positive.")

        # open the slide: OpenSlide for whole-slide images, PIL otherwise
        #
        is_wsi = Path(slide_path).suffix.lower() in [f'.{nft.DEF_EXT_SVS}']
        if is_wsi:
            slide = openslide.OpenSlide(slide_path)
        else:
            # convert() returns a new image, so the file handle held by
            # the image returned from open() can be released right away
            #
            with Image.open(slide_path) as source_image:
                slide = source_image.convert("RGB")

        try:
            # get the dimensions of the slide
            #
            if is_wsi:
                slide_width, slide_height = slide.dimensions
            else:
                slide_width, slide_height = slide.size

            if dbgl_d == ndt.BRIEF:
                print(f"Slide dimensions: {slide_width}x{slide_height}")

            # calculate the target microns per pixel (MPP) from the
            # objective power, and the downsample factor relative to the
            # baseline MPP
            #
            target_mpp = self.objective_power2mpp(self.objective_power)
            downsample_factor = target_mpp / self.baseline_mpp

            if dbgl_d == ndt.BRIEF:
                print(f"Target MPP: {target_mpp}, Downsample factor: {downsample_factor}")

            # calculate the target size; truncation could give zero for
            # very small images, so keep each side at least one pixel
            #
            target_width = max(1, int(slide_width / downsample_factor))
            target_height = max(1, int(slide_height / downsample_factor))

            # if the target dimensions exceed the maximum size, scale
            # them down proportionally
            #
            longest_side = max(target_width, target_height)
            if longest_side > self.thumb_max_size:
                if dbgl_d == ndt.BRIEF:
                    print(f"Scaling down thumbnail to fit within max size: {self.thumb_max_size}")

                scale = self.thumb_max_size / longest_side
                if dbgl_d == ndt.BRIEF:
                    print(f"Scaling factor: {scale}")

                target_width = max(1, int(target_width * scale))
                target_height = max(1, int(target_height * scale))

            if dbgl_d == ndt.BRIEF:
                print(f"Target thumbnail size: {target_width}x{target_height}")

            # get the thumbnail in RGB format
            #
            if is_wsi:
                thumbnail = slide.get_thumbnail((target_width, target_height)).convert("RGB")
            else:
                # if the slide is not a WSI, resize the image to the
                # target dimensions
                #
                thumbnail = slide.resize((target_width, target_height), Image.Resampling.LANCZOS)

            # convert the thumbnail to a numpy array
            #
            thumbnail = np.array(thumbnail, dtype=np.uint8)

            if dbgl_d == ndt.BRIEF:
                print(f"Thumbnail shape: {thumbnail.shape}")

        finally:
            # close the slide to free resources, whether or not the
            # thumbnail could be created
            #
            slide.close()

        # exit gracefully
        #
        return thumbnail
    #
    # end method slide_thumbnail

#
# end of class TissueMasker

class OtsuTissueMasker(TissueMasker):
    """class: OtsuTissueMasker
    description:
        Tissue masker using Otsu's method for thresholding. This class inherits
        from TissueMasker and implements the fit and transform methods to
        create tissue masks based on Otsu's thresholding technique.
    """

    def __init__(self, thumb_max_size, objective_power, baseline_mpp) -> None:
        """
        method: __init__
        arguments:
         thumb_max_size: int, upper bound on the size of a thumbnail.
         objective_power: float, objective power from which the target
                          microns per pixel (MPP) is derived.
         baseline_mpp: float, baseline MPP of the slide.
        return: None
        description:
         Forward the thumbnail and resolution settings to TissueMasker and
         start in the unfitted state: no threshold is known until fit()
         has been called.
        """

        # let the base class store the shared settings
        #
        super().__init__(
            thumb_max_size=thumb_max_size,
            objective_power=objective_power,
            baseline_mpp=baseline_mpp,
        )

        # the masker starts unfitted and without a threshold
        #
        self.fitted: bool = False
        self.threshold: float = None
    #
    # end method __init__

    def fit(
        self,
        images: List[np.ndarray],
        masks: np.ndarray | None = None,
    ) -> None:
        """
        method: fit
        arguments:
         images: List[np.ndarray], images to fit the masker to. Images
                 with three channels are converted from RGB to grayscale;
                 for other 3-D images the first channel is used; 2-D
                 images are used as they are.
         masks: np.ndarray | None, not used by this masker.
        return: None
        raises:
         ValueError: if no images are provided.
        description:
         Pool the grayscale pixels of all images, compute a single Otsu
         threshold over them with OpenCV, store it in self.threshold and
         set the fitted flag to True.
        """

        if dbgl_d == ndt.FULL:
            images_shape = [np.shape(image) for image in images]
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"images shape: {images_shape}")

        # a threshold cannot be estimated without any pixels
        #
        if len(images) == 0:
            raise ValueError("At least one image is required to fit the masker.")

        # collect the grayscale pixels of every image
        #
        all_pixels = []
        for image in images:
            if image.ndim == 3 and image.shape[-1] == 3:
                grey = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            elif image.ndim == 3:
                grey = image[..., 0]
            else:
                grey = image

            # ravel avoids an extra copy; concatenate copies anyway
            #
            all_pixels.append(grey.ravel())

        # concatenate the pixel values of all grayscale images into a
        # single array for threshold calculation
        #
        pixels = np.concatenate(all_pixels)

        # calculate the Otsu's threshold using OpenCV's threshold function;
        # only the threshold value is kept, the binarized pixels are
        # discarded
        #
        self.threshold, _ = cv2.threshold(
            pixels,
            DEF_INIT_OSTU_THRESHOLD,
            DEF_MAX_BINARY_VALUE,
            cv2.THRESH_BINARY + cv2.THRESH_OTSU,
        )

        # set the fitted flag to True
        #
        self.fitted = True
    #
    # end method fit

    def transform(self, images: List[np.ndarray]) -> np.ndarray:
        """
        method: transform
        arguments:
         images: List[np.ndarray], images to transform into tissue masks.
                 Grayscale conversion follows the same rules as fit().
        return: np.ndarray of boolean masks, True where a pixel is below
                the fitted threshold. When all masks share one shape they
                are stacked into a single boolean array; otherwise a 1-D
                object array holding one mask per image is returned.
        raises:
         SyntaxError: if fit() has not been called. The exception type is
                      kept for compatibility with existing callers.
        description:
         Apply the fitted Otsu threshold to each image and return the
         resulting boolean masks.
        """
        if dbgl_d == ndt.FULL:
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"calling transform with {len(images)} images.")

        # check if the masker has been fitted
        #
        if not self.fitted:
            msg = "Fit must be called before transform."
            raise SyntaxError(msg)

        # apply the threshold to the grayscale version of each image
        #
        masks = []
        for image in images:
            if image.ndim == 3 and image.shape[-1] == 3:
                grey = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            elif image.ndim == 3:
                grey = image[..., 0]
            else:
                # a 2-D image is already grayscale; indexing it with
                # [..., 0] would select a single column
                #
                grey = image

            masks.append(grey < self.threshold)

        # masks of equal shape (or no masks at all) can be stacked directly
        #
        if len({mask.shape for mask in masks}) <= 1:
            return np.array(masks)

        # thumbnails of different sizes cannot be stacked, so keep one
        # mask per slot of an object array instead
        #
        ragged = np.empty(len(masks), dtype=object)
        for index, mask in enumerate(masks):
            ragged[index] = mask

        # exit gracefully
        #
        return ragged
    #
    # end method transform
#
# end of class OtsuTissueMasker

class MorphologicalMasker(OtsuTissueMasker):
    """class: MorphologicalMasker
    description:
        Tissue masker using Otsu's method followed by morphological operations
        for small region removal and dilation. This class inherits from
        OtsuTissueMasker and implements additional functionality for
        morphological operations to refine the tissue masks.
    """

    def __init__(
        self, thumb_max_size, objective_power, baseline_mpp,
        kernel_size: int | tuple[int, int] | np.ndarray | None = None,
        min_region_size: int | None = None,
    ) -> None:
        """
        method: __init__
        arguments:
         thumb_max_size: int, maximum size of the thumbnail image.
         objective_power: float, objective power used to calculate the
                          target MPP.
         baseline_mpp: float, baseline microns per pixel (MPP) for the slide.
         kernel_size: int or pair of ints, size of the morphological kernel
                      used for dilation. A single value is used for both
                      dimensions. If None, defaults to (1, 1).
         min_region_size: int, minimum size of regions to keep. If None,
                          defaults to the sum of the kernel elements.
        return: None
        raises:
         ValueError: if kernel_size does not hold one or two values, or
                     if a value rounds to less than 1.
        description:
         Initialize the MorphologicalMasker with the thumbnail and
         resolution settings, the kernel size for morphological operations,
         and the minimum region size. The elliptical kernel is created
         with OpenCV's getStructuringElement function.
        """
        # call the parent class constructor
        #
        super().__init__(thumb_max_size=thumb_max_size, objective_power=objective_power, baseline_mpp=baseline_mpp)

        self.min_region_size = min_region_size
        self.threshold = None

        # check if kernel_size is None or not
        #
        if kernel_size is None:
            kernel_size = np.array([1, 1])

        # a single value is used for both dimensions; anything other than
        # one or two values cannot describe a 2-D kernel
        #
        kernel_size_array = np.asarray(kernel_size).ravel()
        if kernel_size_array.size == 1:
            kernel_size_array = kernel_size_array.repeat(2)
        elif kernel_size_array.size != 2:
            raise ValueError(
                "kernel_size must be a single value or a pair of values, "
                f"got {kernel_size_array.size} elements.")

        # store plain Python ints rather than NumPy integers for OpenCV
        #
        self.kernel_size: tuple[int, int]
        self.kernel_size = tuple(int(v) for v in np.round(kernel_size_array))
        if min(self.kernel_size) < 1:
            raise ValueError(
                f"kernel_size values must be at least 1, got {self.kernel_size}.")

        if dbgl_d == ndt.FULL:
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"kernel size: {self.kernel_size}")

        # create the elliptical kernel used for the dilation step
        #
        self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, self.kernel_size)

        # if the minimum region size is not provided, set it to the sum of
        # the kernel elements, so that regions smaller than the kernel
        # footprint are removed
        #
        if self.min_region_size is None:
            self.min_region_size = int(np.sum(self.kernel))
    #
    # end method __init__

    def transform(self, images: List[np.ndarray]) -> List[np.ndarray]:
        """
        method: transform
        arguments:
         images: List[np.ndarray], images to transform into tissue masks.
                 Images with three channels are converted from RGB to
                 grayscale; for other 3-D images the first channel is
                 used; 2-D images are used as they are.
        return: List[np.ndarray], one boolean mask per input image. A list
                is returned (not a stacked array), so images of different
                sizes are supported.
        raises:
         SyntaxError: if fit() has not been called. The exception type is
                      kept for compatibility with existing callers.
        description:
         Threshold each image with the fitted Otsu threshold, remove
         connected regions smaller than min_region_size, and dilate the
         remaining regions with the morphological kernel.
        """
        if dbgl_d == ndt.FULL:
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"calling transform with {len(images)} images.")

        # check if the masker has been fitted
        #
        if not self.fitted:
            msg = "Fit must be called before transform."
            raise SyntaxError(msg)

        if dbgl_d == ndt.FULL:
            print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                  f"Processing {len(images)} images for tissue masking.")

        results = []
        for image in images:
            # reduce the image to a single channel
            #
            if image.ndim == 3 and image.shape[-1] == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            elif image.ndim == 3:
                gray = image[..., 0]
            else:
                gray = image

            # pixel values below the threshold are considered tissue
            #
            mask = (gray < self.threshold).astype(np.uint8)

            if dbgl_d == ndt.BRIEF:
                print(f"{ndt.__FILE__} (line: {ndt.__LINE__}) {ndt.__NAME__}: "
                      f"Mask shape: {mask.shape}, Threshold: {self.threshold}")

            # label the connected regions of the mask; label 0 is the
            # background and stats holds one row per label
            #
            _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=DEF_MORF_CONNECTIVITY)

            # decide per label whether it is kept, then map that decision
            # back onto the pixels in a single pass instead of rescanning
            # the image once per region
            #
            keep = stats[:, cv2.CC_STAT_AREA] >= self.min_region_size
            keep[0] = False
            mask = keep[labels].astype(np.uint8)

            # dilate the remaining tissue regions with the kernel
            #
            mask = cv2.morphologyEx(mask, cv2.MORPH_DILATE, self.kernel)

            results.append(mask.astype(bool))

        # exit gracefully: return the list of boolean masks
        #
        return results
    #
    # end method transform
    
 
 # 
 # end of class MorphologicalMasker

# end of file

