import os
import gc
import copy
import random
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# --- PROJECT IMPORTS ---
import scripts.config as config
import scripts.data_setup as data_setup
import scripts.datasets_classes as ds_classes
from quantization.apply import (
    apply_simple_ptq,
    apply_learned_step_size_quantization,
    apply_rotation_lsq
)

# ==============================================================================
# 0. CONFIGURATION
# ==============================================================================
# One complete trial (FP32 baseline + every method at every bit-width) is run
# per seed, giving 5 samples per configuration for the distribution.
SEEDS = [42, 1337, 2024, 55, 20524] # 5 Samples for distribution
# Weight bit-widths swept in main(), in this order.
WEIGHT_BITS_LIST = [8, 7, 6, 5, 4]
# Activation bit-width; passed unchanged as `act_bits` to every method.
ACT_BITS = 8

# Dotted module name looked up in model.named_modules(); the forward hook that
# collects activations is attached to this module.
LAYER_NAME = "visual.transformer.resblocks.11"
TARGET_DEVICE = config.TARGET_DEVICE
# NOTE: this mutates the shared project config module at import time. It is
# used for the SVD DataLoader in main(); presumably the project data loaders
# read it too -- verify against scripts.data_setup.
config.BATCH_SIZE = 64 

# Fraction (not a percentage) of the ImageNet validation set used for the SVD:
# 0.1 keeps 10% of the samples, i.e. 5,000 of the 50,000 validation images.
# 5,000 images is more than enough for a stable singular value distribution.
DATASET_PERCENTAGE = 0.1 

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for reproducible kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # The running interpreter fixed its hash seed at startup, so this only
    # takes effect in child processes that inherit the environment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; the call is a no-op
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different convolution algorithms from run to run,
    # which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

def cleanup():
    """Free unreachable Python objects, then release cached CUDA memory."""
    # Collect first so tensors that just became unreachable are destroyed
    # before the CUDA caching allocator is asked to give its blocks back.
    gc.collect()

    torch.cuda.empty_cache()

# ==============================================================================
# 1. SPECTRAL ENERGY HELPER
# ==============================================================================

def get_intrinsic_energy_spectrum(model, dataloader):
    """Compute the cumulative SVD energy spectrum of one layer's activations.

    A forward hook on the module named ``LAYER_NAME`` collects that module's
    output for every batch in ``dataloader``. The collected rows are stacked
    into one matrix whose singular values are turned into a cumulative,
    normalised energy curve.

    Note that the activation matrix is not mean-centred, so the result is the
    cumulative share of squared singular values (energy) rather than a
    variance in the strict PCA sense.

    Args:
        model: Model exposing ``named_modules()``, ``eval()`` and
            ``encode_image()``. It is switched to eval mode in place.
        dataloader: Iterable of batches; ``batch[0]`` is the image tensor fed
            to ``model.encode_image``.

    Returns:
        1-D NumPy array where entry ``k`` is the fraction of total squared
        singular value mass captured by the first ``k + 1`` singular values.

    Raises:
        KeyError: If ``LAYER_NAME`` is not a module of ``model``.
        RuntimeError: If the hook captured no activations (for example, an
            empty dataloader).
    """
    model.eval()
    target_module = dict(model.named_modules())[LAYER_NAME]
    activations = []

    # Size of the batch currently in flight; the hook uses it to tell
    # batch-first from sequence-first outputs.
    ctx = {'b': 0}

    def hook(m, i, o):
        t = o.detach().cpu()
        if t.dim() == 3:
            # Output is either [Batch, Seq, Dim] or [Seq, Batch, Dim]. If the
            # leading dimension equals the batch size it is treated as
            # batch-first (this also wins when Seq happens to equal Batch).
            if t.shape[0] == ctx['b']:  # [Batch, Seq, Dim]
                # Token 0 of every sample (assumed to be the CLS token).
                activations.append(t[:, 0, :])
            else:  # [Seq, Batch, Dim]
                activations.append(t[0, :, :])
        else:
            activations.append(t)

    handle = target_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="      SVD Accumulation", leave=False):
                ctx['b'] = batch[0].shape[0]
                model.encode_image(batch[0].to(TARGET_DEVICE))
    finally:
        # Always detach the hook so a failed forward pass does not leave it
        # registered on the model.
        handle.remove()

    if not activations:
        raise RuntimeError(
            f"No activations captured from '{LAYER_NAME}'; is the dataloader empty?"
        )

    X = torch.cat(activations, dim=0).to(TARGET_DEVICE)
    # Only the singular values are needed, so skip computing U and Vh.
    S = torch.linalg.svdvals(X)

    s_vals = S.detach().cpu().numpy()
    # Cumulative share of squared singular values.
    energy = s_vals ** 2
    cum_var = np.cumsum(energy) / np.sum(energy)

    del X, S
    cleanup()
    return cum_var

# ==============================================================================
# 2. MAIN EXECUTION
# ==============================================================================

def _energy_records(trial_id, seed, method, w_bits, energy):
    """Turn one cumulative-energy curve into one CSV row dict per rank."""
    return [
        {
            "Trial": trial_id, "Seed": seed, "Method": method,
            "W_Bits": w_bits, "Rank": rank, "CumVariance": val
        }
        for rank, val in enumerate(energy)
    ]


