"""
Experiment 4: Unlearning Controllability Analysis
Test the controllability of unlearning degree parameter k.
Dataset: CIFAR-10 (class unlearning)
Framework: Complete DataOpt
"""

import os
import sys
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import logging
from typing import Dict, List, Tuple, Any
import json
import argparse

# Add project root to path
# The appended directory is the parent of the directory containing this file;
# presumably that is where the `src`, `baselines` and `utils` packages
# imported below live (verify against the repository layout).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.dataopt import dataopt_algorithm, DataOptFramework
from baselines.classification import NEGGRADUnlearning
from utils.metrics import UnlearningMetrics, ResultLogger

# Configure root logging at INFO level and create the logger used by every
# function in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_cifar10_data(batch_size: int = 64) -> Tuple[DataLoader, DataLoader]:
    """Build the CIFAR-10 training and test data loaders.

    The training split is augmented with a padded 32x32 random crop and a
    random horizontal flip. Both splits are converted to tensors and
    normalized with the same per-channel mean/std. The datasets are rooted at
    ``./data`` and requested with ``download=True``.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    # Per-channel normalization statistics shared by both pipelines
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    eval_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    cifar_train = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_pipeline
    )
    cifar_test = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=eval_pipeline
    )

    return (
        DataLoader(cifar_train, batch_size=batch_size, shuffle=True),
        DataLoader(cifar_test, batch_size=batch_size, shuffle=False),
    )


class SimpleResNet10(nn.Module):
    """ResNet classifier for CIFAR-10.

    Despite the class name, the backbone is torchvision's ``resnet18`` built
    without pretrained weights; its final fully connected layer is replaced
    by a fresh linear layer with ``num_classes`` outputs.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=False)
        # Swap the stock classification head for one sized to our label set
        head_inputs = resnet.fc.in_features
        resnet.fc = nn.Linear(head_inputs, num_classes)
        self.backbone = resnet

    def forward(self, x):
        """Return the backbone's logits for the input batch ``x``."""
        logits = self.backbone(x)
        return logits


def pretrain_model(model: nn.Module, 
                   train_loader: DataLoader, 
                   epochs: int = 20,
                   device: str = 'cuda') -> nn.Module:
    """Pre-train ``model`` on ``train_loader`` with Adam and cross-entropy.

    The model is moved to ``device`` and trained in place for ``epochs``
    passes over the loader using Adam (lr=0.001). Training loss and accuracy
    are logged every fifth epoch and for the final epoch.

    Args:
        model: Network to train; modified in place.
        train_loader: Loader yielding ``(inputs, targets)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and each batch are moved to.

    Returns:
        The same ``model`` object after training.
    """
    logger.info("Pre-training model...")
    
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            num_batches += 1
        
        if total == 0:
            # An empty loader would otherwise divide by zero below, and
            # further epochs could not train anything either.
            logger.warning("train_loader yielded no samples; stopping pre-training early")
            break
        
        acc = 100.0 * correct / total
        avg_loss = total_loss / num_batches
        
        # Also report the last epoch so the final training metrics are
        # visible even when epochs - 1 is not a multiple of 5.
        if epoch % 5 == 0 or epoch == epochs - 1:
            logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.2f}%")
    
    logger.info("Pre-training completed")
    return model


def create_forget_retain_split(dataset: torch.utils.data.Dataset,
                              forget_class: int = 0) -> Tuple[List[int], List[int]]:
    """Create forget and retain indices for class unlearning.

    Samples whose label equals ``forget_class`` go to the forget set; all
    others go to the retain set. Indices are positions in ``dataset`` and are
    returned in ascending order.

    Args:
        dataset: Indexable/iterable collection of ``(data, label)`` pairs.
        forget_class: Label of the class to forget.

    Returns:
        ``(forget_indices, retain_indices)``.
    """
    # Fast path: read labels from a `targets` attribute (as exposed by
    # torchvision datasets such as CIFAR10) rather than loading and
    # transforming every sample just to look at its label. It is only used
    # when it lines up one-to-one with the dataset and no target_transform
    # could make the yielded labels differ from the stored ones.
    labels = None
    targets = getattr(dataset, 'targets', None)
    if targets is not None and getattr(dataset, 'target_transform', None) is None:
        try:
            if len(targets) == len(dataset):
                labels = targets
        except TypeError:
            labels = None

    if labels is None:
        # Fallback: iterate the samples themselves, as the original code did
        labels = (label for _, label in dataset)

    forget_indices = []
    retain_indices = []
    for idx, label in enumerate(labels):
        if label == forget_class:
            forget_indices.append(idx)
        else:
            retain_indices.append(idx)

    return forget_indices, retain_indices


