import argparse
import os
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict

from models.resnet import ResNet9, ResNet18, CifarResNet18
from models.lenet import LeNet
from datasets.load_datasets import load_dataset


def get_model(model_name, in_channels, num_classes):
    """Instantiate the network architecture named by ``model_name``.

    Args:
        model_name: One of 'resnet9', 'resnet18' or 'lenet'.
        in_channels: Number of input image channels.
        num_classes: Number of output classes.

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'resnet9':
        return ResNet9(num_classes=num_classes, in_channels=in_channels)
    if model_name == 'resnet18':
        # The CIFAR-style ResNet18 is used regardless of num_classes.
        # NOTE(review): the generic ResNet18 is imported by this file but never
        # used here — confirm CifarResNet18 is intended for every dataset.
        return CifarResNet18(num_classes=num_classes, in_channels=in_channels)
    if model_name == 'lenet':
        return LeNet(num_classes=num_classes, in_channels=in_channels)
    raise ValueError(f"Unsupported model: {model_name}")


def get_dataset_info(dataset_name):
    """Return the ``(in_channels, num_classes)`` pair for a dataset name.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # Each entry maps a group of dataset names to their shared spec.
    known_specs = (
        (('mnist',), (1, 10)),
        (('cifar10', 'svhn'), (3, 10)),
        (('cifar100',), (3, 100)),
    )
    for names, spec in known_specs:
        if dataset_name in names:
            return spec
    raise ValueError(f"Unsupported dataset: {dataset_name}")


def split_test_data(test_loader, forget_class=0):
    """Partition every ``(inputs, targets)`` batch of ``test_loader`` by label.

    Returns:
        A pair of lists of ``(inputs, targets)`` tuples: the first holds, per
        batch, only the samples labelled ``forget_class``; the second holds the
        remaining samples. A batch contributes nothing to a list when it has
        no samples for that side.
    """
    forget_batches, retain_batches = [], []
    for batch_inputs, batch_targets in test_loader:
        is_forget = (batch_targets == forget_class)
        selections = (
            (is_forget, forget_batches),
            (~is_forget, retain_batches),
        )
        for mask, bucket in selections:
            if mask.sum() > 0:
                bucket.append((batch_inputs[mask], batch_targets[mask]))
    return forget_batches, retain_batches


def evaluate_on_subset(model, data_subset, device, subset_name):
    """Evaluate ``model`` on pre-batched data and print overall and per-class accuracy.

    Args:
        model: Model exposing ``eval()`` and callable on an input batch; the
            predicted class is the argmax of its output over dimension 1.
        data_subset: List of ``(inputs, targets)`` tensor pairs, e.g. one of
            the lists returned by ``split_test_data``.
        device: Device each batch is moved to before the forward pass.
        subset_name: Label used in the printed report.

    Returns:
        ``(accuracy_percent, num_correct, num_total)``, or ``(0.0, 0, 0)``
        when ``data_subset`` is empty.
    """
    if not data_subset:
        print(f"{subset_name}: No data")
        return 0.0, 0, 0
    model.eval()
    correct = 0
    total = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    with torch.no_grad():
        for inputs, targets in data_subset:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            hits = predicted.eq(targets)
            total += targets.size(0)
            correct += hits.sum().item()
            # Move labels and hit flags to the host once per batch instead of
            # calling .item() twice per sample (each call forces a device sync).
            for label, hit in zip(targets.tolist(), hits.tolist()):
                class_total[label] += 1
                if hit:
                    class_correct[label] += 1
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    print(f"{subset_name}:")
    print(f"  Overall accuracy: {accuracy:.2f}% ({correct}/{total})")
    if class_total:
        print("  Per-class accuracy:")
        for class_id in sorted(class_total):
            class_acc = 100.0 * class_correct[class_id] / class_total[class_id]
            print(f"    Class {class_id}: {class_acc:.2f}% ({class_correct[class_id]}/{class_total[class_id]})")
    return accuracy, correct, total


