import argparse
import torch
import time
import os
import sys
import pandas as pd

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import print_log_info
from ablation.perturbation_unlearning import PerturbationUnlearning
from models.lenet import LeNet
from models.resnet import ResNet9, CifarResNet18


def parse_args():
    """Build the ablation-experiment CLI and parse ``sys.argv``.

    Every option is declared in a single ordered table below, so the
    ``--help`` listing follows the table order.

    Returns:
        argparse.Namespace: parsed options (model/dataset choice, checkpoint
        path, batch size, seed, device, comma-separated class list, pruning
        ratio, and the log/results/noise directories plus the
        noise-skipping flag).
    """
    cli = argparse.ArgumentParser(
        description='Ablation experiment for parameter selection strategy (Independent Unlearning)')

    option_table = (
        # Model / data / runtime configuration.
        ('--model', dict(type=str, default='resnet18', choices=['resnet18'],
                         help='Model architecture (fixed to resnet18)')),
        ('--dataset', dict(type=str, default='cifar10', choices=['cifar10'],
                           help='Dataset (fixed to cifar10)')),
        ('--model_path', dict(type=str, required=True,
                              help='Pre-trained model path')),
        ('--batch_size', dict(type=int, default=128,
                              help='Batch size')),
        ('--seed', dict(type=int, default=42,
                        help='Random seed')),
        ('--device', dict(type=str, default=None,
                          help='Device, e.g. "cuda:0" or "cpu"')),
        # Which classes to forget (each one is handled independently).
        ('--class_idxs', dict(type=str, default="0,1,2,3,4,5,6,7,8,9",
                              help='Multiple class indices to forget, comma separated (will be forgotten independently)')),
        # Parameter-selection strength.
        ('--k_percent', dict(type=float, default=0.3,
                             help='Percentage of top sensitive parameters to prune')),
        # Output / cache locations.
        ('--log_dir', dict(type=str, default=None,
                           help='Log directory, if not specified, will use default directory')),
        ('--results_dir', dict(type=str, default='results/ablation',
                               help='Directory to save the final CSV results')),
        ('--noise_data_dir', dict(type=str, default='noise_data',
                                  help='Directory to store or load pre-generated noise')),
        ('--skip_noise_generation', dict(action='store_true',
                                         help='Skip the noise pre-generation step if noise already exists')),
    )

    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def get_model(model_name, dataset_name):
    """Instantiate the network named on the command line.

    Only ``'resnet18'`` is accepted; it is built as ``CifarResNet18`` with
    10 output classes and 3 input channels, the CIFAR-10 configuration.
    ``dataset_name`` is part of the call signature but is not consulted.

    Raises:
        ValueError: if ``model_name`` is anything other than ``'resnet18'``.
    """
    if model_name != 'resnet18':
        raise ValueError(f"Unsupported model for this script: {model_name}")

    # CIFAR-10 configuration: 3-channel images, 10 target classes.
    return CifarResNet18(num_classes=10, in_channels=3)


# Number of target classes per supported dataset. Used to reject
# out-of-range --class_idxs values before any expensive work starts.
_DATASET_NUM_CLASSES = {'cifar10': 10}


def _parse_class_indices(spec, num_classes=None):
    """Parse a comma-separated list of class indices such as ``"0,1,2"``.

    Whitespace around entries and empty entries (e.g. from a trailing comma)
    are ignored. Order and duplicates are preserved.

    Args:
        spec: The raw ``--class_idxs`` string.
        num_classes: If given, every index must satisfy
            ``0 <= idx < num_classes``. Negative indices are always rejected.

    Returns:
        list[int]: The parsed indices.

    Raises:
        ValueError: on a non-integer entry, an out-of-range index, or when the
            string contains no index at all.
    """
    indices = []
    for token in spec.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            idx = int(token)
        except ValueError:
            raise ValueError(
                f"Invalid class index {token!r} in --class_idxs={spec!r}") from None
        if idx < 0 or (num_classes is not None and idx >= num_classes):
            upper = num_classes if num_classes is not None else "inf"
            raise ValueError(
                f"Class index {idx} out of range [0, {upper}) in --class_idxs={spec!r}")
        indices.append(idx)
    if not indices:
        raise ValueError(f"No class indices given in --class_idxs={spec!r}")
    return indices


