"""
Experiment 5: DELETE Framework Label Strategy Comparison
Compare DataOpt label assignment with DELETE label strategy.
Dataset: CIFAR-10 (class unlearning, no retain set)
"""

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import logging
from typing import Dict, List, Tuple, Any
import json
import argparse

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.dataopt import DataOptFramework
from utils.metrics import UnlearningMetrics, ResultLogger

# Configure root logging at INFO so the progress messages emitted by the
# functions below are visible; ``logger`` is this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)


def load_cifar10_data(batch_size: int = 64) -> Tuple[DataLoader, DataLoader]:
    """Build CIFAR-10 train/test loaders.

    Downloads the dataset into ``./data`` if needed. The training split gets
    random-crop (padding 4) and horizontal-flip augmentation and is shuffled;
    the test split is only converted to tensors, normalized and kept in order.

    Args:
        batch_size: Mini-batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augmenting = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_split = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=augmenting
    )
    test_split = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=plain
    )

    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )


class SimpleResNet10(nn.Module):
    """CIFAR-10 classifier wrapping a randomly initialised torchvision ResNet-18.

    Despite the class name, the backbone is ``torchvision.models.resnet18``
    built without pretrained weights; only its final fully-connected layer is
    replaced so the network outputs ``num_classes`` logits. The name is kept
    for compatibility with existing callers and result files.

    NOTE(review): the backbone keeps torchvision's stock (ImageNet-style) stem,
    which is not adapted to 32x32 inputs — confirm this is intended.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        try:
            # torchvision >= 0.13: ``weights=None`` replaces the deprecated
            # ``pretrained=False`` (which now emits a deprecation warning).
            self.backbone = torchvision.models.resnet18(weights=None)
        except TypeError:
            # Older torchvision forwards unknown kwargs to ResNet.__init__,
            # which rejects ``weights``; fall back to the legacy flag there.
            self.backbone = torchvision.models.resnet18(pretrained=False)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's logits (``num_classes`` per sample) for ``x``."""
        return self.backbone(x)


def pretrain_model(model: nn.Module,
                   train_loader: DataLoader,
                   epochs: int = 20,
                   device: str = 'cuda',
                   lr: float = 0.001) -> nn.Module:
    """Pre-train ``model`` on the full training set with Adam and cross-entropy.

    Args:
        model: Network to train; moved to ``device`` and updated in place.
        train_loader: Yields ``(data, target)`` batches with integer class targets.
        epochs: Number of passes over ``train_loader``.
        device: Device string passed to ``.to()``.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).

    Returns:
        The same ``model`` object after training (left in train mode).
    """
    logger.info("Pre-training model...")

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += target.size(0)
            num_batches += 1

        # Guard against an empty loader instead of raising ZeroDivisionError.
        acc = 100.0 * correct / total if total else 0.0
        avg_loss = total_loss / num_batches if num_batches else 0.0

        # Log every 5th epoch and always the last one, so the final training
        # state is reported whatever ``epochs`` is.
        if epoch % 5 == 0 or epoch == epochs - 1:
            logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.2f}%")

    logger.info("Pre-training completed")
    return model


def create_delete_labels(model: nn.Module,
                         forget_data: torch.Tensor,
                         forget_targets: torch.Tensor,
                         device: str = 'cuda',
                         batch_size: int = 256) -> torch.Tensor:
    """Create DELETE-style soft labels for the forget samples.

    The model's own logits are computed for every forget sample. The logit of
    each sample's true class is set to ``-1e6``, which gives it effectively
    zero probability after softmax. The logits are then softmax-normalized.
    The result is the model's predictive distribution renormalized over the
    non-forget classes.

    Args:
        model: Classifier producing one row of logits per sample; it is put in
            eval mode and not updated.
        forget_data: Batch of forget samples (first dim indexes samples).
        forget_targets: True class index per forget sample (list or tensor).
            Row ``i`` of the logits is masked at ``forget_targets[i]``.
        device: Device the forward pass runs on.
        batch_size: Samples per forward pass. Large forget sets therefore do
            not have to fit through the network in a single batch.

    Returns:
        Probability tensor with one row per forget sample, on the device the
        model produced its logits on.
    """
    logger.info("Creating DELETE labels...")

    model.eval()
    with torch.no_grad():
        # Chunked inference: in eval mode the outputs do not depend on how the
        # samples are grouped, and peak activation memory stays bounded.
        logits = torch.cat([
            model(chunk.to(device))
            for chunk in torch.split(forget_data, max(1, batch_size))
        ])

        rows = torch.arange(len(forget_targets), device=logits.device)
        cols = torch.as_tensor(forget_targets, device=logits.device).long()
        # Vectorized form of setting logits[i, forget_targets[i]] for every i.
        logits[rows, cols] = -1e6

        # Apply softmax to get the new target distribution.
        delete_labels = F.softmax(logits, dim=1)

    return delete_labels


