"""
Experiment 3: Retain Set Composition Analysis
Analyze the impact of different retain set selection strategies.
Dataset: CIFAR-10 (class unlearning)
Fixed unlearning algorithm: NEGGRAD
"""

import os
import sys
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import logging
from typing import Dict, List, Tuple, Any
import json
import argparse

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from baselines.classification import NEGGRADUnlearning
from src.dataopt import DataOptFramework
from utils.metrics import UnlearningMetrics, ResultLogger

# Configure the root logger at INFO level so the progress messages logged
# by the functions below are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def load_cifar10_data(batch_size: int = 64) -> Tuple[DataLoader, DataLoader]:
    """Build the CIFAR-10 training and test data loaders.

    The training split is augmented with a padded random crop and a random
    horizontal flip before normalisation; the test split is only converted
    to a tensor and normalised. Both splits are stored under ``./data`` and
    downloaded if missing.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Per-channel mean / std applied to both splits.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )

    augmented_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    plain_pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    loaders = []
    # (is_train, pipeline): the training split is also the only shuffled one.
    for is_train, pipeline in ((True, augmented_pipeline), (False, plain_pipeline)):
        split = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, download=True, transform=pipeline
        )
        loaders.append(DataLoader(split, batch_size=batch_size, shuffle=is_train))

    train_loader, test_loader = loaders
    return train_loader, test_loader


class SimpleResNet10(nn.Module):
    """CIFAR-10 classifier built on a randomly initialised torchvision ResNet.

    Note that, despite the class name, the backbone is
    ``torchvision.models.resnet18``; the name is kept so existing callers
    keep working. The backbone's final fully connected layer is replaced by
    a fresh ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Called without arguments on purpose: random initialisation is the
        # default under both the legacy ``pretrained=`` and the newer
        # ``weights=`` torchvision APIs, and omitting the keyword avoids the
        # deprecation warning that ``pretrained=False`` triggers on recent
        # torchvision releases.
        self.backbone = torchvision.models.resnet18()
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.backbone(x)


def pretrain_model(model: nn.Module, 
                   train_loader: DataLoader, 
                   epochs: int = 20,
                   device: str = 'cuda') -> nn.Module:
    """Pre-train ``model`` on ``train_loader`` with Adam and cross-entropy.

    Args:
        model: Network to train; it is moved to ``device`` and updated in place.
        train_loader: Loader yielding ``(data, target)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.

    Returns:
        The same ``model`` instance after training. If ``train_loader``
        yields no samples, a warning is logged and the model is returned
        untrained instead of failing with a division by zero.
    """
    logger.info("Pre-training model...")
    
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            num_batches += 1
        
        if total == 0:
            # Nothing to average over; further epochs would be just as empty.
            logger.warning("train_loader yielded no samples; skipping pre-training")
            break
        
        acc = 100.0 * correct / total
        # Mean per-batch loss over the batches actually seen this epoch.
        avg_loss = total_loss / num_batches
        
        # Only every fifth epoch is reported to keep the log short.
        if epoch % 5 == 0:
            logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.2f}%")
    
    logger.info("Pre-training completed")
    return model


def _stored_labels(dataset):
    """Return the dataset's in-memory label sequence, or None if unusable."""
    targets = getattr(dataset, 'targets', None)
    if targets is None:
        return None
    # With a target_transform, dataset[idx][1] may differ from targets[idx].
    if getattr(dataset, 'target_transform', None) is not None:
        return None
    try:
        if len(targets) != len(dataset):
            return None
    except TypeError:
        return None
    return targets


