import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from nedc_image_tools import Nil
import sys


# ================================================================
# Gaussian Blur (1D separable)
# ================================================================

def gaussian_blur_1D(sigma, radius, device='cpu'):
    """
    Build a normalized 1D Gaussian kernel.

    sigma:  standard deviation of the Gaussian, in samples (must be > 0)
    radius: number of taps on each side of the center (must be >= 0)
    device: torch device on which the kernel is created

    Returns: float32 tensor of shape (2 * radius + 1,) whose taps sum to 1

    Raises: ValueError if sigma is not positive or radius is negative
    """
    # A zero sigma divides by zero below and produces a NaN kernel, and a
    # negative radius produces an empty one, so reject both up front
    #
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")

    # Sample positions centered on zero: -radius, ..., 0, ..., +radius
    #
    x = torch.arange(-radius, radius + 1, dtype=torch.float32, device=device)

    # Unnormalized Gaussian, rescaled so the taps sum to one
    #
    kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2))
    return kernel / kernel.sum()

def apply_gaussian_blur_batch(imgs, sigma, radius):
    """
    Blur a batch of images with a separable Gaussian filter.

    imgs:   (B, C, H, W) tensor
    sigma:  standard deviation of the Gaussian
    radius: kernel half-width; the kernel has 2 * radius + 1 taps

    Returns: (B, C, H, W) tensor on the same device as imgs

    Raises: ValueError if imgs is not 4-dimensional

    The convolutions zero-pad the borders, so pixels within `radius` of an
    edge are attenuated compared with the interior.
    """
    if imgs.dim() != 4:
        raise ValueError(
            f"expected imgs of shape (B, C, H, W), got {tuple(imgs.shape)}"
        )
    channels = imgs.shape[1]

    # 1D kernel, created on the input's device
    #
    kernel = gaussian_blur_1D(sigma, radius, device=imgs.device)

    # The kernel is built as float32; match other floating dtypes so that
    # conv2d does not fail on a weight/input dtype mismatch
    #
    if imgs.is_floating_point():
        kernel = kernel.to(dtype=imgs.dtype)

    # Horizontal pass: one copy of the kernel per channel, and groups equal
    # to the channel count so each channel is filtered independently
    #
    kernel_h = kernel.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
    blurred_h = F.conv2d(imgs, kernel_h, padding=(0, radius), groups=channels)

    # Vertical pass over the horizontally blurred result
    #
    kernel_v = kernel.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)
    return F.conv2d(blurred_h, kernel_v, padding=(radius, 0), groups=channels)

# ================================================================
# FFT Processing
# ================================================================

def compute_log_fft_batch(img_batch):
    """
    Channel-averaged log power spectrum of each image in a batch.

    img_batch: (B, C, H, W)
    Returns: (B, H, W), 10 * log10 of the squared FFT magnitude, averaged
             over the channel axis
    """
    # 2D FFT over the last two axes, with the zero frequency moved to the
    # center of the spectrum
    #
    spectrum = torch.fft.fftshift(
        torch.fft.fft2(img_batch, dim=(-2, -1)),
        dim=(-2, -1),
    )

    # Squared magnitude from the real and imaginary parts
    #
    real_part = spectrum.real
    imag_part = spectrum.imag
    power = real_part ** 2 + imag_part ** 2

    # The small offset keeps log10 finite where the power is exactly zero
    #
    power_db = 10 * torch.log10(power + 1e-12)

    # Collapse the channel axis
    #
    return power_db.mean(dim=1)

# ================================================================
# Histogram
# ================================================================

def compute_histogram_batch(log_mag_batch, bins=16, value_range=None):
    """
    Compute one fixed-width histogram per image in a batch.

    log_mag_batch: (B, H, W)
    bins:          number of histogram bins
    value_range:   optional (min, max) pair giving fixed bin edges. When
                   None (the default), the minimum and maximum over the
                   WHOLE batch are used, so the bin edges depend on which
                   images share a batch; pass an explicit range when
                   histograms from different batches must be comparable.

    Returns: (B, bins) tensor of per-image bin counts, with the same dtype
             and device as the input
    """
    # Nothing to bin: return an empty result instead of failing in min()
    #
    if log_mag_batch.shape[0] == 0:
        return log_mag_batch.new_zeros((0, bins))

    # torch.histc takes scalar bounds, so convert them to Python numbers
    # explicitly rather than passing 0-dim tensors
    #
    if value_range is None:
        lo = log_mag_batch.min().item()
        hi = log_mag_batch.max().item()
    else:
        lo, hi = (float(v) for v in value_range)

    # histc has no batch dimension, so bin each image separately against
    # the shared [lo, hi] range
    #
    return torch.stack([
        torch.histc(img.flatten(), bins=bins, min=lo, max=hi)
        for img in log_mag_batch
    ])

# ================================================================
# Patch Processing
# ================================================================

def process_patches_batch(patches, sigma, radius, bins, device):
    """
    Run the blur -> FFT -> histogram pipeline on one batch of patches.

    patches: sequence of images accepted by torchvision's ToTensor, all of
             the same size (presumably the tiles returned by the slide
             reader -- confirm against the caller)
    sigma:   Gaussian standard deviation used for the blur
    radius:  Gaussian kernel half-width
    bins:    number of histogram bins
    device:  torch device on which the batch is processed

    Returns: (len(patches), bins) tensor of histograms on the CPU
    """
    # An empty batch cannot be stacked; return an empty result instead
    #
    if len(patches) == 0:
        return torch.empty((0, bins))

    # Build the converter once instead of once per patch
    #
    to_tensor = transforms.ToTensor()
    batch = torch.stack([to_tensor(patch) for patch in patches]).to(device)

    blurred = apply_gaussian_blur_batch(batch, sigma, radius)
    log_mag = compute_log_fft_batch(blurred)

    # Clean up: replace any NaN or infinite values with zero so they cannot
    # distort the histogram range
    #
    log_mag = torch.nan_to_num(log_mag, nan=0.0, posinf=0.0, neginf=0.0)

    hists = compute_histogram_batch(log_mag, bins=bins)
    return hists.cpu()

