import argparse
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
from PIL import Image
from utils.dataset import load_datasets
from utils.utils import load_model, load_checkpoint, construct_checkpoint_name
from utils.config import *

class EnsembleModel(nn.Module):
    """Average the predictive distributions of several classifiers.

    ``forward`` returns the log of the mean softmax probability across the
    member models, so ``softmax``/``log_softmax`` of the output recovers the
    ensemble's averaged distribution exactly.

    Args:
        models: iterable of ``nn.Module`` members. Each is assumed to map a
            batch to logits with the class axis at ``dim=1`` (the softmax
            below is taken over ``dim=1``); TODO confirm for all architectures.
    """

    def __init__(self, models):
        super().__init__()
        # ModuleList registers the members so .to()/.eval() propagate to them.
        self.models = nn.ModuleList(models)

    def forward(self, x):
        # log(mean_m p_m) = logsumexp_m(log p_m) - log(M), evaluated in log
        # space. Adding an epsilon before taking the log would floor
        # log-probabilities near -18.4 and understate the NLL of confidently
        # wrong predictions.
        log_probs = torch.stack([F.log_softmax(model(x), dim=1) for model in self.models])
        return torch.logsumexp(log_probs, dim=0) - np.log(len(self.models))

def _collect_predictions(model, device, test_loader):
    """Run ``model`` over ``test_loader`` in eval mode without gradients.

    Returns ``(probs, targets, nll_total)``. ``probs`` holds the softmax over
    ``dim=1`` of the model outputs and ``targets`` the labels, both
    concatenated on the CPU. Both are ``None`` when the loader yields no
    batches. ``nll_total`` is the summed negative log-likelihood.
    """
    model.eval()
    prob_chunks, target_chunks = [], []
    nll_total = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = model(data)
            nll_total += F.nll_loss(F.log_softmax(logits, dim=1),
                                    target,
                                    reduction='sum').item()
            prob_chunks.append(F.softmax(logits, dim=1).cpu())
            target_chunks.append(target.cpu())
    if not prob_chunks:
        return None, None, nll_total
    return torch.cat(prob_chunks, dim=0), torch.cat(target_chunks, dim=0), nll_total


def _bin_mask(values, lo, hi, first):
    """Boolean mask of ``values`` in the bin ``(lo, hi]``.

    The first bin is closed on the left (``[lo, hi]``) so values equal to the
    lowest edge are counted. Softmax outputs that underflowed to exactly 0 are
    one example.
    """
    lower_ok = values >= lo if first else values > lo
    return lower_ok & (values <= hi)