def create_forget_retain_split(dataset: torch.utils.data.Dataset,
                              forget_class: int = 0) -> Tuple[List[int], List[int]]:
    """Create forget and retain indices for class unlearning.

    When the dataset exposes a ``targets`` sequence of matching length and
    has no ``target_transform``, labels are read from it directly, which
    avoids loading and transforming every sample just to inspect its label.
    Otherwise the dataset is iterated as ``(sample, label)`` pairs.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        forget_class: Label whose samples form the forget set.

    Returns:
        ``(forget_indices, retain_indices)``: positions whose label equals
        ``forget_class`` and all remaining positions, both in dataset order.
    """
    stored = _stored_labels(dataset)
    if stored is not None:
        labels = iter(stored)
    else:
        labels = (label for _, label in dataset)

    forget_indices = []
    retain_indices = []
    for idx, label in enumerate(labels):
        if label == forget_class:
            forget_indices.append(idx)
        else:
            retain_indices.append(idx)

    return forget_indices, retain_indices


def select_random_retain_samples(retain_indices: List[int],
                                num_samples: int,
                                seed: int = 42) -> List[int]:
    """Random sampling baseline.

    Draws ``num_samples`` distinct entries from ``retain_indices``.

    Args:
        retain_indices: Candidate indices to sample from.
        num_samples: Number of indices to draw.
        seed: Seed of the private generator used for the draw.

    Returns:
        A new list. When ``retain_indices`` has at most ``num_samples``
        entries, it is a copy of all of them in their original order.
    """
    if len(retain_indices) <= num_samples:
        # Copy so callers can extend/mutate the result without touching the pool.
        return list(retain_indices)
    # A private legacy RandomState yields the same draw as seeding the global
    # NumPy generator would, without resetting global random state for the
    # rest of the program.
    rng = np.random.RandomState(seed)
    return rng.choice(retain_indices, num_samples, replace=False).tolist()


