#!/usr/bin/env python3
import sys, os
from time import perf_counter
import numpy as np
import torch
import torch.nn.functional as F
import nedc_image_tools

# -----------------------------
# Config
# -----------------------------
# Step between the top-left corners of neighbouring windows, as (x, y).
# (Units are presumably pixels; the values are passed straight to the reader.)
FRAME_SIZE  = (128, 128)
# Extent of each extracted window, as (x, y); forwarded to the image reader
# as npixx / npixy.
WINDOW_SIZE = (256, 256)
# Number of histogram bins; NUM_BINS + 1 evenly spaced edges are built in main().
NUM_BINS    = 16
# Maximum number of windows stacked into a single batch per pipeline call.
BATCH_SIZE  = 1024

# Side length (must be odd) and standard deviation of the Gaussian smoothing
# kernel built in main().
GAUSS_K     = 7
GAUSS_SIGMA = 1.0

# List file used when no path is supplied on the command line.
DEFAULT_LIST = "files.list"

# Ask cuDNN to auto-tune its convolution algorithms.
# NOTE(review): the last batch of each image can have a different size, which
# may trigger re-tuning; confirm this setting is a net win for this workload.
torch.backends.cudnn.benchmark = True


# -----------------------------
# List file reader
# -----------------------------
def read_image_list(list_path: str):
    """Load image paths from a plain-text list file.

    Every line is stripped of surrounding whitespace. Blank lines and lines
    whose stripped form starts with ``#`` are ignored.

    Args:
        list_path: Path of the list file to read.

    Returns:
        The remaining entries, in file order, as a list of strings.

    Raises:
        FileNotFoundError: If ``list_path`` does not exist.
        RuntimeError: If no entries remain after filtering.
    """
    if not os.path.exists(list_path):
        raise FileNotFoundError(f"List file not found: {list_path}")

    with open(list_path, "r") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # Consume the generator while the file is still open.
        entries = [
            entry
            for entry in stripped_lines
            if entry and not entry.startswith("#")
        ]

    if not entries:
        raise RuntimeError("List file contained no image paths.")
    return entries


# -----------------------------
# Window extraction (CPU)
# -----------------------------
def generateTopLeftFrameCoordinates(height: int, width: int, frame_size: tuple, window_size: tuple):
    """Enumerate the top-left corners of every window that fits in the image.

    Corners advance by ``frame_size`` (x step, y step). Only positions where a
    full ``window_size`` (x extent, y extent) window fits inside
    ``width`` x ``height`` are produced, so a trailing strip narrower than one
    step is not covered.

    Args:
        height: Image height.
        width: Image width.
        frame_size: ``(step_x, step_y)`` stride between windows.
        window_size: ``(extent_x, extent_y)`` size of each window.

    Returns:
        A list of ``(x, y)`` tuples ordered column by column (x is the outer
        loop, y the inner loop). Empty when the image is smaller than one
        window in either dimension.
    """
    step_x, step_y = frame_size
    extent_x, extent_y = window_size
    return [
        (left, top)
        for left in range(0, width - extent_x + 1, step_x)
        for top in range(0, height - extent_y + 1, step_y)
    ]


def windowRGBValues(image_file: str, frame_size: tuple, window_size: tuple):
    """Read every full window of an image as RGB data.

    Opens ``image_file`` with ``nedc_image_tools.Nil``, tiles it with
    ``generateTopLeftFrameCoordinates`` and reads all windows in one
    ``read_data_multithread`` call.

    Args:
        image_file: Path of the image to open.
        frame_size: ``(step_x, step_y)`` stride between window corners.
        window_size: ``(extent_x, extent_y)`` size of each window.

    Returns:
        Whatever ``read_data_multithread`` returns for the computed
        coordinates (callers treat it as a sequence of per-window arrays),
        or ``[]`` when the image cannot be opened.
    """
    img = nedc_image_tools.Nil()
    if not img.open(image_file):
        return []  # fail gracefully

    # Always release the image handle, even if reading raises part-way.
    try:
        width, height = img.get_dimension()
        coords = generateTopLeftFrameCoordinates(height, width, frame_size, window_size)
        return img.read_data_multithread(
            coords,
            npixy=window_size[1],
            npixx=window_size[0],
            color_mode="RGB",
        )
    finally:
        img.close()


