#! /usr/bin/env python3
"""Estimate global statistics of the log-magnitude FFT spectrum of SVS slides.

Each slide is covered by windows of ``window_size`` pixels whose top-left
corners lie on a ``frame_size`` grid.  Every window is converted to
grayscale, smoothed with a 7x7 Gaussian, transformed with a 2-D FFT, and the
values ``log1p(|FFT|)`` are binned into a fixed-range histogram.  One worker
process per visible GPU accumulates a histogram over its share of the
slides; the parent sums them and prints the normalized histogram together
with its mean and standard deviation.
"""
# NOTE: import order is kept as in the original; native libraries (cv2,
# torch) can be sensitive to load order.
import glob
import cv2
import nedc_image_tools
import nedc_dpath_ann_tools
import numpy as np
import torch
import torch.nn.functional as F
import os
import traceback
from multiprocessing import Process, Manager
import torch.multiprocessing as mp

# Only takes effect if set before CUDA is first initialized in this process;
# spawned workers inherit this environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


def generateTopLeftFrameCoordinates(height: int, width: int, frame_size: tuple) -> list:
    """
    Generate top-left (x, y) coordinates for each frame over the full image.

    Coordinates step by ``frame_size[0]`` along x (width) and by
    ``frame_size[1]`` along y (height); x is the outer loop.
    """
    return [
        (x, y)
        for x in range(0, width, frame_size[0])
        for y in range(0, height, frame_size[1])
    ]


def make_gaussian_kernel_7x7(device, sigma: float = 1.0):
    """Return a normalized 7x7 Gaussian kernel shaped (1, 1, 7, 7) for ``F.conv2d``."""
    size = 7
    coords = torch.arange(size, device=device) - size // 2
    yy, xx = torch.meshgrid(coords, coords, indexing="ij")
    kernel = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
    kernel = kernel / kernel.sum()
    return kernel.view(1, 1, size, size)


def process_batch_on_gpu(batch_list, device, gaussian_kernel, num_bins, hist_min, hist_max):
    """Return the histogram of ``log1p(|FFT|)`` over a batch of RGB windows.

    Args:
        batch_list: list of equally shaped (H, W, 3) RGB arrays.
        device: torch device to compute on.
        gaussian_kernel: (1, 1, 7, 7) smoothing kernel on ``device``.
        num_bins, hist_min, hist_max: passed to ``torch.histc``; values
            outside ``[hist_min, hist_max]`` are not counted.

    Returns:
        float64 tensor of ``num_bins`` counts on ``device`` (all zeros for
        an empty batch).
    """
    if not batch_list:
        return torch.zeros(num_bins, dtype=torch.float64, device=device)

    # Stack into (B, H, W, 3)
    batch_t = torch.from_numpy(np.stack(batch_list, axis=0)).to(device)

    # Grayscale: Y = 0.299R + 0.587G + 0.114B.  The channel is the LAST axis,
    # so index it with [..., c]; [:, :, c] would select a column of the W
    # axis instead and discard most of the window.
    gray = (
        0.299 * batch_t[..., 0]
        + 0.587 * batch_t[..., 1]
        + 0.114 * batch_t[..., 2]
    )

    # 7x7 Gaussian smoothing; padding=3 keeps H and W unchanged.
    smoothed = F.conv2d(gray.unsqueeze(1), gaussian_kernel, padding=3).squeeze(1)

    # 2-D FFT over (H, W).  No fftshift: it only permutes the spectrum, which
    # cannot change a histogram of its values.
    log_mag = torch.log1p(torch.abs(torch.fft.fft2(smoothed))).flatten()

    # histc accumulates counts in the input dtype; float32 is exact only up
    # to 2**24 per bin, so bin in float64 to keep large batches exact.
    return torch.histc(
        log_mag.to(torch.float64), bins=num_bins, min=hist_min, max=hist_max
    )


