import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import matplotlib.pyplot as plt
from train_main import get_model
from torchvision.datasets import CIFAR10


def evaluate_split(model, loader, device):
    """
    Runs the model on all examples in loader and returns per-sample confidences and correctness flags.

    The model is switched to eval mode (and left in it) and is run without
    gradient tracking. For each batch the confidence is the maximum softmax
    probability over dim 1 of the model output, and the prediction is the
    index of that maximum.

    Args:
        model: Callable torch module mapping a batch of inputs to logits.
            Assumes logits are shaped (batch, classes) -- TODO confirm for
            every architecture returned by the project's model factory.
        loader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(confidences, corrects)`` of NumPy arrays concatenated over
        all batches in loader order. ``corrects`` holds 1.0 where the
        prediction equals the label and 0.0 otherwise. Both arrays are empty
        when the loader yields no batches.
    """
    model.eval()
    confidences, corrects = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            probs = F.softmax(model(x), dim=1)
            # NOTE(review): the softmax covers every output column; if the
            # model has an extra (abstention) output, a prediction of that
            # index can never equal a label and counts as incorrect.
            conf, preds = probs.max(dim=1)
            confidences.append(conf.cpu().numpy())
            corrects.append((preds == y).cpu().numpy().astype(float))

    # np.concatenate rejects an empty list, so handle an empty loader explicitly.
    if not confidences:
        return np.array([], dtype=float), np.array([], dtype=float)
    return np.concatenate(confidences), np.concatenate(corrects)


