#!/usr/bin/env python
#
# file: $NEDC_NFC/src/class/nedc_dpath_decode_slide.py
#
# revision history:
#
# 20230929 (SM): prep for v2.0.0 release
# 20211224 (PM): refactored code
# 20210416 (VK): initial release
#
# Usage:
#  import nedc_dpath_decode_slide as ndds
#
# This file contains a class that decodes svs files with a trained neural
# network
#------------------------------------------------------------------------------

# import system modules
#
import os
import sys
import numpy as np
from skimage import io
import torch
import torch.nn as nn
from torchvision import models
import copy
import xml.etree.ElementTree as ET

# import NEDC support modules
#
from nedc_dpath_slide import Slide
import nedc_file_tools as nft

#------------------------------------------------------------------------------
#
# global variables are listed here
#
#------------------------------------------------------------------------------

# names of the dataset partitions
#
TRAIN = 'train'
DEV = 'dev'
EVAL = 'eval'

# extension of the mask images and name of the color legend image
#
IMG_EXT = 'tif'
COLORS_FNAME = 'colors.tif'

# the xml template is expected to sit next to this module
#
XML_TEMPLATE = os.path.join(os.path.dirname(__file__), 'template.xml')

# labels that are written to the xml annotation files
#
LABELS = [
    'artf', 'bckg', 'dcis', 'indc', 'infl', 'nneo', 'norm', 'null', 'susp',
]

