import os
import sys
import gc
import copy
import random
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# --- PROJECT IMPORTS ---
import scripts.config as config
import scripts.data_setup as data_setup
import scripts.datasets_classes as ds_classes
from quantization.apply import (
    apply_simple_ptq,
    apply_learned_step_size_quantization,
    apply_rotation_lsq
)

# ==============================================================================
# 0. CONFIGURATION & SETUP
# ==============================================================================

# Evaluation batch size; overrides the value from scripts.config for this run.
config.BATCH_SIZE = 32
# Fraction of the ImageNet-1k validation set to evaluate (1 = 100%, 0.5 = 50%).
DATASET_PERCENTAGE = 1

# Default model key unless the caller already exported TARGET_MODEL_KEY.
# NOTE(review): scripts.config is imported above this line; if it reads
# TARGET_MODEL_KEY at import time, this default is applied too late — confirm.
os.environ.setdefault("TARGET_MODEL_KEY", "CLIP_ViT-B-32_OpenAI")

TARGET_DEVICE = config.TARGET_DEVICE
# Residual block whose output is decomposed (presumably the last of ViT-B/32's 12 blocks).
LAYER_NAME = "visual.transformer.resblocks.11"

# Rank intervals [start, end) of SVD basis directions evaluated by the sweep;
# `end` is clamped to the number of basis directions when used.
INTERVALS = [
    (0, 64), (64, 768)
]

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.

    Note:
        Setting ``PYTHONHASHSEED`` here only affects Python interpreters
        launched after this call (subprocesses); the current process's hash
        randomization was fixed when it started.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed only seeds the *current* GPU; seed all of them.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels per input shape and may pick different
    # algorithms between runs, defeating `deterministic`.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def cleanup():
    """Run Python garbage collection, then release unoccupied cached CUDA memory."""
    # Collect first: tensors kept alive only by reference cycles must be freed
    # before the CUDA caching allocator has anything to hand back.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# ==============================================================================
# 1. BASIS & PROFILER LOGIC (DISTRIBUTION AWARE)
# ==============================================================================

def compute_model_specific_basis(model, dataloader, layer_name):
    """Compute an SVD basis of the token-0 activations at ``layer_name``.

    Runs ``model.encode_image`` over every batch of ``dataloader``, captures the
    output of ``layer_name`` with a forward hook, keeps token 0 of each image,
    stacks the results into ``X`` (not mean-centred) and factorises
    ``X = U @ diag(S) @ Vh``.

    Args:
        model: Model exposing ``encode_image`` and containing ``layer_name``.
        dataloader: Iterable of batches whose first element is the image tensor.
        layer_name: Fully-qualified module name as listed by ``model.named_modules()``.

    Returns:
        ``(Vh, s_vals)``. ``Vh`` holds the right singular vectors as rows, on
        ``TARGET_DEVICE``, ordered by decreasing singular value.
        ``s_vals`` holds the singular values as a NumPy array.

    Raises:
        KeyError: if ``layer_name`` is not a submodule of ``model``.
        ValueError: if ``dataloader`` yields no batches.
    """
    print(f"   Generating SVD Basis for {layer_name}...")
    model.eval()
    target_module = dict(model.named_modules())[layer_name]
    activations = []
    # Size of the batch currently being encoded. The hook compares against the
    # real size (not config.BATCH_SIZE), so a short final batch or a batch size
    # larger than the sequence length cannot flip the layout heuristic.
    current = {"batch_size": None}

    def hook(m, i, o):
        t = o.detach()
        if t.dim() == 3:
            # Same heuristic as ModelSpecificProfiler.sweep: a middle dimension
            # equal to the batch size is treated as (seq, batch, dim); anything
            # else as (batch, seq, dim). Either way, keep token 0 per image.
            t = t[0, :, :] if t.shape[1] == current["batch_size"] else t[:, 0, :]
        # Slice before the device transfer so only the kept tokens are copied.
        activations.append(t.cpu())

    handle = target_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="      Collecting Activations", leave=False):
                imgs = batch[0].to(TARGET_DEVICE)
                current["batch_size"] = imgs.shape[0]
                model.encode_image(imgs)
    finally:
        # Detach the hook even if a forward pass raises.
        handle.remove()

    if not activations:
        raise ValueError(f"No activations captured for {layer_name}: the dataloader is empty")

    X = torch.cat(activations, dim=0).to(TARGET_DEVICE)
    # torch.linalg.svd does not support half precision; upcast if needed.
    if X.dtype in (torch.float16, torch.bfloat16):
        X = X.float()
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)

    # Singular values are returned separately for energy-distribution plotting.
    s_vals = S.detach().cpu().numpy()
    del X, U, S; cleanup()
    return Vh, s_vals

