import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as T
import matplotlib.pyplot as plt
from train_cifar_n import get_model, make_transforms
from torchvision.datasets import CIFAR10, CIFAR100


def evaluate_split(model, loader, device, extra_class=False):
    """
    Run ``model`` over ``loader`` and collect per-sample confidences and correctness.

    Args:
        model: classifier whose forward pass returns logits with the class
            dimension on axis 1.
        loader: iterable of ``(x, y)`` batches (e.g. a DataLoader).
        device: device each batch is moved to before the forward pass.
        extra_class: if True, the last output is treated as an abstention
            unit (presumably the one added by ``get_model(..., extra_class=True)``
            — confirm against the training code). The confidence is then
            ``1 - P(abstain)`` and predictions are taken over the remaining
            outputs only. If False (default), the confidence is the maximum
            softmax probability over all outputs.

    Returns:
        Tuple ``(confidences, corrects)`` of 1-D arrays, one entry per sample.
        ``corrects`` is 1.0 where the prediction equals the label, else 0.0.
        Both arrays are empty if the loader yields no batches.
    """
    model.eval()
    confidences, corrects = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            probs = F.softmax(model(x), dim=1)
            if extra_class:
                # Selection score is "not abstaining"; the abstain unit must
                # never be returned as a class prediction.
                conf = 1.0 - probs[:, -1]
                preds = probs[:, :-1].argmax(dim=1)
            else:
                conf, preds = probs.max(dim=1)
            confidences.append(conf.cpu().numpy())
            corrects.append((preds == y).cpu().numpy().astype(float))
    if not confidences:
        # np.concatenate rejects an empty list; report "no samples" explicitly.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=float)
    return np.concatenate(confidences), np.concatenate(corrects)


def selective_accuracy(confidences, corrects):
    """
    Accuracy-coverage curve from accepting samples in decreasing-confidence order.

    Entry k of the returned arrays describes keeping the k+1 most confident
    samples: ``coverage[k]`` is the fraction of samples kept and ``acc[k]``
    is the mean of ``corrects`` over them.
    """
    # Rank samples from most to least confident.
    ranking = np.argsort(-confidences)
    hits = corrects[ranking]

    kept = np.arange(1, len(hits) + 1)  # number of samples kept at each step
    running_accuracy = np.cumsum(hits) / kept
    return kept / len(hits), running_accuracy


def perfect_upper_bound(a_full, coverage):
    """
    Upper bound on selective accuracy under a perfect confidence ordering.

    If the model is correct on a fraction ``a_full`` of the samples, an oracle
    ranking puts every correct sample first. Selective accuracy is then 1.0
    while ``coverage <= a_full``, and ``a_full / coverage`` beyond that point:
    all correct samples are kept, diluted by the incorrect ones.

    Args:
        a_full: accuracy at full coverage (fraction in [0, 1]).
        coverage: scalar or array-like of coverage fractions.

    Returns:
        ndarray of upper-bound accuracies, broadcast to ``coverage``'s shape.
    """
    coverage = np.asarray(coverage, dtype=float)
    # np.where evaluates both branches eagerly. Silence the divide-by-zero /
    # 0/0 warnings at coverage == 0; that point always takes the 1.0 branch
    # for a non-negative a_full.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(coverage <= a_full, 1.0, a_full / coverage)


