# run_scaling_study.py

import copy
import csv
import gc
import itertools
import logging
import os

import numpy as np
import torch
import torch.utils.data as data
from tqdm import tqdm

# --- 1. OVERRIDE CONFIGURATION ---
import scripts.config as config

# Force Model: OpenCLIP ViT-B-32-quickgelu (OpenAI)
# NOTE: these assignments mutate the shared ``scripts.config`` module object.
# They are deliberately placed ABOVE the ``scripts.data_setup`` /
# ``scripts.evaluation`` / ``quantization.apply`` imports further down, in case
# those modules read config values when they are first imported (not verifiable
# from this file — TODO confirm before reordering any of these lines).
config.MODEL_CONFIG = {'type': 'open_clip', 'arch': 'ViT-B-32-quickgelu', 'data': 'openai'}
# Mirror the architecture / pretrained tag into the flat config fields as well,
# so code reading either form sees the same model.
config.CLIP_MODEL_ARCHITECTURE = config.MODEL_CONFIG['arch']
config.CLIP_MODEL_PRETRAINED_DATASET = config.MODEL_CONFIG['data']

# Optimization & Resource Settings
config.EVAL_DATASET_SUBSAMPLE_RATIO = 1  # presumably 1 == evaluate on the full split; confirm in scripts.config
config.USE_AMP = True  # also gates the autocast used when pre-computing text features in main()
config.ENABLE_TEACHER = True # We use the model_fp32 copy as a reference to save VRAM

# --- INTERNAL IMPORTS ---
import scripts.data_setup as data_setup
import scripts.evaluation as evaluation
from quantization.apply import apply_rotation_lsq, apply_learned_step_size_quantization

# Logging: emit bare messages at INFO level through this module's logger.
logging.basicConfig(format='%(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# --- CONSTANTS ---
# Proxy (calibration) datasets; each name is looked up in config.PROXY_DATASETS.
PROXY_DATASETS = ["CC3M", "YFCC", "SBU"]
# Calibration-set sizes swept by the study: one per decade, 1 .. 10_000.
SAMPLE_COUNTS = [10 ** exponent for exponent in range(5)]
# Zero-shot benchmarks every tuned model is evaluated on.
EVAL_DATASETS = ["imagenet1kval", "cifar100"]
# Independent repetitions per (proxy, sample count) scenario.
NUM_SEEDS = 5
OUTPUT_CSV = "results_scaling_hybrid_study.csv"
CSV_HEADERS = [
    "Proxy",
    "Samples",
    "Seed",
    "Eval_Dataset",
    "Accuracy",
    "Method",
]

# Hybrid tuning objective: contrastive loss (weight 0.5) plus MSE feature
# distillation (weight 0.5). Forwarded as keyword arguments to every
# quantization apply function.
HYBRID_KWARGS = dict(
    learning_rate=1e-6,
    lsq_learning_rate=1e-4,
    total_steps=100,
    main_loss_weight=0.5,  # contrastive term
    distill_weight=0.5,    # MSE feature-distillation term
)

class CalibrationListDataset(data.Dataset):
    """Map-style dataset over an in-memory list of pre-collected samples.

    Items are returned exactly as stored, so a ``DataLoader`` built on top of
    it collates whatever tuples the caller placed in the list.
    """

    def __init__(self, data_list):
        # Keeps a reference to the caller's list; no copy is made.
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        samples = self.data_list
        return samples[idx]

def get_proxy_loader(proxy_name, n_samples, preprocess):
    """Collect up to ``n_samples`` proxy examples and wrap them in a DataLoader.

    The proxy registered under ``proxy_name`` in ``config.PROXY_DATASETS`` is
    streamed one example at a time (``batch_size=1``). The first ``n_samples``
    items are kept in memory as ``(batch[0] without its batch dim, batch[1][0])``
    pairs (presumably image tensor and caption/label — confirm against
    ``data_setup.create_train_iterable``). They are then served by a shuffled
    DataLoader, so the same small calibration set can be iterated repeatedly.

    Args:
        proxy_name: Key into ``config.PROXY_DATASETS``; the first element of the
            registered value is used as the dataset path.
        n_samples: Maximum number of examples to collect. Values <= 0 collect
            nothing.
        preprocess: Transform forwarded to ``data_setup.create_train_iterable``.

    Returns:
        A ``DataLoader`` with batch size ``min(32, collected)`` and
        ``shuffle=True``. Returns ``None`` if the proxy is not configured, if
        reading the stream failed, or if no samples could be collected.
    """
    path_info = config.PROXY_DATASETS.get(proxy_name)
    if not path_info:
        logger.warning("Proxy dataset %r is not configured; skipping.", proxy_name)
        return None

    loader_stream = data_setup.create_train_iterable(str(path_info[0]), preprocess, batch_size=1)
    # Creating the stream/iterator stays outside the try: setup errors should
    # surface, only failures while reading samples are treated as best-effort.
    iterator = iter(loader_stream)
    collected_data = []
    try:
        with torch.no_grad():
            # islice stops after n_samples without pulling an extra element and
            # ends early (without error) if the stream is exhausted.
            for batch in itertools.islice(iterator, max(n_samples, 0)):
                # Drop the singleton batch dimension introduced by batch_size=1.
                collected_data.append((batch[0].squeeze(0), batch[1][0]))
    except Exception:
        # Best-effort: a broken proxy stream skips this seed instead of
        # aborting the whole study, but the failure is no longer silent.
        logger.exception("Failed while reading proxy dataset %r; skipping.", proxy_name)
        return None

    if not collected_data:
        logger.warning("Proxy dataset %r yielded no samples; skipping.", proxy_name)
        return None
    if len(collected_data) < n_samples:
        # The CSV records the *requested* sample count, so flag short streams.
        logger.warning(
            "Proxy dataset %r: requested %d samples but the stream only yielded %d.",
            proxy_name, n_samples, len(collected_data),
        )

    return data.DataLoader(
        CalibrationListDataset(collected_data),
        batch_size=min(32, len(collected_data)),
        shuffle=True,
    )

def _append_csv_row(row):
    """Append one result row to ``OUTPUT_CSV``.

    The file is reopened for every row so results already written survive a
    crash partway through the long study.
    """
    with open(OUTPUT_CSV, 'a', newline='') as fh:
        csv.writer(fh).writerow(row)


def _encode_class_prompts(model, tokenizer, texts, device, use_amp, chunk_size=512):
    """Return L2-normalised text embeddings for ``texts``, concatenated along dim 0.

    Prompts are tokenized and encoded in chunks of ``chunk_size`` to bound peak
    memory. Autocast is the CUDA flavour, so it only affects ops run on CUDA.
    """
    feats_list = []
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=use_amp):
        for start in range(0, len(texts), chunk_size):
            tokens = tokenizer(texts[start:start + chunk_size]).to(device)
            feats = model.encode_text(tokens)
            feats_list.append(feats / feats.norm(dim=-1, keepdim=True))
    return torch.cat(feats_list, dim=0)