class ModelSpecificProfiler:
    """Zero-shot accuracy when features are restricted to SVD rank intervals.

    For each batch, the output of ``layer_name`` is captured and token 0 of each
    image is kept. The features are expressed in the basis ``Vh``. For every rank
    interval, they are rebuilt from only that interval's coefficients, passed
    through ``tail_ln`` (and ``tail_proj`` when given), normalised and scored
    against ``text_features``.
    """

    def __init__(self, model, layer_name, Vh, tail_ln, tail_proj):
        """
        Args:
            model: Model exposing ``encode_image`` and containing ``layer_name``.
            layer_name: Module whose output is probed.
            Vh: Basis matrix whose rows are basis directions (as returned by
                ``torch.linalg.svd``).
            tail_ln: Callable applied to the reconstructed features.
            tail_proj: Optional matrix applied after ``tail_ln``; ``None`` skips it.
        """
        self.model = model
        self.Vh = Vh
        self.tail_ln = tail_ln
        self.tail_proj = tail_proj
        self.m_layer = dict(model.named_modules())[layer_name]

    @staticmethod
    def _select_token(o, batch_size):
        """Return token 0 of every image from a hooked layer output.

        A 3-D output whose middle dimension equals ``batch_size`` is treated as
        (seq, batch, dim); any other 3-D output as (batch, seq, dim). Outputs
        that are not 3-D are returned unchanged.
        """
        if o.dim() != 3:
            return o
        if o.shape[1] == batch_size:
            return o[0, :, :]
        return o[:, 0, :]

    def _interval_accuracies(self, feats, labels, intervals, text_features):
        """Top-1 accuracy (percent) of one batch for each rank interval.

        Returns:
            dict mapping ``f"{s}-{e}"`` (with the unclamped ``e``) to accuracy.
        """
        k = self.Vh.shape[0]
        # Match the basis dtype so the projection cannot fail on mixed precision.
        coeffs = feats.to(self.Vh.dtype) @ self.Vh.mT
        stats = {}
        for (s, e) in intervals:
            stop = min(e, k)
            # Same as zeroing coefficients outside [s, stop) and reconstructing
            # with the full basis, without allocating a dense mask. An empty
            # range yields an all-zero reconstruction, as the mask did.
            x_rec = coeffs[:, s:stop] @ self.Vh[s:stop]
            out = self.tail_ln(x_rec)
            if self.tail_proj is not None:
                out = out @ self.tail_proj
            out = out / out.norm(dim=-1, keepdim=True)
            logits = 100.0 * out @ text_features.T
            stats[f"{s}-{e}"] = (logits.argmax(dim=1) == labels).float().mean().item() * 100
        return stats

    def sweep(self, dataloader, intervals, text_features):
        """Evaluate every rank interval on every batch of ``dataloader``.

        Returns:
            ``pd.DataFrame`` with one row per batch and one ``"{s}-{e}"`` column
            per interval holding that batch's top-1 accuracy in percent. Rows
            are per batch, so averaging them weights a smaller final batch the
            same as a full one.
        """
        batch_results = []
        self.model.eval()
        captured = []
        current = {"batch_size": None}

        def hook(m, i, o):
            captured.append(self._select_token(o, current["batch_size"]))

        # One hook for the whole sweep, removed even if a forward pass fails.
        handle = self.m_layer.register_forward_hook(hook)
        try:
            with torch.no_grad():
                for batch in tqdm(dataloader, desc="      Sampling Batches", leave=False):
                    imgs, labels = batch[0].to(TARGET_DEVICE), batch[1].to(TARGET_DEVICE)
                    current["batch_size"] = imgs.shape[0]
                    captured.clear()
                    self.model.encode_image(imgs)
                    batch_results.append(
                        self._interval_accuracies(captured[0], labels, intervals, text_features)
                    )
        finally:
            handle.remove()

        return pd.DataFrame(batch_results)

# ==============================================================================
# 2. MAIN
# ==============================================================================