def _expected_calibration_error(conf_max, pred_max, targets, n_bins):
    """Top-1 ECE with ``n_bins`` equal-width confidence bins on [0, 1]."""
    edges = torch.linspace(0., 1., n_bins + 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = _bin_mask(conf_max, edges[b].item(), edges[b + 1].item(), b == 0)
        prop = float(in_bin.float().mean().item())
        if prop > 0:
            acc_bin = float((pred_max[in_bin] == targets[in_bin]).float().mean().item())
            conf_bin = float(conf_max[in_bin].mean().item())
            ece += abs(conf_bin - acc_bin) * prop
    return ece


def _static_calibration_error(probs, targets, num_classes, n_bins):
    """Classwise calibration error with equal-width bins.

    The error is averaged over ``num_classes * N``.
    """
    n = probs.size(0)
    edges = torch.linspace(0., 1., n_bins + 1)
    total = 0.0
    for k in range(num_classes):
        class_probs = probs[:, k]
        is_k = targets == k  # loop-invariant across bins
        for b in range(n_bins):
            in_bin = _bin_mask(class_probs, edges[b].item(), edges[b + 1].item(), b == 0)
            nbk = int(in_bin.sum().item())
            if nbk > 0:
                acc_bk = float(is_k[in_bin].float().mean().item())
                conf_bk = float(class_probs[in_bin].mean().item())
                total += nbk * abs(acc_bk - conf_bk)
    return total / (num_classes * n)


def _adaptive_calibration_error(conf_max, pred_max, targets, n_ranges):
    """Top-1 calibration error over ``n_ranges`` equal-mass confidence ranges.

    Samples are sorted by confidence and split into contiguous groups of
    (almost) equal size. The result is the mean |accuracy - confidence| over
    the non-empty groups; this matches dividing by ``n_ranges`` whenever
    ``N >= n_ranges``. Returns NaN when there are no samples or
    ``n_ranges <= 0``.
    """
    n = conf_max.size(0)
    if n == 0 or n_ranges <= 0:
        return float('nan')
    _, order = torch.sort(conf_max)
    gaps = []
    for r in range(n_ranges):
        start, end = r * n // n_ranges, (r + 1) * n // n_ranges
        if start >= end:
            continue
        idx_r = order[start:end]
        conf_r = float(conf_max[idx_r].mean().item())
        acc_r = float((pred_max[idx_r] == targets[idx_r]).float().mean().item())
        gaps.append(abs(acc_r - conf_r))
    return sum(gaps) / len(gaps)


def evaluate_checkpoint(model, device, test_loader, num_classes, n_bins_static=20, n_bins_adaptive=20, thresh=0.0):
    """Evaluate ``model`` on ``test_loader`` and compute calibration metrics.

    Args:
        model: module mapping a batch to logits with classes on ``dim=1``.
        device: device each ``(data, target)`` batch is moved to.
        test_loader: iterable of ``(data, target)`` batches.
        num_classes: number of classes used for the classwise (SCE) metric.
        n_bins_static: number of equal-width bins for ECE and SCE.
        n_bins_adaptive: number of equal-mass ranges for ACE.
        thresh: accepted for backward compatibility; currently unused.

    Returns:
        ``(accuracy_percent, ece, nll, sce, ace)``. All five are NaN when the
        loader yields no samples. ``ace`` is NaN when ``n_bins_adaptive <= 0``.
    """
    nan = float('nan')
    probs, targets, nll_total = _collect_predictions(model, device, test_loader)
    if probs is None or probs.size(0) == 0:
        return nan, nan, nan, nan, nan

    n = probs.size(0)
    conf_max, pred_max = probs.max(dim=1)

    correct = int((pred_max == targets).sum().item())
    accuracy = 100. * correct / n
    nll = nll_total / n

    ece = _expected_calibration_error(conf_max, pred_max, targets, n_bins_static)
    sce = _static_calibration_error(probs, targets, int(num_classes), n_bins_static)
    ace = _adaptive_calibration_error(conf_max, pred_max, targets, n_bins_adaptive)

    return accuracy, ece, nll, sce, ace


class CIFAR10Corrupted(Dataset):
    """One (corruption type, severity) slice of a CIFAR-10-C style dataset.

    Reads ``<data_dir>/<corruption_type>.npy`` and ``<data_dir>/labels.npy``.
    Rows are assumed to be grouped by severity: severity ``s`` occupies rows
    ``[(s - 1) * images_per_severity, s * images_per_severity)``.

    Items are ``(image, label)``. ``image`` is a PIL image, passed through
    ``transform`` when one is given; ``label`` is a ``torch.long`` tensor.

    Raises:
        FileNotFoundError: if either ``.npy`` file is missing.
        ValueError: if ``images_per_severity`` is not positive, or if
            ``severity`` is outside ``1 .. len(labels) // images_per_severity``.
    """

    def __init__(self, data_dir, corruption_type, severity, transform=None, images_per_severity=10000):
        self.data_dir = data_dir
        self.corruption_type = corruption_type
        self.severity = severity
        self.transform = transform

        corruption_file_path = os.path.join(data_dir, f'{corruption_type}.npy')
        labels_file_path = os.path.join(data_dir, 'labels.npy')

        if not os.path.exists(corruption_file_path):
            raise FileNotFoundError(f"Corruption file not found: {corruption_file_path}")
        if not os.path.exists(labels_file_path):
            raise FileNotFoundError(f"Labels file not found: {labels_file_path}")
        if images_per_severity <= 0:
            raise ValueError(f"images_per_severity must be positive, got {images_per_severity}")

        all_labels = np.load(labels_file_path)
        num_severities = len(all_labels) // images_per_severity
        # An out-of-range severity would otherwise silently produce an empty
        # (or wrong) slice.
        if not 1 <= severity <= num_severities:
            raise ValueError(f"Severity must be between 1 and {num_severities}, got {severity}")

        start_index = (severity - 1) * images_per_severity
        end_index = start_index + images_per_severity

        # Memory-map the image file and copy only the requested severity,
        # instead of reading every severity into RAM.
        all_images = np.load(corruption_file_path, mmap_mode='r')
        self.images = np.array(all_images[start_index:end_index])
        self.labels = all_labels[start_index:end_index]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.fromarray(self.images[idx])
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label


class CIFAR10Perturbed(Dataset):
    """All frames of one perturbation type, flattened into a single image list.

    ``<perturbation_type>.npy`` is read as an array with at least five axes.
    Its first two axes are merged, so flat index ``i * num_steps + t`` refers to
    ``array[i, t]``. ``num_images_per_step`` and ``num_steps`` record the sizes
    of axes 0 and 1 so callers can undo the flattening.

    Items are images only (no labels): a PIL image, passed through ``transform``
    when one is given.

    Raises:
        FileNotFoundError: if the ``.npy`` file is missing.
    """

    def __init__(self, data_dir, perturbation_type, transform=None):
        self.data_dir = data_dir
        self.perturbation_type = perturbation_type
        self.transform = transform

        npy_path = os.path.join(data_dir, f'{perturbation_type}.npy')
        if not os.path.exists(npy_path):
            raise FileNotFoundError(f"Perturbation file not found: {npy_path}")

        frames = np.load(npy_path)
        shape = frames.shape
        self._original_shape = shape
        self.num_images_per_step, self.num_steps = shape[0], shape[1]
        # Merge axes 0 and 1; the remaining three axes describe a single image.
        self.images = frames.reshape(-1, shape[2], shape[3], shape[4])

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        frame = Image.fromarray(self.images[idx])
        return self.transform(frame) if self.transform else frame

def load_test_dataset(args):
    """Describe the evaluation split(s) selected by ``args.dataset``.

    Returns:
        ``(datasets_info, num_classes)``:

        * ``cifar10c``: one ``('cifar10c', corruption_type, severity, data_dir)``
          tuple per corruption type and severity 1-5. The caller builds the
          datasets lazily. ``num_classes`` is 10.
        * ``cifar10p``: one ``('cifar10p', perturbation_type, data_dir)`` tuple
          per perturbation type. ``num_classes`` is 10.
        * any other name: ``[('standard', test_dataset)]``. The split comes
          from ``load_datasets(args)`` and ``num_classes`` from ``DATA_CONFIG``.

    The dataset name is lower-cased before lookup. Whatever ``DATA_CONFIG``
    raises for an unknown name propagates to the caller.
    """
    dataset = args.dataset.lower()
    data_dir = DATA_CONFIG[dataset]['path']

    if dataset == 'cifar10c':
        corruption_types = [
            'brightness', 'contrast', 'defocus_blur', 'elastic_transform',
            'fog', 'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur',
            'impulse_noise', 'jpeg_compression', 'motion_blur', 'pixelate',
            'saturate', 'shot_noise', 'snow', 'spatter', 'speckle_noise',
            'zoom_blur',
        ]
        datasets_info = [
            ('cifar10c', c_type, severity, data_dir)
            for c_type in corruption_types
            for severity in range(1, 6)
        ]
        return datasets_info, 10

    if dataset == 'cifar10p':
        perturbation_types = [
            'brightness', 'gaussian_blur', 'gaussian_noise', 'gaussian_noise_2',
            'gaussian_noise_3', 'motion_blur', 'rotate', 'scale', 'shear',
            'shot_noise', 'shot_noise_2', 'shot_noise_3', 'snow', 'spatter',
            'speckle_noise', 'speckle_noise_2', 'speckle_noise_3', 'tilt',
            'translate', 'zoom_blur',
        ]
        return [('cifar10p', p_type, data_dir) for p_type in perturbation_types], 10

    # Every other dataset uses the project's standard test split.
    _, _, test_dataset = load_datasets(args)
    return [('standard', test_dataset)], DATA_CONFIG[dataset]['num_classes']

def parse_args():
    """Parse the command-line options for this evaluation script.

    Returns:
        argparse.Namespace with ``dataset``, ``model``, ``method``,
        ``checkpoints`` (list of paths, required), ``batch_size``,
        ``num_workers``, ``device``, ``n_bins`` and ``ens``.
    """
    parser = argparse.ArgumentParser(description="Evaluate model: print model info, accuracy, ECE and NLL on one line")
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='Dataset (default: CIFAR10)')
    parser.add_argument('--model', type=str, default='ResNet18', help='Model architecture (default: ResNet18)')
    parser.add_argument('--method', type=str, default='bayesian', help='Method (bayesian, baseline, temporal, dlb, pskd)')
    # Required: the script iterates over the checkpoints unconditionally, so a
    # missing value would otherwise fail later with an unhelpful TypeError.
    parser.add_argument('--checkpoints', nargs='+', type=str, required=True,
                        help="One or more checkpoint paths; each must contain a 'state_dict' entry")
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size (default: 256)')
    parser.add_argument('--num_workers', type=int, default=12, help='Number of dataloader workers (default: 12)')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device (default: cuda if available)')
    parser.add_argument('--n_bins', type=int, default=15, help='Number of bins for ECE, SCE and ACE (default: 15)')
    parser.add_argument('--ens', action='store_true', help='Treat checkpoints as an ensemble (average probabilities)')
    return parser.parse_args()

