import os
import gc
import copy
import random
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# --- PROJECT IMPORTS ---
import scripts.config as config
import scripts.data_setup as data_setup
import scripts.datasets_classes as ds_classes
from quantization.apply import (
    apply_simple_ptq,
    apply_learned_step_size_quantization,
    apply_rotation_lsq
)

# ==============================================================================
# 0. CONFIGURATION
# ==============================================================================
# One full trial (fresh model + PTQ / LSQ / ROT+LSQ) is run per seed; the
# spread of results across seeds is used to estimate run-to-run variance.
SEEDS = [42, 54, 154, 45, 546, 425, 524, 1524, 4215, 5146]
# Name (as reported by model.named_modules()) of the module whose output
# activations are captured for the SVD analysis: residual block index 11.
LAYER_NAME = "visual.transformer.resblocks.11"
TARGET_DEVICE = config.TARGET_DEVICE

# --- MEMORY TUNING ---
# Batch size for the LSQ training loader and the calibration loader.
TRAIN_BATCH_SIZE = 100
# Batch size for the activation-collection passes feeding the SVD analysis.
ANALYSIS_BATCH_SIZE = 64

# Fraction in (0, 1] of the ImageNet validation set used for the SVD analysis.
# Both None and 1 use the full set (50k images); with 1 only the sample order
# changes, which does not affect the singular values.
DATASET_PERCENTAGE = 1