def _print_accuracy_summary(full_df, methods, intervals):
    """Print a mean ± std table of per-batch accuracy per method and interval.

    Args:
        full_df: Per-batch results with a ``Method`` column and one
            ``"{s}-{e}"`` column per interval.
        methods: Method names in display order; names absent from ``full_df``
            are skipped.
        intervals: ``(start, end)`` pairs whose columns are reported.
    """
    keys = [f"{s}-{e}" for (s, e) in intervals]
    summary_df = full_df.groupby('Method').agg(['mean', 'std'])
    header = f"{'Method':<10} | " + " | ".join(f"{'Interval ' + k:<20}" for k in keys)
    rule = "-" * max(65, len(header))

    print("\n>> Accuracy Summary (Mean ± Std Dev):")
    print(rule)
    print(header)
    print(rule)
    for method in methods:
        if method not in summary_df.index:
            continue
        cells = []
        for k in keys:
            mean = summary_df.loc[method, (k, 'mean')]
            std = summary_df.loc[method, (k, 'std')]
            cell = f"{mean:6.2f}% ± {std:4.2f}"
            cells.append(f"{cell:<20}")
        print(f"{method:<10} | " + " | ".join(cells))
    print(rule)


def _export_energy(energy_stats, path):
    """Write each method's singular values to CSV, NaN-padding shorter spectra.

    Pads to the longest spectrum rather than a fixed 768 so that wider layers
    do not make ``np.pad`` fail with a negative pad width.
    """
    max_len = max((len(v) for v in energy_stats.values()), default=0)
    energy_data = {
        name: np.pad(s_vals, (0, max_len - len(s_vals)), 'constant', constant_values=np.nan)
        for name, s_vals in energy_stats.items()
    }
    pd.DataFrame(energy_data).to_csv(path, index_label="Rank")