def run_dataopt_with_k(model: nn.Module,
                      dataset: torch.utils.data.Dataset,
                      forget_indices: List[int],
                      retain_indices: List[int],
                      k: int,
                      device: str = 'cuda') -> Tuple[nn.Module, Dict[str, float]]:
    """Unlearn with DataOpt + NEGGRAD for one value of ``k`` and evaluate.

    A deep copy of ``model`` is passed to ``dataopt_algorithm`` together with
    the stacked forget inputs and a retain pool made of the first 1000 retain
    indices. The optimized data it returns is fed to ``NEGGRADUnlearning``,
    and the resulting model is scored with ``UnlearningMetrics`` on the
    forget set, the retain pool and the CIFAR-10 test split.

    Args:
        model: Pre-trained model; it is deep-copied before unlearning.
        dataset: Dataset yielding ``(data, label)`` pairs.
        forget_indices: Indices of samples to forget.
        retain_indices: Indices of samples to retain.
        k: Unlearning degree forwarded to ``dataopt_algorithm``.
        device: Device forwarded to the DataOpt, unlearning and metric helpers.

    Returns:
        ``(unlearned_model, evaluation_results)``.
    """
    import copy
    from torch.utils.data import TensorDataset

    logger.info(f"Running DataOpt with k={k}...")

    working_model = copy.deepcopy(model)

    def stack_inputs(indices):
        # Keep only the inputs (labels are discarded) and batch them together
        return torch.stack([sample for sample, _ in (dataset[i] for i in indices)])

    def eval_loader(indices):
        # Fixed-order loader over a subset of the original dataset
        return DataLoader(Subset(dataset, indices), batch_size=64, shuffle=False)

    def train_loader(inputs, soft_labels):
        # argmax over dim 1 turns the returned label tensors into class indices
        hard_labels = soft_labels.argmax(dim=1)
        return DataLoader(TensorDataset(inputs, hard_labels), batch_size=64, shuffle=True)

    forget_tensor = stack_inputs(forget_indices)

    # Limit the retain pool to its first 1000 indices for efficiency
    retain_pool_indices = retain_indices[:1000]
    retain_tensor = stack_inputs(retain_pool_indices)

    # DataOpt produces the data/label tensors used for unlearning
    optimized = dataopt_algorithm(
        working_model, forget_tensor, retain_tensor, k=k, device=device
    )
    opt_forget_data, opt_forget_labels, opt_retain_data, opt_retain_labels = optimized

    opt_forget_loader = train_loader(opt_forget_data, opt_forget_labels)
    opt_retain_loader = train_loader(opt_retain_data, opt_retain_labels)

    # Base unlearning algorithm (NEGGRAD) operating on the copied model
    unlearned_model = NEGGRADUnlearning(working_model, device).unlearn(
        opt_forget_loader, opt_retain_loader
    )

    # Evaluation loaders draw from the original dataset
    forget_loader = eval_loader(forget_indices)
    retain_loader = eval_loader(retain_pool_indices)

    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=eval_transform
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # The original (pre-unlearning) model is passed as the last argument;
    # presumably it serves as a reference -- confirm against UnlearningMetrics.
    evaluator = UnlearningMetrics(unlearned_model, device)
    evaluation_results = evaluator.evaluate_classification(
        forget_loader, retain_loader, test_loader, model
    )

    return unlearned_model, evaluation_results


def analyze_label_distribution(model: nn.Module,
                              forget_tensor: torch.Tensor,
                              k_values: List[int],
                              device: str = 'cuda') -> Dict[int, torch.Tensor]:
    """Collect the forget-set labels DataOpt assigns for each ``k``.

    For every ``k`` the labels returned by
    ``DataOptFramework.assign_forget_labels`` are stored, and the mean over
    samples of ``-sum(p * log(p + 1e-8))`` taken along dim 1 is logged.

    Args:
        model: Model handed to ``DataOptFramework``.
        forget_tensor: Stacked forget-set inputs.
        k_values: Unlearning degrees to analyze.
        device: Device handed to ``DataOptFramework``.

    Returns:
        Mapping from each ``k`` to the label tensor assigned for it.
    """
    logger.info("Analyzing label distributions for different k values...")

    framework = DataOptFramework(model, device)
    distributions = {}

    for k in k_values:
        assigned = framework.assign_forget_labels(forget_tensor, k)
        distributions[k] = assigned

        # The small epsilon keeps log() finite where an entry is exactly zero
        log_terms = torch.log(assigned + 1e-8)
        per_sample_entropy = -torch.sum(assigned * log_terms, dim=1)
        logger.info(f"k={k}: Average entropy = {torch.mean(per_sample_entropy):.4f}")

    return distributions


