import os
import copy
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

import scripts.config as config
import scripts.data_setup as data_setup
import scripts.evaluation as evaluation
from quantization.apply import apply_rotation_lsq

# --- CONFIGURATION ---
# Runtime settings are pushed into the shared ``config`` module so that the
# project helpers (data_setup / evaluation / quantization) pick them up.
if torch.cuda.is_available():
    config.TARGET_DEVICE = "cuda:0"
else:
    config.TARGET_DEVICE = "cpu"
# NOTE(review): set unconditionally, including on CPU; confirm the consumers of
# ``config.USE_AMP`` handle the CPU case.
config.USE_AMP = True

# Evaluation datasets, processed in this order for every model.
DATASETS_TO_PROCESS = [
    "imagenet1kval",
    "cifar100",
    "cifar10",
    "sun397",
]

# Display name (used as the results key) -> model-zoo key in ``config.MODEL_ZOO``.
MODELS_TO_TEST = {
    "WIT (OpenAI)": "CLIP_ViT-B-32_OpenAI",
    "LAION": "CLIP_ViT-B-32_Laion2B",
}

# Key into ``config.PROXY_DATASETS`` for the calibration/training proxy data.
PROXY_KEY = "CC3M"

# QAT settings, forwarded verbatim as keyword arguments to apply_rotation_lsq.
ROT_LSQ_KWARGS = dict(
    learning_rate=1e-6,
    lsq_learning_rate=1e-4,
    total_steps=100,
    weight_bits=8,
    act_bits=8,
    main_loss_weight=1.0,
    # presumably disables the distillation term — confirm in quantization.apply
    distill_weight=0.0,
    quant_scope='ALL',
)

# --- HELPER FUNCTIONS ---

def get_logits_and_targets(model, loader, tokenizer, device, class_names, template,
                           quantize_text_flag, text_batch_size=256):
    """Run zero-shot classification and return raw logits and targets.

    One prompt per class is built with ``template.format(class_name)``. The
    prompts are tokenized and encoded in chunks of ``text_batch_size`` and
    L2-normalized. Each image batch from ``loader`` is encoded, L2-normalized
    and scored against every class embedding. The scores are scaled by
    ``model.logit_scale.exp()``.

    Args:
        model: Object exposing ``eval()``, ``encode_text``, ``encode_image`` and
            ``logit_scale``. Whether its text encoder is quantized is decided
            by the model object itself, not by this function.
        loader: Iterable of ``(images, labels)`` batches.
        tokenizer: Callable mapping a list of strings to a token tensor.
        device: Device that tokens, images and labels are moved to.
        class_names: Class names, in label-index order (presumably — confirm
            against the dataset loader).
        template: Format string with a single ``{}`` placeholder.
        quantize_text_flag: Unused; kept so existing call sites keep working.
        text_batch_size: Number of prompts encoded per ``encode_text`` call.

    Returns:
        ``(logits, labels)``. ``logits`` is float32 with one row per image and
        one column per class. ``labels`` holds the concatenated targets.
        Both are on ``device``.

    Raises:
        ValueError: If ``class_names`` is empty or ``loader`` yields no batches.
    """
    del quantize_text_flag  # intentionally unused (see docstring)
    model.eval()

    prompts = [template.format(name) for name in class_names]
    if not prompts:
        raise ValueError("class_names is empty; cannot build zero-shot classifier")

    logits_list = []
    labels_list = []
    with torch.no_grad():
        # Class embeddings, encoded in chunks to bound peak memory.
        text_chunks = []
        for start in range(0, len(prompts), text_batch_size):
            tokens = tokenizer(prompts[start:start + text_batch_size]).to(device)
            text_chunks.append(F.normalize(model.encode_text(tokens), dim=-1))
        text_features = torch.cat(text_chunks, dim=0)

        # Loop-invariant temperature: compute once instead of per batch.
        logit_scale = model.logit_scale.exp()

        for images, labels in tqdm(loader, desc="Inference", leave=False):
            img_feats = F.normalize(model.encode_image(images.to(device)), dim=-1)
            logits = logit_scale * img_feats @ text_features.t()
            # float32 so downstream softmax / ECE are not computed in reduced precision.
            logits_list.append(logits.float())
            labels_list.append(labels.to(device))

    if not logits_list:
        raise ValueError("loader yielded no batches")
    return torch.cat(logits_list), torch.cat(labels_list)