def selective_accuracy(confidences, corrects):
    """
    Computes coverage and selective accuracy from confidences and correctness arrays.

    Samples are ranked from most to least confident. Entry ``k`` of the
    result describes keeping only the ``k + 1`` most confident samples:
    ``coverage[k]`` is the kept fraction and ``acc[k]`` the mean of
    ``corrects`` over the kept samples.

    Args:
        confidences: 1-D array-like of per-sample confidence scores.
        corrects: 1-D array-like of per-sample correctness flags (1/0 or
            bool), aligned with ``confidences``.

    Returns:
        Tuple ``(coverage, acc)`` of float arrays with one entry per sample
        (both empty for empty input).

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    confidences = np.asarray(confidences, dtype=float)
    corrects = np.asarray(corrects, dtype=float)
    if confidences.shape != corrects.shape:
        raise ValueError(
            f"confidences shape {confidences.shape} does not match "
            f"corrects shape {corrects.shape}"
        )

    # A stable sort keeps tied confidences in input order, so the curve is
    # reproducible even when many samples share the same confidence value.
    order = np.argsort(-confidences, kind="stable")
    sorted_correct = corrects[order]

    n = len(sorted_correct)
    kept = np.arange(1, n + 1)  # number of samples kept at each threshold
    coverage = kept / n
    acc = np.cumsum(sorted_correct) / kept
    return coverage, acc


def perfect_upper_bound(a_full, coverage):
    """
    Computes the perfect-ordering upper bound given full-coverage accuracy and coverage array.

    The bound is 1.0 wherever ``coverage <= a_full`` and ``a_full / coverage``
    elsewhere.

    Args:
        a_full: Accuracy at full coverage (scalar).
        coverage: Array-like (or scalar) of coverage values.

    Returns:
        Float NumPy array with the same shape as ``coverage``.
    """
    coverage = np.asarray(coverage, dtype=float)
    # Negating the "<=" test (rather than using ">") keeps NaN coverage on the
    # division branch, so NaN inputs still propagate to the output.
    needs_division = ~(coverage <= a_full)
    bound = np.ones_like(coverage)
    # Divide only where required: evaluating a_full / coverage everywhere
    # would emit divide-by-zero / invalid warnings for coverage == 0 even
    # though those entries are 1.0 anyway.
    np.divide(a_full, coverage, out=bound, where=needs_division)
    return bound


class CIFAR10CSeverityDataset(Dataset):
    """
    Mixed CIFAR-10-C dataset for one severity level.
    Randomly assigns one of the corruption types to each test index.

    Every ``.npy`` file in ``cifarc_dir`` other than ``labels.npy`` is treated
    as one corruption type (files are taken in sorted name order). For each
    file only the rows belonging to the requested severity are kept in
    memory. Item ``idx`` is the image at row ``idx`` of the corruption type
    drawn for that index, paired with ``labels[idx]``.

    Args:
        cifarc_dir: Directory holding the corruption ``.npy`` files and
            ``labels.npy``.
        severity: 1-based severity level; selects rows
            ``(severity - 1) * 10000 : severity * 10000`` of each file.
        transform: Callable applied to the PIL image built for each item.
        seed: Optional seed for the per-index corruption choice. When given,
            a private ``RandomState`` is used, so the global NumPy RNG is not
            reseeded; the draws match what seeding the global legacy RNG with
            the same value would produce. When ``None``, the global NumPy RNG
            is used.

    Raises:
        FileNotFoundError: If ``cifarc_dir`` has no corruption ``.npy`` files.
        ValueError: If ``severity`` is below 1, selects no rows, a corruption
            file has fewer rows than the first one for this severity, or
            ``labels.npy`` cannot be matched to the severity chunk.
    """

    # Number of images stored per severity level in each corruption file.
    IMAGES_PER_SEVERITY = 10000

    def __init__(self, cifarc_dir, severity, transform, seed=None):
        if severity < 1:
            raise ValueError(f"severity must be >= 1, got {severity}")
        start = (severity - 1) * self.IMAGES_PER_SEVERITY
        end = severity * self.IMAGES_PER_SEVERITY

        # Load all labels (may cover one severity or all of them)
        full_labels = np.load(os.path.join(cifarc_dir, 'labels.npy'))

        # Gather the severity slice of each corruption file
        files = sorted(f for f in os.listdir(cifarc_dir)
                       if f.endswith('.npy') and f != 'labels.npy')
        if not files:
            raise FileNotFoundError(
                f"no corruption .npy files found in {cifarc_dir}")

        subsets = []
        for fname in files:
            # Memory-map the file and copy only the requested rows: a plain
            # slice of a fully loaded array is a view that would keep all
            # severities of every corruption alive in memory.
            arr = np.load(os.path.join(cifarc_dir, fname), mmap_mode='r')
            subsets.append(np.array(arr[start:end]))
        self.subsets = subsets

        n = subsets[0].shape[0]  # expected 10000
        if n == 0:
            raise ValueError(
                f"severity {severity} selects no images from {files[0]}")
        for fname, subset in zip(files, subsets):
            # Items index every subset with the same idx, so none may be shorter.
            if subset.shape[0] < n:
                raise ValueError(
                    f"{fname} has {subset.shape[0]} images for severity "
                    f"{severity}, expected at least {n}")

        # Slice or validate labels to match severity slice length
        n_labels = full_labels.shape[0]
        if n_labels == n:
            self.labels = full_labels
        elif n_labels % n == 0 and severity * n <= n_labels:
            self.labels = full_labels[(severity - 1) * n:severity * n]
        else:
            raise ValueError(f"labels.npy length {full_labels.shape[0]} not compatible with severity chunk size {n}")

        # Number of corruption types
        self.n_corrupt = len(subsets)

        # Randomly choose a corruption per image index, without reseeding the
        # process-wide NumPy RNG when a seed is supplied.
        rng = np.random if seed is None else np.random.RandomState(seed)
        self.choice = rng.randint(0, self.n_corrupt, size=n)

        self.transform = transform
        self.to_pil = T.ToPILImage()

    def __len__(self):
        """Returns the number of labelled images in this severity level."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Returns ``(transformed_image, label)`` for image index ``idx``."""
        corrupt_idx = self.choice[idx]
        img_arr = self.subsets[corrupt_idx][idx]
        img = self.to_pil(img_arr)
        img = self.transform(img)
        return img, int(self.labels[idx])


def _evaluate_and_plot_split(model, loader, device, ax, split, severity, label, color):
    """
    Evaluates one data split, draws its curves on ``ax`` and returns its result rows.

    The achieved selective-accuracy curve is drawn solid and the
    perfect-ordering bound dashed, both in ``color`` and labelled
    ``"<label> achieved"`` / ``"<label> bound"``.

    Returns:
        List of dicts with keys ``split``, ``severity``, ``coverage``,
        ``achieved`` and ``bound`` (one dict per coverage point).
    """
    conf, corr = evaluate_split(model, loader, device)
    cov, acc = selective_accuracy(conf, corr)
    bound = perfect_upper_bound(corr.mean(), cov)

    ax.plot(cov, acc, '-', label=f'{label} achieved', color=color)
    ax.plot(cov, bound, '--', label=f'{label} bound', color=color)

    return [{"split": split, "severity": severity,
             "coverage": c, "achieved": a, "bound": b}
            for c, a, b in zip(cov, acc, bound)]


