#!/usr/bin/env python
#
# file: $NEDC_NFC/src/class/nedc_dpath_extract_patch.py
#
# revision history:
#
# 20230929 (SM): prep for v2.0.0 release
# 20211224 (PM): refactored code
# 20210325 (VK): initial release
#
# Usage:
#  import nedc_dpath_extract_patch as nep
#
# This file contains a class that does several jobs to extract patches
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import time
import numpy as np
import openslide
import xml.etree.ElementTree as ET
from PIL import Image
from skimage.draw import polygon
from skimage.transform import rescale

# import NEDC modules
#
import nedc_file_tools as nft
import nedc_ann_dpath_tools

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# remove PIL's upper bound on the number of pixels in an image (PIL documents
# None as "no limit"); presumably needed because level 0 regions can be very
# large - TODO confirm
#
Image.MAX_IMAGE_PIXELS = None

# file extension appended to every patch image written by
# ExtractPatch::save_window
#
IMG_EXT = '.tif'

# names of the data partitions (not referenced elsewhere in this file)
#
TRAIN, DEV, EVAL = 'train', 'dev', 'eval'

# name of the log file written by ExtractPatch::save_state_dict
#
LOG_FILE_NAME = "log.txt"

# name of the csv file written by ExtractPatch::save_labels_patches
#
LABEL_PATCH_FNAME = "patches.csv"

#------------------------------------------------------------------------------
#
# classes listed here
#
#------------------------------------------------------------------------------

