#!/usr/bin/env python

import nedc_image_tools
import pathlib
import numpy as np
import torch
import ctypes
import sys
import time
import multiprocessing as mp
from queue import Empty

# ---------------- CONFIG ----------------
# Step, as (x, y), between successive window anchor coordinates; it is used as
# the range() stride when the coordinate grid is generated.
FRAME_SIZE = (128, 128)
# Size, as (x, y), of each window requested from the image reader
# (passed as npixx / npixy). It is larger than FRAME_SIZE, so neighbouring
# windows presumably overlap -- TODO confirm against the reader's anchoring.
WINDOW_SIZE = (256, 256)
# Maximum number of windows read per batch; also the first dimension of the
# pinned host staging buffer allocated by each GPU worker.
BATCH_SIZE = 128
# Forwarded as num_threads to the reader's read_data_multithread() call.
NUM_THREADS_PER_PROCESS = 12  # reader threads used inside each worker process
# ---------------------------------------

# Parameters handed to gpu.create_gaussian_filter(); the size is also passed
# to gpu.gaussian_smoother() on every batch.
GAUSSIAN_FILTER_SIZE = 7
GAUSSIAN_FILTER_SIGMA = 1.0
# Number of histogram bins and the [LOG_MIN, LOG_MAX] range used to compute
# bin centers on the host.
# NOTE(review): these presumably have to match the binning compiled into
# gpu.so -- confirm against the CUDA source.
HISTOGRAM_BINS = 16
LOG_MIN = -5.0
LOG_MAX = 10.0

# Load the compiled GPU helper library from the current working directory.
# Nothing below can run without it, so a load failure terminates the process.
try:
    gpu = ctypes.CDLL('./gpu.so')
except OSError:
    print("ERROR: Could not load './gpu.so'. Ensure it is compiled.", file=sys.stderr)
    sys.exit(1)

# Declare the argument types of every exported function this script calls so
# that ctypes converts Python values correctly. Return types are left at the
# ctypes default.
# NOTE(review): the parameter meanings in the comments are taken from the call
# sites in this file, not from the C declarations -- confirm against gpu.so.
for _export_name, _export_argtypes in (
    # called with (gpu_id)
    ("set_device", [ctypes.c_int]),
    # called with (data ptr, scratch ptr, N, H, W, filter size)
    ("gaussian_smoother", [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_int,
        ctypes.c_int,
        ctypes.c_int,
        ctypes.c_int,
    ]),
    # called with (filter size, sigma)
    ("create_gaussian_filter", [ctypes.c_int, ctypes.c_float]),
    # called with (FFT data ptr, image count, H, W)
    ("log_mag_merge_histogram", [
        ctypes.c_void_p,
        ctypes.c_int,
        ctypes.c_int,
        ctypes.c_int,
    ]),
    # called with no arguments
    ("reset_global_histogram", None),
    # called with (host output ptr, 0)
    ("get_global_histogram", [ctypes.c_void_p, ctypes.c_int]),
):
    getattr(gpu, _export_name).argtypes = _export_argtypes
del _export_name, _export_argtypes

def generateTopLeftFrameCoordinates(height: int, width: int, frame_size: tuple):
    """Return the grid of ``(x, y)`` anchor coordinates covering an image.

    Args:
        height: Image height; bounds the ``y`` coordinates.
        width: Image width; bounds the ``x`` coordinates.
        frame_size: ``(x_step, y_step)`` stride between successive anchors.

    Returns:
        A list of ``(x, y)`` tuples ordered column by column: every ``y`` for
        ``x == 0`` first, then every ``y`` for the next ``x``, and so on.
        The list is empty when ``width`` or ``height`` is not positive.
    """
    return [
        (left, top)
        for left in range(0, width, frame_size[0])
        for top in range(0, height, frame_size[1])
    ]