def main():
    """
    Command-line entry point: compares selective accuracy on CIFAR-10 and CIFAR-10-C.

    Loads the checkpoint ``<arch>_cifar10_<method>_seed<seed>.pt`` from
    ``--ckpt_dir``, evaluates it on the clean CIFAR-10 test set and on the
    mixed-corruption CIFAR-10-C set at severities 1-5, then writes
    ``./plots/cifar10c_<method>_tradeoff.png`` and
    ``./results/cifar10c_<method>_tradeoff.csv``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate CIFAR-10-C selective accuracy tradeoff"
    )
    parser.add_argument("--arch", choices=["simple_cnn","resnet18","wideresnet"],
                        required=True, help="Model architecture")
    parser.add_argument("--method", choices=["msp","sat"], default="msp",
                        help="Selective prediction method")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed: selects the checkpoint and the "
                             "per-image corruption mix")
    parser.add_argument("--ckpt_dir", required=True,
                        help="Directory containing trained CIFAR-10 checkpoints")
    parser.add_argument("--data_dir", required=True,
                        help="Directory containing CIFAR-10 test set data")
    parser.add_argument("--cifar10c_dir", required=True,
                        help="Directory containing CIFAR-10-C .npy files")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    device = torch.device(args.device)

    # Load model trained on CIFAR-10 (clean)
    model = get_model(args.arch, "cifar10", extra_class=(args.method=="sat")).to(device)
    ckpt_path = os.path.join(
        args.ckpt_dir,
        f"{args.arch}_cifar10_{args.method}_seed{args.seed}.pt"
    )
    # NOTE(security): depending on the installed PyTorch version's defaults,
    # torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)

    # Prepare transform for CIFAR-10/-C
    mean, std = (0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616)
    transform = T.Compose([T.ToTensor(), T.Normalize(mean, std)])

    # Ensure output dirs
    os.makedirs("./results", exist_ok=True)
    os.makedirs("./plots", exist_ok=True)

    # Explicit figure/axes so the figure can be closed once it is saved
    fig, ax = plt.subplots()
    colors = plt.get_cmap('tab10').colors

    # Evaluate and plot the clean CIFAR-10 test set
    clean_ds = CIFAR10(args.data_dir, train=False, transform=transform, download=True)
    clean_loader = DataLoader(clean_ds, batch_size=args.batch_size, shuffle=False)
    records = _evaluate_and_plot_split(
        model, clean_loader, device, ax,
        split="clean", severity=0, label="Clean", color=colors[0])

    # Evaluate and plot each CIFAR-10-C severity. The same seed is passed for
    # every severity, so the per-index corruption draw is seeded identically.
    for sev in range(1,6):
        ds_sev = CIFAR10CSeverityDataset(
            args.cifar10c_dir, severity=sev, transform=transform, seed=args.seed)
        loader_sev = DataLoader(ds_sev, batch_size=args.batch_size, shuffle=False)
        records += _evaluate_and_plot_split(
            model, loader_sev, device, ax,
            split="cifar10c", severity=sev, label=f"Sev {sev}",
            color=colors[sev % len(colors)])

    # Finalize plot
    ax.set_xlabel('Coverage')
    ax.set_ylabel('Selective Accuracy')
    ax.set_title(f'CIFAR-10 vs CIFAR-10-C Tradeoffs ({args.method})')
    ax.legend()
    fig.tight_layout()
    # NOTE: output names depend only on the method, so runs that differ in
    # architecture or seed overwrite each other's plot and CSV.
    plot_path = f"./plots/cifar10c_{args.method}_tradeoff.png"
    fig.savefig(plot_path, dpi=300)
    plt.close(fig)  # release the figure's resources
    print(f"Plot saved to {plot_path}")

    # Save results DataFrame
    df = pd.DataFrame.from_records(records)
    res_path = f"./results/cifar10c_{args.method}_tradeoff.csv"
    df.to_csv(res_path, index=False)
    print(f"Results saved to {res_path}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
