import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
from train_main import get_model, make_transforms
from torchvision.datasets import CIFAR10, CIFAR100, StanfordCars


def evaluate_split(model, loader, device, temperature=1.0):
    """
    Score every sample in ``loader`` with ``model`` using the max softmax probability.

    Logits are divided by ``temperature`` before the softmax; the default of 1.0 leaves
    them unchanged. Returns two 1-D numpy arrays with one entry per sample:
    the top-class probability, and a float 0/1 flag that is 1.0 when the top class
    equals the label.
    """
    model.eval()
    conf_chunks = []
    hit_chunks = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            probabilities = F.softmax(model(inputs) / temperature, dim=1)
            top_prob, top_class = probabilities.max(dim=1)
            conf_chunks.append(top_prob.cpu().numpy())
            hit_chunks.append((top_class == targets).cpu().numpy().astype(float))
    return np.concatenate(conf_chunks), np.concatenate(hit_chunks)


def evaluate_ensemble(models, loader, device):
    """
    Runs an ensemble of models, averaging softmax probabilities.

    For each batch, every model's logits are turned into softmax probabilities.
    The probabilities (not the logits) are averaged across models, and the
    prediction and confidence are taken from the max of that average.

    Args:
        models: non-empty sequence of models; each is switched to eval mode.
        loader: iterable of (inputs, targets) batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        (confidences, corrects): 1-D numpy arrays with one entry per sample.
        ``corrects`` is 1.0 where the ensemble prediction equals the target and 0.0 otherwise.

    Raises:
        ValueError: if ``models`` is empty (the average would be undefined).
    """
    if not models:
        raise ValueError('evaluate_ensemble requires at least one model')
    for m in models:
        m.eval()
    confidences, corrects = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            probs_sum = None
            for m in models:
                probs = F.softmax(m(x), dim=1)
                probs_sum = probs if probs_sum is None else probs_sum + probs
            probs_avg = probs_sum / len(models)
            conf, preds = probs_avg.max(dim=1)
            confidences.append(conf.cpu().numpy())
            corrects.append((preds == y).cpu().numpy().astype(float))
    return np.concatenate(confidences), np.concatenate(corrects)


def selective_accuracy(confidences, corrects):
    """
    Build the selective-accuracy curve from per-sample confidences and 0/1 correctness.

    Samples are ranked by descending confidence. For each k = 1..n the curve reports
    the coverage k / n and the accuracy over the k most confident samples.

    Returns:
        (coverage, acc): two arrays of length n.
    """
    order = np.argsort(-confidences)
    ranked_hits = corrects[order]
    total = len(ranked_hits)
    kept = np.arange(1, total + 1)
    return kept / total, np.cumsum(ranked_hits) / kept


def perfect_upper_bound(a_full, coverage):
    """
    Selective accuracy achievable at each ``coverage`` by an ideal ranking.

    ``a_full`` is the accuracy at full coverage. An ideal ranking that places every
    correct prediction first stays at 1.0 while ``coverage <= a_full``. Beyond that
    point, the accuracy is ``a_full / coverage``.
    """
    within = coverage <= a_full
    ratio = a_full / coverage
    return np.where(within, 1.0, ratio)


def get_dataset_loader(dataset_name, data_dir, transform):
    """
    Build the held-out test split for ``dataset_name``.

    Despite the name, this returns a torchvision dataset rather than a DataLoader;
    the caller wraps it. CIFAR-10/100 are downloaded on demand. Stanford Cars is
    opened with ``download=False``, so it must already be present in ``data_dir``.

    Raises:
        ValueError: for any name other than 'cifar10', 'cifar100' or 'stanfordcars'.
    """
    # Builders are lambdas so that only the selected dataset is ever constructed.
    builders = (
        ('cifar10', lambda: CIFAR10(data_dir, train=False, transform=transform, download=True)),
        ('cifar100', lambda: CIFAR100(data_dir, train=False, transform=transform, download=True)),
        ('stanfordcars',
         lambda: StanfordCars(data_dir, split='test', transform=transform, download=False)),
    )
    for key, build in builders:
        if dataset_name == key:
            return build()
    raise ValueError(f"Unsupported dataset {dataset_name}")


def _trapezoid(y, x):
    """Return the trapezoidal-rule integral of ``y`` sampled at ``x``.

    NumPy 2.0 renamed ``np.trapz`` to ``np.trapezoid`` and deprecated the old
    name. Prefer the new function and fall back to ``np.trapz`` on NumPy 1.x.
    """
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    return integrate(y, x)