def _strip_wrapper_prefixes(state_dict):
    """Remove leading wrapper prefixes (``_orig_mod.``, ``module.``) from state-dict keys.

    These prefixes are added by wrappers such as ``torch.compile`` and
    ``DataParallel``. Only *leading* occurrences are removed, repeatedly and
    in any nesting order. Names that merely contain the text, such as
    ``se_module.fc.weight``, are left intact. The dict is returned unchanged
    when no key carries a prefix.
    """
    prefixes = ('_orig_mod.', 'module.')
    if not any(key.startswith(prefixes) for key in state_dict):
        return state_dict
    cleaned = {}
    for key, value in state_dict.items():
        while key.startswith(prefixes):
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    break
        cleaned[key] = value
    return cleaned


def _flip_probability(preds, num_sequences, num_steps, compare_to_first=False):
    """Fraction of frames whose top-1 prediction differs from a reference frame.

    ``preds`` is the flat prediction vector in ``CIFAR10Perturbed`` order:
    index ``i * num_steps + t`` is frame ``t`` of sequence ``i``. The
    reference is the previous frame, or the first frame of the sequence when
    ``compare_to_first`` is set.
    """
    seqs = preds.view(num_sequences, num_steps)
    reference = seqs[:, :1] if compare_to_first else seqs[:, :-1]
    return (seqs[:, 1:] != reference).float().mean().item()