def create_dataopt_labels(model: nn.Module,
                          forget_data: torch.Tensor,
                          k: int = 9,
                          device: str = 'cuda') -> torch.Tensor:
    """Assign forget-set target labels through ``DataOptFramework``.

    Thin wrapper: wraps ``model`` in a ``DataOptFramework`` on ``device`` and
    returns whatever ``assign_forget_labels(forget_data, k)`` produces.

    Args:
        model: Model handed to the framework.
        forget_data: Forget samples to label.
        k: Passed straight through to ``assign_forget_labels``. The caller in
            this file uses 9, described there as the maximum unlearning degree.
        device: Device string handed to the framework.
    """
    logger.info("Creating DataOpt labels...")
    return DataOptFramework(model, device).assign_forget_labels(forget_data, k)


def finetune_with_labels(model: nn.Module,
                         forget_data: torch.Tensor,
                         target_labels: torch.Tensor,
                         epochs: int = 10,
                         lr: float = 0.001,
                         device: str = 'cuda',
                         batch_size: int = 32) -> nn.Module:
    """Fine-tune ``model`` on the forget samples only, toward ``target_labels``.

    Soft labels (one probability row per sample) are fitted by minimizing
    ``F.kl_div(log_softmax(output), labels, reduction='batchmean')``. A 1-D
    tensor of integer class indices is also accepted; it is fitted with
    cross-entropy, because ``kl_div`` cannot consume hard labels.

    Args:
        model: Network to fine-tune; moved to ``device`` and updated in place.
        forget_data: Forget samples (first dim indexes samples).
        target_labels: Per-sample soft label rows, or 1-D integer class labels.
        epochs: Passes over the forget set.
        lr: Adam learning rate.
        device: Device each batch is moved to.
        batch_size: Mini-batch size (defaults to the previously hard-coded 32).

    Returns:
        The same ``model`` object after fine-tuning (left in train mode).
    """
    from torch.utils.data import DataLoader, TensorDataset

    logger.info("Fine-tuning with target labels...")

    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Targets are constants for this optimization. Detaching prevents
    # backpropagating into whatever graph produced them.
    targets = target_labels.detach()
    hard_labels = targets.dim() == 1 and not targets.is_floating_point()

    loader = DataLoader(TensorDataset(forget_data, targets),
                        batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0

        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(data)

            if hard_labels:
                loss = F.cross_entropy(output, labels)
            else:
                # KL(target || model); kl_div expects log-probabilities as input.
                log_probs = F.log_softmax(output, dim=1)
                loss = F.kl_div(log_probs, labels, reduction='batchmean')

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty forget set yields no batches; avoid ZeroDivisionError.
        avg_loss = total_loss / num_batches if num_batches else 0.0

        if epoch % 2 == 0 or epoch == epochs - 1:
            logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}")

    logger.info("Fine-tuning completed")
    return model


def _loader_accuracy(model: nn.Module, loader: DataLoader, device: str) -> float:
    """Return the fraction of samples in ``loader`` whose argmax prediction equals the target (0.0 if empty)."""
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            hits += model(inputs).argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)
    return hits / seen if seen > 0 else 0.0


def evaluate_unlearning(model: nn.Module,
                        forget_loader: DataLoader,
                        retain_test_loader: DataLoader,
                        device: str = 'cuda') -> Dict[str, float]:
    """Measure top-1 accuracy on the forget and retain evaluation sets.

    The model is switched to eval mode first. Both accuracies are fractions in
    ``[0, 1]``. An empty loader contributes 0.0.

    Returns:
        ``{'acc_rt': retain accuracy (higher is better),
        'acc_ft': forget accuracy (lower means more forgetting)}``.
    """
    model.eval()
    acc_ft = _loader_accuracy(model, forget_loader, device)
    acc_rt = _loader_accuracy(model, retain_test_loader, device)
    return {'acc_rt': acc_rt, 'acc_ft': acc_ft}


def _dataset_labels(ds: torch.utils.data.Dataset) -> List[int]:
    """Return the integer label of every sample in ``ds``, in index order.

    When ``ds`` exposes a ``targets`` sequence (as torchvision's ``CIFAR10``
    does) and has no ``target_transform``, labels are read from it directly,
    so no image has to be loaded and transformed. Otherwise every sample is
    indexed.
    """
    targets = getattr(ds, 'targets', None)
    if (targets is not None
            and getattr(ds, 'target_transform', None) is None
            and len(targets) == len(ds)):
        return [int(t) for t in targets]
    return [int(label) for _, label in ds]