def select_neighborhood_samples(model: nn.Module,
                               dataset: torch.utils.data.Dataset,
                               forget_indices: List[int],
                               retain_indices: List[int],
                               num_samples: int,
                               device: str = 'cuda',
                               pool_size: int = 1000) -> List[int]:
    """Select neighborhood samples using feature similarity.

    Stacks all forget samples and the first ``pool_size`` retain samples
    into tensors and passes them to
    ``DataOptFramework.find_neighborhood_samples``.

    Args:
        model: Model handed to ``DataOptFramework``.
        dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        forget_indices: Dataset indices of the forget set.
        retain_indices: Dataset indices of the retain candidates.
        num_samples: Maximum number of indices to return.
        device: Device handed to ``DataOptFramework``.
        pool_size: Number of leading retain candidates used as search pool
            (limits memory and compute).

    Returns:
        At most ``num_samples`` retain indices; an empty list when there is
        nothing to select from or nothing requested.
    """
    
    logger.info("Selecting neighborhood samples...")
    
    # Nothing to select: avoids torch.stack on an empty list and a division
    # by zero when computing k1 below.
    if num_samples <= 0 or not forget_indices or not retain_indices:
        return []
    
    dataopt = DataOptFramework(model, device)
    
    # Get forget samples
    forget_tensor = torch.stack([dataset[idx][0] for idx in forget_indices])
    
    # Get retain pool (limited for efficiency)
    retain_pool = list(retain_indices[:pool_size])
    retain_tensor = torch.stack([dataset[idx][0] for idx in retain_pool])
    
    # Integer division yields 0 whenever there are more forget samples than
    # requested retain samples, so request at least one.
    # NOTE(review): presumably k1 is the number of neighbours per forget
    # sample -- confirm against DataOptFramework.
    k1 = max(1, num_samples // len(forget_indices))
    neighborhood_samples = dataopt.find_neighborhood_samples(
        forget_tensor, retain_tensor, k1=k1
    )
    
    # NOTE(review): only the *number* of returned samples is used here; the
    # leading pool entries are returned regardless of which samples were
    # actually found. Mapping the real neighbours back to dataset indices
    # needs the return format of find_neighborhood_samples -- TODO confirm
    # and fix.
    num_selected = min(len(neighborhood_samples), len(retain_pool), num_samples)
    return retain_pool[:num_selected]


def select_boundary_samples(model: nn.Module,
                           dataset: torch.utils.data.Dataset,
                           forget_indices: List[int],
                           retain_indices: List[int],
                           num_samples: int,
                           device: str = 'cuda') -> List[int]:
    """Select boundary samples near decision boundaries.

    Scans ``retain_indices`` in order and keeps samples for which the model
    assigns the forget class a probability above 0.1 that is less than 0.5
    below the top probability. If fewer than ``num_samples`` qualify, the
    rest is filled with random retain samples.

    Args:
        model: Classifier to query; it is switched to eval mode.
        dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        forget_indices: Dataset indices of the forget set; the label of the
            first one is taken as the forget class.
        retain_indices: Dataset indices of the retain candidates.
        num_samples: Number of indices to return.
        device: Device each sample is moved to before the forward pass.

    Returns:
        At most ``num_samples`` retain indices, boundary samples first.

    Raises:
        ValueError: If ``forget_indices`` is empty.
    """
    
    logger.info("Selecting boundary samples...")
    
    if not forget_indices:
        raise ValueError("forget_indices must not be empty")
    
    model.eval()
    boundary_samples = []
    
    # Get forget class
    forget_class = dataset[forget_indices[0]][1]
    
    # Find samples close to decision boundary with forget class
    with torch.no_grad():
        for idx in retain_indices:
            if len(boundary_samples) >= num_samples:
                break
            
            data, _ = dataset[idx]
            probs = torch.softmax(model(data.unsqueeze(0).to(device)), dim=1)
            
            forget_class_prob = probs[0, forget_class].item()
            predicted_class_prob = probs.max().item()
            
            # Keep samples where the forget class has noticeable probability
            # and is within 0.5 of the top probability.
            # NOTE(review): samples whose top prediction *is* the forget
            # class also pass (the gap is then 0); confirm this is intended.
            if forget_class_prob > 0.1 and predicted_class_prob - forget_class_prob < 0.5:
                boundary_samples.append(idx)
    
    # If we don't have enough, fill with random samples
    if len(boundary_samples) < num_samples:
        remaining = num_samples - len(boundary_samples)
        # Set lookup keeps the exclusion linear in the number of candidates.
        already_chosen = set(boundary_samples)
        additional = select_random_retain_samples(
            [idx for idx in retain_indices if idx not in already_chosen],
            remaining
        )
        boundary_samples.extend(additional)
    
    return boundary_samples[:num_samples]


def select_dataopt_samples(model: nn.Module,
                          dataset: torch.utils.data.Dataset,
                          forget_indices: List[int],
                          retain_indices: List[int],
                          num_samples: int,
                          device: str = 'cuda') -> List[int]:
    """Select samples using full DataOpt strategy.

    Combines three groups: neighborhood samples and boundary samples (about
    a third of ``num_samples`` each) plus a random remainder. The remainder
    stands in for adversarial samples, which are not generated here. The
    three groups are kept disjoint, and the random group is sized after the
    other two so the total still reaches ``num_samples`` when one of them
    returns fewer indices than requested (as long as enough candidates
    remain).

    Args:
        model: Classifier passed to the neighborhood and boundary selectors.
        dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        forget_indices: Dataset indices of the forget set.
        retain_indices: Dataset indices of the retain candidates.
        num_samples: Total number of retain indices wanted.
        device: Device passed to the selectors.

    Returns:
        Neighborhood, boundary and random indices, concatenated in that order.
    """
    
    logger.info("Selecting DataOpt mixed samples...")
    
    # Divide samples among different types
    num_neighborhood = num_samples // 3
    num_boundary = num_samples // 3
    
    # Select neighborhood samples
    neighborhood_indices = select_neighborhood_samples(
        model, dataset, forget_indices, retain_indices, num_neighborhood, device
    )
    
    # Select boundary samples from candidates not already taken, so the same
    # index cannot appear in both groups.
    taken = set(neighborhood_indices)
    boundary_pool = [idx for idx in retain_indices if idx not in taken]
    boundary_indices = select_boundary_samples(
        model, dataset, forget_indices, boundary_pool, num_boundary, device
    )
    taken.update(boundary_indices)
    
    # For adversarial, just select random samples (adversarial generation is
    # complex). Sized from what was actually selected above.
    num_adversarial = max(
        0, num_samples - len(neighborhood_indices) - len(boundary_indices)
    )
    remaining_indices = [idx for idx in retain_indices if idx not in taken]
    adversarial_indices = select_random_retain_samples(remaining_indices, num_adversarial)
    
    return list(neighborhood_indices) + list(boundary_indices) + list(adversarial_indices)


def run_retain_composition_experiment(model: nn.Module,
                                    dataset: torch.utils.data.Dataset,
                                    forget_indices: List[int],
                                    retain_indices: List[int],
                                    retain_strategies: List[str],
                                    num_retain_samples: int = 200,
                                    device: str = 'cuda') -> Dict[str, Dict[str, float]]:
    """Run retain set composition experiment.

    For every strategy a retain subset is selected, a deep copy of ``model``
    is unlearned with NEGGRAD on the forget set and that subset, and the
    result is evaluated against the CIFAR-10 test split. A strategy that
    fails is logged with its traceback and skipped, so one failure does not
    abort the whole comparison.

    Args:
        model: Pre-trained model; only copies of it are passed to the unlearner.
        dataset: Dataset the forget/retain indices refer to.
        forget_indices: Dataset indices of the forget set.
        retain_indices: Dataset indices of the retain candidates.
        retain_strategies: Names out of ``"Random"``, ``"Neighborhood"``,
            ``"Boundary"`` and ``"DataOpt"``.
        num_retain_samples: Number of retain samples to select per strategy.
        device: Device used for selection, unlearning and evaluation.

    Returns:
        Mapping from strategy name to its evaluation result; strategies that
        failed are absent.
    """
    import copy
    
    results = {}
    
    # Create forget loader
    forget_loader = DataLoader(Subset(dataset, forget_indices), batch_size=64, shuffle=True)
    
    # Create test loader for evaluation
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
    )
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
    # Strategy name -> zero-argument selector, evaluated lazily per strategy.
    model_based_args = (
        model, dataset, forget_indices, retain_indices, num_retain_samples, device
    )
    selectors = {
        "Random": lambda: select_random_retain_samples(retain_indices, num_retain_samples),
        "Neighborhood": lambda: select_neighborhood_samples(*model_based_args),
        "Boundary": lambda: select_boundary_samples(*model_based_args),
        "DataOpt": lambda: select_dataopt_samples(*model_based_args),
    }
    
    for strategy in retain_strategies:
        logger.info(f"Testing retain strategy: {strategy}")
        
        try:
            # Select retain samples based on strategy
            if strategy not in selectors:
                raise ValueError(f"Unknown strategy: {strategy}")
            selected_retain_indices = selectors[strategy]()
            
            # Fail with a clear message instead of handing an empty loader
            # to the unlearner.
            if not selected_retain_indices:
                raise ValueError(f"Strategy {strategy} selected no retain samples")
            
            # Create retain loader
            retain_subset = Subset(dataset, selected_retain_indices)
            retain_loader = DataLoader(retain_subset, batch_size=64, shuffle=True)
            
            # Unlearn on a copy so every strategy starts from the same weights.
            model_copy = copy.deepcopy(model)
            
            # Run NEGGRAD unlearning
            unlearner = NEGGRADUnlearning(model_copy, device)
            unlearned_model = unlearner.unlearn(forget_loader, retain_loader)
            
            # Evaluate
            metrics = UnlearningMetrics(unlearned_model, device)
            evaluation_results = metrics.evaluate_classification(
                forget_loader, retain_loader, test_loader, model
            )
            
            results[strategy] = evaluation_results
            
            logger.info(f"{strategy} results: {evaluation_results}")
            
        except Exception as e:
            # Best-effort boundary: keep the traceback, move on to the next
            # strategy.
            logger.exception(f"Error with strategy {strategy}: {e}")
            continue
    
    return results