def check_model(model_name, dataset_name, checkpoint_dir="checkpoints", forget_class=0):
    """Load a saved checkpoint and report forget/retain accuracy on the test set.

    The checkpoint is read from ``<checkpoint_dir>/<model>_<dataset>_best.pth``
    and must contain a ``'model_state_dict'`` entry; an optional ``'best_acc'``
    entry is printed when present.

    Args:
        model_name: Architecture name accepted by ``get_model``.
        dataset_name: Dataset name accepted by ``get_dataset_info``.
        checkpoint_dir: Directory holding the checkpoint files.
        forget_class: Label whose test samples form the forget set; all other
            labels form the retain set. Defaults to 0.

    Returns:
        A dict with the model/dataset names, forget/retain/overall accuracy in
        percent and the forget/retain sample counts, or ``None`` if the
        checkpoint file does not exist.

    Raises:
        ValueError: If the model or dataset name is unsupported.
    """
    print(f"\n{'='*60}")
    print(f"Check model: {model_name} + {dataset_name}")
    print(f"{'='*60}")
    model_path = os.path.join(checkpoint_dir, f"{model_name}_{dataset_name}_best.pth")
    if not os.path.exists(model_path):
        print(f"Model file does not exist: {model_path}")
        return None
    in_channels, num_classes = get_dataset_info(dataset_name)
    _, test_loader, _ = load_dataset(dataset_name, batch_size=128, num_workers=4)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(model_name, in_channels, num_classes)
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    # 'best_acc' is optional. Applying ':.2f' to a string fallback raises
    # ValueError, so only format numerically when a value is present.
    best_acc = checkpoint.get('best_acc')
    best_acc_text = f"{best_acc:.2f}%" if best_acc is not None else "Unknown"
    print(f"Model loaded, best accuracy: {best_acc_text}")
    forget_data, retain_data = split_test_data(test_loader, forget_class=forget_class)
    forget_acc, forget_correct, forget_total = evaluate_on_subset(
        model, forget_data, device, f"Forget set (class {forget_class})")
    retain_acc, retain_correct, retain_total = evaluate_on_subset(
        model, retain_data, device, "Retain set (other classes)")
    total_correct = forget_correct + retain_correct
    total_samples = forget_total + retain_total
    overall_acc = 100.0 * total_correct / total_samples if total_samples > 0 else 0.0
    print(f"\nOverall statistics:")
    print(f"  Total accuracy: {overall_acc:.2f}% ({total_correct}/{total_samples})")
    print(f"  Forget set samples: {forget_total}")
    print(f"  Retain set samples: {retain_total}")
    return {
        'model': model_name,
        'dataset': dataset_name,
        'forget_acc': forget_acc,
        'retain_acc': retain_acc,
        'overall_acc': overall_acc,
        'forget_total': forget_total,
        'retain_total': retain_total,
    }


def _print_summary(results):
    """Print a fixed-width table with one row per result dict from ``check_model``."""
    print(f"\n{'='*80}")
    print("Summary Report")
    print(f"{'='*80}")
    print(f"{'Model-Dataset':<20} {'Forget Acc':<12} {'Retain Acc':<12} {'Overall Acc':<10} {'Forget/Retain Samples':<15}")
    print("-" * 80)
    for result in results:
        model_dataset = f"{result['model']}_{result['dataset']}"
        forget_acc = f"{result['forget_acc']:.2f}%"
        retain_acc = f"{result['retain_acc']:.2f}%"
        overall_acc = f"{result['overall_acc']:.2f}%"
        sample_info = f"{result['forget_total']}/{result['retain_total']}"
        print(f"{model_dataset:<20} {forget_acc:<12} {retain_acc:<12} {overall_acc:<10} {sample_info:<15}")


def main():
    """Main entry for checking trained model accuracy.

    With ``--model model_dataset`` only that pair is checked; otherwise every
    pair in the built-in list is checked. A summary table is printed for all
    pairs whose checkpoint file was found.
    """
    parser = argparse.ArgumentParser(description='Check accuracy of trained models')
    parser.add_argument('--checkpoint-dir', type=str, default='checkpoints',
                        help='Model checkpoint directory')
    parser.add_argument('--model', type=str, default=None,
                        help='Specify a particular model to check, format: model_dataset (e.g., resnet9_mnist)')
    args = parser.parse_args()
    combinations = [
        ("lenet", "mnist"),
        ("resnet9", "svhn"),
        ("resnet18", "cifar10"),
    ]
    if args.model:
        # Only the name parsing is guarded: a ValueError raised later (e.g. an
        # unsupported model or dataset) must not be reported as a bad format.
        try:
            model_name, dataset_name = args.model.split('_', 1)
        except ValueError:
            print(f"Invalid model format: {args.model}, should be model_dataset")
            return
        combinations = [(model_name, dataset_name)]
    results = []
    for model_name, dataset_name in combinations:
        result = check_model(model_name, dataset_name, args.checkpoint_dir)
        if result is not None:
            results.append(result)
    if results:
        _print_summary(results)


if __name__ == '__main__':
    # Run the accuracy check only when executed as a script, not on import.
    main()
