#!/usr/bin/env python3

"""
ECE 4822 Final Project: Production GPU-Accelerated Image Processing
"""

import sys
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.fft import fft2, fftshift
import gc

# Add NEDC tools to path
sys.path.insert(0, '/home/tur49364/ece_4822/homework/hw_11/patel_jainil/nedc_image_tools')
import nedc_image_tools

class ProductionGPUImageProcessor:
    """Spectral-histogram pipeline for large RGB images, run on GPU when available.

    For each sampled frame of each image, an RGB window is read. Each channel is
    smoothed with a 7x7 Gaussian, and the log magnitude of its centred 2-D FFT is
    binned into a NUM_BINS-bin histogram. The three channel histograms are summed.
    Per-bin mean and standard deviation across all frames are accumulated with
    running sums, so memory does not grow with the number of frames.
    """

    # Number of histogram bins per frame (shared by histc and the accumulators).
    NUM_BINS = 16
    # Frame grid pitch in pixels, (width, height), used to place sample points.
    FRAME_SIZE = (128, 128)
    # Side length in pixels of the window read at each frame coordinate.
    WINDOW_SIZE = 256

    def __init__(self, max_reasonable_frames=1000, device=None):
        """Set up the compute device, smoothing kernel and streaming accumulators.

        Args:
            max_reasonable_frames: upper bound on frames sampled per image.
            device: torch device or device string. Defaults to 'cuda' when
                available, otherwise 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Per-image cap on sampled frames; larger images are subsampled.
        self.max_reasonable_frames = max_reasonable_frames

        # 7x7 Gaussian kernel, built once and kept on the compute device.
        self.gaussian_kernel = self._create_gaussian_kernel_7x7().to(self.device)

        # Streaming statistics: running sums only, no per-frame storage.
        self.histogram_sum = np.zeros(self.NUM_BINS, dtype=np.float64)
        self.histogram_sum_sq = np.zeros(self.NUM_BINS, dtype=np.float64)
        self.total_frames_processed = 0

    def _create_gaussian_kernel_7x7(self):
        """Return a normalized 7x7 Gaussian (sigma = 1.0) shaped [1, 1, 7, 7] for conv2d."""
        sigma = 1.0
        kernel_size = 7

        # Integer offsets -3..3 along each axis.
        x = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
        y = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
        x_grid, y_grid = torch.meshgrid(x, y, indexing='ij')

        kernel = torch.exp(-(x_grid**2 + y_grid**2) / (2 * sigma**2))
        kernel = kernel / kernel.sum()  # normalize so smoothing preserves the mean

        # conv2d weight layout: [out_channels, in_channels, height, width]
        return kernel.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _evenly_spaced(length, count):
        """Return `count` distinct, evenly spread indices from range(length); requires count <= length."""
        # Centre of each of `count` equal strata. The indices are distinct because length >= count.
        return [(2 * k + 1) * length // (2 * count) for k in range(count)]

    def generate_frame_coordinates(self, height, width, frame_size=(128, 128)):
        """Return the top-left (x, y) coordinates of the frames to process.

        The frame grid has a position every `frame_width` pixels along x and every
        `frame_height` pixels along y, starting at 0. The last column or row may be
        a partial frame at the image edge. If the grid has more positions than
        `self.max_reasonable_frames`, an evenly spaced subset of columns and rows
        is kept. The selection then covers the whole image and never exceeds the cap.

        Args:
            height: image height in pixels.
            width: image width in pixels.
            frame_size: (frame_width, frame_height) in pixels.

        Returns:
            List of (x, y) tuples, ordered by x and then by y.

        Raises:
            ValueError: if either frame dimension is not positive.
        """
        frame_width, frame_height = frame_size
        if frame_width <= 0 or frame_height <= 0:
            raise ValueError(f"frame_size must be positive, got {frame_size}")

        # Grid positions per axis: ceil division, so partial edge frames count
        # toward the cap exactly as they are emitted below.
        cols = max(0, -(-width // frame_width))
        rows = max(0, -(-height // frame_height))
        total_cells = cols * rows

        target_frames = min(self.max_reasonable_frames, total_cells)
        if target_frames <= 0:
            return []

        if total_cells <= target_frames:
            col_indices = range(cols)
            row_indices = range(rows)
        else:
            # Shrink both axes by the same factor to preserve the aspect ratio...
            scale = np.sqrt(target_frames / total_cells)
            n_cols = max(1, min(cols, int(cols * scale)))
            n_rows = max(1, min(rows, target_frames // n_cols))
            # ...then hand any leftover budget back to the columns.
            # The invariant n_cols * n_rows <= target_frames holds.
            n_cols = max(1, min(cols, target_frames // n_rows))
            col_indices = self._evenly_spaced(cols, n_cols)
            row_indices = self._evenly_spaced(rows, n_rows)

        return [(c * frame_width, r * frame_height)
                for c in col_indices for r in row_indices]

    def _window_histogram(self, window_rgb):
        """Return one window's NUM_BINS-bin log-magnitude histogram, summed over its 3 channels.

        Each channel histogram spans that channel's own min..max log magnitude, so
        bin edges differ between channels and between windows.
        """
        # NOTE(review): assumes a channels-first layout [3, H, W]; confirm against
        # nedc_image_tools.read_data_multithread.
        rgb_tensor = torch.from_numpy(np.array(window_rgb)).float().to(self.device)

        channel_histograms = []
        for channel in range(3):  # R, G, B
            # conv2d expects [batch, channels, H, W]
            channel_data = rgb_tensor[channel].unsqueeze(0).unsqueeze(0)
            smoothed = F.conv2d(channel_data, self.gaussian_kernel, padding=3)

            # Index [0, 0] instead of squeeze() so a 1-pixel-high or 1-pixel-wide
            # window stays 2-D, as fft2 requires.
            magnitude = torch.abs(fftshift(fft2(smoothed[0, 0])))

            # The epsilon keeps log finite where the magnitude is exactly zero.
            log_mag_flat = torch.log(magnitude + 1e-8).flatten()
            hist = torch.histc(log_mag_flat, bins=self.NUM_BINS,
                               min=log_mag_flat.min().item(),
                               max=log_mag_flat.max().item())
            channel_histograms.append(hist.cpu().numpy())

        # Merge the three channel spectra into one histogram.
        return np.sum(channel_histograms, axis=0)

    def process_windows_batch_gpu(self, windows_rgb, batch_size=16):
        """Compute one histogram per usable window, freeing device memory between batches.

        Args:
            windows_rgb: sequence of windows. None or empty entries are skipped.
            batch_size: number of windows handled between memory cleanups
                (values below 1 are treated as 1).

        Returns:
            List of NUM_BINS-element numpy arrays, one per successfully processed window.
            A window that raises is reported on stderr and skipped.
        """
        if windows_rgb is None or len(windows_rgb) == 0:
            return []

        batch_size = max(1, int(batch_size))
        all_histograms = []

        for batch_start in range(0, len(windows_rgb), batch_size):
            batch = windows_rgb[batch_start:batch_start + batch_size]
            for offset, window_rgb in enumerate(batch):
                if window_rgb is None or len(window_rgb) == 0:
                    continue
                try:
                    all_histograms.append(self._window_histogram(window_rgb))
                except Exception as exc:  # best-effort: one bad window must not stop the run
                    print(f"warning: skipping window {batch_start + offset}: {exc}",
                          file=sys.stderr)

            # Release cached GPU blocks and Python garbage between batches.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

        return all_histograms

    def _update_streaming_stats(self, histograms):
        """Fold a list of per-frame histograms into the running sums and sums of squares."""
        if not histograms:
            return

        hist_array = np.array(histograms, dtype=np.float64)  # [frames, NUM_BINS]

        self.histogram_sum += np.sum(hist_array, axis=0)
        self.histogram_sum_sq += np.sum(hist_array**2, axis=0)
        self.total_frames_processed += len(histograms)

    @staticmethod
    def _chunk_plan(total_frames):
        """Return (chunk_size, batch_size) for an image with `total_frames` frames."""
        if total_frames > 1000:
            return 100, 2  # larger chunks for better GPU utilization
        if total_frames > 500:
            return 75, 2
        return 50, 4

    def process_single_image(self, image_file):
        """Stream one image through the pipeline and fold its frames into the running stats.

        A failure to open the image, or to read or process a chunk, is reported on
        stderr. The image or chunk is then skipped, so one bad file does not abort
        the whole list. The reader is closed afterwards if it exposes close().
        """
        image_reader = None
        try:
            image_reader = nedc_image_tools.Nil()
            image_reader.open(image_file)

            width, height = image_reader.get_dimension()

            frame_coordinates = self.generate_frame_coordinates(height, width, self.FRAME_SIZE)
            total_frames = len(frame_coordinates)
            chunk_size, batch_size = self._chunk_plan(total_frames)

            for chunk_start in range(0, total_frames, chunk_size):
                chunk_coordinates = frame_coordinates[chunk_start:chunk_start + chunk_size]
                try:
                    # NOTE(review): the window (WINDOW_SIZE) is larger than the frame
                    # pitch (FRAME_SIZE). Whether coordinates are window corners or
                    # centres depends on nedc_image_tools; confirm.
                    windows = image_reader.read_data_multithread(
                        chunk_coordinates,
                        npixy=self.WINDOW_SIZE,  # window height
                        npixx=self.WINDOW_SIZE,  # window width
                        color_mode="RGB"
                    )
                    chunk_histograms = self.process_windows_batch_gpu(windows, batch_size=batch_size)

                    # Stream statistics: update running sums without storing frames.
                    self._update_streaming_stats(chunk_histograms)

                    del windows, chunk_histograms
                    gc.collect()
                except Exception as exc:  # best-effort: skip this chunk only
                    print(f"warning: {image_file}: skipping frames "
                          f"{chunk_start}-{chunk_start + len(chunk_coordinates) - 1}: {exc}",
                          file=sys.stderr)
        except Exception as exc:  # best-effort: skip this image only
            print(f"warning: skipping image {image_file}: {exc}", file=sys.stderr)
        finally:
            close = getattr(image_reader, "close", None)
            if callable(close):
                try:
                    close()
                except Exception as exc:
                    print(f"warning: failed to close {image_file}: {exc}", file=sys.stderr)

    def process_image_list(self, file_list_path):
        """Process every non-blank path listed in `file_list_path`, then print the statistics."""
        with open(file_list_path, 'r') as f:
            image_files = [line.strip() for line in f if line.strip()]

        for image_file in image_files:
            self.process_single_image(image_file)

        self.compute_final_statistics()

    def compute_final_statistics(self):
        """Print the per-bin mean and std over all processed frames, plus overall summaries.

        Uses population statistics (dividing by n). If no frames were processed,
        nothing is printed to stdout and a note goes to stderr.
        """
        n = self.total_frames_processed
        if n == 0:
            print("No frames were processed; no statistics to report.", file=sys.stderr)
            return

        print(f"Processing Summary: {n} frames analyzed")
        print("="*50)

        # Var = E[x^2] - E[x]^2. Clamp at 0 to absorb floating-point round-off.
        mean_histogram = self.histogram_sum / n
        variance_histogram = (self.histogram_sum_sq / n) - (mean_histogram ** 2)
        std_histogram = np.sqrt(np.maximum(variance_histogram, 0))

        print(f"Histogram Results ({self.NUM_BINS} bins: mean std_dev)")
        for i, (mean_val, std_val) in enumerate(zip(mean_histogram, std_histogram)):
            print(f"{mean_val:12.6f} {std_val:12.6f}  # Bin {i+1:2d}")

        print()
        print("Overall Statistics:")
        print(f"{np.mean(mean_histogram):12.6f}              # Mean of histogram means")
        print(f"{np.std(mean_histogram):12.6f}              # Std of histogram means")

def main():
    """Command-line entry point: `script <file_list>`.

    Processes every image path listed (one per line) in `file_list` and prints
    the aggregated histogram statistics. On bad usage or a missing file list, it
    prints a message to stderr and exits with status 1.
    """
    if len(sys.argv) != 2:
        print(f"usage: {os.path.basename(sys.argv[0])} <file_list>", file=sys.stderr)
        sys.exit(1)

    file_list_path = sys.argv[1]

    # isfile (not exists) rejects a directory here instead of letting open() fail later.
    if not os.path.isfile(file_list_path):
        print(f"error: file list not found: {file_list_path}", file=sys.stderr)
        sys.exit(1)

    # Production mode only - full competition processing
    processor = ProductionGPUImageProcessor()
    processor.process_image_list(file_list_path)


if __name__ == "__main__":
    main()