def run_delete_comparison(model: nn.Module,
                          dataset: torch.utils.data.Dataset,
                          test_dataset: torch.utils.data.Dataset,
                          forget_class: int = 0,
                          device: str = 'cuda') -> Dict[str, Dict[str, float]]:
    """Compare DELETE and DataOpt forget labels on one class-unlearning task.

    All training samples of ``forget_class`` form the forget set. For each
    strategy, a fresh deep copy of ``model`` is taken, labels are assigned to
    the forget set with that copy, and the copy is fine-tuned on the forget set
    alone (no retain set). It is then evaluated on the test split: accuracy on
    test samples of ``forget_class`` (``acc_ft``) and on all other test samples
    (``acc_rt``). ``model`` itself is not modified.

    Returns:
        ``{'DELETE-Label': metrics, 'DataOpt-Label': metrics}``, where each
        ``metrics`` dict comes from ``evaluate_unlearning``.

    Raises:
        ValueError: if ``dataset`` contains no sample of ``forget_class``.
    """
    import copy

    # Split the test set into forget-class and retain-class indices.
    test_labels = _dataset_labels(test_dataset)
    forget_indices = [i for i, y in enumerate(test_labels) if y == forget_class]
    retain_test_indices = [i for i, y in enumerate(test_labels) if y != forget_class]

    # Training samples of the forget class are the only data fine-tuned on.
    train_forget_indices = [
        i for i, y in enumerate(_dataset_labels(dataset)) if y == forget_class
    ]
    if not train_forget_indices:
        raise ValueError(
            f"No training samples found for forget_class={forget_class}"
        )

    samples = [dataset[i] for i in train_forget_indices]
    forget_tensor = torch.stack([x for x, _ in samples])
    forget_targets_tensor = torch.tensor([y for _, y in samples])

    # Loaders for evaluation
    forget_test_loader = DataLoader(
        Subset(test_dataset, forget_indices), batch_size=64, shuffle=False
    )
    retain_test_loader = DataLoader(
        Subset(test_dataset, retain_test_indices), batch_size=64, shuffle=False
    )

    # (result key, log message, label factory applied to the model copy).
    strategies = (
        ('DELETE-Label', "Testing DELETE label strategy...",
         lambda m: create_delete_labels(m, forget_tensor, forget_targets_tensor, device)),
        # k=9: maximum unlearning degree
        ('DataOpt-Label', "Testing DataOpt label strategy...",
         lambda m: create_dataopt_labels(m, forget_tensor, k=9, device=device)),
    )

    results = {}
    for name, message, make_labels in strategies:
        logger.info(message)
        # Every strategy starts from an untouched copy of the pre-trained model.
        candidate = copy.deepcopy(model)
        labels = make_labels(candidate)
        candidate = finetune_with_labels(candidate, forget_tensor, labels, device=device)
        results[name] = evaluate_unlearning(
            candidate, forget_test_loader, retain_test_loader, device
        )

    return results


def _aggregate_runs(
        all_results: Dict[str, List[Dict[str, float]]]) -> Dict[str, Dict[str, float]]:
    """Reduce per-run metric dicts to ``<metric>_mean`` / ``<metric>_std`` per method."""
    final_results = {}
    for method, run_metrics in all_results.items():
        stats = {}
        for metric in ('acc_rt', 'acc_ft'):
            values = [run[metric] for run in run_metrics]
            stats[f'{metric}_mean'] = float(np.mean(values))
            stats[f'{metric}_std'] = float(np.std(values))
        final_results[method] = stats
    return final_results