# ================================================================
# File Discovery
# ================================================================

def find_svs_files(root_dir, max_files=None):
    """
    Lazily yield the paths of all .svs files under a directory tree.

    root_dir:  directory to search recursively
    max_files: optional cap on the number of paths yielded; None means no
               limit, and a value <= 0 yields nothing

    Yields: path strings (root_dir joined with the relative location),
            matched case-insensitively on the ".svs" extension, in sorted
            order within each directory
    """
    # Check the limit before the first yield; checking only afterwards would
    # still produce one file when max_files is 0
    #
    if max_files is not None and max_files <= 0:
        return

    count = 0
    for dirpath, dirnames, filenames in os.walk(root_dir):

        # Sorting dirnames in place fixes the order os.walk descends in, so
        # a max_files subset is the same on every run
        #
        dirnames.sort()

        for f in sorted(filenames):
            if not f.lower().endswith(".svs"):
                continue

            yield os.path.join(dirpath, f)
            count += 1
            if max_files is not None and count >= max_files:
                return
# ================================================================
# Main
# ================================================================

def main():
    """
    Compute FFT log-power histograms over every tile of every .svs slide
    under the dataset root and print the per-bin mean and standard deviation.

    Command line: an optional first argument caps the number of slides
    processed; a non-integer value prints the usage line and exits.

    Requires at least one CUDA device; returns early with a message when
    none is available or when no patches could be processed.
    """

    # Optional first argument: maximum number of slides to process
    #
    if len(sys.argv) > 1:
        try:
            max_files = int(sys.argv[1])
        except ValueError:
            print("Usage: python3 myaverager.py [max_files]")
            return
    else:
        max_files = None  # process all files

    # Processing parameters
    #
    sigma = 1.0
    radius = 3
    bins = 16
    tile_size = 256
    batch_size = 64

    # Dataset root
    #
    root_dir = "/data/isip/data/tuh_dpath_breast/deidentified/v3.0.0/svs"

    # ===== GPU detection =====
    #
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("ERROR: No GPUs available!")
        return

    gpu_ids = list(range(num_gpus))
    print(f"Using GPUs: {gpu_ids}")

    # ===== Find image files automatically =====
    #
    print("Walking directory tree to find .svs files ...")
    files = list(find_svs_files(root_dir, max_files=max_files))
    print(f"Found {len(files)} .svs files.")

    # ===== Timing start =====
    #
    start_time = time.time()

    all_hists = []

    for fname in files:
        print(f"\nProcessing {fname} ...")
        img = Nil()

        if not img.open(fname):
            print(f"Could not open: {fname}")
            continue

        # Close the slide even if reading or processing raises
        #
        try:
            width, height = img.get_dimension()

            # Build list of tile coordinates (top-left corner of each tile)
            #
            coords = [(x, y) for x in range(0, width, tile_size)
                             for y in range(0, height, tile_size)]

            print(f"  Size: {width}x{height}, tiles: {len(coords)}")

            # Read tiles
            #
            patches = img.read_data_multithread(coords, npixx=tile_size, npixy=tile_size)
            num_patches = len(patches)

            # GPU batching: consecutive batches are assigned to the GPUs
            # round-robin. Each call returns its result on the CPU, so the
            # batches run one after another rather than concurrently.
            #
            for batch_idx, s in enumerate(range(0, num_patches, batch_size)):
                e = min(s + batch_size, num_patches)
                gpu_idx = batch_idx % num_gpus
                device = f"cuda:{gpu_ids[gpu_idx]}"

                h = process_patches_batch(patches[s:e], sigma, radius, bins, device)
                all_hists.append(h)

                print(f"  GPU {gpu_idx}: processed {e}/{num_patches}", end='\r')
        finally:
            img.close()

    # torch.cat raises on an empty list, which happens when no slide was
    # found, none could be opened, or none yielded any patches
    #
    if not all_hists:
        print("\nNo patches were processed; nothing to summarize.")
        return

    # ===== Combine results =====
    #
    all_hist_tensor = torch.cat(all_hists, dim=0)

    hist_avg = torch.mean(all_hist_tensor, dim=0)
    hist_std = torch.std(all_hist_tensor, dim=0)

    elapsed = time.time() - start_time

    # ===== Nice output =====
    #
    print("\n===== FINAL SUMMARY =====")
    print(f"Total patches processed: {all_hist_tensor.shape[0]}")
    print("Histogram mean per bin:")
    print(hist_avg)
    print("Histogram std per bin:")
    print(hist_std)
    print(f"Total time: {elapsed:.2f} sec")

    # Skip the rate if the clock did not advance, to avoid dividing by zero
    #
    if elapsed > 0:
        print(f"Patches/sec: {all_hist_tensor.shape[0] / elapsed:.2f}")

# Begin gracefully: run the pipeline only when this file is executed as a
# script, not when it is imported as a module
#
if __name__ == "__main__":
    main()
