"""
Vision experiment (CLIP): Foundation Model Assisted Statistical Inference
Includes:
1. Main experiment: Minimax Risk & Sample Efficiency
2. Ablation study (--ablation): Zero-Initialization vs. Random Initialization
"""

import os
import sys
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
import copy
import argparse # For parsing command line arguments
import numpy as np
import matplotlib.pyplot as plt # For plotting

sys.path.append("..")

from utils.seed_utils import set_global_seed
from utils.logging_utils import CSVLogger
from utils.metrics import accuracy, brier_score
from utils.models import LinearHead, ConcatHead, WeightedEnsemble, ResidualModel
from vision.datasets_vision import make_vision_loaders, get_dataset_info

try:
    from transformers import CLIPModel, CLIPProcessor
    import huggingface_hub
except ImportError:
    print("Error: transformers library required.")
    exit(1)

# === Add a Residual Model that supports Random Init ===
class ResidualModelAblation(nn.Module):
    """Residual correction head added on top of fixed black-box logits.

    The forward pass returns ``bb_logits + mlp(features)``, where ``mlp`` is a
    two-layer perceptron (``input_dim -> 512 -> num_classes``) with a ReLU and
    dropout (p=0.3) in between.

    Args:
        input_dim: Size of the last dimension of ``features``.
        num_classes: Number of logits produced by the MLP; it is added to
            ``bb_logits``, so the two must be broadcast-compatible.
        init_type: ``'zero'`` zeroes the weight and bias of the final linear
            layer, so an untrained model returns ``bb_logits`` unchanged.
            ``'random'`` keeps PyTorch's default ``nn.Linear`` initialization.

    Raises:
        ValueError: If ``init_type`` is neither ``'zero'`` nor ``'random'``.
    """

    _INIT_TYPES = ('zero', 'random')

    def __init__(self, input_dim, num_classes, init_type='zero'):
        super().__init__()
        # Fail loudly on typos: silently falling back to random init would
        # invalidate the zero-vs-random comparison without any visible error.
        if init_type not in self._INIT_TYPES:
            raise ValueError(
                f"Unknown init_type {init_type!r}; expected one of {self._INIT_TYPES}"
            )

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

        if init_type == 'zero':
            # Zero only the output layer so the residual starts at exactly 0.
            nn.init.zeros_(self.mlp[-1].weight)
            nn.init.zeros_(self.mlp[-1].bias)
        # 'random': nothing to do, the layers keep their default initialization.

    def forward(self, features, bb_logits):
        """Return the black-box logits plus the learned residual correction."""
        return bb_logits + self.mlp(features)

def precompute_features(model, processor, loader, class_names, device, desc="Caching"):
    """Cache normalized image features, zero-shot logits and labels for a loader.

    Text prompts of the form "a photo of a <class>." are embedded once, then
    every batch from ``loader`` is embedded and scored against them.

    Args:
        model: Object exposing ``eval()``, ``get_text_features``,
            ``get_image_features`` and a ``logit_scale`` tensor (a CLIP model
            in this script).
        processor: Callable used to tokenize the prompts and, when it has a
            ``preprocess`` attribute, to convert non-tensor image batches.
        loader: Iterable yielding ``(images, labels)`` batches.
        class_names: Sequence of class-name strings, one per class.
        device: Device on which the model is evaluated.
        desc: Progress-bar label.

    Returns:
        Dict with CPU tensors ``"feats"`` (L2-normalized image features),
        ``"bb_logits"`` (``exp(logit_scale) * feats @ text_feats.T``) and
        ``"labels"``, each concatenated over all batches.
    """

    def _l2_normalize(x):
        # Unit-normalize along the last (embedding) dimension.
        return x / x.norm(p=2, dim=-1, keepdim=True)

    model.eval()

    # 1. Text side: one unit-norm embedding per class.
    print(" Computing text embeddings...")
    templates = ["a photo of a {}."]
    prompts = [tpl.format(name) for name in class_names for tpl in templates]
    with torch.no_grad():
        tokens = processor(text=prompts, return_tensors="pt", padding=True).to(device)
        text_feats = _l2_normalize(model.get_text_features(**tokens))
        # Average over templates per class, then re-normalize the mean.
        per_class = text_feats.view(len(class_names), len(templates), -1)
        text_feats = _l2_normalize(per_class.mean(dim=1))

    # 2. Image side: embed every batch and score it against the text embeddings.
    feat_chunks, logit_chunks, label_chunks = [], [], []
    logit_scale = model.logit_scale.exp()
    can_preprocess = hasattr(processor, 'preprocess')

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            # Only raw (non-tensor) batches go through the processor; tensors
            # are taken to be ready-made pixel values.
            if can_preprocess and not isinstance(images, torch.Tensor):
                encoded = processor(images=images, return_tensors="pt")
                pixel_values = encoded["pixel_values"].to(device)
            else:
                pixel_values = images.to(device)

            img_feats = _l2_normalize(model.get_image_features(pixel_values=pixel_values))
            bb_logits = logit_scale * img_feats @ text_feats.T

            feat_chunks.append(img_feats.cpu())
            logit_chunks.append(bb_logits.cpu())
            label_chunks.append(labels.cpu())

    return {
        "feats": torch.cat(feat_chunks),
        "bb_logits": torch.cat(logit_chunks),
        "labels": torch.cat(label_chunks),
    }