def main():
    """Run the bit-width sweep for every seed and write the results to CSV.

    For each seed in ``SEEDS``: load a fresh model, draw a random subset
    (``DATASET_PERCENTAGE``) of the test dataset, record the FP32 energy
    spectrum, then for each entry of ``WEIGHT_BITS_LIST`` apply PTQ, LSQ and
    ROT+LSQ to a fresh copy of the model and record each spectrum. A method
    that raises is reported and skipped; the sweep continues.

    All rows are written once, at the end, to
    ``spectral_results/bitwidth_scaling_energy_distribution.csv``.
    """
    all_energy_records = []

    os.makedirs("spectral_results", exist_ok=True)

    for trial_id, seed in enumerate(SEEDS):
        print("\n" + "=" * 80)
        print(f"STARTING TRIAL {trial_id+1}/{len(SEEDS)} (SEED: {seed})")
        print("=" * 80)

        seed_everything(seed)
        cleanup()

        # 1. Load Baseline Model
        model_raw, tokenizer, preprocess = data_setup.get_model_and_tokenizer()

        # 2. Setup Data
        loaders, full_test_loader, _, _ = data_setup.get_dataset_loaders(
            "imagenet1kval", preprocess, get_train=True)
        cal_loader, train_loader = loaders
        class_names = ds_classes.IMAGENET1K_CLASSES  # Use explicit class list

        # Subsample the test set for the SVD; the shuffle uses the seeded
        # `random` module, so the subset is tied to the trial seed.
        indices = list(range(len(full_test_loader.dataset)))
        random.shuffle(indices)
        subset_indices = indices[:int(len(indices) * DATASET_PERCENTAGE)]
        svd_loader = DataLoader(Subset(full_test_loader.dataset, subset_indices),
                                batch_size=config.BATCH_SIZE, shuffle=False, num_workers=8)

        # 3. Baseline FP32 Energy
        print(f"   [Trial {trial_id+1}] Method: FP32")
        fp32_energy = get_intrinsic_energy_spectrum(model_raw.to(TARGET_DEVICE), svd_loader)
        all_energy_records.extend(
            _energy_records(trial_id, seed, "FP32", 32, fp32_energy))

        # Frozen full-precision copy used as the teacher for LSQ / ROT+LSQ.
        teacher = copy.deepcopy(model_raw).to(TARGET_DEVICE)
        for p in teacher.parameters():
            p.requires_grad = False

        # 4. Quantization Loops
        for w_bits in WEIGHT_BITS_LIST:
            print(f"\n   --- Scaling Weights to {w_bits}-bit ---")

            # The lambdas read `w_bits` when called; they are only ever called
            # inside this iteration, so late binding is not an issue here.
            methods = [
                ("PTQ", lambda m: apply_simple_ptq(
                    m, calibration_dataloader=cal_loader, target_device=TARGET_DEVICE,
                    tokenizer=tokenizer, prompts=class_names, weight_bits=w_bits, act_bits=ACT_BITS)),

                ("LSQ", lambda m: apply_learned_step_size_quantization(
                    m, training_dataloader=train_loader, calibration_dataloader=cal_loader,
                    target_device=TARGET_DEVICE, tokenizer=tokenizer, prompts=class_names,
                    teacher=teacher, weight_bits=w_bits, act_bits=ACT_BITS, **config.LSQ_KWARGS)),

                ("ROT+LSQ", lambda m: apply_rotation_lsq(
                    m, training_dataloader=train_loader, calibration_dataloader=cal_loader,
                    target_device=TARGET_DEVICE, tokenizer=tokenizer, prompts=class_names,
                    teacher=teacher, weight_bits=w_bits, act_bits=ACT_BITS, **config.LSQ_KWARGS))
            ]

            for name, apply_fn in methods:
                print(f"      [Trial {trial_id+1}] Method: {name} (W{w_bits})")
                m_quant = None
                try:
                    # Apply quantization to a fresh copy of the raw model
                    m_quant = apply_fn(copy.deepcopy(model_raw).cpu()).to(TARGET_DEVICE)

                    # Compute Energy
                    energy = get_intrinsic_energy_spectrum(m_quant, svd_loader)

                    # Store
                    all_energy_records.extend(
                        _energy_records(trial_id, seed, name, w_bits, energy))
                except Exception as e:
                    # Best effort: report the failure and move on to the next
                    # method instead of aborting the whole sweep.
                    print(f"      [Error] Failed {name} at W{w_bits}: {e}")
                finally:
                    # Release the quantized model on failure as well, so it is
                    # not still resident while the next method runs.
                    m_quant = None
                    cleanup()

        # Trial cleanup
        del model_raw, teacher, svd_loader
        cleanup()

    # Save Aggregate Results
    df = pd.DataFrame(all_energy_records)
    output_path = "spectral_results/bitwidth_scaling_energy_distribution.csv"
    df.to_csv(output_path, index=False)

    print(f"\n{'='*80}")
    print(f"DONE. Results saved to: {output_path}")
    print(f"{'='*80}")

# Run the full sweep only when this file is executed as a script, so that
# importing the module does not start the experiment.
if __name__ == "__main__":
    main()