def _print_summary(forget_class: int,
                   runs: int,
                   final_results: Dict[str, Dict[str, float]]) -> None:
    """Print the result table, the DataOpt-vs-DELETE differences and a strategy description."""
    print("\n" + "="*60)
    print("EXPERIMENT 5 SUMMARY - DELETE COMPARISON")
    print("="*60)
    print(f"Forget Class: {forget_class}")
    print(f"Number of runs: {runs}")
    print("No retain set used (forget-only fine-tuning)")
    print("-" * 60)
    print(f"{'Method':>15} | {'Acc_rt (Retain)':>15} | {'Acc_ft (Forget)':>15}")
    print("-" * 60)

    for method in ['DELETE-Label', 'DataOpt-Label']:
        if method in final_results:
            metrics = final_results[method]
            acc_rt = f"{metrics['acc_rt_mean']:.3f}±{metrics['acc_rt_std']:.3f}"
            acc_ft = f"{metrics['acc_ft_mean']:.3f}±{metrics['acc_ft_std']:.3f}"

            print(f"{method:>15} | {acc_rt:>15} | {acc_ft:>15}")

    # Comparison of means (no significance test is performed).
    print("\nStatistical Analysis:")
    print("-" * 30)

    if 'DELETE-Label' in final_results and 'DataOpt-Label' in final_results:
        delete_acc_rt = final_results['DELETE-Label']['acc_rt_mean']
        dataopt_acc_rt = final_results['DataOpt-Label']['acc_rt_mean']

        delete_acc_ft = final_results['DELETE-Label']['acc_ft_mean']
        dataopt_acc_ft = final_results['DataOpt-Label']['acc_ft_mean']

        print(f"Retain accuracy improvement (DataOpt vs DELETE): "
              f"{dataopt_acc_rt - delete_acc_rt:+.3f}")
        print(f"Forget accuracy difference (DataOpt vs DELETE): "
              f"{dataopt_acc_ft - delete_acc_ft:+.3f}")

        # A method "wins" only if it is better on retain and no worse on forget.
        if dataopt_acc_rt > delete_acc_rt and dataopt_acc_ft <= delete_acc_ft:
            print("DataOpt shows better trade-off (higher retain, similar/lower forget)")
        elif delete_acc_rt > dataopt_acc_rt and delete_acc_ft <= dataopt_acc_ft:
            print("DELETE shows better trade-off (higher retain, similar/lower forget)")
        else:
            print("Mixed results - no clear winner")

    # Label distribution analysis. DELETE labels are the model's own softmax
    # with the true class masked out (see create_delete_labels), not uniform.
    print("\nLabel Strategy Analysis:")
    print("-" * 30)
    print("DELETE strategy: Masks the forget-class logit (-1e6) and softmaxes "
          "the model's remaining logits")
    print("DataOpt strategy: Optimizes label distribution based on learning dynamics theory")
    print("Key difference: DELETE reuses the model's own predictions over the "
          "remaining classes; DataOpt assigns optimized labels (k=9) instead")


def main():
    """Pre-train once, then compare DELETE vs DataOpt forget labels over several runs.

    Steps:
        1. Parse the CLI arguments.
        2. Pre-train a ResNet on CIFAR-10 for 15 epochs.
        3. Call ``run_delete_comparison`` ``--runs`` times on that same
           pre-trained model.
        4. Report mean/std of retain and forget accuracy through
           ``ResultLogger``, write them to ``<output_dir>/exp5_summary.json``
           and print a summary table.
    """
    parser = argparse.ArgumentParser(description='Experiment 5: DELETE Comparison')
    parser.add_argument('--forget_class', type=int, default=0, choices=range(10),
                        metavar='{0-9}', help='Class to forget (0-9)')
    parser.add_argument('--device', default='cuda', help='Device to use')
    parser.add_argument('--output_dir', default='results', help='Output directory')
    parser.add_argument('--runs', type=int, default=5,
                        help='Number of runs for statistical significance')
    parser.add_argument('--seed', type=int, default=None,
                        help='Base random seed (run i uses seed + i); unseeded if omitted')

    args = parser.parse_args()
    if args.runs < 1:
        parser.error('--runs must be at least 1')

    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available; falling back to CPU")
        device = 'cpu'

    # Setup output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize result logger
    result_logger = ResultLogger(args.output_dir)

    logger.info("Starting DELETE Framework Comparison...")

    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # Load data (augmented loader, used only for pre-training)
    train_loader, _ = load_cifar10_data()

    # Create and pre-train model
    model = SimpleResNet10(num_classes=10)
    model = pretrain_model(model, train_loader, epochs=15, device=device)

    # Un-augmented views of both splits for label assignment and evaluation.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=eval_transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=eval_transform
    )

    # Run multiple runs for statistical significance
    all_results = {'DELETE-Label': [], 'DataOpt-Label': []}

    for run in range(args.runs):
        logger.info(f"Starting run {run + 1}/{args.runs}")
        if args.seed is not None:
            torch.manual_seed(args.seed + run)

        run_results = run_delete_comparison(
            model, train_dataset, test_dataset, args.forget_class, device
        )

        for method, metrics in run_results.items():
            all_results[method].append(metrics)

    final_results = _aggregate_runs(all_results)

    # Log results
    for method, metrics in final_results.items():
        result_logger.log_results(
            experiment_name='exp5_delete_comparison',
            method_name=method,
            dataset='cifar10',
            metrics=metrics,
            hyperparams={
                'forget_class': args.forget_class,
                'num_runs': args.runs,
                'no_retain_set': True,
                'seed': args.seed
            }
        )

    # Save summary
    summary_file = os.path.join(args.output_dir, 'exp5_summary.json')
    with open(summary_file, 'w') as f:
        json.dump(final_results, f, indent=2)

    logger.info(f"Experiment completed. Results saved to {summary_file}")

    _print_summary(args.forget_class, args.runs, final_results)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()