def _reliability_bins(conf, corr, n_bins=10):
    """Group predictions into ``n_bins`` equal-width confidence bins over [0, 1].

    Returns one dict per bin with keys 'bin', 'bin_confidence' (mean confidence),
    'bin_accuracy' (mean correctness) and 'bin_count'. Empty bins report 0.0 / 0.0 / 0.
    """
    edges = np.linspace(0, 1, n_bins + 1)
    # Digitize against the interior edges only and clip the result. Otherwise a
    # confidence of exactly 1.0 (frequent for saturated softmax outputs) would get
    # index n_bins and be silently dropped from every bin.
    bin_idx = np.clip(np.digitize(conf, edges[1:-1]), 0, n_bins - 1)
    rows = []
    for b in range(n_bins):
        mask = bin_idx == b
        count = int(mask.sum())
        if count > 0:
            mean_conf = conf[mask].mean()
            mean_acc = corr[mask].mean()
        else:
            mean_conf, mean_acc = 0.0, 0.0
        rows.append({
            'bin': b,
            'bin_confidence': mean_conf,
            'bin_accuracy': mean_acc,
            'bin_count': count,
        })
    return rows


def _evaluate_arch(arch, index, args, loader, device):
    """Load the checkpoint(s) for one architecture and return (confidences, corrects).

    ``index`` is the architecture's position in ``args.archs``; it selects the matching
    entry of ``args.temperatures`` when ``args.method == 'temp'``.
    """
    ckpt_root = os.path.join(args.ckpt_dir, args.dataset)
    if args.method == 'de':
        # The deep ensemble is assembled from the MSP checkpoints of each requested seed.
        models = []
        for s in args.ensemble_seeds:
            m = get_model(arch, args.dataset, extra_class=False).to(device)
            path = os.path.join(ckpt_root, f"{arch}_{args.dataset}_msp_seed{s}.pt")
            m.load_state_dict(torch.load(path, map_location=device))
            models.append(m)
        return evaluate_ensemble(models, loader, device)

    # Temperature scaling reuses the MSP checkpoint; only SAT has its own weights.
    method_tag = 'sat' if args.method == 'sat' else 'msp'
    m = get_model(arch, args.dataset, extra_class=(args.method == 'sat')).to(device)
    path = os.path.join(ckpt_root, f"{arch}_{args.dataset}_{method_tag}_seed{args.seed}.pt")
    m.load_state_dict(torch.load(path, map_location=device))
    # NOTE(review): for SAT the model is built with an extra output (presumably an
    # abstention class). evaluate_split takes softmax/argmax over all outputs, including
    # that one. Confirm whether SAT should instead score with 1 - p(extra) and predict
    # over the first K classes only.
    temp = args.temperatures[index] if args.method == 'temp' else 1.0
    return evaluate_split(m, loader, device, temperature=temp)