def _mean_std(values):
    """Return ``(np.mean(values), np.std(values))`` (population std, ddof=0)."""
    return np.mean(values), np.std(values)


def main():
    """CLI entry point: load the checkpoints and print metrics for ``--dataset``.

    - Standard datasets: accuracy, ECE, SCE, ACE and NLL as mean ± std over
      the evaluated models (a single ensemble when ``--ens`` is set).
    - ``cifar10c``: accuracy, ECE and NLL per corruption/severity, plus an
      overall summary.
    - ``cifar10p``: flip probability per perturbation, plus an overall summary.
    """
    args = parse_args()

    # Print model and method information
    print(f"Model: {args.model}")
    print(f"Method: {args.method}")
    print()

    device = torch.device(args.device)
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == 'cuda'
    dataset_name = args.dataset.lower()
    cfg, is_vit = DATA_CONFIG[dataset_name], "vit" in args.model.lower()

    # Transform for the CIFAR-10-C/P arrays. ViT models get 224x224 inputs
    # with ImageNet mean/std; other models use the DATA_CONFIG mean/std.
    test_transform_cifar10 = transforms.Compose([
        *([transforms.Resize((224, 224))] if is_vit else []),
        transforms.ToTensor(),
        *([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])] if is_vit else [
            transforms.Normalize(mean=cfg['mean'], std=cfg['std'])]),
    ])
    datasets_info, num_classes = load_test_dataset(args)

    loaded_models = []
    for ckpt in args.checkpoints:
        model = load_model(args.model, num_classes)
        model.to(device)
        checkpoint = torch.load(ckpt, map_location=device)
        model.load_state_dict(_strip_wrapper_prefixes(checkpoint['state_dict']))
        model.eval()
        loaded_models.append(model)

    if args.ens:
        if len(loaded_models) < 2:
            print("Warning: Ensembling enabled but fewer than 2 models provided.")

        print(f"Ensembling {len(loaded_models)} checkpoints...")
        ensemble = EnsembleModel(loaded_models)
        ensemble.eval()
        ensemble.to(device)

        models_to_evaluate = [ensemble]
    else:
        models_to_evaluate = loaded_models

    if dataset_name in ['cifar10c', 'cifar10p']:
        print(f"Evaluating on {args.dataset}")

        all_accuracies = []
        all_ece_values = []
        all_nll_values = []
        all_flip_probs = []

        for dataset_info in datasets_info:
            dataset_type = dataset_info[0]
            if dataset_type == 'cifar10c':
                _, c_type, severity, data_dir = dataset_info
                print(f"  Loading {c_type} at severity {severity}")
                try:
                    test_dataset = CIFAR10Corrupted(data_dir=data_dir, corruption_type=c_type, severity=severity, transform=test_transform_cifar10)
                except FileNotFoundError as e:
                    print(f"  Skipping {c_type} severity {severity}: {e}")
                    continue
                variant_name = f"{c_type} severity {severity}"
                is_perturbed = False
            elif dataset_type == 'cifar10p':
                _, p_type, data_dir = dataset_info
                print(f"  Loading {p_type}")
                try:
                    test_dataset = CIFAR10Perturbed(data_dir=data_dir, perturbation_type=p_type, transform=test_transform_cifar10)
                except FileNotFoundError as e:
                    print(f"  Skipping {p_type}: {e}")
                    continue
                variant_name = f"{p_type}"
                is_perturbed = True
                num_steps = test_dataset.num_steps
                num_images_per_step = test_dataset.num_images_per_step
                # Noise sequences are not temporally ordered, so each frame is
                # compared with the first frame rather than its predecessor
                # (CIFAR-10-P flip-probability definition).
                compare_to_first = 'noise' in p_type
            else:
                raise ValueError(f"Unknown dataset type in datasets_info: {dataset_type}")

            test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=pin_memory)

            accuracies_variant = []
            ece_values_variant = []
            nll_values_variant = []
            flip_probs_variant = []

            for model in models_to_evaluate:
                if is_perturbed:
                    model.eval()
                    all_preds_for_flip = []
                    with torch.no_grad():
                        # CIFAR10Perturbed yields images only (no labels).
                        for data in test_loader:
                            data = data.to(device)
                            outputs = model(data)
                            _, preds = torch.max(F.softmax(outputs, dim=1), dim=1)
                            all_preds_for_flip.append(preds.cpu())
                    preds_tensor = torch.cat(all_preds_for_flip)
                    flip_probs_variant.append(
                        _flip_probability(preds_tensor, num_images_per_step, num_steps, compare_to_first))
                else:
                    acc, ece, nll, _, _ = evaluate_checkpoint(model, device, test_loader, num_classes,
                                                              n_bins_static=args.n_bins,
                                                              n_bins_adaptive=args.n_bins)
                    accuracies_variant.append(acc)
                    ece_values_variant.append(ece)
                    nll_values_variant.append(nll)

            if is_perturbed:
                if flip_probs_variant:
                    avg_flip, std_flip = _mean_std(flip_probs_variant)
                    print(f"    {variant_name} | Flip Prob: {avg_flip:.4f} ± {std_flip:.4f}")
                    all_flip_probs.extend(flip_probs_variant)
                else:
                    print(f"    No evaluation results for {variant_name}")
            else:
                if accuracies_variant:
                    avg_acc_variant, std_acc_variant = _mean_std(accuracies_variant)
                    avg_ece_variant, std_ece_variant = _mean_std(ece_values_variant)
                    avg_nll_variant, std_nll_variant = _mean_std(nll_values_variant)
                    print(f"    {variant_name} | Acc: {avg_acc_variant:.2f} ± {std_acc_variant:.2f} | ECE: {avg_ece_variant:.4f} ± {std_ece_variant:.4f} | NLL: {avg_nll_variant:.4f} ± {std_nll_variant:.4f}")

                    all_accuracies.extend(accuracies_variant)
                    all_ece_values.extend(ece_values_variant)
                    all_nll_values.extend(nll_values_variant)
                else:
                    print(f"    No evaluation results for {variant_name}")

        if dataset_name == 'cifar10p':
            if all_flip_probs:
                overall_avg_flip, overall_std_flip = _mean_std(all_flip_probs)
                print(f"\nOverall {args.dataset.lower()} {args.model.lower()} {args.method.lower()} | Flip Prob: {overall_avg_flip:.4f} ± {overall_std_flip:.4f}")
            else:
                print(f"No evaluation results across all variants for {args.model} {args.method} {args.dataset}")
        elif dataset_name == 'cifar10c':
            if all_accuracies:
                overall_avg_acc, overall_std_acc = _mean_std(all_accuracies)
                overall_avg_ece, overall_std_ece = _mean_std(all_ece_values)
                overall_avg_nll, overall_std_nll = _mean_std(all_nll_values)
                print(f"\nOverall {args.dataset.lower()} {args.model.lower()} {args.method.lower()} | Acc: {overall_avg_acc:.2f} ± {overall_std_acc:.2f} | ECE: {overall_avg_ece:.4f} ± {overall_std_ece:.4f} | NLL: {overall_avg_nll:.4f} ± {overall_std_nll:.4f}")
            else:
                print(f"No evaluation results across all variants for {args.model} {args.method} {args.dataset}")

    else:
        test_dataset = datasets_info[0][1]
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=pin_memory)
        accuracies = []
        ece_values = []
        sce_values = []
        ace_values = []
        nll_values = []
        for model in models_to_evaluate:
            acc, ece, nll, sce, ace = evaluate_checkpoint(model, device, test_loader, num_classes, n_bins_static=args.n_bins, n_bins_adaptive=args.n_bins)
            accuracies.append(acc)
            ece_values.append(ece)
            sce_values.append(sce)
            ace_values.append(ace)
            nll_values.append(nll)

        if accuracies:
            avg_acc, std_acc = _mean_std(accuracies)
            avg_ece, std_ece = _mean_std(ece_values)
            avg_ace, std_ace = _mean_std(ace_values)
            avg_sce, std_sce = _mean_std(sce_values)
            avg_nll, std_nll = _mean_std(nll_values)
            print(f"{args.dataset.lower()} {args.model.lower()} {args.method.lower()} | Acc: {avg_acc:.3f} ± {std_acc:.3f} | ECE: {avg_ece:.6f} ± {std_ece:.6f} | SCE: {avg_sce:.6f} ± {std_sce:.6f} | ACE: {avg_ace:.6f} ± {std_ace:.6f} | NLL: {avg_nll:.4f} ± {std_nll:.4f}")
        else:
            print(f"No evaluation results for {args.model} {args.method} {args.dataset}")


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()