class ExtractPatch:
    """
    class: ExtractPatch

    description:
     This class contains methods for extracting patches from DPath images. It
     reads level 0, extracts windows, and saves them.
    """

    def __init__(self, params, flist):
        """
        method: ExtractPatch::constructor
        
        arguments:
         params: parameters for extracting the patches
         flist: list of svs files
        
        return: 
         none
        
        description:
         none
        """
        self.params = params
        self.acc_lbls = self.params['acc_lbls']
        self.flists = flist
        self.svs, self.xml = self.svs_xml(flist)

        # exit gracefully
        #
        return None
    #
    # end of method

    def svs_xml(self, flist):
        """
        method: ExtractPatch::svs_xml

        arguments:
         flists: list of svs files

        return:
         svs_fnames: list of svs file names
         xml_fnames: list of xml file names

        description:
         this function finds the corresponding (svs, xml) file pairs
        """

        # extract svs file names
        #
        svs_fnames = []
        with open(flist, mode='r') as file:
            fnames = file.readlines()
            fnames = [os.path.realpath(fn.strip())
                        for fn in fnames if fn.strip() != '']
            svs_fnames += fnames

        # extract xml file names
        #
        all_xml_fnames = [fname.replace('.svs', '.xml') for fname in svs_fnames]
        xml_fnames = []

        # add only the xml files that are exist
        #
        for fname in all_xml_fnames:
            if os.path.isfile(fname):
                xml_fnames.append(fname)
            else:
                print(f'{fname} does not exist.')
        
        # exit gracefully
        #
        return svs_fnames, xml_fnames
    #
    # end of method

    def extract_patch(self, patch_dir, acc_lbls):
        """
        method: ExtractPatch::extract_patch

        arguments:
         patch_dir: the subdirectory of the output directory in which the
                    patches will be written
         acc_lbls: a list of acceptable prediction labels

        return: 
         flps: a state dictionary for the patches
         plist: dictionary of patches

        description:
         extract patches for every given file into the specified directory
        """

        # initialize a file counter
        #
        fcounter = 0

        # retrieve parameters
        #
        win_len, frm_len, lv1_rs = (int(self.params['win_len']),
                                    int(self.params['frm_len']),
                                    int(self.params['lv1_rescale']))
                                    
        # Initialize file processing and patches lists
        #
        flists_ps = dict()
        patch_list = dict()

        # extract patch for every given file
        #
        for xml_fname, svs_fname in zip(self.xml, self.svs):

            # count the amount of files
            #
            fcounter += 1

            # attempt to extract the patch
            #
            try:

                # generate regions, slide, and prefixes
                #
                regions = self.xml2dict(xml_fname)
                slide = openslide.open_slide(svs_fname)
                prefix = os.path.splitext(os.path.basename(xml_fname))[0]

                # create a directory for this patch
                #
                fpath = os.path.join(patch_dir, prefix)
                os.makedirs(fpath, exist_ok=True)
                                
                # extract the patch
                #
                self.save_window(flists_ps, patch_list, prefix, slide, regions,
                                 acc_lbls, fpath, win_len, frm_len, lv1_rs)

            except Exception as e:
                print(e)

        # sort the patch list
        #
        for prefix in patch_list:
            patch_list[prefix].sort()

        # sort and complete the file processing and patch lists
        #
        for key in patch_list:
            patch_list[key].sort()
        plist = {key: patch_list[key] for key in patch_list}
        flps = {key: flists_ps[key] for key in flists_ps}
        
        # exit gracefully
        #
        return (flps, plist)
    #
    # end of method

    def save_window(self, flps, plist, fname_prefix, slide,
                    regions, acc_labels, wdir, wlen, flen, rs):
        """
        method: ExtractPatch::save_window
        
        arguments:
         flps: a list of (file name, process time) pairs to log at the end
         plist: a list of (label, patch name) pairs to log at the end
         fname_prefix: output file names prefix
         slide: openslide slide object
         regions: regions dictionary with label keys and list of coordinates
         acc_labels: a list of acceptable prediction labels
         wdir: the output windows directory
         wlen: window length
         flen: frame length
         rescale: rescaling factor, for level zero this value must be one
        
        return: 
         True
        
        description:
         this function does the main job. It reads level 0, extracts windows and
         saves them
        """

        # initialize every prefix key to an empty list object
        #
        plist[fname_prefix] = list()

        # make separate directory for every acceptable label
        #
        for label in acc_labels:
            label_dir = os.path.join(wdir, label)
            os.makedirs(label_dir, exist_ok=True)

        # find acceptable labels in regions
        #
        acc_regions = {key: regions[key]
                       for key in regions if key in acc_labels}

        # start processing
        #
        time0 = time.time()
        bmask = []
        for label in acc_regions:

            # check if the label directory exists
            #
            label_dir = os.path.join(wdir, label)

            # for each coord in each region
            #  
            for coords3d in acc_regions[label]:

                # initialize coordinates
                #
                coords = np.array(coords3d)[:, 0:2]

                # read the region
                #
                bbox = self.boundary_box(coords)
                bmask = self.bin_mask(coords)

                # windows upper-left coordinates
                #
                wcoords = self.win_coords(bmask, wlen, flen)

                # if wcoords is empty, it means the window size is larger than
                # the region, but all the region will be assumed as label.
                #
                # read the region
                #
                bbox = self.boundary_box(coords)
                lv_box = slide.read_region((bbox[1], bbox[0]), 0,
                                           (bbox[3]-bbox[1]+wlen,
                                            bbox[2]-bbox[0]+wlen))
                
                # saving each window
                #
                for coord in wcoords:

                    # crop and resize the patch
                    #
                    win_pil = lv_box.crop((coord[1], coord[0],
                                           coord[1]+wlen, coord[0]+wlen))
                    win_pil = win_pil.resize((win_pil.width//rs,
                                              win_pil.height//rs))
                    
                    # create the basename of the file. the format is:
                    # basename_row_column
                    #
                    fbase = \
                f'{fname_prefix}_{coord[1]+bbox[1]:05d}_{coord[0]+bbox[0]:05d}'
                    
                    # generate the entire fname path
                    #
                    fpath = label_dir
                    fname = os.path.join(fpath, fbase + IMG_EXT)
                    
                    # save the patch data to the file
                    #
                    win_pil.save(fname, optimize=False)
                    plist[fname_prefix] += [(label, fname)]

        # save the overall duration for saving the windows
        #  
        flps[fname_prefix] = time.time() - time0
        
        # exit gracefully
        #
        return True
    #
    # end of method

    def xml2dict(self, xml_fname, csv_fname=None):
        """
        method: ExtractPatch::xml2dict

        arguments:
         xml_fname: xml file name
         csv_fname: csv file name (default=None); when given, the
                    coordinates are also written to this file

        return:
         regions: a dictionary with labels as keys and lists of boundaries
                  as values

        description:
         returns a dictionary with labels as keys and lists of boundaries as
         values. each boundary is a list of [row, column, z] coordinates. in
         xml annotation files, the 'x' is column and the 'y' is row, so the
         first two entries are swapped with respect to the annotation.

         when csv_fname is given, one line per coordinate is written in the
         form "region,X,Y,Z", where region is the position of the region in
         the order the annotation tool returned it and X, Y, Z are in the
         annotation's own order.
        """

        # read the annotation; the second element of the result holds the
        # regions, keyed by region id
        #
        reg = nedc_ann_dpath_tools.read(xml_fname)
        regions = dict()

        # boundaries in the order they were read, used for the csv output
        #
        ordered = []

        # If the regions labels are known beforehand, they can be defined
        # here
        #     {
        #         'Normal': [],
        #         'Benign': [],
        #         'Carcinoma in situ': [],
        #         'Invasive carcinoma': [],
        #         'Carcinoma invasive': [],
        #         'In situ carcinoma': []
        #     }
        #
        for regid in reg[1]:

            # set the region
            #
            region = reg[1][regid]

            # get the annotation for the current region
            #
            label = region.get('text')

            # if the annotation label has not been seen yet, initialize
            # an empty list for that label
            #
            if label not in regions:
                regions[label] = []

            # convert the coordinates to integer [row, column, z] triples
            #
            boundary = [[int(Y), int(X), int(Z)]
                        for X, Y, Z in region.get('coordinates')]

            # add the coordinates to the label key for the region
            #
            regions[label].append(boundary)
            ordered.append(boundary)

        # if a csv file was requested
        #
        if csv_fname is not None:

            # open the csv file
            #
            with open(csv_fname, mode='w') as file:

                # write the region data in csv format: region, X, Y, Z.
                # the stored triples are [Y, X, Z], so the first two entries
                # are swapped back
                #
                for rc, boundary in enumerate(ordered):
                    file.writelines([f'{rc},{v[1]},{v[0]},{v[2]}\n'
                                     for v in boundary])

        # exit gracefully
        #
        return regions
    #
    # end of method

    def boundary_box(self, coords):
        """
        method: ExtractPatch::boundary_box

        arguments:
         coords: a 2 column Numpy array (x, y)

        return:
         [xmin, ymin, xmax, ymax]: rectangular coordinates of the boundary box

        description:
         this function returns the bounding box of an area which is just a
         combination of boundry pixel coordinates
        """
        
        # change the coords array to a Numpy array object
        #
        coords = np.array(coords)

        # exit gracefully
        #
        return(np.array([coords[:, 0].min(), coords[:, 1].min(),
                         coords[:, 0].max(), coords[:, 1].max()]))
    #
    # end of method

    def bin_mask(self, coords):
        """
        method: ExtractPatch::bin_mask

        arguments:
         coords: an array-like of boundary points; the first two columns
                 are used as the two image axes

        return:
         bmask: a two dimensional uint8 Numpy array, sized to the bounding
                box of the boundary, that is 1 at the pixels returned by
                polygon() and 0 elsewhere

        description:
         this function returns the binary mask of the area from the boundary
         coordinates. the argument itself is not modified.
        """

        # take a private copy of the points so the caller's data is untouched
        #
        pts = np.array(coords)

        # move the points so that the bounding box starts at the origin
        #
        origin = self.boundary_box(pts)[:2]
        pts -= origin

        # pixels that belong to the polygon described by the boundary
        #
        fill_rows, fill_cols = polygon(pts[:, 0], pts[:, 1])

        # allocate a mask that spans the bounding box and mark the polygon
        #
        n_rows = pts[:, 0].max() + 1
        n_cols = pts[:, 1].max() + 1
        bmask = np.zeros((n_rows, n_cols), dtype=np.uint8)
        bmask[fill_rows, fill_cols] = 1

        # exit gracefully
        #
        return bmask
    #
    # end of method

    def win_coords(self, bmask, wlen, flen):
        """
        method: ExtractPatch::win_coords

        arguments:
         bmask: a binary mask from ExtractPatch::bin_mask
         wlens: window lengths
         flen: frame length

        return:
         wcoords; a list of windows coordinates (Numpy Arrays)

        description:
         this functions generates window coordinates
        """

        # overlap
        #
        overlap = flen / wlen

        # check if at least one window can be extracted from the bmask
        #
        if (bmask.shape[0] < wlen) & (bmask.shape[1] < wlen):
            wcoords = np.array([[0, 0]])

        elif bmask.shape[0] < wlen:

            # just extract horizontal windows
            #
            col_coords = np.arange(0, bmask.shape[1]-wlen, flen)
            row_coords = np.zeros(col_coords.size, dtype=np.int)
            wcoords = np.vstack((row_coords, col_coords)).T

        elif bmask.shape[1] < wlen:

            # just extract vertical windows
            #
            row_coords = np.arange(0, bmask.shape[0]-wlen, flen)
            col_coords = np.zeros(row_coords.size, dtype=np.int)
            wcoords = np.vstack((row_coords, col_coords)).T

        else:

            # both horizontal and vertical windows
            #
            bm_small = (rescale((bmask == 1), 1/flen,
                                anti_aliasing=False,
                                channel_axis=0) > overlap)[:-1, :-1]

            col_coords = np.arange(0, bm_small.shape[1]).\
                reshape((1, bm_small.shape[1])).\
                repeat(bm_small.shape[0], axis=0) * flen

            row_coords = np.arange(0, bm_small.shape[0]).\
                reshape((bm_small.shape[0], 1)).\
                repeat(bm_small.shape[1], axis=1) * flen

            all_coords = np.concatenate(
                (row_coords.reshape((row_coords.size, 1)),
                 col_coords.reshape((col_coords.size, 1))),
                axis=1)

            bm_small_2d = np.concatenate((bm_small.reshape(bm_small.size, 1),
                                          bm_small.reshape(bm_small.size, 1)),
                                         axis=1)

            wcoords = all_coords[bm_small_2d == 1]
            wcoords = wcoords.reshape((wcoords.size//2, 2))

        # exit gracefully
        #
        return(wcoords)
    #
    # end of method

    def save_state_dict(self, state, log_dir):
        """
        method: ExtractPatch::save_state_dict

        arguments:
         state: the state dictionary of processed files, keyed by file
                prefix (the xml base name without its extension)
         log_dir: the output log directory

        return:
         flist: the list of lines that were written to the log file

        description:
         this function saves the state dictionary. one "file name,value"
         line is written for every xml file that has an entry in the state
         dictionary, with the value rounded to zero decimals.
        """

        # build one log line per xml file that has a state entry
        #
        flist = []
        for xml_path in self.xml:

            # the state is keyed by the base name without the extension
            #
            key = os.path.splitext(os.path.basename(xml_path))[0]
            value = state.get(key)

            # files without an entry are left out of the log
            #
            if value is None:
                continue
            flist.append(f'{xml_path},{value:0.0f}\n')

        # save the file list
        #
        log_path = os.path.join(log_dir, LOG_FILE_NAME)
        with open(log_path, mode=nft.MODE_WRITE_TEXT+'+') as log_file:
            log_file.writelines(flist)

        # exit gracefully
        #
        return flist
    #
    # end of method

    def save_labels_patches(self, patches_list, log_dir):
        """
        method: ExtractPatch::save_labels_patches

        arguments:
         patches_list: a dictionary whose values are lists of
                       (label, patch file name) pairs
         log_dir: the output log directory

        return:
         csv: a dictionary that maps every acceptable label to the sorted
              list of its patch file names

        description:
         this function saves a csv with two columns, label and patch file
         path. patches whose label is not an acceptable label are ignored.
        """

        # start with an empty bucket for each acceptable label
        #
        csv = {label: [] for label in self.acc_lbls}

        # drop every patch into the bucket of its label, ignoring labels
        # that have no bucket
        #
        for pairs in patches_list.values():
            for label, pfname in pairs:
                bucket = csv.get(label)
                if bucket is not None:
                    bucket.append(pfname)

        # sort the filenames of every label
        #
        for label in csv:
            csv[label].sort()

        # build the csv lines, label by label
        #
        rows = [f"{label},{pfname}\n"
                for label in csv for pfname in csv[label]]

        # save the files and their labels to the log csv file
        #
        out_name = os.path.join(log_dir, LABEL_PATCH_FNAME)
        with open(out_name, mode='w') as out_file:
            out_file.writelines(rows)

        # exit gracefully
        #
        return csv
    #
    # end of method
    
#
# end of class

#
# end of file