import torch
from torch.utils.data import TensorDataset
import argparse
import json
import os
from models import REPVLM, ProbVLM

# Run on the highest-indexed visible GPU; fall back to the CPU when no GPU is
# visible, because "cuda:-1" (device_count() - 1 == -1) is not a valid device.
DEVICE = (
    f"cuda:{torch.cuda.device_count() - 1}"
    if torch.cuda.device_count() > 0
    else "cpu"
)

def _per_sample_uncertainty(uncer):
    """Collapse any trailing dimensions of ``uncer`` to one scalar per row by averaging."""
    if uncer.dim() > 1:
        return uncer.flatten(start_dim=1).mean(dim=1)
    return uncer


def i2t_acc(image_emb, image_uncer, text_emb, text_uncer, all_targets, num_uncer_levels):
    """
    Compute zero-shot image-to-text accuracy at decreasing coverage levels.

    Each image is assigned the class whose text embedding has the highest
    cosine similarity with it. Samples are then ranked by the sum of the
    image's uncertainty and the uncertainty of its predicted class prompt.
    Accuracy is measured on the most certain fraction of samples for the
    coverages ``L/L, (L-1)/L, ..., 1/L`` with ``L = num_uncer_levels``
    (i.e. the most uncertain samples are rejected first).

    Args:
        image_emb: (N, D) image embeddings used for prediction.
        image_uncer: per-image uncertainties whose first dimension is N.
            Trailing dimensions, if any, are averaged to one value per sample.
        text_emb: (C, D) class-prompt embeddings.
        text_uncer: per-prompt uncertainties whose first dimension is C.
            Trailing dimensions, if any, are averaged to one value per prompt.
        all_targets: (N,) ground-truth class indices.
        num_uncer_levels: number of coverage levels L.

    Returns:
        list[float]: accuracy at each coverage level, highest coverage first.
        Levels that would keep zero samples are skipped, so the list can be
        shorter than ``num_uncer_levels`` for very small N.
    """
    # 1. Zero-shot predictions from cosine similarity.
    image_emb_norm = image_emb / image_emb.norm(dim=-1, keepdim=True)
    text_emb_norm = text_emb / text_emb.norm(dim=-1, keepdim=True)
    cosine_sim = image_emb_norm @ text_emb_norm.t()
    all_predictions = cosine_sim.argmax(dim=1)

    # 2. Baseline accuracy over all samples.
    overall_acc = (all_predictions == all_targets).float().mean().item()
    print(f"Overall I2T accuracy (100% Coverage): {overall_acc:.4f}")

    # 3. Combined uncertainty = image uncertainty + uncertainty of the predicted prompt.
    # Index with the prediction tensor directly (no NumPy round-trip). Reduce each
    # term to one scalar per sample before adding so that mismatched trailing
    # shapes such as (N,) + (N, 1) cannot broadcast to (N, N).
    pred_text_uncer = text_uncer[all_predictions.to(text_uncer.device)]
    uncertainties = (
        _per_sample_uncertainty(pred_text_uncer)
        + _per_sample_uncertainty(image_uncer)
    )
    num_samples = uncertainties.shape[0]

    # 4. Rank samples from most to least confident (ascending uncertainty).
    # `uncertainties` is 1-D here, so argsort ranks across samples rather than
    # within rows.
    sorted_indices = torch.argsort(uncertainties)
    sorted_preds = all_predictions[sorted_indices.to(all_predictions.device)]
    sorted_targets = all_targets[sorted_indices.to(all_targets.device)]

    # 5. Accuracy on the most confident fraction for each coverage level.
    # Use exact integer arithmetic for the kept-sample count. Float32 coverage
    # values (as produced by torch.linspace) are inexact, so int(coverage * N)
    # can truncate to one sample fewer than intended.
    i2t_accs = []
    for level in range(num_uncer_levels):
        kept_levels = num_uncer_levels - level
        coverage = kept_levels / num_uncer_levels
        n_keep = num_samples * kept_levels // num_uncer_levels

        if n_keep == 0:
            # Not enough samples to populate this coverage level.
            continue

        current_preds = sorted_preds[:n_keep]
        current_targets = sorted_targets[:n_keep]
        acc = (current_preds == current_targets).float().mean().item()

        i2t_accs.append(acc)
        print(f"Accuracy at {coverage*100:.0f}% Coverage: {acc:.4f}")

    return i2t_accs

