import argparse
import os
import torch
import numpy as np
import csv # Added import
from diffusers.models import AutoencoderKL
from tqdm import tqdm
from omegaconf import OmegaConf # Added import

# Assuming dataset.py and its dependencies are in the Python path
# This script should be in the same directory as dataset.py (e.g., /mnt/pvc/REPA/)
from dataset import CellDataModule, to_rgb


def _batch_to_rgb(x, device):
    """Convert a batch of multi-channel images to RGB, one image at a time.

    Each image is moved to the CPU, given a leading batch axis for ``to_rgb``,
    converted, and squeezed back; the results are stacked and moved to ``device``.
    """
    rgb_images = [to_rgb(img.cpu()[None], dtype=torch.float32).squeeze(0) for img in x]
    return torch.stack(rgb_images).to(device)


def _vae_roundtrip(vae, images_01):
    """Encode and decode a 3-channel batch with ``vae`` and return the reconstruction.

    The input is mapped with ``x * 2 - 1`` before encoding and the decoder
    output is mapped back with ``(y + 1) / 2``, so input and output live on the
    same scale (the script treats that scale as [0, 1]).

    The latent is drawn with ``latent_dist.sample()``, so reconstructions are
    stochastic unless the global RNG is seeded by the caller.
    """
    latents = vae.encode(images_01 * 2.0 - 1.0).latent_dist.sample()
    return (vae.decode(latents).sample + 1) / 2.0