def main():
    """Compare selective accuracy and calibration across architectures.

    Evaluates each architecture in ``--archs`` on the test split of ``--dataset``
    with the chosen ``--method``. Plots go to ./plots and the underlying data to
    ./results as CSV: a selective-accuracy-vs-coverage plot (with the
    perfect-ranking bound) and a reliability diagram.

    Raises:
        ValueError: on inconsistent method-specific options. This is checked before
        any data or checkpoint is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Compare selective accuracy and calibration across architectures with MSP, SAT, DE, and Temperature scaling"
    )
    parser.add_argument('--dataset', choices=['cifar10','cifar100','stanfordcars'], required=True,
                        help='Dataset to evaluate')
    parser.add_argument('--archs', nargs='+', choices=['simple_cnn','resnet18','wideresnet'], required=True,
                        help='List of model architectures')
    parser.add_argument('--method', choices=['msp','sat','de','temp'], default='msp',
                        help='Method: msp, sat, de (deep ensemble), or temp (temperature scaling)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed for reproducibility')
    parser.add_argument('--ckpt_dir', required=True,
                        help='Root dir containing checkpoints: <dataset>/<arch>_<dataset>_<method>_seed<seed>.pt')
    parser.add_argument('--data_dir', required=True,
                        help='Directory containing dataset data')
    parser.add_argument('--ensemble_seeds', nargs='+', type=int,
                        help='Seeds for deep ensemble when --method de')
    parser.add_argument('--temperatures', nargs='+', type=float,
                        help='List of temperatures, one per architecture, when --method temp')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    # Validate method-specific options before any dataset or checkpoint is touched.
    if args.method == 'temp':
        if not args.temperatures or len(args.temperatures) != len(args.archs):
            raise ValueError('Please provide one --temperatures value per architecture')
        if any(t <= 0 for t in args.temperatures):
            raise ValueError('All --temperatures values must be positive')
    if args.method == 'de' and not args.ensemble_seeds:
        raise ValueError('Specify --ensemble_seeds for deep ensemble')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(args.device)

    transform = make_transforms(args.dataset)[1]
    os.makedirs('./results', exist_ok=True)
    os.makedirs('./plots', exist_ok=True)

    # The test split and transform are the same for every architecture, so build the
    # dataset and loader once instead of re-instantiating them per architecture.
    loader = DataLoader(get_dataset_loader(args.dataset, args.data_dir, transform),
                        batch_size=args.batch_size, shuffle=False)

    tradeoff_records = []
    reliability_records = []
    colors = plt.get_cmap('tab10').colors
    trade_fig, trade_ax = plt.subplots()

    for i, arch in enumerate(args.archs):
        conf, corr = _evaluate_arch(arch, i, args, loader, device)
        color = colors[i % len(colors)]

        # Selective accuracy vs. the perfect-ranking bound; gap_area is the area between them.
        cov, acc = selective_accuracy(conf, corr)
        bound = perfect_upper_bound(corr.mean(), cov)
        gap_area = _trapezoid(bound - acc, cov)
        for c, a, b in zip(cov, acc, bound):
            tradeoff_records.append({'arch': arch, 'coverage': c,
                                     'achieved': a, 'bound': b,
                                     'gap_area': gap_area, 'method': args.method})
        trade_ax.plot(cov, acc, '-', label=f'{arch} achieved (gap={gap_area:.4f})', color=color)
        trade_ax.plot(cov, bound, '--', label=f'{arch} bound', color=color)

        for row in _reliability_bins(conf, corr, n_bins=10):
            reliability_records.append({'arch': arch, **row, 'method': args.method})

    # Plot and save tradeoff
    trade_ax.set_xlabel('Coverage')
    trade_ax.set_ylabel('Selective Accuracy')
    trade_ax.set_title(f'Selective Accuracy vs Coverage on {args.dataset} ({args.method})')
    trade_ax.legend()
    trade_fig.tight_layout()
    trade_path = f'./plots/{args.dataset}_{args.method}_archs_tradeoff.png'
    trade_fig.savefig(trade_path, dpi=300)
    plt.close(trade_fig)
    print(f'Tradeoff plot saved to {trade_path}')

    df_trade = pd.DataFrame.from_records(tradeoff_records)
    trade_csv = f'./results/{args.dataset}_{args.method}_archs_tradeoff.csv'
    df_trade.to_csv(trade_csv, index=False)
    print(f'Tradeoff data saved to {trade_csv}')

    # Plot and save reliability
    df_rel = pd.DataFrame.from_records(reliability_records)
    rel_fig, rel_ax = plt.subplots()
    for i, arch in enumerate(args.archs):
        # Skip empty bins: their placeholder (0, 0) values would draw a spurious
        # point at the origin. They are still written to the CSV below.
        df_arch = df_rel[(df_rel['arch'] == arch) & (df_rel['bin_count'] > 0)]
        rel_ax.plot(df_arch['bin_confidence'], df_arch['bin_accuracy'], marker='o',
                    label=arch, color=colors[i % len(colors)])
    rel_ax.plot([0,1],[0,1],'--', color='gray', label='Ideal')
    rel_ax.set_xlabel('Mean predicted confidence')
    rel_ax.set_ylabel('Observed accuracy')
    rel_ax.set_title(f'Reliability Diagram on {args.dataset} ({args.method})')
    rel_ax.legend()
    rel_fig.tight_layout()
    reliab_path = f'./plots/{args.dataset}_{args.method}_archs_reliability.png'
    rel_fig.savefig(reliab_path, dpi=300)
    plt.close(rel_fig)
    print(f'Reliability plot saved to {reliab_path}')

    reliab_csv = f'./results/{args.dataset}_{args.method}_archs_reliability.csv'
    df_rel.to_csv(reliab_csv, index=False)
    print(f'Reliability data saved to {reliab_csv}')

if __name__ == '__main__':
    main()