import argparse
import torch
import time
import os
import sys
import torch.nn as nn
import torchvision.models as models

sys.path.append(os.path.dirname(os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))))
from utils import print_log_info
from methods.our.unlearning import OurUnlearning
from models.lenet import LeNet
from models.resnet import ResNet9, CifarResNet18

def parse_args(argv=None):
    """Parse command line arguments for the model unlearning experiment.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the original behavior.

    Returns:
        argparse.Namespace with the parsed options. Exactly one of
        ``class_idx`` / ``class_idxs`` / ``all`` is set (required mutually
        exclusive group). ``alpha_min`` and ``alpha_max`` are either both
        ``None`` or both set with ``alpha_min <= alpha_max``.

    Invalid arguments terminate the program through ``parser.error``
    (``SystemExit`` with status 2).
    """
    parser = argparse.ArgumentParser(description='Model unlearning experiment')

    parser.add_argument('--model', type=str, required=True,
                        # 'allcnn' is intentionally not offered: get_model() has no
                        # builder for it, so accepting it only deferred the failure.
                        choices=['resnet9', 'resnet18', 'lenet'],
                        help='Model architecture')
    parser.add_argument('--dataset', type=str, required=True,
                        choices=['mnist', 'cifar10', 'cifar100', 'svhn', 'tinyimagenet'],
                        help='Dataset')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Pre-trained model path')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--device', type=str, default=None,
                        help='Device, e.g. "cuda:0" or "cpu"')

    # Exactly one way of selecting the classes to forget must be given.
    forget_group = parser.add_mutually_exclusive_group(required=True)
    forget_group.add_argument('--class_idx', type=int,
                              help='Single class index to forget')
    forget_group.add_argument('--class_idxs', type=str,
                              help='Multiple class indices to forget, comma separated (will be forgotten sequentially)')
    forget_group.add_argument('--all', action='store_true',
                              help='Forget all classes')

    parser.add_argument('--target_layers', type=str, default=None,
                        help='Target layers, comma separated')
    parser.add_argument('--sens_source', type=str, default='noise',
                        choices=['noise', 'sample', 'hybrid'],
                        help='Parameter sensitivity calculation method')
    parser.add_argument('--skip_noise', action='store_true',
                        help='Skip noise generation step')
    parser.add_argument('--alpha_min', type=float, default=None,
                        help='Adaptive alpha minimum value')
    parser.add_argument('--alpha_max', type=float, default=None,
                        help='Adaptive alpha maximum value')
    parser.add_argument('--lambda_value', type=float, default=10.0,
                        help='Lambda parameter value for controlling unlearning strength')

    parser.add_argument('--log_dir', type=str, default=None,
                        help='Log directory, if not specified, will use default directory')
    # NOTE(review): main() in this file never reads args.verbose.
    parser.add_argument('--verbose', action='store_true',
                        help='Show verbose output')

    args = parser.parse_args(argv)

    # The adaptive alpha range is only used when BOTH bounds are present;
    # reject a lone bound instead of silently ignoring it.
    if (args.alpha_min is None) != (args.alpha_max is None):
        parser.error('--alpha_min and --alpha_max must be given together')
    if args.alpha_min is not None and args.alpha_min > args.alpha_max:
        parser.error(f'--alpha_min ({args.alpha_min}) must not exceed '
                     f'--alpha_max ({args.alpha_max})')

    return args


def get_model(model_name, dataset_name):
    """Build a fresh (untrained) model for an architecture/dataset pair.

    Args:
        model_name: architecture key: 'resnet9', 'resnet18' or 'lenet'.
        dataset_name: dataset key. 'mnist' selects 1 input channel (3 for
            every other dataset); 'cifar100' gives 100 output classes,
            'tinyimagenet' 200, anything else 10.

    Returns:
        The constructed model. For 'resnet18' on 'tinyimagenet' this is
        torchvision's ResNet-18 with its ``fc`` head resized. For 'resnet18'
        on any other dataset it is ``CifarResNet18``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported keys.
    """
    class_counts = {'cifar100': 100, 'tinyimagenet': 200}
    n_classes = class_counts.get(dataset_name, 10)
    channels = 3 if dataset_name != 'mnist' else 1

    if model_name == 'lenet':
        return LeNet(num_classes=n_classes, in_channels=channels)
    if model_name == 'resnet9':
        return ResNet9(num_classes=n_classes, in_channels=channels)
    if model_name != 'resnet18':
        raise ValueError(f"Unsupported model: {model_name}")

    if dataset_name != 'tinyimagenet':
        # Every non-TinyImageNet dataset uses the CIFAR-style ResNet-18.
        return CifarResNet18(num_classes=n_classes, in_channels=channels)

    # TinyImageNet: torchvision ResNet-18 (no pretrained weights) with a
    # classifier head sized to the dataset's class count.
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net


def main():
    """Entry point: run one unlearning experiment configured from the CLI.

    Parses the arguments, selects the device, builds the model with
    ``get_model``, resolves which classes to forget, constructs
    ``OurUnlearning`` and calls ``unlearn_classes`` on it.

    Returns:
        The value returned by ``OurUnlearning.unlearn_classes``.

    Raises:
        SystemExit: if ``--class_idxs`` contains no class index or a
            non-integer entry.
    """
    args = parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")

    print(f"Creating {args.model} model for {args.dataset} dataset")
    model = get_model(args.model, args.dataset)
    print(
        f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Resolve the forget set and its log-directory label together so they
    # cannot disagree. Compare against None explicitly: a truthiness test let
    # `--class_idxs ""` fall through every branch, leaving class_indices unbound.
    if args.class_idx is not None:
        class_indices = [args.class_idx]
        forget_type = "single"
        print(f"Will forget single class: {args.class_idx}")
    elif args.class_idxs is not None:
        try:
            # Skip empty entries so a trailing comma ("1,2,") is tolerated.
            class_indices = [int(idx) for idx in args.class_idxs.split(',')
                             if idx.strip()]
        except ValueError:
            raise SystemExit(
                f"error: --class_idxs must be comma-separated integers, "
                f"got {args.class_idxs!r}") from None
        if not class_indices:
            raise SystemExit(
                "error: --class_idxs must contain at least one class index")
        forget_type = "sequential"
        print(f"Will sequentially forget multiple classes: {class_indices}")
    else:
        # The required mutually exclusive group guarantees --all here.
        class_indices = None
        forget_type = "all"
        print("Will forget all classes")

    target_layers = args.target_layers.split(
        ',') if args.target_layers else None

    # An adaptive alpha range is only passed on when both bounds are given.
    alpha_range = (
        args.alpha_min, args.alpha_max) if args.alpha_min is not None and args.alpha_max is not None else None

    if args.log_dir:
        log_dir = args.log_dir
    else:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        log_dir = f"logs/our/{args.dataset}/{args.model}/{forget_type}_{timestamp}"

    print("Initializing unlearning framework...")
    unlearner = OurUnlearning(
        model=model,
        dataset_name=args.dataset,
        checkpoint_path=args.model_path,
        batch_size=args.batch_size,
        device=device,
        log_dir=log_dir,
        target_layers=target_layers,
        sens_source=args.sens_source,
        seed=args.seed,
        alpha_range=alpha_range
    )

    # These are set as attributes after construction rather than passed to
    # the constructor; presumably read later by unlearn_classes — verify
    # against OurUnlearning.
    unlearner.lambda_value = args.lambda_value

    unlearner.skip_noise = args.skip_noise

    print("Starting unlearning process...")
    results = unlearner.unlearn_classes(class_indices)

    return results


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()