def _iter_batches(items, batch_size):
    """Yield consecutive lists of up to ``batch_size`` items from ``items``."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def _image_histogram(image_reader, img_path, device, gaussian_kernel, frame_size,
                     window_size, num_bins, hist_min, hist_max, batch_size):
    """Return the summed window histogram of one slide; always closes the reader."""
    image_hist = torch.zeros(num_bins, dtype=torch.float64, device=device)
    image_reader.open(img_path)
    try:
        width, height = image_reader.get_dimension()
        all_coords = generateTopLeftFrameCoordinates(height, width, frame_size)

        # Read windows a few batches at a time to bound host memory.
        # NOTE(review): windows near the right/bottom edge extend past the
        # slide; confirm read_data_multithread pads them to a full window,
        # otherwise np.stack in process_batch_on_gpu will fail.
        chunk_size = batch_size * 4
        for start in range(0, len(all_coords), chunk_size):
            windows = image_reader.read_data_multithread(
                all_coords[start:start + chunk_size],
                npixy=window_size[1],
                npixx=window_size[0],
                color_mode="RGB",
                num_threads=16,
            )
            arrays = (np.array(win, dtype=np.float32) for win in windows)
            for batch in _iter_batches(arrays, batch_size):
                image_hist += process_batch_on_gpu(
                    batch, device, gaussian_kernel, num_bins, hist_min, hist_max
                )
            # Free this chunk before reading the next one.
            del windows, arrays
    finally:
        image_reader.close()
    return image_hist


def worker(rank, world_size, image_files, frame_size, window_size, num_bins,
           hist_min, hist_max, batch_size, return_dict):
    """Accumulate the histogram for this rank's share of ``image_files``.

    Rank ``r`` handles images ``r, r + world_size, r + 2*world_size, ...``.
    On success, stores a float64 NumPy array of ``num_bins`` counts in
    ``return_dict[rank]``.  On any exception, prints the traceback and stores
    nothing, so the parent can detect the missing rank.
    """
    try:
        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
        print(f"[rank {rank}] Using device: {device}")

        gaussian_kernel = make_gaussian_kernel_7x7(device, sigma=1.0)
        local_hist = torch.zeros(num_bins, dtype=torch.float64, device=device)

        # Instantiate reader once per process
        image_reader = nedc_image_tools.Nil()

        for img_path in image_files[rank::world_size]:
            print(f"[rank {rank}] Processing: {os.path.basename(img_path)}")
            local_hist += _image_histogram(
                image_reader, img_path, device, gaussian_kernel, frame_size,
                window_size, num_bins, hist_min, hist_max, batch_size,
            )

        # Send a plain NumPy array.  A tensor would go through torch's
        # shared-memory pickling on its way through the Manager.
        return_dict[rank] = local_hist.cpu().numpy()
        print(f"[rank {rank}] Finished successfully.")
    except Exception:
        print(f"!!! CRITICAL FAILURE IN RANK {rank} !!!")
        print(traceback.format_exc())


def main():
    """Spawn one worker per visible GPU, merge their histograms, print statistics."""
    # Sort so that the first-100 subset is the same on every run; glob
    # returns files in arbitrary directory order.
    svs_files = sorted(glob.glob(
        "/data/isip/data/tuh_dpath_breast/deidentified/v3.0.0/svs/**/*.svs",
        recursive=True,
    ))
    if not svs_files:
        print("No SVS files found!")
        return

    image_files = svs_files[:100]
    print(f"Processing {len(image_files)} file(s).")

    frame_size = (128, 128)
    window_size = (256, 256)
    num_bins = 16
    hist_min = 0.0
    hist_max = 16.612394332885742
    batch_size = 256

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Main process using: {device}")

    num_available = torch.cuda.device_count()
    if num_available == 0:
        print("No CUDA devices found.")
        return

    world_size = num_available
    print(f"Spawning {world_size} worker processes.")

    # Each worker uses CUDA, which requires "spawn" rather than "fork".
    mp.set_start_method("spawn", force=True)
    with mp.Manager() as manager:
        return_dict = manager.dict()
        procs = [
            Process(
                target=worker,
                args=(
                    rank, world_size, image_files, frame_size, window_size,
                    num_bins, hist_min, hist_max, batch_size, return_dict,
                ),
            )
            for rank in range(world_size)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Copy results out before the Manager process shuts down.
        results = return_dict.copy()

    global_hist = np.zeros(num_bins, dtype=np.float64)
    success_count = 0
    for rank in range(world_size):
        if rank in results:
            global_hist += np.asarray(results[rank], dtype=np.float64)
            success_count += 1
        else:
            print(f"Warning: No result received from Rank {rank}")

    if success_count == 0:
        print("All workers failed. Exiting.")
        return

    hist_np = global_hist
    total_count = hist_np.sum()
    if total_count > 0:
        probs = hist_np / total_count
        bin_edges = np.linspace(hist_min, hist_max, num_bins + 1)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
        # Mean/std of the binned distribution, using bin centers as values.
        E_val = np.sum(bin_centers * probs)
        var_val = np.sum(((bin_centers - E_val) ** 2) * probs)
        std_val = np.sqrt(var_val)
    else:
        E_val, std_val = float("nan"), float("nan")
        probs, bin_centers = [], []

    print("\n=== Global statistics ===")
    print("Histogram:", probs)
    print("Bin_centers:", bin_centers)
    print("Global expected value:", E_val)
    print("Global std:", std_val)


if __name__ == "__main__":
    main()