# compute_spectrum.py

import os
import sys
import gc
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import copy
import open_clip
import pprint

# --- Project Imports ---
sys.path.append(os.getcwd())

import scripts.config as config
import scripts.data_setup as data_setup
from quantization.apply import apply_simple_ptq, apply_rotation_lsq

# --- GLOBAL SETTINGS ---
# Allocator option for PyTorch's CUDA caching allocator.
# NOTE(review): this is assigned after `import torch` and after the project
# imports above; confirm nothing initialises CUDA earlier, otherwise the
# setting may not take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Override two attributes on the shared project config module for this run.
# Presumably read by the quantization helpers - verify against
# quantization.apply.
config.USE_AMP = True
config.AMP_DTYPE = torch.float16
# Device every model copy is moved to and every image batch is sent to.
TARGET_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- CONFIGURATION ---
# Location handed to data_setup.create_train_iterable for the PTQ calibration
# and QAT training/calibration loaders.
CC3M_PATH = str(config.SHARD_PATH_CC3M)
# Dataset keys passed to data_setup.get_dataset_loaders; they also become the
# per-dataset keys of the printed result dictionary.
EVAL_DATASETS = ["cifar100", "imagenet1kval"]
# Batch size passed to data_setup.create_train_iterable.
BATCH_SIZE = 64
# Written to config.EVAL_DATASET_SUBSAMPLE_RATIO before the eval loaders are
# built (presumably the fraction of each eval set that is kept - verify
# against data_setup).
SUBSAMPLE_RATIO = 0.1
# Forwarded as weight_bits / act_bits to both quantization entry points.
W_BITS = 8
A_BITS = 8
# Forwarded to apply_rotation_lsq as total_steps.
QAT_STEPS = 100
# Forwarded to apply_rotation_lsq as learning_rate and lsq_learning_rate.
LR = 1e-6
LSQ_LR = 1e-4

# Models to analyse, processed in list order. Per entry:
#   arch         - model name given to open_clip.create_model_and_transforms
#                  and open_clip.get_tokenizer
#   data         - value given as `pretrained=` when loading the model
#   layer        - dotted module name the SpectrumAnalyzer hook is attached to
#   display_name - label used in progress messages and as the result key
TARGET_MODELS = [
    dict(
        arch="ViT-B-32-quickgelu",
        data="openai",
        layer="visual.transformer.resblocks.11",
        display_name="ViT-B-32 (OpenAI)",
    ),
    dict(
        arch="convnext_base",
        data="laion400m_s13b_b51k",
        layer="visual.trunk.stages.3",
        display_name="ConvNeXt Base",
    ),
]

