import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from wilds import get_dataset
import torchvision.transforms as T
import matplotlib.pyplot as plt
from train_main import get_model, make_transforms


def evaluate_split(model, loader, device, *, extra_class=False):
    """
    Runs the model on all examples in loader and returns per-sample confidences and correctness flags.

    Args:
        model: classifier; its output is assumed to be (batch, num_outputs) logits.
            The model is switched to eval mode and left there.
        loader: iterable yielding ``(x, y, metadata)`` batches; metadata is ignored.
        device: device that inputs and labels are moved to before the forward pass.
        extra_class: when True, the LAST output is treated as an abstention logit
            (SAT-style). Predictions are then the argmax over the remaining outputs,
            and confidence is ``1 - p(abstain)``. When False (default), confidence
            is the maximum softmax probability and the prediction is its index.

    Returns:
        (confidences, corrects): 1-D numpy arrays with one entry per example.
        ``corrects`` holds 1.0 for a correct prediction and 0.0 otherwise.
        Both arrays are empty if the loader yields no batches.
    """
    model.eval()
    confidences = []
    corrects = []
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            probs = F.softmax(logits, dim=1)
            if extra_class:
                # Abstention probability is a selection score, not a class:
                # predict among the real classes only and rank by 1 - p(abstain).
                conf = 1.0 - probs[:, -1]
                preds = logits[:, :-1].argmax(dim=1)
            else:
                conf, preds = probs.max(dim=1)
            confidences.append(conf.cpu().numpy())
            corrects.append((preds == y).cpu().numpy().astype(float))
    if not confidences:
        # np.concatenate([]) raises; an empty split yields empty results instead.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=float)
    return np.concatenate(confidences), np.concatenate(corrects)


def selective_accuracy(confidences, corrects):
    """
    Compute the selective-accuracy curve from confidence scores.

    Examples are ranked from most to least confident. For every k = 1..N, the
    k most confident examples are kept: coverage is k / N, and selective
    accuracy is the fraction of those k examples that are correct.

    Args:
        confidences: 1-D numpy array of per-example confidence scores.
        corrects: 1-D numpy array of 0/1 correctness flags aligned with confidences.

    Returns:
        (coverage, accuracy): two 1-D numpy arrays of length N.
    """
    order = np.argsort(-confidences)
    total = len(corrects)
    ranks = np.arange(1, total + 1)
    running_hits = np.cumsum(corrects[order])
    return ranks / total, running_hits / ranks


def perfect_upper_bound(a_full, coverage):
    """
    Return the perfect-ordering upper bound on selective accuracy (Definition 1).

    If a fraction ``a_full`` of all examples is correct, an ideal ranking puts
    every correct example first. Selective accuracy is then 1.0 for any
    coverage up to ``a_full`` and ``a_full / coverage`` beyond it.

    Args:
        a_full: accuracy at full coverage.
        coverage: numpy array of coverage levels.

    Returns:
        numpy array of upper-bound accuracies, same shape as ``coverage``.
    """
    fits_entirely = coverage <= a_full
    ratio = a_full / coverage
    return np.where(fits_entirely, 1.0, ratio)


def _split_curves(model, loader, device):
    """Evaluate one split; return (coverage, achieved selective accuracy, upper bound)."""
    confidences, corrects = evaluate_split(model, loader, device)
    coverage, achieved = selective_accuracy(confidences, corrects)
    # Accuracy at full coverage is the overall fraction of correct predictions.
    bound = perfect_upper_bound(corrects.mean(), coverage)
    return coverage, achieved, bound


def _trapezoid_area(y, x):
    """Trapezoidal integral of y over x; uses np.trapezoid (NumPy >= 2.0) or np.trapz."""
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(y, x)