def main(cli_args):
    """Measure VAE reconstruction MSE on the training set under three scenarios.

    Scenario 1: convert the 6-channel image to RGB, reconstruct it with the VAE.
    Scenario 2: replicate each of the 6 original channels to 3 channels and
        reconstruct each replicated image separately (one MSE per channel).
    Scenario 3: convert to RGB, replicate each RGB channel to 3 channels,
        reconstruct, average the 3 output channels, and reassemble an RGB image.

    Results are printed and, if ``cli_args.output_csv_path`` is non-empty,
    appended to that CSV file.

    Args:
        cli_args: Parsed command-line namespace. This function reads
            ``vae_model_name_or_path``, ``max_batches`` and ``output_csv_path``.
            The batch size is fixed to 64 through ``conf.batch_size`` and is
            not taken from ``cli_args``.
    """
    num_channels = 6  # Number of channels in the original images.

    if torch.cuda.is_available():
        # Prefer GPU 1 as before, but fall back to GPU 0 on single-GPU hosts
        # instead of failing with an invalid device ordinal.
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # 1. Load DataModule using OmegaConf, similar to train.py
    print("Loading configuration from diffusion_sit_full.yaml...")
    conf = OmegaConf.load("./diffusion_sit_full.yaml")
    conf.batch_size = 64
    data_module = CellDataModule(conf)

    dataloader = data_module.get_train_loader()

    # 2. Load VAE
    # Load the model that is actually named on the command line, so that the
    # log line and the "VAE_Model" CSV column describe the weights really used.
    print(f"Loading VAE from: {cli_args.vae_model_name_or_path}")
    vae = AutoencoderKL.from_pretrained(cli_args.vae_model_name_or_path).to(device)
    vae.eval()

    mse_loss_fn = torch.nn.MSELoss(reduction='none')

    # Accumulators for sums of element-wise squared errors. Scenario 2 uses a
    # float64 CPU tensor so it accumulates with the same precision as the
    # Python floats used for scenarios 1 and 3.
    total_loss1_sum_sq_error = 0.0
    scenario2_channel_losses_sum_sq_error = torch.zeros(num_channels, dtype=torch.float64, device="cpu")
    total_loss3_sum_sq_error = 0.0

    num_samples_processed = 0
    # Element counts are accumulated per batch, so an empty first batch or a
    # batch with a different shape cannot corrupt the normalisation.
    rgb_elements_seen = 0      # Denominator for scenarios 1 and 3.
    stacked_elements_seen = 0  # Denominator for each scenario 2 channel.

    print(f"Starting processing using dataloader with {len(dataloader)} batches.")
    for batch_idx, batch_data in enumerate(tqdm(dataloader, desc="Processing batches")):
        # The loader may yield a 3-item sequence or a dict with an 'image' key.
        if isinstance(batch_data, (list, tuple)):
            x, _, _ = batch_data  # x expected as (B, 6, H, W) in range [0, 1]
        else:
            x = batch_data['image']

        x = x.to(device)
        current_batch_size = x.shape[0]

        if current_batch_size == 0:
            continue

        with torch.no_grad():
            # The RGB conversion is shared by scenarios 1 and 3.
            # NOTE(review): assumes to_rgb is deterministic, so one conversion
            # can stand in for the two identical calls made before - confirm.
            x_rgb = _batch_to_rgb(x, device)  # B,3,H,W
            rgb_elements_seen += x_rgb.numel()

            # Scenario 1: RGB conversion then VAE
            recons_s1 = _vae_roundtrip(vae, x_rgb)
            total_loss1_sum_sq_error += mse_loss_fn(recons_s1, x_rgb).sum().item()

            # Scenario 2: Per-Channel Stacking then VAE
            batch_channel_sum_sq_error = torch.zeros(num_channels, device=device)
            for c_idx in range(num_channels):
                # Slice keeps the channel axis (B,1,H,W); repeat makes it B,3,H,W.
                stacked_channel = x[:, c_idx:c_idx + 1, :, :].repeat(1, 3, 1, 1)
                recons_s2 = _vae_roundtrip(vae, stacked_channel)
                batch_channel_sum_sq_error[c_idx] = mse_loss_fn(recons_s2, stacked_channel).sum()

            scenario2_channel_losses_sum_sq_error += batch_channel_sum_sq_error.cpu()
            # Each channel is compared as a 3-channel image: 3 * B * H * W elements.
            stacked_elements_seen += 3 * x[:, 0].numel()

            # Scenario 3: RGB -> Per-RGB-Chan Stack -> VAE -> Mean -> Assemble RGB
            reconstructed_rgb_components = []
            for rgb_c_idx in range(3):
                stacked_rgb_channel = x_rgb[:, rgb_c_idx:rgb_c_idx + 1, :, :].repeat(1, 3, 1, 1)
                recons_s3 = _vae_roundtrip(vae, stacked_rgb_channel)
                # Collapse the three reconstructed copies back into one channel.
                reconstructed_rgb_components.append(recons_s3.mean(dim=1, keepdim=True))

            final_reconstructed_s3 = torch.cat(reconstructed_rgb_components, dim=1)  # B,3,H,W
            total_loss3_sum_sq_error += mse_loss_fn(final_reconstructed_s3, x_rgb).sum().item()

        num_samples_processed += current_batch_size
        if cli_args.max_batches is not None and batch_idx + 1 >= cli_args.max_batches:
            print(f"Reached max_batches ({cli_args.max_batches}), stopping.")
            break

    if num_samples_processed == 0 or rgb_elements_seen == 0 or stacked_elements_seen == 0:
        print("No samples were processed or pixel count could not be determined. Exiting.")
        return

    # Mean squared error = summed squared error / number of compared elements.
    avg_loss1 = total_loss1_sum_sq_error / rgb_elements_seen
    avg_scenario2_channel_losses = (scenario2_channel_losses_sum_sq_error / stacked_elements_seen).tolist()
    avg_loss2_overall = sum(avg_scenario2_channel_losses) / len(avg_scenario2_channel_losses)
    avg_loss3 = total_loss3_sum_sq_error / rgb_elements_seen

    print(f"\n--- Reconstruction Performance (Processed {num_samples_processed} samples) ---")
    print(f"Scenario 1 (6ch -> RGB -> VAE -> RGB): Avg MSE = {avg_loss1:.6e}")
    print("Scenario 2 (Per-Original-Channel Stack -> VAE):")
    for i, channel_mse in enumerate(avg_scenario2_channel_losses):
        print(f"  Original Channel {i+1} Avg MSE = {channel_mse:.6e}")
    print(f"  Overall Avg MSE for Scenario 2 (mean of 6 channel MSEs) = {avg_loss2_overall:.6e}")
    print(f"Scenario 3 (6ch -> RGB -> Per-RGB-Chan Stack -> VAE -> Mean -> Assemble RGB): Avg MSE = {avg_loss3:.6e}")

    # Save results to CSV
    if cli_args.output_csv_path:
        csv_path = cli_args.output_csv_path
        print(f"Saving results to {csv_path}")
        # Write the header for a new file and also for an existing but empty
        # one, which would otherwise end up with data rows and no header.
        needs_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0

        with open(csv_path, 'a', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            if needs_header:
                csv_writer.writerow(["Scenario", "Channel_Index", "Avg_MSE_Loss", "Num_Samples_Processed", "VAE_Model", "Data_Dir"])

            # Columns shared by every row of this run.
            common_data = [num_samples_processed, cli_args.vae_model_name_or_path, data_module.data_dir]

            csv_writer.writerow(["Scenario 1", "N/A", f"{avg_loss1:.6e}"] + common_data)
            for i, channel_mse in enumerate(avg_scenario2_channel_losses):
                csv_writer.writerow(["Scenario 2", f"Channel {i+1}", f"{channel_mse:.6e}"] + common_data)
            csv_writer.writerow(["Scenario 2 Overall", "N/A", f"{avg_loss2_overall:.6e}"] + common_data)
            csv_writer.writerow(["Scenario 3", "N/A", f"{avg_loss3:.6e}"] + common_data)
        print(f"Results saved to {csv_path}")


if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them, and
    # pass the resulting namespace to main().
    # NOTE: main() does not read --batch_size, --num_workers or
    # --vae_resolution; they are accepted here but have no effect on the run.
    parser = argparse.ArgumentParser(description="VAE Reconstruction Analysis Script")

    # One row per option: (flag, value type, default, help text).
    option_specs = (
        ("--vae_model_name_or_path", str, "stabilityai/sd-vae-ft-mse", "Path or HuggingFace name of the VAE model"),
        ("--batch_size", int, 16, "Batch size for processing"),
        ("--num_workers", int, 4, "Number of dataloader workers"),
        ("--vae_resolution", int, 256, "Resolution for VAE input images (images will be resized to this)"),
        ("--max_batches", int, None, "Maximum number of batches to process (for quick testing)"),
        ("--output_csv_path", str, "reconstruction_analysis_results.csv", "Path to save the CSV results file."),
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    cli_args = parser.parse_args()
    main(cli_args)