def main():
    """Run the proxy-data scaling study for W8A8 hybrid quantization tuning.

    1. Load the FP32 model and pre-compute normalised zero-shot text features
       for every dataset in ``EVAL_DATASETS``.
    2. Truncate ``OUTPUT_CSV``, then write the header and the FP32 baselines.
    3. For every proxy x sample count x seed, build a calibration loader. For
       each quantization method, tune a fresh deep copy of the FP32 model with
       ``HYBRID_KWARGS`` and append its zero-shot accuracy on every eval
       dataset to the CSV.
    """
    data_setup.set_seed(config.RANDOM_SEED)
    model_fp32, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
    device = config.TARGET_DEVICE
    # The FP32 model is only used for inference here (text features, baseline,
    # teacher); set eval mode once instead of once per eval dataset.
    model_fp32.eval()

    # 1. Setup Evaluation Suite
    eval_suite = {}
    for d_name in EVAL_DATASETS:
        print(f"Initializing {d_name}...")
        (_, _), eval_loader, class_names, template = data_setup.get_dataset_loaders(d_name, preprocess, get_train=False)
        texts = [template.format(c) for c in class_names]
        eval_suite[d_name] = {
            'loader': eval_loader,
            'txt_feats': _encode_class_prompts(model_fp32, tokenizer, texts, device, config.USE_AMP),
        }

    # 2. Init CSV (truncating any previous results) and Baseline
    with open(OUTPUT_CSV, 'w', newline='') as fh:
        csv.writer(fh).writerow(CSV_HEADERS)

    print("\n" + "=" * 60)
    print("EVALUATING FP32 BASELINES")
    # NOTE: the loop variable must not be called ``data`` -- that would shadow
    # the module-level ``torch.utils.data`` alias inside main().
    for d_name, d_info in eval_suite.items():
        acc = evaluation.run_comprehensive_evaluation(
            model_fp32, model_fp32, d_info['loader'], d_info['txt_feats'], device
        )["Zero-Shot Accuracy"]
        print(f"  {d_name:<15} Accuracy: {acc:.2%}")
        _append_csv_row(["Baseline", 0, 0, d_name, acc, "FP32"])
    print("=" * 60 + "\n")

    methods = [
        ("LSQ", apply_learned_step_size_quantization),
        ("Rot+LSQ", apply_rotation_lsq),
    ]

    # 3. Scaling sweep
    for proxy in PROXY_DATASETS:
        for n_samples in SAMPLE_COUNTS:
            print(f"\n>>> Scenario: {proxy} | Samples: {n_samples}")

            for seed_idx in range(NUM_SEEDS):
                current_seed = config.RANDOM_SEED + seed_idx
                data_setup.set_seed(current_seed)
                loader = get_proxy_loader(proxy, n_samples, preprocess)
                if not loader:
                    print(f"  [Seed {current_seed}] no calibration data for {proxy}; skipping")
                    continue

                for m_name, apply_fn in methods:
                    # Reseed so every method starts from the same RNG state for
                    # this seed (loader shuffle order, any random init). Without
                    # it, later methods inherit the RNG state left behind by
                    # earlier methods' training, which breaks reproducibility
                    # of individual (seed, method) runs.
                    data_setup.set_seed(current_seed)
                    temp_model = copy.deepcopy(model_fp32)

                    # Apply Hybrid Tuning
                    # Note: We pass model_fp32 as the teacher here because it is already
                    # in memory and frozen. Since it's identical to the starting state
                    # of temp_model, it serves perfectly for feature distillation.
                    q_model = apply_fn(
                        temp_model, training_dataloader=loader, calibration_dataloader=loader,
                        target_device=device, tokenizer=tokenizer, prompts=None, teacher=model_fp32,
                        quantize_text=False, weight_bits=8, act_bits=8, **HYBRID_KWARGS
                    )

                    # Multi-Dataset Evaluation
                    for d_name, d_info in eval_suite.items():
                        acc = evaluation.run_comprehensive_evaluation(
                            q_model, model_fp32, d_info['loader'], d_info['txt_feats'], device
                        )["Zero-Shot Accuracy"]
                        _append_csv_row([proxy, n_samples, current_seed, d_name, acc, m_name])
                        print(f"  [Seed {current_seed}] {m_name:<8} on {d_name:<15}: {acc:.2%}")

                    # Release the tuned copy before the next method to keep peak
                    # GPU memory at roughly one extra model.
                    del q_model, temp_model
                    gc.collect()
                    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()