def _plot_relative_accuracy(full_df, fp32_df, path):
    """Plot each method's per-batch accuracy relative to the FP32 mean per interval."""
    melted_df = full_df.melt(id_vars=['Method'], var_name='Interval', value_name='Accuracy')
    # Look up each row's FP32 baseline by interval name (vectorised).
    fp32_means = fp32_df.mean(numeric_only=True)
    base = melted_df['Interval'].map(fp32_means)
    melted_df['Relative_Change'] = ((melted_df['Accuracy'] - base) / (base + 1e-6)) * 100

    fig = plt.figure(figsize=(14, 7))
    sns.lineplot(data=melted_df, x='Interval', y='Relative_Change', hue='Method',
                 marker='o', err_style="band", errorbar=("ci", 95), linewidth=2.5)
    plt.axhline(0, color='black', linestyle='--', label="FP32 Baseline")
    plt.title(f"ImageNet: Relative Accuracy Distribution (Subset: {DATASET_PERCENTAGE*100}%)", fontsize=14)
    plt.ylabel("Relative Accuracy Change (%)", fontsize=12)
    plt.xlabel("Rank Interval", fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close(fig)


def _plot_energy(energy_stats, path, top_k=128):
    """Plot normalised cumulative squared singular values for the first ``top_k`` ranks."""
    fig = plt.figure(figsize=(10, 6))
    for name, s_vals in energy_stats.items():
        # The activations are not mean-centred before the SVD, so this is the
        # cumulative energy fraction rather than variance in the strict PCA sense.
        energy = s_vals ** 2
        cum_energy = np.cumsum(energy) / np.sum(energy)
        plt.plot(cum_energy[:top_k], label=name, linewidth=2)

    plt.title(f"Intrinsic Feature Energy Distribution (Layer: {LAYER_NAME})", fontsize=14)
    plt.xlabel("Rank Index", fontsize=12)
    plt.ylabel("Cumulative Explained Variance", fontsize=12)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close(fig)


def main():
    """Run the rank-interval analysis for FP32, PTQ, LSQ and ROT+LSQ models.

    Steps:
      1. Load the model and ImageNet-1k val, and build a random subset loader.
      2. Compute text features.
      3. Analyse each model variant (SVD basis plus interval sweep).
      4. Write the CSVs, print a summary table and save two plots to the
         working directory.
    """
    seed_everything(42)
    cleanup()

    model_raw, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
    print(">> Loading ImageNet-1k...")
    loaders, full_test_loader, _, _ = data_setup.get_dataset_loaders("imagenet1kval", preprocess, get_train=True)

    # Random subset (fraction DATASET_PERCENTAGE) of the validation set; the
    # shuffle is reproducible because seed_everything seeded `random`.
    indices = list(range(len(full_test_loader.dataset)))
    random.shuffle(indices)
    subset_indices = indices[:int(len(indices) * DATASET_PERCENTAGE)]
    test_loader = DataLoader(Subset(full_test_loader.dataset, subset_indices),
                             batch_size=config.BATCH_SIZE, shuffle=False, num_workers=8)

    cal_loader, train_loader = loaders
    class_names = ds_classes.IMAGENET1K_CLASSES

    model_raw = model_raw.to(TARGET_DEVICE)
    with torch.no_grad():
        toks = tokenizer([f"a photo of a {c}." for c in class_names]).to(TARGET_DEVICE)
        text_features = model_raw.encode_text(toks)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Copies of the FP32 model's ln_post / proj serve as the scoring head for
    # every variant, so all models are evaluated through the same tail.
    tail_ln = copy.deepcopy(model_raw.visual.ln_post).to(TARGET_DEVICE)
    tail_proj = copy.deepcopy(model_raw.visual.proj).to(TARGET_DEVICE) if model_raw.visual.proj is not None else None

    results_dfs = {}
    energy_stats = {}

    def analyze_model(name, m):
        """Compute the basis for ``m``, record its spectrum and return its sweep results."""
        print(f"\n>> Analyzing {name}...")
        m = m.to(TARGET_DEVICE)
        # NOTE(review): the basis is fit on the same loader it is evaluated on.
        basis, s_vals = compute_model_specific_basis(m, test_loader, LAYER_NAME)
        energy_stats[name] = s_vals
        prof = ModelSpecificProfiler(m, LAYER_NAME, basis, tail_ln, tail_proj)
        df = prof.sweep(test_loader, INTERVALS, text_features)
        df['Method'] = name
        return df

    # --- Run Models ---
    results_dfs['FP32'] = analyze_model('FP32', model_raw)
    teacher = copy.deepcopy(model_raw).cpu()
    for p in teacher.parameters():
        p.requires_grad = False

    # PTQ
    m_ptq = apply_simple_ptq(copy.deepcopy(model_raw).cpu(), calibration_dataloader=cal_loader,
                             target_device=TARGET_DEVICE, tokenizer=tokenizer, prompts=class_names, **config.SIMPLE_PTQ_KWARGS)
    results_dfs['PTQ'] = analyze_model('PTQ', m_ptq)
    del m_ptq; cleanup()

    # LSQ
    m_lsq = apply_learned_step_size_quantization(copy.deepcopy(model_raw).cpu(), training_dataloader=train_loader,
                             calibration_dataloader=cal_loader, target_device=TARGET_DEVICE, tokenizer=tokenizer,
                             prompts=class_names, teacher=teacher.to(TARGET_DEVICE), **config.LSQ_KWARGS)
    results_dfs['LSQ'] = analyze_model('LSQ', m_lsq)
    del m_lsq; cleanup()

    # ROT+LSQ
    m_rot = apply_rotation_lsq(copy.deepcopy(model_raw).cpu(), training_dataloader=train_loader,
                             calibration_dataloader=cal_loader, target_device=TARGET_DEVICE, tokenizer=tokenizer,
                             prompts=class_names, teacher=teacher.to(TARGET_DEVICE), **config.LSQ_KWARGS)
    results_dfs['ROT+LSQ'] = analyze_model('ROT+LSQ', m_rot)
    # The teacher is no longer needed; release its device memory too.
    del m_rot, teacher; cleanup()

    # --- Export & summary ---
    print("\n" + "="*50)
    print("DATA EXPORT & SUMMARY")
    print("="*50)

    full_df = pd.concat(results_dfs.values(), ignore_index=True)
    full_df.to_csv("results_raw_accuracies.csv", index=False)
    print(f">> Raw accuracy data saved to 'results_raw_accuracies.csv' ({len(full_df)} batches total)")

    # Column keys are derived from INTERVALS, so editing INTERVALS cannot
    # desynchronise the table from the sweep output.
    _print_accuracy_summary(full_df, list(results_dfs), INTERVALS)

    _export_energy(energy_stats, "results_energy_values.csv")
    print(">> Singular Values (Energy) saved to 'results_energy_values.csv'")

    # --- Plotting ---
    print("\n>> Generating Plots...")
    _plot_relative_accuracy(full_df, results_dfs['FP32'], "results_imagenet_distribution_relative.png")
    _plot_energy(energy_stats, "results_imagenet_energy_distribution.png")

    print(">> Done.")

if __name__ == "__main__":
    main()