def main():
    """Command-line entry point for the retain set composition experiment.

    Parses the arguments, pre-trains a model on CIFAR-10, splits the
    training set into forget/retain indices for the chosen class, runs every
    requested retain strategy, logs and saves the results, and prints a
    summary.
    """
    parser = argparse.ArgumentParser(description='Experiment 3: Retain Set Composition')
    parser.add_argument('--forget_class', type=int, default=0, choices=range(10),
                       help='Class to forget (0-9)')
    parser.add_argument('--num_retain_samples', type=int, default=200,
                       help='Number of retain samples to use')
    parser.add_argument('--strategies', nargs='+', 
                       default=['Random', 'Neighborhood', 'Boundary', 'DataOpt'],
                       help='Retain set selection strategies to test')
    # Fall back to the CPU by default when no CUDA device is present; an
    # explicit --device is always honoured.
    parser.add_argument('--device',
                       default='cuda' if torch.cuda.is_available() else 'cpu',
                       help='Device to use')
    parser.add_argument('--output_dir', default='results', help='Output directory')
    
    args = parser.parse_args()
    
    # Setup output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Initialize result logger
    result_logger = ResultLogger(args.output_dir)
    
    logger.info("Starting Retain Set Composition Analysis...")
    
    # Load data (the evaluation loader is rebuilt inside the experiment)
    train_loader, _ = load_cifar10_data()
    
    # Create and pre-train model
    model = SimpleResNet10(num_classes=10)
    model = pretrain_model(model, train_loader, epochs=15, device=args.device)
    
    # Get full dataset for splitting (no augmentation, unlike the training loader)
    full_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
    )
    
    # Create forget/retain split
    forget_indices, retain_indices = create_forget_retain_split(
        full_dataset, args.forget_class
    )
    
    logger.info(f"Forget samples: {len(forget_indices)}, Retain samples: {len(retain_indices)}")
    
    # Run experiment
    results = run_retain_composition_experiment(
        model, full_dataset, forget_indices, retain_indices, 
        args.strategies, args.num_retain_samples, args.device
    )
    
    # Log results
    for strategy, metrics in results.items():
        result_logger.log_results(
            experiment_name='exp3_retain_composition',
            method_name=f'NEGGRAD_{strategy}',
            dataset='cifar10',
            metrics=metrics,
            hyperparams={
                'forget_class': args.forget_class,
                'num_retain_samples': args.num_retain_samples,
                'retain_strategy': strategy
            }
        )
    
    # Save summary
    summary_file = os.path.join(args.output_dir, 'exp3_summary.json')
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    
    logger.info(f"Experiment completed. Results saved to {summary_file}")
    
    # Print summary
    print("\n" + "="*60)
    print("EXPERIMENT 3 SUMMARY - RETAIN SET COMPOSITION")
    print("="*60)
    print(f"Forget Class: {args.forget_class}")
    print(f"Retain Samples Used: {args.num_retain_samples}")
    print("-" * 60)
    
    for strategy, metrics in results.items():
        print(f"{strategy:15} | Acc_rt: {metrics.get('acc_rt', 0):.3f} | "
              f"Acc_ft: {metrics.get('acc_ft', 0):.3f} | "
              f"MIA: {metrics.get('mia', 0):.3f} | "
              f"RUD: {metrics.get('rud', 0):.3f}")
    
    # Analysis
    print("\nAnalysis:")
    print("-" * 30)
    
    # Every strategy may have failed (failures are logged and skipped), and
    # min()/max() over an empty mapping would raise.
    if not results:
        print("No strategy completed successfully; nothing to analyze.")
        return
    
    best_strategy = min(results, 
                       key=lambda x: results[x].get('rud', float('inf')))
    print(f"Best strategy (lowest RUD): {best_strategy}")
    
    best_forget = min(results, 
                     key=lambda x: results[x].get('acc_ft', float('inf')))
    print(f"Best forgetting (lowest Acc_ft): {best_forget}")
    
    best_retain = max(results, 
                     key=lambda x: results[x].get('acc_rt', 0))
    print(f"Best retention (highest Acc_rt): {best_retain}")


# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()