def main():
    """
    Plot accuracy-coverage curves on the most label-noisy held-out training points.

    Ranks the held-out CIFAR-N training indices by how many human labels
    disagree with the clean label. For each requested percentage, the
    checkpointed model is evaluated on the top ``pct`` percent of them. A plot
    is written to ``./plots`` and per-coverage values to ``./results``.
    """
    parser = argparse.ArgumentParser(
        description="Accuracy-coverage on most noisy training points (CIFAR-10N/100N)"
    )
    parser.add_argument('--dataset', choices=['cifar10','cifar100'], required=True,
                        help='Choose cifar10 or cifar100')
    parser.add_argument('--arch', choices=['simple_cnn','resnet18','wideresnet'], required=True)
    parser.add_argument('--method', choices=['msp','sat'], default='msp')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--ckpt_dir', required=True,
                        help='Path to checkpoints: <ckpt_dir>/<dataset>_n/...')
    parser.add_argument('--data_dir', required=True,
                        help='Root dir for CIFAR dataset')
    parser.add_argument('--label_pt', required=True,
                        help='Path to CIFAR-10_human.pt or CIFAR-100_human.pt')
    parser.add_argument('--indices_dir', default='./indices',
                        help='Directory where holdout index files are saved')
    parser.add_argument('--percentages', nargs='+', type=float, default=[10,20,50],
                        help='Noise-percentages (%%) to evaluate')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    # reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device)

    # Load human labels. weights_only=False unpickles arbitrary objects, so
    # --label_pt must point at a trusted file.
    noise_dict = torch.load(args.label_pt, weights_only=False)
    clean = np.array(noise_dict['clean_label'], dtype=int)
    if args.dataset == 'cifar10':
        # Number of human label sets (0-4) that disagree with the clean label.
        human_keys = ('worse_label', 'random_label1', 'random_label2', 'random_label3')
        disagreement = sum(
            (np.array(noise_dict[key], dtype=int) != clean).astype(int)
            for key in human_keys
        )
    else:
        noisy = np.array(noise_dict['noisy_label'], dtype=int)
        disagreement = (noisy != clean).astype(int)

    # load held-out indices
    idx_path = os.path.join(args.indices_dir,
                            f"{args.dataset}_n_holdout_{args.method}_seed{args.seed}.npy")
    held_idx = np.load(idx_path)

    # Sort held-out points by disagreement, most-disputed first. Disagreement
    # takes only a handful of integer values, so ties are everywhere. A stable
    # sort breaks them by held-out order, so the set chosen at each cut-off
    # does not depend on the sorting algorithm's tie handling.
    order = np.argsort(-disagreement[held_idx], kind='stable')
    sorted_local = held_idx[order]
    n_local = len(sorted_local)
    pct2idx = {pct: sorted_local[:int(n_local * pct / 100)] for pct in args.percentages}

    # load model
    # NOTE(review): for --method sat the model gets an extra output
    # (extra_class=True). Confirm that the selection score computed during
    # evaluation accounts for that abstention output.
    model = get_model(args.arch, args.dataset+'_n', extra_class=(args.method=='sat')).to(device)
    ckpt = os.path.join(args.ckpt_dir, args.dataset+'_n',
                        f"{args.arch}_{args.dataset}_n_{args.method}_seed{args.seed}.pt")
    model.load_state_dict(torch.load(ckpt, map_location=device))

    # transforms: only the second returned transform is applied to the data
    _, transform = make_transforms(args.dataset)

    # load full training data
    dataset_cls = CIFAR10 if args.dataset == 'cifar10' else CIFAR100
    full = dataset_cls(args.data_dir, train=True, transform=transform, download=True)

    # prepare outputs
    os.makedirs('./results', exist_ok=True)
    os.makedirs('./plots', exist_ok=True)

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:
        trapezoid = np.trapz

    fig = plt.figure()
    colors = plt.get_cmap('tab10').colors

    records = []
    for i, pct in enumerate(args.percentages):
        idx_list = pct2idx[pct]
        if len(idx_list) == 0:
            # The held-out set is too small for this percentage; there is
            # nothing to evaluate, so don't compute metrics on an empty set.
            print(f'Skipping {pct:g}%: it selects no held-out points')
            continue
        subset = Subset(full, idx_list)
        loader = DataLoader(subset, batch_size=args.batch_size, shuffle=False)

        conf, corr = evaluate_split(model, loader, device)
        cov, acc = selective_accuracy(conf, corr)
        bound = perfect_upper_bound(corr.mean(), cov)
        # Area between the oracle bound and the achieved curve (lower is better).
        gap_area = trapezoid(bound - acc, cov)

        records.extend(
            {'percent': pct, 'coverage': c, 'achieved': a,
             'bound': b, 'gap_area': gap_area}
            for c, a, b in zip(cov, acc, bound)
        )

        color = colors[i % len(colors)]
        # :g keeps fractional percentages (e.g. 12.5) instead of truncating them.
        plt.plot(cov, acc, '-', label=f'{pct:g}% achieved', color=color)
        plt.plot(cov, bound, '--', label=f'{pct:g}% bound', color=color)

    plt.xlabel('Coverage')
    plt.ylabel('Selective Accuracy')
    plt.title(f'Accuracy–Coverage on top noisy points of {args.dataset}')
    plt.legend()
    plt.tight_layout()
    plot_path = f'./plots/{args.dataset}_noisy_tradeoff.png'
    plt.savefig(plot_path, dpi=300)
    plt.close(fig)
    print(f'Saved plot to {plot_path}')

    # save to CSV
    df = pd.DataFrame.from_records(records)
    csv_path = f'./results/{args.dataset}_noisy_tradeoff.csv'
    df.to_csv(csv_path, index=False)
    print(f'Saved results to {csv_path}')


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