def seed_everything(seed):
    """Seed every random number generator the experiment uses.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and CUDA
    generators. It also configures cuDNN for deterministic kernel selection so
    that a trial can be reproduced.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED only influences interpreters launched after this
    # call (e.g. spawned subprocesses); the current process's str/bytes hash
    # randomization was fixed at interpreter startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes and may pick different algorithms from run to
    # run; disable it so the deterministic setting above actually holds.
    torch.backends.cudnn.benchmark = False

def cleanup():
    """Free memory between heavy experiment steps.

    First runs a full garbage-collection pass so that unreachable objects
    (including tensors) are released. Then it asks PyTorch's CUDA caching
    allocator to return its unused cached blocks.
    """
    for release_step in (gc.collect, torch.cuda.empty_cache):
        release_step()

# ==============================================================================
# 1. YOUR EXACT SPECTRAL ENGINE
# ==============================================================================
def get_model_individual_energy(model, dataloader):
    """Return the normalized singular-value energy spectrum at ``LAYER_NAME``.

    Runs ``model.encode_image`` over every batch of ``dataloader``; ``batch[0]``
    is used as the input. A forward hook on the module named ``LAYER_NAME``
    captures its output on the CPU. The captured rows are stacked and their
    singular values are computed on ``TARGET_DEVICE``.

    For 3-D outputs only token 0 of each sample is kept. Both batch-first
    ``(B, L, D)`` and sequence-first ``(L, B, D)`` layouts are accepted; the
    layout is inferred by matching the leading dimensions against the batch
    size. Outputs of any other rank are stored unchanged.

    Args:
        model: Module exposing ``named_modules()`` and ``encode_image``; it is
            switched to eval mode.
        dataloader: Iterable of batches whose first element is the input tensor.

    Returns:
        1-D NumPy array of ``S**2 / sum(S**2)``, where ``S`` holds the singular
        values in descending order.

    Raises:
        ValueError: If no activations were captured (e.g. an empty loader).
        RuntimeError: If a 3-D output's batch axis is ambiguous (batch size
            equals sequence length) and no earlier batch fixed the layout.
    """
    model.eval()
    target_module = dict(model.named_modules())[LAYER_NAME]
    activations = []
    # State shared with the hook: the current batch size and, once known,
    # whether the hooked module's output is batch-first.
    state = {'batch': 0, 'batch_first': None}

    def hook(module, inputs, output):
        t = output.detach().cpu()
        if t.dim() != 3:
            activations.append(t)
            return
        lead_is_batch = t.shape[0] == state['batch']
        if lead_is_batch and t.shape[1] == state['batch']:
            # Batch size equals sequence length, so the shape alone cannot
            # distinguish (B, L, D) from (L, B, D): reuse an earlier layout.
            if state['batch_first'] is None:
                raise RuntimeError(
                    f"Cannot infer the batch axis of {LAYER_NAME} output with "
                    f"shape {tuple(t.shape)}: batch size equals sequence length."
                )
            batch_first = state['batch_first']
        else:
            batch_first = lead_is_batch
            state['batch_first'] = batch_first
        # Keep token 0 of every sample (presumably the CLS token — confirm
        # for the model in use).
        activations.append(t[:, 0, :] if batch_first else t[0, :, :])

    handle = target_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="      SVD Activation Pass", leave=False):
                state['batch'] = batch[0].shape[0]
                model.encode_image(batch[0].to(TARGET_DEVICE))
    finally:
        # Always detach the hook, even if a forward pass fails, so the module
        # does not keep capturing activations for later callers.
        handle.remove()

    if not activations:
        raise ValueError(
            f"No activations were captured at {LAYER_NAME}; is the dataloader empty?"
        )

    # Move to GPU for SVD computation
    X = torch.cat(activations, dim=0).to(TARGET_DEVICE)
    _, S, _ = torch.linalg.svd(X, full_matrices=False)

    energies_sq = S.pow(2)
    normalized_energies = (energies_sq / energies_sq.sum()).cpu().numpy()

    del X, S, energies_sq, activations
    cleanup()
    return normalized_energies

# ==============================================================================
# 2. MAIN EXPERIMENT LOOP
# ==============================================================================
def _build_svd_loader(dataset):
    """Return the DataLoader used for every SVD pass (FP32 and quantized).

    When DATASET_PERCENTAGE is set, a fixed fraction of ``dataset`` is drawn
    with a dedicated RNG seeded from SEEDS[0]. The same images are therefore
    analysed in every trial and on every run of the script, independent of
    the global ``random`` state.
    """
    if DATASET_PERCENTAGE is None:
        return DataLoader(dataset, batch_size=ANALYSIS_BATCH_SIZE, shuffle=False, num_workers=8)
    indices = list(range(len(dataset)))
    random.Random(SEEDS[0]).shuffle(indices)
    # Keep at least one sample so a tiny fraction never yields an empty subset.
    n_keep = max(1, int(len(indices) * DATASET_PERCENTAGE))
    return DataLoader(Subset(dataset, indices[:n_keep]),
                      batch_size=ANALYSIS_BATCH_SIZE, shuffle=False, num_workers=8)


def _delta_cum_var(q_energies, fp32_energies):
    """Return the cumulative sum of the 3-tap moving average of (quantized - FP32) energies.

    ``np.convolve(..., mode='same')`` zero-pads, so the first and last smoothed
    entries average only two real values (each still scaled by 1/3).
    """
    diff_indiv = q_energies - fp32_energies
    smoothed_diff = np.convolve(diff_indiv, np.ones(3) / 3, mode='same')
    return np.cumsum(smoothed_diff)


def main():
    """Run the multi-seed spectral-energy comparison and save it to CSV.

    For every seed in SEEDS, the trial does the following:
      - loads a fresh model;
      - computes its FP32 energy spectrum at LAYER_NAME;
      - quantizes CPU deep copies with PTQ, LSQ and ROT+LSQ (8-bit weights and
        activations); the LSQ variants receive a frozen FP32 copy as ``teacher``;
      - records the smoothed cumulative energy difference per rank.

    ``full_trial_energy.csv`` is rewritten after every seed, so completed
    trials survive a later crash.
    """
    results_csv_path = "full_trial_energy.csv"
    all_data = []

    # Setup Loaders
    print(">> Initializing Dataloaders...")
    _, _, preprocess = data_setup.get_model_and_tokenizer()
    loaders, full_test_loader, _, _ = data_setup.get_dataset_loaders("imagenet1kval", preprocess, get_train=True)
    cal_loader_raw, train_loader_raw = loaders
    class_names = ds_classes.IMAGENET1K_CLASSES

    # Re-batch training/calibration data with TRAIN_BATCH_SIZE.
    train_loader = DataLoader(train_loader_raw.dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=4)
    cal_loader = DataLoader(cal_loader_raw.dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4)
    # Built once so every trial and method is analysed on the same images.
    svd_loader = _build_svd_loader(full_test_loader.dataset)

    for trial_id, seed in enumerate(SEEDS):
        print(f"\nTrial {trial_id+1}/{len(SEEDS)} (Seed: {seed})")
        seed_everything(seed)
        cleanup()

        model_raw, tokenizer, _ = data_setup.get_model_and_tokenizer()

        # 1. FP32 Baseline
        print(">> Computing FP32 SVD...")
        # nn.Module.to() moves in place: model_raw itself is on the device now.
        fp32_indiv_energies = get_model_individual_energy(model_raw.to(TARGET_DEVICE), svd_loader)
        # Park the reference model on the CPU again, so the deep copies below
        # are made in host memory instead of first being allocated on the GPU.
        model_raw.cpu()
        cleanup()

        # Frozen teacher for the LSQ variants (kept on the device).
        teacher = copy.deepcopy(model_raw).to(TARGET_DEVICE)
        teacher.requires_grad_(False)

        # 2. Methods Setup
        methods = [
            ("PTQ", lambda m: apply_simple_ptq(
                m, calibration_dataloader=cal_loader, target_device=TARGET_DEVICE,
                tokenizer=tokenizer, prompts=class_names, weight_bits=8, act_bits=8)),

            ("LSQ", lambda m: apply_learned_step_size_quantization(
                m, training_dataloader=train_loader, calibration_dataloader=cal_loader,
                target_device=TARGET_DEVICE, tokenizer=tokenizer, prompts=class_names,
                teacher=teacher, weight_bits=8, act_bits=8, **config.LSQ_KWARGS)),

            ("ROT+LSQ", lambda m: apply_rotation_lsq(
                m, training_dataloader=train_loader, calibration_dataloader=cal_loader,
                target_device=TARGET_DEVICE, tokenizer=tokenizer, prompts=class_names,
                teacher=teacher, weight_bits=8, act_bits=8, **config.LSQ_KWARGS))
        ]

        # 3. Execution & Calculation
        for name, apply_fn in methods:
            print(f">> Applying and Re-computing SVD for {name}...")
            # Each method gets a fresh CPU copy; target_device is passed to the
            # apply function, which presumably moves the model itself.
            m_q = apply_fn(copy.deepcopy(model_raw).cpu()).to(TARGET_DEVICE)
            q_indiv_energies = get_model_individual_energy(m_q, svd_loader)

            delta_cum_var = _delta_cum_var(q_indiv_energies, fp32_indiv_energies)
            all_data.extend(
                {"Trial": trial_id, "Seed": seed, "Method": name, "Rank": rank, "DeltaCumVar": val}
                for rank, val in enumerate(delta_cum_var)
            )

            del m_q
            cleanup()

        del model_raw, teacher
        cleanup()

        # Incremental Save to CSV (prevents loss if script crashes later)
        pd.DataFrame(all_data).to_csv(results_csv_path, index=False)
        print(f">> Seed {seed} saved to {results_csv_path}")


if __name__ == "__main__":
    main()