def windowRGBValues_generator(image_file: str):
    """Yield float32 NumPy batches of RGB windows read from ``image_file``.

    The image is opened with ``nedc_image_tools.Nil``, a coordinate grid with
    stride ``FRAME_SIZE`` is generated over its dimensions, and the grid is
    consumed ``BATCH_SIZE`` coordinates at a time. Each chunk is read with
    ``read_data_multithread`` using a ``WINDOW_SIZE`` window per coordinate.

    Args:
        image_file: Path of the image to read.

    Yields:
        ``np.ndarray`` of dtype ``float32`` built from the reader's result for
        one chunk of coordinates. The consumer in this file unpacks it as
        ``(N, H, W, C)``; the last batch may hold fewer than ``BATCH_SIZE``
        windows.
    """
    # NOTE(review): the result of open() is not checked and the reader is never
    # explicitly closed -- confirm whether nedc_image_tools.Nil requires either.
    reader = nedc_image_tools.Nil()
    reader.open(image_file)
    width, height = reader.get_dimension()

    anchors = generateTopLeftFrameCoordinates(height, width, FRAME_SIZE)
    total = len(anchors)

    for start in range(0, total, BATCH_SIZE):
        stop = start + BATCH_SIZE
        windows = reader.read_data_multithread(
            anchors[start:stop],
            npixy=WINDOW_SIZE[1],
            npixx=WINDOW_SIZE[0],
            color_mode="RGB",
            num_threads=NUM_THREADS_PER_PROCESS
        )
        # Convert whatever the reader returned into one float32 array.
        yield np.array(windows, dtype=np.float32)

def calculate_stats_from_histogram(histogram: np.ndarray):
    """Estimate the mean and standard deviation described by a histogram.

    Every count is treated as if it sat at the center of its bin, with
    ``HISTOGRAM_BINS`` equal-width bins spanning ``[LOG_MIN, LOG_MAX]``.

    Args:
        histogram: Per-bin counts; must have ``HISTOGRAM_BINS`` entries.

    Returns:
        ``(mean, std)`` of the binned values (population standard deviation),
        or ``(0.0, 0.0)`` when the histogram holds no counts.
    """
    step = (LOG_MAX - LOG_MIN) / HISTOGRAM_BINS
    n_samples = np.sum(histogram)
    if n_samples == 0:
        # Nothing was accumulated; avoid dividing by zero.
        return 0.0, 0.0

    bin_centers = LOG_MIN + (np.arange(HISTOGRAM_BINS) + 0.5) * step
    weighted_mean = np.sum(bin_centers * histogram) / n_samples

    deviation = bin_centers - weighted_mean
    variance = np.sum(histogram * deviation * deviation) / n_samples
    return weighted_mean, np.sqrt(variance)