def run_controllability_experiment(model: nn.Module,
                                 dataset: torch.utils.data.Dataset,
                                 forget_indices: List[int],
                                 retain_indices: List[int],
                                 k_values: List[int],
                                 device: str = 'cuda') -> Dict[int, Dict[str, float]]:
    """Evaluate DataOpt unlearning once for every ``k`` in ``k_values``.

    Label distributions for all ``k`` are computed first; then each ``k`` is
    run through ``run_dataopt_with_k`` and its metrics are extended with a
    ``label_entropy`` entry. A ``k`` whose run raises is logged and left out
    of the result.

    Args:
        model: Pre-trained model shared by all runs.
        dataset: Dataset yielding ``(data, label)`` pairs.
        forget_indices: Indices of samples to forget.
        retain_indices: Indices of samples to retain.
        k_values: Unlearning degrees to test.
        device: Device forwarded to the helpers.

    Returns:
        Mapping from each successfully tested ``k`` to its metrics dict.
    """
    # Stack the forget inputs once; they feed the label-distribution analysis
    forget_tensor = torch.stack(
        [sample for sample, _ in (dataset[i] for i in forget_indices)]
    )
    label_distributions = analyze_label_distribution(model, forget_tensor, k_values, device)

    results = {}
    for k in k_values:
        logger.info(f"Testing controllability with k={k}")

        try:
            _, metrics = run_dataopt_with_k(
                model, dataset, forget_indices, retain_indices, k, device
            )

            # Attach the mean entropy of the labels assigned for this k
            assigned = label_distributions[k]
            per_sample_entropy = -torch.sum(assigned * torch.log(assigned + 1e-8), dim=1)
            metrics['label_entropy'] = torch.mean(per_sample_entropy).item()

            results[k] = metrics

            logger.info(f"k={k} results: Acc_rt={metrics['acc_rt']:.3f}, "
                       f"Acc_ft={metrics['acc_ft']:.3f}, MIA={metrics['mia']:.3f}")

        except Exception as e:
            # One failing k must not abort the sweep; report it and move on
            logger.error(f"Error with k={k}: {e}")

    return results


def _aggregate_run_metrics(all_results: Dict[int, List[Dict[str, Any]]]) -> Dict[int, Dict[str, float]]:
    """Collapse per-run metric dicts into ``<metric>_mean`` / ``<metric>_std`` entries."""
    final_results = {}
    for k, run_metrics in all_results.items():
        summary = {}
        # Metric names come from the first run; every run is expected to
        # report the same set of metrics.
        for metric in run_metrics[0]:
            values = [run[metric] for run in run_metrics]
            # Cast to built-in float: json.dump rejects NumPy scalar types
            # that do not subclass float (e.g. float32).
            summary[f'{metric}_mean'] = float(np.mean(values))
            summary[f'{metric}_std'] = float(np.std(values))
        final_results[k] = summary
    return final_results


def _print_summary(final_results: Dict[int, Dict[str, float]],
                   k_values: List[int],
                   forget_class: int,
                   runs: int) -> None:
    """Print the per-k results table and a short controllability analysis."""
    # Cells look like "0.123±0.456" (11 characters, 12 with a sign), so the
    # columns must be at least that wide for the table to line up.
    cell_width = 12
    columns = ['Acc_rt', 'Acc_ft', 'MIA', 'RUD', 'Entropy']
    header = f"{'k':>3} | " + " | ".join(f"{name:>{cell_width}}" for name in columns)
    rule_width = len(header)

    print("\n" + "=" * rule_width)
    print("EXPERIMENT 4 SUMMARY - UNLEARNING CONTROLLABILITY")
    print("=" * rule_width)
    print(f"Forget Class: {forget_class}")
    print(f"Number of runs: {runs}")
    print("-" * rule_width)
    print(header)
    print("-" * rule_width)

    # Only k values that produced at least one successful run have results
    tested = [k for k in sorted(k_values) if k in final_results]

    for k in tested:
        metrics = final_results[k]
        cells = [
            f"{metrics['acc_rt_mean']:.3f}±{metrics['acc_rt_std']:.3f}",
            f"{metrics['acc_ft_mean']:.3f}±{metrics['acc_ft_std']:.3f}",
            f"{metrics['mia_mean']:.3f}±{metrics['mia_std']:.3f}",
            f"{metrics['rud_mean']:.3f}±{metrics['rud_std']:.3f}",
            f"{metrics.get('label_entropy_mean', 0):.3f}±{metrics.get('label_entropy_std', 0):.3f}",
        ]
        print(f"{k:>3} | " + " | ".join(f"{cell:>{cell_width}}" for cell in cells))

    print("\nControllability Analysis:")
    print("-" * 40)

    # Trend of forget accuracy, judged from the smallest and largest tested k
    acc_ft_means = [final_results[k]['acc_ft_mean'] for k in tested]
    if len(acc_ft_means) > 1:
        first, last = acc_ft_means[0], acc_ft_means[-1]
        if last < first:
            trend = "decreasing"
        elif last > first:
            trend = "increasing"
        else:
            trend = "flat"
        print(f"Forget accuracy trend: {trend}")

    # Spread of the mean RUD across k values
    rud_means = [final_results[k]['rud_mean'] for k in tested]
    if len(rud_means) > 1:
        rud_variation = np.std(rud_means)
        print(f"RUD stability (lower is better): {rud_variation:.4f}")

    # k with the lowest mean forget accuracy
    if final_results:
        best_k = min(final_results, key=lambda x: final_results[x]['acc_ft_mean'])
        print(f"Most effective unlearning (lowest Acc_ft): k={best_k}")