def train_model(model_name, train_data, val_data, input_dim, num_classes, device, epochs=50, lr=1e-3, init_type='zero'):
    """Train one head on cached features and keep the best-validation weights.

    Training uses AdamW (``weight_decay=1e-4``), cross-entropy loss and
    shuffled mini-batches of 32. After every epoch the model is scored on the
    full validation set; the state dict with the highest validation accuracy
    is restored before returning.

    Args:
        model_name: One of ``"scratch"``, ``"weighted"``, ``"concat"`` or
            ``"residual"``. ``"scratch"`` is called with features only; the
            others are called with ``(features, bb_logits)``.
        train_data: Dict with tensors under ``"feats"``, ``"bb_logits"`` and
            ``"labels"``.
        val_data: Dict with the same keys as ``train_data``.
        input_dim: Feature dimension passed to the model constructor.
        num_classes: Number of classes passed to the model constructor.
        device: Device used for training and validation.
        epochs: Number of passes over the training data.
        lr: AdamW learning rate.
        init_type: Forwarded to ``ResidualModelAblation``; only used when
            ``model_name == "residual"``.

    Returns:
        Tuple ``(model, best_acc, norms)``. ``norms`` holds the per-epoch norm
        of the final-layer weight of the residual MLP and is empty for every
        other model type.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Lazy constructors: only the requested architecture is instantiated.
    builders = {
        "scratch": lambda: LinearHead(input_dim, num_classes),
        "weighted": lambda: WeightedEnsemble(input_dim, num_classes),
        "concat": lambda: ConcatHead(input_dim, num_classes),
        "residual": lambda: ResidualModelAblation(input_dim, num_classes, init_type=init_type),
    }
    # Reject unknown names up front instead of failing later on an unbound
    # ``model`` variable.
    if model_name not in builders:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(builders)}"
        )

    train_ds = TensorDataset(train_data["feats"], train_data["bb_logits"], train_data["labels"])
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

    val_feats = val_data["feats"].to(device)
    val_bb = val_data["bb_logits"].to(device)
    # Labels are only needed as a NumPy array; convert once, not every epoch.
    val_y_np = val_data["labels"].cpu().numpy()

    model = builders[model_name]().to(device)
    uses_bb = model_name != "scratch"

    def forward(feats, bb):
        # The scratch head ignores the black-box logits entirely.
        return model(feats, bb) if uses_bb else model(feats)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    best_acc = -1.0
    best_state = copy.deepcopy(model.state_dict())

    # Per-epoch weight norm of the residual head (only filled for "residual").
    norms = []

    for epoch in range(epochs):
        model.train()
        for feats, bb, y in train_loader:
            feats, bb, y = feats.to(device), bb.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(forward(feats, bb), y)
            loss.backward()
            optimizer.step()

        if model_name == "residual":
            with torch.no_grad():
                norms.append(model.mlp[-1].weight.norm().item())

        # Validation: keep a copy of the weights whenever accuracy improves.
        model.eval()
        with torch.no_grad():
            val_pred = forward(val_feats, val_bb)
            acc = accuracy(val_pred.argmax(1).cpu().numpy(), val_y_np)

            if acc > best_acc:
                best_acc = acc
                best_state = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_state)
    return model, best_acc, norms

def evaluate_all_safe(models, data, device, bb_val_acc, val_accuracies, best_alpha=None):
    """Score the black box and every trained head on ``data``.

    Args:
        models: Dict mapping model name to trained module. Recognized names
            are ``"scratch"``, ``"weighted"``, ``"concat"`` and ``"residual"``.
        data: Dict with tensors under ``"feats"``, ``"bb_logits"``, ``"labels"``.
        device: Device used for inference.
        bb_val_acc: Validation accuracy of the raw black-box logits.
        val_accuracies: Dict of validation accuracies; ``"residual"`` is read
            when a residual model is present.
        best_alpha: Optional mixing weight for the validation-tuned ensemble
            ``alpha * bb + (1 - alpha) * scratch``.

    Returns:
        Dict of metrics. Accuracy is stored under the model name and the
        Brier score under ``"<name>_mse"``. For the residual model,
        ``"residual_safe*"`` falls back to the black-box scores when the
        residual's validation accuracy is below ``bb_val_acc``, and
        ``"fallback_rate"`` is 1.0 or 0.0 accordingly.
    """
    import torch.nn.functional as F

    for net in models.values():
        net.eval()

    feats = data["feats"].to(device)
    bb = data["bb_logits"].to(device)
    y_np = data["labels"].to(device).cpu().numpy()

    def _acc_and_brier(logits):
        # Accuracy from the arg-max class, Brier score from softmax probabilities.
        acc = accuracy(logits.argmax(1).cpu().numpy(), y_np)
        probs = F.softmax(logits, dim=1).cpu().numpy()
        return acc, brier_score(probs, y_np)

    results = {}
    with torch.no_grad():
        # 1. Black box alone
        bb_test_score, bb_mse = _acc_and_brier(bb)
        results["bb_only"] = bb_test_score
        results["bb_mse"] = bb_mse

        # 2. Baseline heads
        for m_name in ("scratch", "weighted", "concat"):
            if m_name not in models:
                continue
            net = models[m_name]
            logits = net(feats) if m_name == "scratch" else net(feats, bb)
            results[m_name], results[f"{m_name}_mse"] = _acc_and_brier(logits)

        # 3. Residual head, raw and with the validation-based safety fallback
        if "residual" in models:
            res_raw_score, res_raw_mse = _acc_and_brier(models["residual"](feats, bb))
            results["residual_raw"] = res_raw_score
            results["residual_raw_mse"] = res_raw_mse

            fell_back = val_accuracies["residual"] < bb_val_acc
            results["residual_safe"] = bb_test_score if fell_back else res_raw_score
            results["residual_safe_mse"] = (
                results["bb_mse"] if fell_back else results["residual_raw_mse"]
            )
            results["fallback_rate"] = 1.0 if fell_back else 0.0

        # 4. Fixed-weight ensemble with alpha chosen on validation data
        if best_alpha is not None and "scratch" in models:
            scratch_logits = models["scratch"](feats)
            w_logits = best_alpha * bb + (1 - best_alpha) * scratch_logits
            tuned_acc, tuned_mse = _acc_and_brier(w_logits)
            results["weighted_val_tuned"] = tuned_acc
            results["weighted_val_tuned_mse"] = tuned_mse

    return results

def run_ablation_study(full_train_cache, test_cache, info, device):
    """Run the Zero-Init vs Random-Init ablation for the residual head.

    A random subset of up to 500 cached training samples is split 80/20 into
    train/validation. Two residual models are trained on it (zero and random
    output-layer initialization), both are evaluated on the test cache, and
    the per-epoch weight norms are plotted to
    ``outputs/ablation/ablation_zero_init.png``.

    Args:
        full_train_cache: Dict with tensors under ``"feats"``, ``"bb_logits"``
            and ``"labels"`` for the training pool.
        test_cache: Dict with the same keys for the test set.
        info: Dataset info dict; ``info["num_classes"]`` is read.
        device: Device used for training and evaluation.
    """
    print("\n=== Running Ablation Study: Zero vs Random Init ===")
    # Fixed sample size, clamped so the 80/20 split stays correct when the
    # cache holds fewer samples than requested.
    n = min(500, len(full_train_cache["labels"]))
    # Take the feature width from the cache instead of hard-coding it
    # (assumes the last dimension of "feats" is the embedding dimension).
    input_dim = full_train_cache["feats"].shape[-1]
    num_classes = info["num_classes"]

    # Prepare data
    indices = torch.randperm(len(full_train_cache["labels"]))[:n]
    keys = ("feats", "bb_logits", "labels")
    subset = {key: full_train_cache[key][indices] for key in keys}

    n_train = int(0.8 * n)
    train_data = {key: subset[key][:n_train] for key in keys}
    val_data = {key: subset[key][n_train:] for key in keys}

    # 1. Train Zero-Init
    print("Training Residual with Zero-Init...")
    model_zero, _, norms_zero = train_model(
        "residual", train_data, val_data, input_dim, num_classes, device, init_type='zero'
    )

    # 2. Train Random-Init
    print("Training Residual with Random-Init...")
    model_random, _, norms_random = train_model(
        "residual", train_data, val_data, input_dim, num_classes, device, init_type='random'
    )

    # 3. Evaluate on Test
    model_zero.eval()
    model_random.eval()
    test_feats = test_cache["feats"].to(device)
    test_bb = test_cache["bb_logits"].to(device)
    test_y_np = test_cache["labels"].cpu().numpy()

    with torch.no_grad():
        out_zero = model_zero(test_feats, test_bb)
        out_random = model_random(test_feats, test_bb)

        final_acc_zero = accuracy(out_zero.argmax(1).cpu().numpy(), test_y_np)
        final_acc_random = accuracy(out_random.argmax(1).cpu().numpy(), test_y_np)

    print(f"\n[Result] Zero-Init Acc: {final_acc_zero:.4f} | Random-Init Acc: {final_acc_random:.4f}")

    # 4. Plot Norm Curves
    output_dir = "outputs/ablation"
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, 'ablation_zero_init.png')

    # Font sizes are applied through a context so they do not leak into any
    # other figure drawn by the same process.
    plot_style = {
        'font.size': 12,
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'legend.fontsize': 12,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
    }
    with plt.rc_context(plot_style):
        fig = plt.figure(figsize=(8, 5))
        try:
            plt.plot(norms_zero, label='Zero-Init', linewidth=2)
            plt.plot(norms_random, label='Random-Init', linewidth=2, linestyle='--')
            plt.xlabel('Epoch')
            plt.ylabel(r'Norm of Residual Head $\|\hat{r}\|_2$')
            plt.title(f'Training Dynamics: Residual Norm (n={n})')
            plt.legend()
            plt.grid(True, alpha=0.5)
            plt.savefig(save_path, dpi=300)
        finally:
            # Release the figure so it does not stay registered with pyplot.
            plt.close(fig)
    print(f"Saved ablation plot to {save_path}")

def main():
    """Entry point: cache CLIP features, then run the ablation or the main sweep.

    With ``--ablation`` the zero-vs-random initialization study is run and the
    function returns. Otherwise, for each labeled-set size a random subset of
    the cached training features is split 80/20 into train/validation, all
    heads are trained, a mixing weight for the fixed ensemble is tuned on the
    validation split, and test metrics are printed and appended to
    ``outputs/vision_comparison/vision_results.csv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ablation", action="store_true", help="Run ablation study instead of main experiment")
    args = parser.parse_args()

    dataset_name = "cifar100"
    model_name = "openai/clip-vit-base-patch32"
    seed = 42
    save_dir = "outputs/vision_comparison"

    set_global_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(save_dir, exist_ok=True)
    logger = CSVLogger(save_dir, "vision_results.csv")

    # HuggingFace Auth (only when a token is provided via the environment)
    hf_token = os.getenv("HF_TOKEN", None)
    if hf_token:
        huggingface_hub.login(token=hf_token)

    # Load Model
    print(f"Loading CLIP: {model_name}")
    clip_model = CLIPModel.from_pretrained(model_name, use_safetensors=True, token=hf_token).to(device)
    processor = CLIPProcessor.from_pretrained(model_name, use_safetensors=True, token=hf_token)
    info = get_dataset_info(dataset_name)

    # Cache Data
    # NOTE(review): make_vision_loaders is called twice with identical
    # arguments; a single call could probably serve both loaders, but it is
    # kept as-is in case the helper consumes random state. TODO confirm.
    _, _, test_loader = make_vision_loaders(dataset_name, "./data", batch_size=128, num_workers=4)
    test_cache = precompute_features(clip_model, processor, test_loader, info["class_names"], device, desc="Caching Test")

    raw_train_loader, _, _ = make_vision_loaders(dataset_name, "./data", batch_size=128, num_workers=4)
    full_train_cache = precompute_features(clip_model, processor, raw_train_loader, info["class_names"], device, desc="Caching Train")

    # === Branch: Ablation or Main ===
    if args.ablation:
        run_ablation_study(full_train_cache, test_cache, info, device)
        return

    # Feature width comes from the cache rather than a hard-coded constant, so
    # switching the CLIP checkpoint does not require editing the loop below.
    input_dim = full_train_cache["feats"].shape[-1]
    cache_keys = ("feats", "bb_logits", "labels")
    alpha_grid = [0.0, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0]

    # === Main Experiment Loop ===
    labeled_sizes = [100, 200, 500, 1000, 2000]
    for n in labeled_sizes:
        print(f"\n>>> Running Experiment: n={n}")

        indices = torch.randperm(len(full_train_cache["labels"]))[:n]
        subset = {key: full_train_cache[key][indices] for key in cache_keys}

        n_train = int(0.8 * n)
        train_data = {key: subset[key][:n_train] for key in cache_keys}
        val_data = {key: subset[key][n_train:] for key in cache_keys}

        # Black-box accuracy on the validation split: the safety threshold.
        val_bb_logits = val_data["bb_logits"].to(device)
        val_y_np = val_data["labels"].cpu().numpy()
        bb_val_acc = accuracy(val_bb_logits.argmax(1).cpu().numpy(), val_y_np)
        print(f"  BB Val Acc: {bb_val_acc:.4f}")

        models = {}
        val_accuracies = {}
        for m_name in ["scratch", "weighted", "concat", "residual"]:
            print(f"  Training {m_name}...")
            # train_model also returns the residual norms, unused here. The
            # head gets its own name so the CLIP backbone is not shadowed.
            head, val_acc, _ = train_model(m_name, train_data, val_data, input_dim, info["num_classes"], device)
            models[m_name] = head
            val_accuracies[m_name] = val_acc

        # Weighted Val-Tuned Logic: pick the first alpha on the grid that
        # maximizes validation accuracy of the fixed-weight mixture.
        scratch_model = models["scratch"]
        scratch_model.eval()
        val_feats = val_data["feats"].to(device)
        best_alpha = 0.0
        best_val_acc = -1.0
        with torch.no_grad():
            scratch_logits = scratch_model(val_feats)
            for alpha in alpha_grid:
                mixed_logits = alpha * val_bb_logits + (1 - alpha) * scratch_logits
                acc = accuracy(mixed_logits.argmax(1).cpu().numpy(), val_y_np)
                if acc > best_val_acc:
                    best_val_acc = acc
                    best_alpha = alpha
        print(f"  Best Val-Tuned Alpha: {best_alpha}")

        scores = evaluate_all_safe(models, test_cache, device, bb_val_acc, val_accuracies, best_alpha)
        print(f"--- Results (n={n}) ---")
        for k, v in scores.items():
            print(f"{k:>18}: {v:.4f}")
        log_entry = {"n": n}
        log_entry.update(scores)
        logger.log(log_entry)
        # Save after every size so partial results survive an interrupted run.
        logger.save()

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