def main():
    """
    Evaluate a trained checkpoint on WILDS val/test splits and report selective-accuracy trade-offs.

    The checkpoint is loaded from
    ``{ckpt_dir}/{dataset}/{arch}_{dataset}_{method}_seed{seed}.pt``. For the
    ``val`` split, and the ``test`` split when the dataset has one, this
    computes selective accuracy vs. coverage and the perfect-ordering upper
    bound. It writes them to ``./results/{dataset}_{method}_tradeoff.csv``
    and plots them to ``./plots/{dataset}_{method}_tradeoff.png``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate model on WILDS val and test splits with selective accuracy and upper bound, plot and save results"
    )
    parser.add_argument("--arch", default="resnet18", choices=["simple_cnn", "resnet18", "wideresnet"])
    parser.add_argument("--dataset", default="camelyon17", choices=["camelyon17", "fmow"])
    parser.add_argument("--method", default="msp", choices=["msp", "sat"])
    parser.add_argument("--ckpt_dir", default="./checkpoints")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--data_dir', required=True,
                        help='Root dir for dataset')
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model checkpoint; SAT models are built with one extra output.
    extra = (args.method == "sat")
    model = get_model(args.arch, args.dataset, extra_class=extra).to(device)
    ckpt_path = os.path.join(
        args.ckpt_dir, args.dataset,
        f"{args.arch}_{args.dataset}_{args.method}_seed{args.seed}.pt"
    )
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)

    # Prepare data subsets; get_subset raises ValueError when the split is absent.
    _, ttest = make_transforms(args.dataset)
    dataset_w = get_dataset(dataset=args.dataset, root_dir=args.data_dir, download=False)
    ds_val = dataset_w.get_subset("val", transform=ttest)
    try:
        ds_test = dataset_w.get_subset("test", transform=ttest)
    except ValueError:
        ds_test = None

    # NOTE(review): with --method sat the extra output is not treated specially
    # here (confidence and prediction are taken over all outputs) — confirm this
    # matches how SAT models are meant to be scored.
    curves = {
        "val": _split_curves(
            model, DataLoader(ds_val, batch_size=args.batch_size, shuffle=False), device
        )
    }
    if ds_test:  # skip a missing (None) or empty test subset
        curves["test"] = _split_curves(
            model, DataLoader(ds_test, batch_size=args.batch_size, shuffle=False), device
        )

    # NOTE(review): output filenames omit arch and seed, so runs that differ only
    # in those overwrite each other's results and plots.

    # Save results: one row per coverage level per split.
    os.makedirs("./results", exist_ok=True)
    df = pd.concat(
        [
            pd.DataFrame({"coverage": cov, "achieved": acc, "bound": bnd, "split": split})
            for split, (cov, acc, bnd) in curves.items()
        ],
        ignore_index=True,
    )
    result_path = f"./results/{args.dataset}_{args.method}_tradeoff.csv"
    df.to_csv(result_path, index=False)
    print(f"Results saved to {result_path}")

    # Plot trade-offs: achieved (solid), bound (dashed), and the shaded gap between them.
    os.makedirs("./plots", exist_ok=True)
    colors = plt.get_cmap('tab10').colors
    fig, ax = plt.subplots()
    for color, (split, (cov, acc, bnd)) in zip(colors, curves.items()):
        name = split.capitalize()  # "Val" / "Test"
        ax.plot(cov, acc, '-', label=f'{name} achieved', color=color)
        ax.plot(cov, bnd, '--', label=f'{name} bound', color=color)
        area = _trapezoid_area(bnd - acc, cov)
        ax.fill_between(cov, acc, bnd, where=(bnd >= acc), color=color, alpha=0.3,
                        label=f'{name} gap area = {area:.4f}')

    ax.set_xlabel('Coverage')
    ax.set_ylabel('Selective Accuracy')
    ax.set_title(f'Selective Accuracy vs Coverage ({args.dataset}, {args.method})')
    ax.legend()
    fig.tight_layout()

    plot_path = f'./plots/{args.dataset}_{args.method}_tradeoff.png'
    fig.savefig(plot_path, dpi=300)
    plt.close(fig)  # release the figure's memory once it is written
    print(f'Plot saved to {plot_path}')

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