def cleanup():
    """Run Python garbage collection, then ask PyTorch to release CUDA memory.

    The CUDA calls are skipped entirely when no CUDA device is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Hand cached allocator blocks back to the driver, then collect memory
    # that was released through CUDA IPC.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

class SpectrumAnalyzer:
    """Capture one layer's activations and average their 2-D FFT magnitude.

    A forward hook is attached to the sub-module named ``layer_name`` for the
    duration of ``compute_spectrum``. Every captured activation is rearranged
    into a channels-first 4-D grid, moved to the CPU and kept in
    ``self.features`` until the spectrum has been computed.
    """

    def __init__(self, model, layer_name):
        # Model under analysis; compute_spectrum calls named_modules(), eval()
        # and encode_image() on it.
        self.model = model
        # Key looked up in dict(model.named_modules()).
        self.layer_name = layer_name
        # Handle returned by the most recent register_forward_hook call.
        self.hook_handle = None
        # Per-batch CPU tensors collected by _hook_fn.
        self.features = []

    def _hook_fn(self, module, input, output):
        """Forward hook: store ``output`` as a (B, C, H, W) CPU tensor.

        * 3-D outputs are treated as token sequences: the first token along
          dim 1 is dropped and the remaining N tokens are folded into a
          sqrt(N) x sqrt(N) grid.
          NOTE(review): this assumes a batch-first (B, tokens, C) layout with
          a leading class token - confirm for the open_clip version in use.
        * 4-D outputs are kept as-is when dim 1 is larger than both dims 2
          and 3 (taken to be channels-first); otherwise they are assumed to be
          channels-last and permuted to channels-first.
        * Outputs of any other rank are ignored.

        Raises:
            RuntimeError: if a 3-D output's remaining token count is not a
                perfect square.
        """
        grid = None
        if output.ndim == 3:  # ViT
            spatial_tokens = output[:, 1:, :]
            B, N, C = spatial_tokens.shape
            side = int(np.sqrt(N))
            if side * side != N:
                # The reshape below would fail anyway; fail with a message
                # that names the actual problem.
                raise RuntimeError(
                    f"Layer '{self.layer_name}' produced {N} spatial tokens, "
                    "which cannot be folded into a square grid"
                )
            grid = spatial_tokens.permute(0, 2, 1).reshape(B, C, side, side)
        elif output.ndim == 4:  # ConvNeXt
            s1, s2, s3 = output.shape[1], output.shape[2], output.shape[3]
            if s1 > s2 and s1 > s3:
                grid = output  # NCHW
            else:
                grid = output.permute(0, 3, 1, 2)  # NHWC

        if grid is not None:
            self.features.append(grid.detach().cpu())

    def compute_spectrum(self, dataloader):
        """Return the mean 2-D FFT magnitude of the hooked layer over ``dataloader``.

        Each batch is read as ``batch[0]`` (unwrapped once more if that is
        itself a tuple or list), moved to ``TARGET_DEVICE`` and passed through
        ``model.encode_image`` under ``torch.no_grad()``.

        Returns:
            A float32 numpy array of shape (H, W): the orthonormal FFT
            magnitude averaged over all samples and channels, with the
            zero-frequency bin shifted to the centre. ``None`` if the layer
            name is unknown or the hook captured nothing.
        """
        self.features = []
        target = dict(self.model.named_modules()).get(self.layer_name)
        if target is None:
            return None

        self.hook_handle = target.register_forward_hook(self._hook_fn)
        try:
            self.model.eval()
            # Requesting CUDA autocast on a CPU-only host just warns and
            # disables itself, so only enable it when CUDA exists.
            use_amp = torch.cuda.is_available()
            with torch.no_grad():
                for batch in dataloader:
                    images = batch[0]
                    if isinstance(images, (tuple, list)):
                        images = images[0]
                    with torch.amp.autocast('cuda', enabled=use_amp):
                        self.model.encode_image(images.to(TARGET_DEVICE))
        finally:
            # Always detach the hook, even when a batch fails; a stale hook
            # would keep feeding activations into this analyzer.
            self.hook_handle.remove()

        if not self.features:
            return None

        # FFT Calculation (Float32)
        all_feats = torch.cat(self.features, dim=0).to(torch.float32)
        fft = torch.fft.fft2(all_feats, norm="ortho")
        fft_shifted = torch.fft.fftshift(fft, dim=(-2, -1))
        magnitude = torch.abs(fft_shifted)
        avg_mag = magnitude.mean(dim=(0, 1))

        # Drop the large intermediates before the next model is processed.
        del all_feats, fft, fft_shifted, magnitude
        self.features = []
        cleanup()

        # Return as numpy array for easier serialization later
        return avg_mag.cpu().numpy()

def _spectra_per_dataset(model, layer_name, eval_loaders):
    """Return ``{dataset_key: spectrum}`` for ``model`` hooked at ``layer_name``.

    Each value is whatever ``SpectrumAnalyzer.compute_spectrum`` returns for
    that loader: a numpy array, or ``None`` when nothing could be captured.
    """
    analyzer = SpectrumAnalyzer(model, layer_name)
    return {
        d_name: analyzer.compute_spectrum(loader)
        for d_name, loader in eval_loaders.items()
    }


def _spectrum_to_list(spectrum, label):
    """Convert a spectrum to nested lists; pass ``None`` through with a warning."""
    if spectrum is None:
        print(f"Warning: no spectrum computed for {label}", file=sys.stderr)
        return None
    return spectrum.tolist()


def main():
    """Compute FP32 / PTQ / QAT feature spectra for every target model and print them.

    For each entry of ``TARGET_MODELS`` the base model is loaded once on the
    CPU and deep-copied three times: left unquantized, passed through
    ``apply_simple_ptq``, and passed through ``apply_rotation_lsq``. Each
    variant is analysed on every evaluation loader. Progress messages and
    warnings go to stderr; stdout carries only the banner and the
    pretty-printed result dictionary, shaped as
    ``{display_name: {dataset: {"fp32": ..., "ptq": ..., "qat": ...}}}``.
    A variant whose spectrum could not be computed is stored as ``None``.
    """
    print("--- Starting Computation... ---", file=sys.stderr)
    cleanup()

    # Setup
    # Only the preprocessing transform of this model is kept.
    # NOTE(review): the same transform is reused for every target model -
    # confirm its resolution/normalisation suit all of them.
    _, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='openai', device='cpu')

    def get_qat_loader():
        # A fresh iterable per call, so each consumer gets its own one.
        return data_setup.create_train_iterable(CC3M_PATH, preprocess, BATCH_SIZE)

    config.EVAL_DATASET_SUBSAMPLE_RATIO = SUBSAMPLE_RATIO
    eval_loaders = {}
    for d_key in EVAL_DATASETS:
        _, loader, _, _ = data_setup.get_dataset_loaders(d_key, preprocess, get_train=False)
        eval_loaders[d_key] = loader

    # Structure to hold results
    # { "Display Name": { "Dataset": { "fp32": [[...]], "ptq": [[...]], "qat": [[...]] } } }
    results = {}

    for cfg in TARGET_MODELS:
        cleanup()
        display = cfg['display_name']
        print(f"Processing {display}...", file=sys.stderr)

        # Load Base (FP32)
        base_model, _, _ = open_clip.create_model_and_transforms(
            cfg['arch'], pretrained=cfg['data'], device='cpu', precision='fp32'
        )
        tokenizer = open_clip.get_tokenizer(cfg['arch'])

        # 1. FP32
        model_fp32 = copy.deepcopy(base_model).to(TARGET_DEVICE).eval()
        fp32_res = _spectra_per_dataset(model_fp32, cfg['layer'], eval_loaders)
        del model_fp32
        cleanup()

        # 2. PTQ
        model_ptq = copy.deepcopy(base_model).to(TARGET_DEVICE)
        model_ptq = apply_simple_ptq(
            model_ptq, calibration_dataloader=get_qat_loader(), target_device=TARGET_DEVICE,
            quantize_text=False, weight_bits=W_BITS, act_bits=A_BITS
        )
        ptq_res = _spectra_per_dataset(model_ptq, cfg['layer'], eval_loaders)
        del model_ptq
        cleanup()

        # 3. QAT
        model_qat = copy.deepcopy(base_model).to(TARGET_DEVICE)
        for p in model_qat.parameters():
            p.requires_grad = True

        model_qat = apply_rotation_lsq(
            model_qat,
            training_dataloader=get_qat_loader(),
            calibration_dataloader=get_qat_loader(),
            target_device=TARGET_DEVICE, tokenizer=tokenizer, prompts=None, teacher=None,
            quantize_text=False, weight_bits=W_BITS, act_bits=A_BITS,
            total_steps=QAT_STEPS, learning_rate=LR, lsq_learning_rate=LSQ_LR,
            distill_weight=0.0, main_loss_weight=1.0
        )
        qat_res = _spectra_per_dataset(model_qat, cfg['layer'], eval_loaders)
        del model_qat, base_model
        cleanup()

        # Pack results for this model. compute_spectrum may return None
        # (unknown layer, nothing captured); keep that as None instead of
        # crashing here and losing every result computed so far.
        results[display] = {
            d_name: {
                'fp32': _spectrum_to_list(fp32_res[d_name], f"{display} / {d_name} / fp32"),
                'ptq': _spectrum_to_list(ptq_res[d_name], f"{display} / {d_name} / ptq"),
                'qat': _spectrum_to_list(qat_res[d_name], f"{display} / {d_name} / qat"),
            }
            for d_name in EVAL_DATASETS
        }

    # --- PRINT OUTPUT FOR NOTEBOOK ---
    print("\n" + "="*40 + "\nCOPY EVERYTHING BELOW THIS LINE\n" + "="*40 + "\n")

    # Use pprint to make it a valid python dict string
    pp = pprint.PrettyPrinter(width=10000, compact=True)
    pp.pprint(results)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()