@torch.no_grad()
def main(args):
    """
    Evaluate uncertainty-based selective zero-shot classification and save the results.

    Loads the ``args.method`` adaptor checkpoint trained on ``args.proxy_ds``
    and uses it to estimate uncertainties for the pre-computed image and prompt
    embeddings of ``args.eval_ds``. Accuracy is computed at
    ``args.uncer_levels`` coverage levels via ``i2t_acc`` and written to
    ``results/cls/<eval_ds>/<proxy_ds>/<method>/<seed>.json`` under key ``"i2t"``.

    Args:
        args: parsed CLI namespace with ``proxy_ds``, ``eval_ds``, ``method``,
            ``uncer_levels`` and ``seed``.
    """
    proxy_ds = args.proxy_ds
    eval_ds = args.eval_ds

    adaptors = {
        "probvlm": ProbVLM,
        "repvlm": REPVLM,
    }

    weight_path = f"checkpoints/{proxy_ds}/{args.method}/{args.seed}/model.pth"
    adaptor = adaptors[args.method]().to(DEVICE)
    adaptor.load_state_dict(torch.load(weight_path, map_location=DEVICE))
    adaptor.eval()

    print(f"Loading embeddings from {eval_ds}...", flush=True)
    # map_location loads the tensors straight onto DEVICE. Without it, loading
    # fails if they were saved on a device that does not exist on this machine.
    images = torch.load(f'embeddings/{eval_ds}/image.pth', map_location=DEVICE)
    targets = torch.load(f'embeddings/{eval_ds}/target.pth', map_location=DEVICE)
    prompts = torch.load(f'embeddings/{eval_ds}/prompt.pth', map_location=DEVICE)

    # Estimate the uncertainty of images in batches of 256 (the second output of
    # adapt_image). Slicing the tensor directly produces the same in-order
    # batches as a non-shuffled DataLoader, without per-sample indexing and
    # re-stacking.
    image_uncer = torch.cat(
        [adaptor.adapt_image(batch_img.float())[1] for batch_img in images.split(256)],
        dim=0,
    )
    print(f"Mean Image Uncertainty: {image_uncer.mean():.4f}")

    # Estimate the uncertainty of the class prompts in a single pass.
    text_uncer = adaptor.adapt_text(prompts.float())[1]

    # Predictions use the ORIGINAL embeddings; the adaptor only supplies the
    # uncertainties used to rank samples.
    i2t_accs = i2t_acc(images, image_uncer, prompts, text_uncer, targets, args.uncer_levels)

    # Save the per-coverage accuracies.
    result_dir = f"results/cls/{eval_ds}/{proxy_ds}/{args.method}"
    os.makedirs(result_dir, exist_ok=True)
    with open(f"{result_dir}/{args.seed}.json", "w") as f:
        json.dump({
            "i2t": i2t_accs
        }, f)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Selective zero-shot classification accuracy vs. coverage.",
    )
    parser.add_argument("--proxy_ds", type=str, default="cc",
                        help="dataset the adaptor checkpoint was trained on")
    parser.add_argument("--eval_ds", type=str, default="cifar100",
                        help="dataset whose pre-computed embeddings are evaluated")
    # Restrict to the adaptors main() knows, so a typo fails with a clear
    # usage error instead of a KeyError after start-up.
    parser.add_argument("--method", type=str, default="repvlm",
                        choices=["probvlm", "repvlm"],
                        help="uncertainty adaptor to evaluate")
    parser.add_argument("--uncer_levels", type=int, default=10,
                        help="number of coverage levels")
    parser.add_argument("--seed", type=int, default=0,
                        help="checkpoint seed (also names the result file)")
    args = parser.parse_args()

    main(args)