# -----------------------------
# GPU helpers
# -----------------------------
def gaussian_kernel_2d_torch(k: int, sigma: float, device):
    """Build a normalised 2-D Gaussian kernel shaped for ``F.conv2d``.

    Args:
        k: Side length of the square kernel; must be a positive odd integer
            so the kernel has a well-defined centre.
        sigma: Standard deviation of the Gaussian; must be positive.
        device: Torch device on which the kernel is created.

    Returns:
        A float32 tensor of shape ``(1, 1, k, k)`` whose entries sum to 1.

    Raises:
        ValueError: If ``k`` is not a positive odd integer or ``sigma`` is
            not strictly positive.
    """
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and a negative odd k (e.g. -1) would otherwise slip through.
    if k < 1 or k % 2 != 1:
        raise ValueError(f"Kernel size k must be a positive odd integer, got {k!r}")
    # `not sigma > 0` also rejects NaN, which would poison every kernel entry.
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")

    r = k // 2
    ax = torch.arange(-r, r + 1, device=device, dtype=torch.float32)
    xx, yy = torch.meshgrid(ax, ax, indexing="ij")
    ker = torch.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
    ker = ker / ker.sum()  # normalise so smoothing preserves the mean level
    return ker.view(1, 1, k, k)


@torch.no_grad()
def gpu_pipeline_logmag(batch_rgb_u8: np.ndarray, device, gauss_kernel):
    """Convert a batch of RGB windows to log-magnitude spectra.

    Steps: move to ``device`` as float32, convert to luma, subtract the
    per-window mean, smooth with ``gauss_kernel`` (reflect padding keeps the
    spatial size for odd kernel sizes), take the 2-D FFT and return
    ``log1p(|FFT|)``.

    Args:
        batch_rgb_u8: Array of shape ``(B, H, W, 3)``; documented by callers
            as uint8.
        device: Torch device on which the computation runs.
        gauss_kernel: Smoothing kernel of shape ``(1, 1, kh, kw)``. Reflect
            padding requires ``kh // 2 < H`` and ``kw // 2 < W``.

    Returns:
        float32 tensor of shape ``(B, H, W)`` on ``device`` (for odd kernel
        sizes).
    """
    rgb = torch.from_numpy(batch_rgb_u8).to(device=device, dtype=torch.float32)
    rgb = rgb.permute(0, 3, 1, 2)  # (B,3,H,W)

    r = rgb[:, 0:1]
    g = rgb[:, 1:2]
    b = rgb[:, 2:3]

    gray = 0.299 * r + 0.587 * g + 0.114 * b  # (B,1,H,W)
    # Remove each window's mean so brightness does not dominate the spectrum.
    gray = gray - gray.mean(dim=(2, 3), keepdim=True)

    # Derive the padding from the kernel that is actually applied, so a kernel
    # of any odd size keeps the output at (H, W).
    kernel_h, kernel_w = gauss_kernel.shape[-2:]
    pad_h, pad_w = kernel_h // 2, kernel_w // 2
    # F.pad takes the last dimension first: (left, right, top, bottom).
    gray = F.pad(gray, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
    smoothed = F.conv2d(gray, gauss_kernel, padding=0)  # (B,1,H,W)
    smoothed2d = smoothed.squeeze(1)                    # (B,H,W)

    F2 = torch.fft.fft2(smoothed2d)
    mag = torch.abs(F2)
    log_mag = torch.log1p(mag)

    return log_mag


@torch.no_grad()
def hist_per_window_gpu(log_mag: torch.Tensor, bin_edges: torch.Tensor):
    """Histogram each window's values against a shared set of bin edges.

    Non-finite values are replaced by 0.0 before binning, and values outside
    the edge range are counted in the first or last bin.

    Args:
        log_mag: Tensor of shape ``(B, H, W)``.
        bin_edges: 1-D tensor of ascending edges; ``len - 1`` bins result.

    Returns:
        float32 tensor of shape ``(B, num_bins)`` holding per-window counts,
        on the same device as ``log_mag``.
    """
    num_bins = bin_edges.numel() - 1
    batch = log_mag.shape[0]

    values = torch.nan_to_num(
        log_mag.reshape(batch, -1), nan=0.0, posinf=0.0, neginf=0.0
    )

    # bucketize gives 1-based bin positions for in-range values; shift to
    # 0-based and fold out-of-range values into the edge bins.
    bin_idx = torch.bucketize(values, bin_edges, right=False) - 1
    bin_idx = bin_idx.clamp(0, num_bins - 1)

    # Give every window its own block of `num_bins` slots so one bincount
    # call handles the whole batch.
    row_base = torch.arange(batch, device=values.device).view(batch, 1) * num_bins
    flat_idx = (bin_idx + row_base).reshape(-1)

    tallies = torch.bincount(flat_idx, minlength=batch * num_bins)
    return tallies.float().view(batch, num_bins)


# -----------------------------
# Dataset processing
# -----------------------------
def dataset_pass1_global_range(image_paths, device, gauss):
    """Find the global min/max of log_mag across all windows of all images.

    Paths that do not exist, and images that yield no windows (unreadable or
    smaller than one window), are counted as bad and skipped.

    Args:
        image_paths: List of image file paths.
        device: Torch device on which the pipeline runs.
        gauss: Gaussian kernel tensor passed to ``gpu_pipeline_logmag``.

    Returns:
        Tuple ``(gmin, gmax, total_images, bad, total_windows, seconds)``.
        ``gmin`` / ``gmax`` are Python floats, or ``None`` when no window was
        processed. ``seconds`` is the wall-clock duration of the pass.
    """
    gmin, gmax = None, None
    total_images = 0
    bad = 0
    total_windows = 0
    num_paths = len(image_paths)

    t0 = perf_counter()
    for ip, path in enumerate(image_paths, 1):
        if not os.path.exists(path):
            bad += 1
            continue

        windows = windowRGBValues(path, FRAME_SIZE, WINDOW_SIZE)
        if not windows:
            bad += 1
            continue

        total_images += 1
        N = len(windows)
        total_windows += N

        for start in range(0, N, BATCH_SIZE):
            end = min(start + BATCH_SIZE, N)
            # copy=False avoids a second copy when the windows are already uint8.
            batch_np = np.stack(windows[start:end], axis=0).astype(np.uint8, copy=False)
            log_mag = gpu_pipeline_logmag(batch_np, device, gauss)

            # .item() blocks until the value is ready, so no explicit
            # (CUDA-only) synchronize call is needed here; this also keeps
            # the function usable with whatever `device` was passed in.
            bmin = torch.amin(log_mag).item()
            bmax = torch.amax(log_mag).item()

            gmin = bmin if gmin is None else min(gmin, bmin)
            gmax = bmax if gmax is None else max(gmax, bmax)

        if ip % 10 == 0 or ip == num_paths:
            print(f"[pass1] images {ip}/{num_paths} done, windows so far={total_windows}")

    t1 = perf_counter()
    return gmin, gmax, total_images, bad, total_windows, (t1 - t0)


def dataset_pass2_global_hist(image_paths, device, gauss, bin_edges):
    """Compute the per-bin mean and std of normalised window histograms.

    Each window's log_mag values are histogrammed against the fixed
    ``bin_edges`` and normalised to sum to 1. The mean and (population)
    standard deviation of each bin are then taken over all windows of all
    images. Missing paths and images that yield no windows count as bad.

    Args:
        image_paths: List of image file paths.
        device: Torch device on which the pipeline runs.
        gauss: Gaussian kernel tensor passed to ``gpu_pipeline_logmag``.
        bin_edges: 1-D tensor of histogram edges on ``device``.

    Returns:
        Tuple ``(mean, std, total_images, bad, total_windows, seconds)`` where
        ``mean`` and ``std`` are float32 NumPy arrays with one entry per bin.

    Raises:
        RuntimeError: If no window was processed at all.
    """
    # Size the accumulators from the edges actually supplied, so they always
    # match the histogram width produced by hist_per_window_gpu.
    nb = bin_edges.numel() - 1
    # float64 accumulators keep E[p^2] - E[p]^2 numerically stable.
    sum_p  = torch.zeros(nb, device=device, dtype=torch.float64)
    sum_p2 = torch.zeros(nb, device=device, dtype=torch.float64)

    total_windows = 0
    total_images = 0
    bad = 0
    num_paths = len(image_paths)

    t0 = perf_counter()
    for ip, path in enumerate(image_paths, 1):
        if not os.path.exists(path):
            bad += 1
            continue

        windows = windowRGBValues(path, FRAME_SIZE, WINDOW_SIZE)
        if not windows:
            bad += 1
            continue

        total_images += 1
        N = len(windows)

        for start in range(0, N, BATCH_SIZE):
            end = min(start + BATCH_SIZE, N)
            # copy=False avoids a second copy when the windows are already uint8.
            batch_np = np.stack(windows[start:end], axis=0).astype(np.uint8, copy=False)

            log_mag = gpu_pipeline_logmag(batch_np, device, gauss)
            counts = hist_per_window_gpu(log_mag, bin_edges)  # (B,nb)

            # Normalise each window's histogram; clamp guards against 0/0.
            denom = counts.sum(dim=1, keepdim=True).clamp_min(1.0)
            p = (counts / denom).double()  # (B,nb)

            sum_p  += p.sum(dim=0)
            sum_p2 += (p * p).sum(dim=0)

            total_windows += (end - start)

        if ip % 10 == 0 or ip == num_paths:
            print(f"[pass2] images {ip}/{num_paths} done, windows so far={total_windows}")

    # Wait for queued kernels so the timing covers the real work. Only CUDA
    # devices need (or support) this, and we sync the device actually used.
    if torch.device(device).type == "cuda":
        torch.cuda.synchronize(device)
    t1 = perf_counter()

    if total_windows == 0:
        raise RuntimeError("No windows processed in pass2 (check list paths).")

    Nw = float(total_windows)
    mean = sum_p / Nw
    var  = (sum_p2 / Nw) - mean**2
    var  = torch.clamp(var, min=0.0)  # rounding can push tiny variances below 0
    std  = torch.sqrt(var)

    return (mean.float().cpu().numpy(),
            std.float().cpu().numpy(),
            total_images, bad, total_windows,
            (t1 - t0))


# -----------------------------
# Main
# -----------------------------
def main():
    """Run the two-pass global histogram analysis and print a report.

    The list file path is taken from ``sys.argv[1]`` (default
    ``DEFAULT_LIST``). Pass 1 finds the global log_mag range, which defines
    ``NUM_BINS`` evenly spaced histogram bins. Pass 2 computes the per-bin
    mean and std of the normalised window histograms. Exits with status 1
    when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print("FATAL: CUDA not available. Run on a GPU node.")
        sys.exit(1)

    list_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_LIST
    image_paths = read_image_list(list_path)

    device = torch.device("cuda:0")
    gauss = gaussian_kernel_2d_torch(GAUSS_K, GAUSS_SIGMA, device)

    print("\n==============================")
    print("Global (ALL images) Averager - One GPU")
    print("==============================")
    print("List file   :", list_path)
    print("Images in list:", len(image_paths))
    print("FRAME_SIZE  :", FRAME_SIZE)
    print("WINDOW_SIZE :", WINDOW_SIZE)
    print("NUM_BINS    :", NUM_BINS)
    print("BATCH_SIZE  :", BATCH_SIZE)
    print("Device      :", device)
    print("==============================\n")

    # PASS 1
    gmin, gmax, ok_imgs1, bad1, win1, t_pass1 = dataset_pass1_global_range(image_paths, device, gauss)

    range_is_usable = (
        gmin is not None
        and gmax is not None
        and np.isfinite(gmin)
        and np.isfinite(gmax)
        and gmax > gmin
    )
    if not range_is_usable:
        # Make the fallback visible: histograms binned over a made-up range
        # are unlikely to be meaningful.
        print(
            f"WARNING: pass 1 produced no usable log_mag range "
            f"(min={gmin}, max={gmax}); falling back to [0.0, 1.0].",
            file=sys.stderr,
        )
        gmin, gmax = 0.0, 1.0
    # Guard against a range too narrow to yield distinct bin edges.
    if abs(gmax - gmin) < 1e-12:
        gmax = gmin + 1e-6

    bin_edges = torch.linspace(gmin, gmax, steps=NUM_BINS + 1, device=device, dtype=torch.float32)

    print("\nPASS 1 done.")
    print("  Valid images:", ok_imgs1, " Bad/missing:", bad1)
    print("  Total windows:", win1)
    print(f"  Global log_mag range: min={gmin:.6f} max={gmax:.6f}")
    print(f"  Time pass1: {t_pass1:.3f} s")

    # PASS 2
    mean_hist, std_hist, ok_imgs2, bad2, win2, t_pass2 = dataset_pass2_global_hist(image_paths, device, gauss, bin_edges)

    print("\n==============================")
    print("FINAL GLOBAL RESULT (ALL IMAGES)")
    print("==============================")
    print("Valid images processed:", ok_imgs2)
    print("Bad/missing images     :", bad2)
    print("Total windows processed:", win2)

    print("\nMean histogram (per bin):")
    print(mean_hist)

    print("\nStd-dev histogram (per bin):")
    print(std_hist)

    # Scalar summary: mean and std taken across the bins of the mean histogram.
    scalar_mean = float(mean_hist.mean())
    scalar_std  = float(mean_hist.std())
    print("\nScalar summary (from averaged histogram):")
    print("  mean =", scalar_mean)
    print("  std  =", scalar_std)

    print("\nTiming:")
    print(f"  pass1 (range) : {t_pass1:.3f} s")
    print(f"  pass2 (hist)  : {t_pass2:.3f} s")
    print("==============================\n")


if __name__ == "__main__":
    main()