def gpu_worker(gpu_id: int, job_queue: mp.Queue, hist_queue: mp.Queue):
    """Process image files from ``job_queue`` on one GPU and report a histogram.

    The worker pulls file paths until the queue stays empty for one second.
    For every batch of windows it copies the data to the GPU through a reusable
    pinned host buffer, runs the Gaussian smoother from ``gpu.so``, takes a 2-D
    FFT per channel with PyTorch, and hands the FFT output to
    ``gpu.log_mag_merge_histogram``. When no work is left it fetches the
    histogram from ``gpu.so`` and puts it on ``hist_queue``.

    Args:
        gpu_id: CUDA device index this worker uses.
        job_queue: Queue of image file paths shared by all workers.
        hist_queue: Queue that receives this worker's final histogram as a
            ``np.uint32`` array of length ``HISTOGRAM_BINS``.

    Note:
        Any exception raised while processing ends the worker before it has
        put a histogram on ``hist_queue``.
    """

    device = torch.device(f"cuda:{gpu_id}")
    print(f"[GPU Worker {gpu_id}] started on {device}", flush=True)

    gpu.set_device(gpu_id)
    gpu.create_gaussian_filter(GAUSSIAN_FILTER_SIZE, GAUSSIAN_FILTER_SIGMA)
    gpu.reset_global_histogram()
    # Synchronize this worker's device explicitly, matching every other
    # synchronize call in this function, rather than relying on whichever
    # device happens to be current for PyTorch.
    torch.cuda.synchronize(device=device)

    # One pinned host buffer, sized for a full batch, is reused for every
    # host-to-device copy; the NumPy view shares its memory.
    pinned_buffer = torch.empty(
        (BATCH_SIZE, WINDOW_SIZE[1], WINDOW_SIZE[0], 3),
        dtype=torch.float32,
        pin_memory=True,
        device='cpu'
    )
    pinned_np_view = pinned_buffer.numpy()

    batches_processed = 0
    start_time = time.time()

    while True:
        try:
            image_file = job_queue.get(timeout=1)
        except Empty:
            # No job arrived within the timeout: treat the queue as drained.
            break

        print(f"[GPU Worker {gpu_id}] processing {pathlib.Path(image_file).name}", flush=True)

        for batch_np in windowRGBValues_generator(str(image_file)):
            N_batch, H, W, C_dim = batch_np.shape
            # Stage the batch in pinned memory; the last batch of a file may
            # be shorter than BATCH_SIZE, hence the slice.
            pinned_np_view[:N_batch] = batch_np

            tensor_batch = pinned_buffer[:N_batch].to(device, non_blocking=True).contiguous()
            temp = torch.empty_like(tensor_batch, device=device)

            # NOTE(review): the smoothed result is presumably written back into
            # tensor_batch with temp as scratch space, since only tensor_batch
            # is used afterwards -- confirm against the gpu.so source.
            gpu.gaussian_smoother(ctypes.c_void_p(tensor_batch.data_ptr()), ctypes.c_void_p(temp.data_ptr()), N_batch, H, W, GAUSSIAN_FILTER_SIZE)
            # Wait for the kernel (and the earlier async copy) before the
            # pinned buffer is overwritten or the data is consumed.
            torch.cuda.synchronize(device=device)

            # Reorder NHWC -> NCHW so each channel is a contiguous H x W plane,
            # then take the 2-D FFT over the last two dimensions.
            batch_nchw = tensor_batch.permute(0, 3, 1, 2).contiguous()
            fft_img = torch.fft.fft2(batch_nchw, dim=(-2, -1))

            # Every (image, channel) pair is passed as one H x W plane.
            N_fft = N_batch * C_dim
            gpu.log_mag_merge_histogram(ctypes.c_void_p(fft_img.data_ptr()), N_fft, H, W)
            torch.cuda.synchronize(device=device)

            # Drop the references so the device memory can be reused by the
            # next batch.
            del temp, tensor_batch, batch_nchw, fft_img
            batches_processed += 1

    end_time = time.time()

    host_hist = np.zeros(HISTOGRAM_BINS, dtype=np.uint32)
    host_hist_ptr = host_hist.ctypes.data_as(ctypes.POINTER(ctypes.c_uint))

    # NOTE(review): the meaning of the second argument (0) is defined by
    # gpu.so and is not visible here -- confirm.
    gpu.get_global_histogram(host_hist_ptr, 0)
    torch.cuda.synchronize(device=device)

    hist_queue.put(host_hist)

    print(f"[GPU Worker {gpu_id}] finished. Processed {batches_processed} batches in {end_time - start_time:.2f}s.", flush=True)

def _collect_histograms(hist_queue, processes, poll_seconds=1.0):
    """Gather up to one result per worker without blocking forever.

    Results are read while the workers are still running, because joining a
    process before its queued data has been consumed can deadlock. Polling
    stops early once every worker has exited and the queue stays empty, so a
    crashed worker cannot hang the parent.

    Args:
        hist_queue: Queue the workers put their histograms on.
        processes: Worker objects exposing ``is_alive()``.
        poll_seconds: Timeout of each queue poll.

    Returns:
        The histograms received, at most ``len(processes)`` of them.
    """
    histograms = []
    while len(histograms) < len(processes):
        # Sample liveness BEFORE polling: if every worker was already gone and
        # the poll still finds nothing, no further result can arrive.
        any_alive = any(p.is_alive() for p in processes)
        try:
            histograms.append(hist_queue.get(timeout=poll_seconds))
        except Empty:
            if not any_alive:
                break
    return histograms