class DecodeSlide:
    '''
    class: DecodeSlide
    
    description:
     this class contains the methods for decoding a single svs file through the
     trained neural network model
    '''

    def __init__(self, init_model, model_fname, svs_fnames, class_names,
                 transform, batch_size, num_workers, win_len, frm_len, device):
        '''
        method: DecodeSlide::constructor

        arguments:
         init_model: the pretrained model that needs to be trained
         model_fname: the file name of the model that will be loaded
         svs_fnames: a list of svs file names
         class_names: list of class names
         transform: the transform to use
         batch_size: the batch size of dataloader
         num_workers: number of processors used for decoding
         win_len: the window length
         frm_len: the frame length
         device: the device that decoding will be on

        return:
         None
        '''

        # store the input arguments for further use information
        #
        self.svs_fnames = svs_fnames
        self.device = device
        self.transform = transform
        self.batch_size = batch_size
        self.nworkers = num_workers
        self.class_names = class_names
        self.wlen = win_len
        self.flen = frm_len

        # load model
        #
        self.model = init_model
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, len(self.class_names))
        if model_fname is not None:
            mdl = torch.load(model_fname, map_location=self.device)
            self.model.fc = nn.Linear(num_ftrs, mdl['fc.weight'].shape[0])
            self.model.load_state_dict(mdl)
        self.model = self.model.to(device)
    #
    # end of method

    def decode_flist(self):
        '''
        method: DecodeSlide::decode_flist

        arguments:
         None

        return:
         decode_list: a list of probabilities

        description:
         decode the each svs file in the list
        '''
        
        # local list
        #
        decode_list = []

        for fname in self.svs_fnames:

            # read svs file
            #
            dataset = Slide(fname, self.wlen, self.flen, self.transform)

            # dataloader
            #
            dataloader = \
                torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                            shuffle=False,
                                            num_workers=self.nworkers)
            decode_list.append((fname, self.decode(dataloader)))

        # exit gracefully
        #
        return decode_list
    #
    # end of method

    def decode(self, dataloader, verbose=False):
        '''
        method: DecodeSlide::decode

        arguments:
         dataloader: the dataloader object to decode
         verbose: the verbose option for debugging

        return:
         all_probs: probabilities of every event in a patch with corresponding
                    coordinates

        description:
         this function computes all the information needed by decoding
        '''

        # find the number of samples
        #
        nsamples = len(dataloader) * dataloader.batch_size

        # save the model state to restore it at the end
        #
        was_training = self.model.training
        self.model.eval()

        # extract all probabilities
        #
        all_coords_probs = []

        # loop over all epochs to train the neural network
        #
        lcounter = 0
        with torch.no_grad():
            for i, (win_coords, sld_coords, inputs) \
                in enumerate(dataloader):

                inputs = inputs.to(self.device)

                outputs = self.model(inputs)
                outputs_list = nn.functional.softmax(outputs,
                                                     dim=1).cpu().tolist()

                # this convert windows coordinates, slide coordinates, and
                # outputs to the normal list format.
                # coordinates are converted in Numpy format (rows, cols).
                #
                coords_probs = [[int(win_coords[0][c]), int(win_coords[1][c]),
                                 int(sld_coords[1][c]),
                                 int(sld_coords[0][c])] + outputs_list[c]
                                for c in range(len(outputs_list))]

                # store all coordinates and probabilities in a single list
                #
                all_coords_probs += coords_probs
                lcounter += len(outputs_list)
                if verbose:
                    if (lcounter % verbose) == 0:
                        print(f'{lcounter+1}/{nsamples}')

        # change the model state to last state
        #
        self.model.train(mode=was_training)

        # exit gracefully
        #
        return all_coords_probs
    #
    # end of method

    def write_probabilities(self, decode_list, out_dir, classes):
        '''
        method: DecodeSlide::save_probabilities

        arguments:
         decode_list: probabilities of every event in a svs file
         csv_dir: the directory that csv files will be saved
         classes: a list of class names

        return:
         None

        description:
         this function writes all of the prediction data to csv and xml files
        '''

        # check if the path of csv exist.
        #
        os.makedirs(out_dir, exist_ok=True)
                
        # loop over all files
        #
        for (fname, coords_probs) in decode_list:

            # create a directory for the current image
            #
            csv_dir = f"{out_dir}/{os.path.basename(fname).split('.')[0]}"
            os.makedirs(csv_dir, exist_ok=True)

            # write a confidence level csv
            #
            conf_fname = os.path.basename(fname).split('.')[0] + '.csv'
            conf_fname = os.path.join(csv_dir, conf_fname)
            self.write_confidence_csv(conf_fname, coords_probs, classes)

            # write a nedc csv
            #
            nedc_fname = os.path.basename(fname).split('.')[0] + '_nedc.csv'
            nedc_fname = os.path.join(csv_dir, nedc_fname)
            self.write_nedc_csv(nedc_fname, coords_probs, classes)

            # write a nedc xml file
            #
            xml_nedc_fname = nedc_fname.replace('.csv', '.xml')
            info, _, regions, _ = self.csv2list(nedc_fname)
            self.reg2xml(regions, info, xml_nedc_fname, XML_TEMPLATE,
                         LABELS)
    #
    # end of method

    def write_confidence_csv(self, fname, coords_probs, classes):
        '''
        method: DecodeSlide::write_confidence_csv

        arguments:
         fname: the name of the csv file to write
         coords_probs: a list of rows; the first four values of a row are
                       coordinates and the remaining values are the class
                       probabilities
         classes: the class names, indexed like the probabilities

        return:
         None

        description:
         write a confidence level csv file for a single image. every row
         holds the row values (integers verbatim, everything else with six
         decimals) followed by the name of the most probable class
        '''

        # build one text line per row
        #
        lines = []
        for row in coords_probs:
            fields = [str(val) if type(val) == int else f'{val:.6f}'
                      for val in row]
            fields.append(classes[np.argmax(row[4:])])
            lines.append(nft.DELIM_COMMA.join(fields) + nft.DELIM_NEWLINE)

        # write the file
        #
        with open(fname, mode=nft.MODE_WRITE_TEXT+'+') as fp:
            fp.writelines(lines)
    #
    # end of method

    def write_nedc_csv(self, fname, coords_probs, classes):
        '''
        method: DecodeSlide::write_nedc_csv

        arguments:
         fname: the name of the nedc csv file to write
         coords_probs: a list of rows; row[2] and row[3] are used as the
                       (row, column) of the top left corner of a window and
                       row[4:] are the class probabilities
         classes: the class names, indexed like the probabilities

        return:
         None

        description:
         write a nedc csv file for a single image. every window becomes one
         region described by a closed square of five vertices with a side of
         self.wlen. an empty coords_probs produces a file that only holds
         the header
        '''

        # calculate the extent of the current image from the windows that
        # reach furthest down and right. the last row is not assumed to be
        # that window, and an empty list yields a 0x0 image
        #
        wlen = self.wlen
        if len(coords_probs) > 0:
            width = max(row[3] for row in coords_probs) + wlen
            height = max(row[2] for row in coords_probs) + wlen
        else:
            width = height = 0

        # the base name is the file name without its extension
        #
        bname = os.path.splitext(os.path.basename(fname))[0]

        # make the header
        #
        lines = [
            "# version = csv_v1.0.0\n",
            "# MicronsPerPixel = 0.0000\n",
            "# bname = " + bname + '\n',
            f"# width = {width} pixels, height = {height} pixels\n",
            "# tissue = breast\n",
            "#\n",
            "index,region_id,tissue,label,coord_index," +
            "row,column,depth,confidence\n",
        ]

        # loop over each row
        #
        for rowc, row in enumerate(coords_probs):
            top, left = row[2:4]
            label = classes[np.argmax(row[4:])]

            # walk around the window and return to the first corner to
            # close the boundary
            #
            corners = ((top, left),
                       (top, left + wlen),
                       (top + wlen, left + wlen),
                       (top + wlen, left),
                       (top, left))

            # region ids are one-based while indices are zero-based; depth
            # and confidence are constant
            #
            for cc, (rcoord, ccoord) in enumerate(corners):
                fields = (rowc, rowc + 1, 'breast', label, cc,
                          rcoord, ccoord, '0', '1.0')
                lines.append(nft.DELIM_COMMA.join(str(f) for f in fields) +
                             nft.DELIM_NEWLINE)

        # write the file
        #
        with open(fname, mode=nft.MODE_WRITE_TEXT+'+') as fp:
            fp.writelines(lines)
    #
    # end of method

    def csv2list(self, csv_fname, comment_char = '#',
                 seperator=',', header=True):
        '''
        method: DecodeSlide::csv2list

        arguments:
         csv_fname: nedc csv file name

        return:
         info: a dictionary of the file header
         header: name of columns in the nedc csv file
         regions: a dictionary which labels as keys and list of coordinates as
                  values
         matrix: a list of extracted information from the csv file

        description:
         returns a dictionary with label keys and list of coordinates values.
         each value is a list of boundaries. Each boundary is a list of
         coordinates
        '''
        
        regions = dict()
        matrix = []
        with open(csv_fname, mode='r') as cf:
            csv = cf.readlines()

        # read basename, width, height, tissue
        #
        info = dict()
        lcounter = 0
        line = csv[lcounter]
        while line[0] == '#':
            if 'bname' in line:
                info['bname'] = line[1:].split('=')[1].strip()

            elif 'width' in line:
                info['width'] = int(line[1:].split(',')[0].split(
                    '=')[1].replace('pixels', '').strip())
                info['height'] = int(line[1:].split(',')[1].split(
                    '=')[1].replace('pixels', '').strip())
                
            elif 'tissue' in line:
                info['tissue'] = line[1:].split('=')[1].strip()

            lcounter += 1
            line = csv[lcounter]

        # remove comment lines
        #
        csv = [l for l in csv if l[0] != comment_char]

        # read and remove header
        #
        if header:
            header = [h.strip() for h in csv[0].split(seperator)]
            csv = csv[1:]

        for line in csv:
            row = [l.strip() for l in line.split(seperator)]
            region_index = int(row[0])
            region_id = int(row[1])
            tissue = row[2]
            label = row[3]
            coord_index = int(row[4])
            row_coord, col_coord, dep_coord = \
                (int(row[5]), int(row[6]), int(row[7]))
            confidence = float(row[8])
            matrix.append([region_index, region_id, label, coord_index,
                           row_coord, col_coord, dep_coord, confidence])
            
            if not(label in regions.keys()):
                regions[label] = []

            if coord_index == 0:
                regions[label].append({'Id': region_index, 'coords': []})

            regions[label][-1]['coords'].append((row_coord,
                                                 col_coord,
                                                 dep_coord))
        # exit gracefully
        #
        return info, header, regions, matrix
    #
    # end of method

    def reg2xml(self, regions, info, xml_fname, xml_template,
                interesting_labels):
        '''
        method: DecodeSlide::reg2xml

        arguments:
         info: a dictionary of the file header
         regions: a dictionary which labels as keys and list of coordinates as
                  values
         xml_fname: the xml file name
         xml_template: the xml template to base xml files off of
         interesting_labels: all of the available labels

        return:
         None

        description:
         writes an nedc xml file with the prediction probabilities
        '''
        
        # read the nedc XML template
        #
        tree = ET.parse(xml_template)
    
        # get the XML root
        #
        root1 = tree.getroot()

        # append regions' tags to ImageScope XML
        #
        root2 = copy.deepcopy(root1)

        # set width
        #
        root2[0][0][0].set('Value', str(info['width']))

        # set height
        #
        root2[0][0][1].set('Value', str(info['height']))

        # get the Regions section of the XML
        #
        xml_regions = root2[0][1]

        # get a Region from Regions to replicate based on the needs
        # and remove the dummy Region from the template
        #
        xml_vertex = copy.deepcopy(xml_regions[1][1][0])
        xml_regions[1][1].remove(xml_regions[1][1][0])
        xml_region = copy.deepcopy(xml_regions[1])
        xml_regions.remove(xml_regions[1])

        # add the required regions and vertices
        #
        for label in regions:

            if not(label in interesting_labels):
                continue

            for regc, reg in enumerate(regions[label]):
                region = copy.deepcopy(xml_region)
                region.set('Id', str(reg['Id']))
                region.set('Text', label)
                vertices = region[1]
                for coordc, coord in enumerate(reg['coords']):

                    # vertex = ET.SubElement(xml_vertex, 'Vertex')
                    vertex = copy.deepcopy(xml_vertex)
                    vertex.set('Y', str(coord[0]))
                    vertex.set('X', str(coord[1]))
                    vertex.set('Z', str(coord[2]))

                    # correct the indentation
                    #
                    vertex.tail = vertices.text
                    vertices.append(vertex)

                # correct the indentations
                #
                if len(vertices) > 0:
                    vertices[-1].tail = vertices[-1].tail[:-1]
                xml_regions.append(region)
                first_tail = xml_regions[0].tail

                for child in xml_regions[:-1]:
                    child.tail = first_tail

        tree._setroot(root2)
        tree.write(xml_fname)

    def save_masks(self, decode_list, mask_dir, classes,
                   colors, pixel_size=(1, 1)):
        '''
        method: DecodeSlide::save_masks

        arguments:
         decode_list: a list of (file name, coordinates and probabilities)
                      tuples, one per decoded svs file
         mask_dir: the directory that the mask images will be saved in
         classes: the class names, indexed like the probabilities
         colors: a dictionary for (class, color) pairs; the first three
                 components of a color are used
         pixel_size: the (rows, columns) that one window occupies in the mask

        return:
         None

        description:
         this function saves a color legend image and, for every decoded
         file, a mask image in which each window is painted with the color
         of its most probable class. the first two values of a row of
         coordinates and probabilities are used as the (row, column)
         position of the window in the mask
        '''

        # check if the path of mask_dir exist.
        #
        os.makedirs(mask_dir, exist_ok=True)

        # save the colors: one swatch per class, stacked vertically. the
        # picture is one swatch taller than needed, so the last swatch
        # stays black
        #
        colors_fname = os.path.join(mask_dir, COLORS_FNAME)
        swatch_rows, swatch_cols = int(pixel_size[0]), int(pixel_size[1])
        colors_picture = np.zeros(((len(classes)+1)*swatch_rows,
                                   swatch_cols, 3),
                                  dtype=np.uint8)

        for cc, cls in enumerate(classes):
            slc = slice(cc*swatch_rows, (cc+1)*swatch_rows)
            colors_picture[slc, :, 0] = colors[cls][0]
            colors_picture[slc, :, 1] = colors[cls][1]
            colors_picture[slc, :, 2] = colors[cls][2]
        io.imsave(colors_fname, colors_picture)

        # loop over all files
        #
        for (fname, coords_probs) in decode_list:

            # make the mask file name and path
            #
            mask_fname = os.path.basename(fname).split('.')[0] + '.' + IMG_EXT
            mask_fname = os.path.join(mask_dir, mask_fname)

            # make the mask. it must reach the far edge of the last window,
            # i.e. (max index + 1) * pixel size; the former size of
            # max index * pixel size + 1 is kept as a lower bound so that
            # masks never become smaller than before
            #
            coords_probs = np.array(coords_probs)
            max_row = np.max(coords_probs[:, 0])
            max_col = np.max(coords_probs[:, 1])
            nrows = max(int(max_row * pixel_size[0]) + 1,
                        int((max_row + 1) * pixel_size[0]))
            ncols = max(int(max_col * pixel_size[1]) + 1,
                        int((max_col + 1) * pixel_size[1]))
            mask = np.zeros((nrows, ncols, 3), dtype=np.uint8)

            # paint the mask
            #
            for row in coords_probs:
                color = np.array(colors[classes[np.argmax(row[4:])]])
                slc = (slice(int(row[0]*pixel_size[0]),
                             int((row[0]+1)*pixel_size[0])),
                       slice(int(row[1]*pixel_size[1]),
                             int((row[1]+1)*pixel_size[1])))
                mask[slc[0], slc[1], 0] = color[0]
                mask[slc[0], slc[1], 1] = color[1]
                mask[slc[0], slc[1], 2] = color[2]

            # save the mask
            #
            io.imsave(mask_fname, mask)
    #
    # end of method

#
# end of class

#
# end of file