def main():
    """Run the controllability experiment from the command line.

    Parses the CLI arguments, pre-trains a model on CIFAR-10, splits the
    training set into forget/retain indices for ``--forget_class``, repeats
    the k sweep ``--runs`` times, aggregates mean/std per metric, logs the
    results through ``ResultLogger``, writes ``exp4_summary.json`` into
    ``--output_dir`` and prints a summary table.
    """
    parser = argparse.ArgumentParser(description='Experiment 4: Unlearning Controllability')
    parser.add_argument('--forget_class', type=int, default=0, 
                       help='Class to forget (0-9)')
    parser.add_argument('--k_values', nargs='+', type=int,
                       default=[1, 3, 5, 7, 9],
                       help='Unlearning degree values to test')
    parser.add_argument('--device', default='cuda', help='Device to use')
    parser.add_argument('--output_dir', default='results', help='Output directory')
    parser.add_argument('--runs', type=int, default=3,
                       help='Number of runs for each k value')
    
    args = parser.parse_args()
    
    # The default device is 'cuda'; fall back to CPU instead of crashing on
    # machines without a usable CUDA runtime.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available; falling back to CPU")
        device = 'cpu'
    
    # Setup output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Initialize result logger
    result_logger = ResultLogger(args.output_dir)
    
    logger.info("Starting Unlearning Controllability Analysis...")
    
    # Load data (only the training loader is needed here; evaluation builds
    # its own test loader)
    train_loader, _ = load_cifar10_data()
    
    # Create and pre-train model
    model = SimpleResNet10(num_classes=10)
    model = pretrain_model(model, train_loader, epochs=15, device=device)
    
    # Get full dataset for splitting (no augmentation, unlike the training loader)
    full_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
    )
    
    # Create forget/retain split
    forget_indices, retain_indices = create_forget_retain_split(
        full_dataset, args.forget_class
    )
    
    logger.info(f"Forget samples: {len(forget_indices)}, Retain samples: {len(retain_indices)}")
    
    # Run multiple runs for statistical significance
    all_results = {}
    
    for run in range(args.runs):
        logger.info(f"Starting run {run + 1}/{args.runs}")
        
        run_results = run_controllability_experiment(
            model, full_dataset, forget_indices, retain_indices, 
            args.k_values, device
        )
        
        # Group metrics by k across runs
        for k, metrics in run_results.items():
            all_results.setdefault(k, []).append(metrics)
    
    # Compute mean and std for each metric
    final_results = _aggregate_run_metrics(all_results)
    
    # Log results
    for k, metrics in final_results.items():
        result_logger.log_results(
            experiment_name='exp4_controllability',
            method_name=f'DataOpt_k{k}',
            dataset='cifar10',
            metrics=metrics,
            hyperparams={
                'forget_class': args.forget_class,
                'k_value': k,
                'num_runs': args.runs
            }
        )
    
    # Save summary
    summary_file = os.path.join(args.output_dir, 'exp4_summary.json')
    with open(summary_file, 'w') as f:
        json.dump(final_results, f, indent=2)
    
    logger.info(f"Experiment completed. Results saved to {summary_file}")
    
    # Print summary table and analysis
    _print_summary(final_results, args.k_values, args.forget_class, args.runs)


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()