def main():
    """Run the multi-GPU histogram pipeline from the command line.

    Usage: ``python script.py <file_list_path> <num_gpus>``.

    Reads image paths (one per line) from the file list, distributes them to
    GPU worker processes through a shared queue, sums the per-worker
    histograms and prints the combined histogram with its mean and standard
    deviation. Exits with status 1 on invalid arguments, an unreadable or
    empty file list, or when a worker ends without reporting a histogram.
    """
    print("PROGRAM START", flush=True)

    if len(sys.argv) != 3:
        print("Usage: python script.py <file_list_path> <num_gpus>", file=sys.stderr)
        sys.exit(1)

    file_list_path = sys.argv[1]

    try:
        requested_gpus = int(sys.argv[2])
        if requested_gpus <= 0:
            raise ValueError("Number of GPUs must be greater than 0.")
    except ValueError as e:
        print(f"ERROR: Invalid number of GPUs provided: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        with open(file_list_path, 'r') as f:
            svs_files = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        print(f"ERROR: File list not found at: {file_list_path}", file=sys.stderr)
        sys.exit(1)
    except OSError as e:
        # Permission problems, a directory passed as the list, and similar.
        print(f"ERROR: Could not read file list {file_list_path}: {e}", file=sys.stderr)
        sys.exit(1)

    if not svs_files:
        print(f"ERROR: File list is empty or contained no valid paths: {file_list_path}", file=sys.stderr)
        sys.exit(1)

    # Fill the job queue before any worker starts; a worker stops as soon as
    # the queue stays empty for its poll timeout.
    job_queue = mp.Queue()
    for path in svs_files:
        job_queue.put(path)

    hist_queue = mp.Queue()

    # Never start more workers than there are detected devices.
    # NOTE(review): when no CUDA device is detected this still starts one
    # worker -- confirm that fallback is intended.
    detected_gpus = torch.cuda.device_count()
    num_gpus_to_use = min(requested_gpus, detected_gpus if detected_gpus > 0 else 1)

    gpu_processes = []
    for gid in range(num_gpus_to_use):
        p = mp.Process(target=gpu_worker, args=(gid, job_queue, hist_queue))
        gpu_processes.append(p)
        p.start()

    print(f"Spawned {num_gpus_to_use} independent GPU Workers.", flush=True)

    # Drain the result queue first and join afterwards: a child that has put
    # data on a queue does not terminate until that data has been flushed.
    all_histograms = _collect_histograms(hist_queue, gpu_processes)

    for p in gpu_processes:
        p.join()
    print("\nAll GPU Workers finished. Aggregating results...", flush=True)

    if len(all_histograms) < num_gpus_to_use:
        exit_codes = [p.exitcode for p in gpu_processes]
        print(
            f"ERROR: Only {len(all_histograms)} of {num_gpus_to_use} GPU Workers "
            f"reported a histogram (exit codes: {exit_codes}).",
            file=sys.stderr,
        )
        sys.exit(1)

    total_histogram = np.sum(all_histograms, axis=0)
    final_mean, final_std = calculate_stats_from_histogram(total_histogram)

    print("\n--- FINAL GLOBAL RESULTS ---", flush=True)
    print(f"Total Combined Histogram: {total_histogram}", flush=True)
    print(f"Final Log Mag Mean: {final_mean:.4f}", flush=True)
    print(f"Final Log Mag Std Dev: {final_std:.4f}", flush=True)
    print("PROGRAM END\n", flush=True)


# Script entry point.
if __name__ == "__main__":
    # Force the 'spawn' start method so every worker starts in a fresh
    # interpreter (this module is re-imported in each child).
    # NOTE(review): presumably chosen because CUDA cannot be used safely in
    # forked children -- confirm against the PyTorch multiprocessing notes.
    mp.set_start_method('spawn', force=True)
    main()