def _save_results_csv(results, results_dir, csv_filename):
    """Write one CSV row per forgotten class and return the CSV path.

    Each result dict is read for ``class_idx``, ``class_name`` and
    ``[initial|final][active_test|class_test]['accuracy']``. ``active_test``
    is reported as retain accuracy and ``class_test`` as forget accuracy.
    """
    rows = [
        {
            'forgotten_class_idx': res['class_idx'],
            'forgotten_class_name': res['class_name'],
            'initial_retain_accuracy': res['initial']['active_test']['accuracy'],
            'final_retain_accuracy': res['final']['active_test']['accuracy'],
            'initial_forget_accuracy': res['initial']['class_test']['accuracy'],
            'final_forget_accuracy': res['final']['class_test']['accuracy'],
        }
        for res in results
    ]
    os.makedirs(results_dir, exist_ok=True)
    csv_path = os.path.join(results_dir, csv_filename)
    pd.DataFrame(rows).to_csv(csv_path, index=False)
    return csv_path


def main():
    """Run the independent per-class unlearning ablation end to end.

    Steps: parse CLI arguments, build the model, construct a
    ``PerturbationUnlearning`` instance, optionally pre-generate noise,
    unlearn each requested class independently, and write a per-class
    summary CSV into ``--results_dir``.

    Returns:
        list: One result per requested class, as returned by
        ``unlearner.unlearn_class``.

    Raises:
        ValueError: if ``--class_idxs`` is malformed or contains an index
            outside the dataset's class range. This is checked before noise
            generation or unlearning begins.
    """
    args = parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Using device: {device}")

    print(f"Creating {args.model} model for {args.dataset} dataset")
    model = get_model(args.model, args.dataset)
    print(
        f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Validate the class list up front. Noise pre-generation below may be
    # slow, so a typo or out-of-range index should fail before it starts.
    class_indices = _parse_class_indices(
        args.class_idxs, _DATASET_NUM_CLASSES.get(args.dataset))
    print(f"Will independently forget classes: {class_indices}")

    if args.log_dir:
        log_dir = args.log_dir
    else:
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        log_dir = f"logs/ablation/perturbation/{args.dataset}/{args.model}/k_{args.k_percent}_{timestamp}"

    # --- Initialize Unlearner ---
    unlearner = PerturbationUnlearning(
        model=model,
        dataset_name=args.dataset,
        checkpoint_path=args.model_path,
        batch_size=args.batch_size,
        device=device,
        log_dir=log_dir,
        seed=args.seed,
        k_percent=args.k_percent,
        noise_data_dir=args.noise_data_dir
    )

    # --- Step 1: Pre-generate noise for all classes ---
    if not args.skip_noise_generation:
        print("Starting noise pre-generation for all classes...")
        unlearner.pregenerate_and_save_all_noise()
        print("Noise pre-generation complete.")
    else:
        print("Skipping noise pre-generation as requested.")

    # --- Step 2: Perform independent unlearning for each class ---
    results = []
    total = len(class_indices)
    for position, class_idx in enumerate(class_indices, start=1):
        print(
            f"========== Starting independent unlearning for class {class_idx} ({position}/{total}) ==========")
        results.append(unlearner.unlearn_class(class_idx))
        print("-" * 80)

    # --- Step 3: Save results to CSV ---
    if results:
        print("Saving results to CSV...")
        csv_filename = f"independent_perturbation_{args.model}_{args.dataset}_k_{args.k_percent}.csv"
        csv_path = _save_results_csv(results, args.results_dir, csv_filename)
        print(f"Results saved to {csv_path}")
    else:
        print("No results were generated from the unlearning process.")

    return results


if __name__ == '__main__':
    # Launch the experiment only when run as a script; importing this module
    # defines the functions without executing anything.
    main()