def calculate_standard_ece(logits, targets, n_bins=15):
    """Return the expected calibration error of ``logits``, in percent.

    Top-1 confidences are grouped into ``n_bins`` equal-width bins of the form
    ``(lower, upper]`` over ``[0, 1]``. The ECE is the bin-mass-weighted sum of
    ``|mean confidence - accuracy|`` over all non-empty bins, multiplied by 100.
    """
    probabilities = F.softmax(logits, dim=1)
    top_conf, top_pred = probabilities.max(dim=1)
    is_correct = top_pred.eq(targets)
    edges = torch.linspace(0, 1, n_bins + 1).to(logits.device)

    total = 0.0
    for low, high in zip(edges[:-1], edges[1:]):
        members = (top_conf > low) & (top_conf <= high)
        weight = members.float().mean().item()
        # Written as "not > 0" so an empty/NaN weight is skipped, like an empty bin.
        if not weight > 0:
            continue
        bin_acc = is_correct[members].float().mean().item()
        bin_conf = top_conf[members].mean().item()
        total += np.abs(bin_conf - bin_acc) * weight
    return total * 100.0

def calculate_metrics_by_fixed_group(logits, targets, group_indices, n_bins=15):
    """Compute per-group mean confidence and accuracy for trajectory tracking.

    Samples are grouped by the caller-supplied ``group_indices`` rather than by
    their own confidences. This lets the same samples be followed across model
    variants. Group ids outside ``range(n_bins)`` are ignored.

    Returns:
        ``(bin_confs, bin_accs)``: two lists of length ``n_bins``. Each entry is
        the mean top-1 confidence / accuracy of that group, or ``None`` when the
        group has no samples.
    """
    top_conf, top_pred = F.softmax(logits, dim=1).max(dim=1)
    correct = top_pred.eq(targets).float()

    def _group_stats(group_id):
        member = group_indices == group_id
        if not member.any():
            return None, None
        return top_conf[member].mean().item(), correct[member].mean().item()

    stats = [_group_stats(g) for g in range(n_bins)]
    bin_confs = [conf for conf, _ in stats]
    bin_accs = [acc for _, acc in stats]
    return bin_confs, bin_accs

def _assign_confidence_bins(logits, n_bins=15):
    """Return the per-sample index of the top-1 confidence bin of ``logits``.

    Bins are ``(k/n_bins, (k+1)/n_bins]``. This is the same half-open convention
    ``calculate_standard_ece`` uses, so the fixed groups line up with the ECE
    bins. A confidence of exactly 0 is clamped into bin 0.
    """
    confs = F.softmax(logits, dim=1).max(dim=1).values
    boundaries = torch.linspace(0, 1, n_bins + 1, device=confs.device)
    # right=False returns i with boundaries[i-1] < c <= boundaries[i].
    indices = torch.bucketize(confs, boundaries, right=False) - 1
    return indices.clamp_(0, n_bins - 1)


def _evaluate_all_datasets(model, tokenizer, preprocess, quantize_text,
                           fixed_bin_indices=None, n_bins=15):
    """Evaluate ``model`` on every dataset in ``DATASETS_TO_PROCESS``.

    With ``fixed_bin_indices=None`` (the FP32 baseline), groups are derived from
    this model's own confidences. Otherwise the given per-dataset indices are
    reused, so later phases are grouped by the FP32 bins. This assumes the
    evaluation loaders yield samples in the same order on every call — TODO
    confirm no shuffling in ``data_setup.get_dataset_loaders``.

    ``quantize_text`` is forwarded as ``get_logits_and_targets``'
    ``quantize_text_flag``.

    Returns:
        ``(results, bin_indices)``. ``results`` maps dataset name to
        ``{"confs", "accs", "ece"}``. ``bin_indices`` maps dataset name to the
        per-sample group tensor that was used.
    """
    results = {}
    bin_indices = {}
    for d_name in DATASETS_TO_PROCESS:
        # flush so the partial line appears before tqdm's progress output.
        print(f"   {d_name}...", end="", flush=True)
        _, loader, cnames, tmpl = data_setup.get_dataset_loaders(d_name, preprocess, get_train=False)
        logits, targets = get_logits_and_targets(
            model, loader, tokenizer, config.TARGET_DEVICE, cnames, tmpl, quantize_text
        )
        if fixed_bin_indices is None:
            indices = _assign_confidence_bins(logits, n_bins)
        else:
            indices = fixed_bin_indices[d_name]
        confs_out, accs_out = calculate_metrics_by_fixed_group(logits, targets, indices, n_bins)
        std_ece = calculate_standard_ece(logits, targets, n_bins)
        results[d_name] = {"confs": confs_out, "accs": accs_out, "ece": std_ece}
        bin_indices[d_name] = indices
        print(f" ECE: {std_ece:.2f}%")
    return results, bin_indices


def main():
    """Run FP32 -> QAT -> logit tuning for each model and print calibration data.

    For every entry in ``MODELS_TO_TEST``, the model is evaluated three times:
    as FP32, after ``apply_rotation_lsq`` (vision + text quantization), and after
    ``evaluation.tune_logit_scale``. Per-group confidence/accuracy and the
    standard ECE are collected, using the FP32 confidence bins as fixed groups.
    The nested results dict is printed at the end.
    """
    n_bins = 15
    print("--- STARTING EVOLUTION (VISION + TEXT QUANTIZED) ---")
    print(f"Loading Proxy: {PROXY_KEY}...")
    proxy_path, _ = config.PROXY_DATASETS[PROXY_KEY]

    results_data = {d: {} for d in DATASETS_TO_PROCESS}

    for display_name, model_key in MODELS_TO_TEST.items():
        print(f"\n\n=== Processing Model: {display_name} ===")

        # Select the model BEFORE loading. The returned preprocess transform then
        # belongs to this model, not to whatever config held before the loop.
        os.environ["TARGET_MODEL_KEY"] = model_key
        config.MODEL_CONFIG = config.MODEL_ZOO[model_key]
        model, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
        model.eval()

        # Proxy loaders are built per model so they use this model's preprocessing.
        cal_loader = data_setup.create_train_iterable(
            str(proxy_path), preprocess, batch_size=config.BATCH_SIZE
        )
        train_loader = data_setup.create_train_iterable(
            str(proxy_path), preprocess, batch_size=config.BATCH_SIZE
        )

        # Frozen FP32 copy, passed as teacher to QAT and to logit tuning.
        teacher = copy.deepcopy(model)
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad = False

        # --- 1. FP32 ---
        print(">> Phase 1: FP32 Baseline")
        res_fp32, fp32_bin_indices = _evaluate_all_datasets(
            model, tokenizer, preprocess, False, n_bins=n_bins
        )

        # --- 2. QAT (VISION + TEXT) ---
        print(">> Phase 2: QAT (Vision + Text Quantization)")
        model_qat = apply_rotation_lsq(
            model, training_dataloader=train_loader, calibration_dataloader=cal_loader,
            target_device=config.TARGET_DEVICE, tokenizer=tokenizer, prompts=None,
            teacher=teacher, quantize_text=True, **ROT_LSQ_KWARGS
        )
        res_qat, _ = _evaluate_all_datasets(
            model_qat, tokenizer, preprocess, True, fp32_bin_indices, n_bins
        )

        # --- 3. TUNING ---
        print(">> Phase 3: Logit Tuning")
        # Tune a deep copy so any in-place changes made by tune_logit_scale
        # cannot affect model_qat.
        model_for_tuning = copy.deepcopy(model_qat)
        model_tuned = evaluation.tune_logit_scale(
            model_for_tuning, tuning_dataloader=cal_loader, tokenizer=tokenizer,
            class_names=None, device=config.TARGET_DEVICE, teacher_model=teacher
        )
        res_tuned, _ = _evaluate_all_datasets(
            model_tuned, tokenizer, preprocess, True, fp32_bin_indices, n_bins
        )

        # Store
        for d_name in DATASETS_TO_PROCESS:
            results_data[d_name][display_name] = {
                "FP32": res_fp32[d_name],
                "QAT": res_qat[d_name],
                "Tuned": res_tuned[d_name],
            }

        del model, teacher, model_qat, model_tuned, model_for_tuning, fp32_bin_indices
        del cal_loader, train_loader
        torch.cuda.empty_cache()

    print("\n\n" + "="*40)
    print("DATA FOR PLOTTING")
    print("="*40)
    print(results_data)
    print("="*40)

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()