"""
Experiment 1: SOTA Enhancement
Test DataOpt framework enhancement on SOTA unlearning methods.
Datasets: CIFAR-100 (class unlearning), Tiny-ImageNet (random subset unlearning)
"""

import os
import sys
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, TensorDataset
import numpy as np
import logging
from typing import Dict, List, Tuple, Any
import json
import argparse
from datetime import datetime

# Append the parent of this file's directory (assumed to be the project root)
# to the import path so the project-local imports below can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from baselines.classification import (
    NEGGRADUnlearning, SCRUBUnlearning, BadTeacherUnlearning, 
    SalUnUnlearning, DELETEUnlearning
)
from src.dataopt import dataopt_algorithm
from utils.metrics import UnlearningMetrics, ResultLogger

# Configure the root logger at INFO level and create this module's logger,
# which every function in this file uses for progress and error reporting.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimpleResNet(nn.Module):
    """ResNet-18 classifier whose final layer is resized to ``num_classes`` outputs."""

    def __init__(self, num_classes: int = 100):
        super().__init__()
        # Randomly initialised torchvision ResNet-18 with a fresh linear head.
        resnet = torchvision.models.resnet18(pretrained=False)
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.backbone = resnet

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.backbone(x)


def load_cifar100_data(batch_size: int = 64) -> Tuple[DataLoader, DataLoader]:
    """Build the CIFAR-100 training and test data loaders.

    The training split is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalised with
    the same per-channel mean / std constants. The dataset is downloaded to
    ``./data`` if it is not already present.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` pair. The training loader shuffles,
        the test loader does not.
    """
    # Shared by both pipelines so the two splits can never drift apart.
    normalize = transforms.Normalize(
        (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
    )

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train
    )

    test_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


def load_tiny_imagenet_data(batch_size: int = 64) -> Tuple[DataLoader, DataLoader]:
    """Return train/test loaders that stand in for Tiny-ImageNet.

    The real Tiny-ImageNet dataset is not loaded here: CIFAR-100 images are
    resized to 64 pixels and normalised with ImageNet-style statistics to
    act as a proxy.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` pair; only the training loader shuffles.
    """
    preprocess = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # CIFAR-100 proxy: one dataset object per split, train first, then test.
    proxy_train, proxy_test = (
        torchvision.datasets.CIFAR100(
            root='./data', train=is_train, download=True, transform=preprocess
        )
        for is_train in (True, False)
    )

    return (
        DataLoader(proxy_train, batch_size=batch_size, shuffle=True),
        DataLoader(proxy_test, batch_size=batch_size, shuffle=False),
    )


def pretrain_model(model: nn.Module, 
                   train_loader: DataLoader, 
                   epochs: int = 20,
                   device: str = 'cuda') -> nn.Module:
    """Pre-train ``model`` on the full dataset with Adam and cross-entropy loss.

    Args:
        model: Network to train; it is moved to ``device`` and trained in place.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and each batch are moved to.

    Returns:
        The same ``model`` object after training.

    If ``train_loader`` yields no samples, a warning is logged and training
    stops early instead of failing with a division by zero.
    """
    logger.info("Pre-training model...")
    
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            num_batches += 1
        
        if total == 0:
            # Nothing was seen this epoch, so further epochs would be no-ops too.
            logger.warning("Training loader yielded no samples; skipping pre-training")
            break
        
        acc = 100.0 * correct / total
        avg_loss = total_loss / num_batches
        
        # Only report every fifth epoch to keep the log short.
        if epoch % 5 == 0:
            logger.info(f"Epoch {epoch}: Loss = {avg_loss:.4f}, Acc = {acc:.2f}%")
    
    logger.info("Pre-training completed")
    return model


def _to_python_scalar(value):
    """Unwrap tensor / NumPy scalars via ``.item()``; return other values unchanged."""
    return value.item() if hasattr(value, 'item') else value


def _extract_labels(dataset) -> list:
    """Return the label of every sample in ``dataset`` as plain Python values.

    When the dataset exposes a ``targets`` sequence of matching length and has
    no ``target_transform``, the labels are read from it directly, which avoids
    loading and transforming every sample. Otherwise each ``(sample, label)``
    pair is fetched through the dataset itself.
    """
    targets = getattr(dataset, 'targets', None)
    if (targets is not None
            and getattr(dataset, 'target_transform', None) is None
            and len(targets) == len(dataset)):
        return [_to_python_scalar(t) for t in targets]
    return [_to_python_scalar(label) for _, label in dataset]


def create_forget_retain_split(dataset: torch.utils.data.Dataset,
                              forget_classes: List[int] = None,
                              forget_ratio: float = None) -> Tuple[DataLoader, DataLoader]:
    """Split ``dataset`` into a forget set and a retain set.

    Args:
        dataset: Dataset yielding ``(sample, label)`` pairs.
        forget_classes: Class unlearning - every sample whose label is in this
            collection goes to the forget set. Takes precedence over
            ``forget_ratio`` when both are given.
        forget_ratio: Random subset unlearning - fraction in ``[0, 1]`` of the
            samples that is drawn (via ``np.random.permutation``) into the
            forget set.

    Returns:
        ``(forget_loader, retain_loader)``, both shuffling with batch size 64.

    Raises:
        ValueError: If neither ``forget_classes`` nor ``forget_ratio`` is
            given, or if ``forget_ratio`` lies outside ``[0, 1]``.
    """
    
    if forget_classes is not None:
        # Class unlearning
        forget_set = {_to_python_scalar(c) for c in forget_classes}
        forget_indices = []
        retain_indices = []
        
        for idx, label in enumerate(_extract_labels(dataset)):
            if label in forget_set:
                forget_indices.append(idx)
            else:
                retain_indices.append(idx)
    
    elif forget_ratio is not None:
        # Random subset unlearning
        if not 0.0 <= forget_ratio <= 1.0:
            raise ValueError(f"forget_ratio must be in [0, 1], got {forget_ratio}")
        
        num_samples = len(dataset)
        num_forget = int(num_samples * forget_ratio)
        
        indices = np.random.permutation(num_samples)
        forget_indices = indices[:num_forget].tolist()
        retain_indices = indices[num_forget:].tolist()
    
    else:
        raise ValueError("Must specify either forget_classes or forget_ratio")
    
    forget_subset = Subset(dataset, forget_indices)
    retain_subset = Subset(dataset, retain_indices)
    
    forget_loader = DataLoader(forget_subset, batch_size=64, shuffle=True)
    retain_loader = DataLoader(retain_subset, batch_size=64, shuffle=True)
    
    return forget_loader, retain_loader


def run_baseline_experiment(baseline_name: str,
                           model: nn.Module,
                           forget_loader: DataLoader,
                           retain_loader: DataLoader,
                           device: str = 'cuda') -> nn.Module:
    """Apply one named baseline unlearning method to a copy of ``model``.

    Args:
        baseline_name: One of ``'NEGGRAD'``, ``'SCRUB'``, ``'BadTeacher'``,
            ``'SalUn'`` or ``'DELETE'``.
        model: Model to unlearn from; it is deep-copied before being handed
            to the baseline.
        forget_loader: Loader passed to the baseline's ``unlearn`` as the forget data.
        retain_loader: Loader passed to the baseline's ``unlearn`` as the retain data.
        device: Device argument forwarded to the baseline constructor.

    Returns:
        Whatever the baseline's ``unlearn`` call returns.

    Raises:
        ValueError: If ``baseline_name`` is not a known baseline.
    """
    import copy

    registry = {
        'NEGGRAD': NEGGRADUnlearning,
        'SCRUB': SCRUBUnlearning,
        'BadTeacher': BadTeacherUnlearning,
        'SalUn': SalUnUnlearning,
        'DELETE': DELETEUnlearning
    }

    unlearner_cls = registry.get(baseline_name)
    if unlearner_cls is None:
        raise ValueError(f"Unknown baseline: {baseline_name}")

    logger.info(f"Running {baseline_name} baseline...")

    # The baseline receives its own copy of the model.
    unlearner = unlearner_cls(copy.deepcopy(model), device)
    return unlearner.unlearn(forget_loader, retain_loader)


def run_dataopt_enhanced_experiment(baseline_name: str,
                                   model: nn.Module,
                                   forget_loader: DataLoader,
                                   retain_pool_loader: DataLoader,
                                   k: int = 9,
                                   device: str = 'cuda',
                                   max_retain_batches: int = 10) -> nn.Module:
    """Run a baseline unlearning method on DataOpt-optimised forget/retain data.

    The forget samples and a bounded slice of the retain pool are gathered
    into tensors, passed through ``dataopt_algorithm``, wrapped back into
    data loaders and then handed to ``run_baseline_experiment``.

    Args:
        baseline_name: Name of the baseline, forwarded to ``run_baseline_experiment``.
        model: Source model; a deep copy is used for DataOpt and the baseline.
        forget_loader: Loader yielding ``(data, target)`` batches to forget.
            Only the data is used; the targets are ignored.
        retain_pool_loader: Loader yielding ``(data, target)`` batches that
            form the retain pool. Only the data is used.
        k: Forwarded to ``dataopt_algorithm`` as ``k``.
        device: Device forwarded to ``dataopt_algorithm`` and the baseline.
        max_retain_batches: Maximum number of retain-pool batches to collect
            (limits the pool size for efficiency). ``None`` collects all of them.

    Returns:
        The model returned by ``run_baseline_experiment``.

    Raises:
        ValueError: If either loader yields no batches.
    """
    import copy
    from itertools import islice
    
    logger.info(f"Running {baseline_name} + DataOpt...")
    
    model_copy = copy.deepcopy(model)
    
    # Convert data loaders to tensors for DataOpt. The original forget labels
    # are not needed: DataOpt returns its own labels for the optimised data.
    forget_batches = [data for data, _ in forget_loader]
    if not forget_batches:
        raise ValueError("forget_loader yielded no batches")
    forget_tensor = torch.cat(forget_batches, dim=0)
    
    # Limit retain pool size for efficiency
    retain_batches = [data for data, _ in islice(retain_pool_loader, max_retain_batches)]
    if not retain_batches:
        raise ValueError("retain_pool_loader yielded no batches")
    retain_tensor = torch.cat(retain_batches, dim=0)
    
    # Apply DataOpt optimization
    opt_forget_data, opt_forget_labels, opt_retain_data, opt_retain_labels = dataopt_algorithm(
        model_copy, forget_tensor, retain_tensor, k=k, device=device
    )
    
    # Convert back to data loaders; argmax over dim 1 turns the returned
    # per-class label scores into class indices.
    opt_forget_dataset = TensorDataset(opt_forget_data, opt_forget_labels.argmax(dim=1))
    opt_retain_dataset = TensorDataset(opt_retain_data, opt_retain_labels.argmax(dim=1))
    
    opt_forget_loader = DataLoader(opt_forget_dataset, batch_size=64, shuffle=True)
    opt_retain_loader = DataLoader(opt_retain_dataset, batch_size=64, shuffle=True)
    
    # Run baseline with optimized data
    return run_baseline_experiment(
        baseline_name, model_copy, opt_forget_loader, opt_retain_loader, device
    )


def evaluate_model(model: nn.Module,
                  forget_loader: DataLoader,
                  retain_loader: DataLoader,
                  test_loader: DataLoader,
                  original_model: nn.Module = None,
                  device: str = 'cuda') -> Dict[str, float]:
    """Compute classification unlearning metrics for ``model``.

    Builds an ``UnlearningMetrics`` evaluator for ``model`` on ``device`` and
    returns the result of its ``evaluate_classification`` call on the forget,
    retain and test loaders, together with the optional ``original_model``.
    """
    evaluator = UnlearningMetrics(model, device)
    return evaluator.evaluate_classification(
        forget_loader, retain_loader, test_loader, original_model
    )


def run_cifar100_experiment(baselines: List[str], device: str = 'cuda') -> Dict[str, Any]:
    """Run the CIFAR-100 class unlearning experiment.

    A ``SimpleResNet`` is pre-trained for 10 epochs on the CIFAR-100 training
    loader, classes 0-9 are chosen as the forget set, and each baseline is run
    both on its own and with DataOpt enhancement.

    Args:
        baselines: Names of the baseline methods to run.
        device: Device used for training, unlearning and evaluation.

    Returns:
        Dict mapping ``"<baseline>"`` and ``"<baseline> + DataOpt"`` to the
        metrics returned by ``evaluate_model``. A baseline that raises is
        logged with its traceback and skipped; results recorded before the
        failure are kept.
    """
    import copy
    
    logger.info("Starting CIFAR-100 experiment...")
    
    # Load data
    train_loader, test_loader = load_cifar100_data()
    
    # Create model and pre-train
    model = SimpleResNet(num_classes=100)
    model = pretrain_model(model, train_loader, epochs=10, device=device)
    
    # Save original model for evaluation
    original_model = copy.deepcopy(model)
    
    # Create forget/retain split (forget 10 classes)
    forget_classes = list(range(10))  # Forget first 10 classes
    
    # Get full dataset for splitting (no augmentation, only normalisation)
    full_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        ])
    )
    
    forget_loader, retain_loader = create_forget_retain_split(
        full_dataset, forget_classes=forget_classes
    )
    
    results = {}
    
    # Run baseline experiments
    for baseline in baselines:
        try:
            # Vanilla baseline
            vanilla_model = run_baseline_experiment(
                baseline, model, forget_loader, retain_loader, device
            )
            
            vanilla_metrics = evaluate_model(
                vanilla_model, forget_loader, retain_loader, test_loader, original_model, device
            )
            
            results[baseline] = vanilla_metrics
            
            # DataOpt-enhanced baseline
            enhanced_model = run_dataopt_enhanced_experiment(
                baseline, model, forget_loader, retain_loader, k=9, device=device
            )
            
            enhanced_metrics = evaluate_model(
                enhanced_model, forget_loader, retain_loader, test_loader, original_model, device
            )
            
            results[f"{baseline} + DataOpt"] = enhanced_metrics
            
        except Exception as e:
            # Best effort: one failing baseline must not abort the others,
            # but keep the traceback so the failure can be diagnosed.
            logger.exception(f"Error running {baseline}: {e}")
    
    return results


def run_tiny_imagenet_experiment(baselines: List[str], 
                                forget_ratios: List[float],
                                device: str = 'cuda') -> Dict[str, Any]:
    """Run the Tiny-ImageNet random subset unlearning experiment.

    A ``SimpleResNet`` is pre-trained for 10 epochs; then, for every forget
    ratio, a random forget/retain split is drawn and each baseline is run both
    on its own and with DataOpt enhancement.

    Args:
        baselines: Names of the baseline methods to run.
        forget_ratios: Fractions of the training set to forget, one split per value.
        device: Device used for training, unlearning and evaluation.

    Returns:
        Dict keyed by ``"forget_ratio_<ratio>"``; each value maps
        ``"<baseline>"`` and ``"<baseline> + DataOpt"`` to the metrics returned
        by ``evaluate_model``. A baseline that raises is logged with its
        traceback and skipped.
    """
    import copy
    
    logger.info("Starting Tiny-ImageNet experiment...")
    
    # Load data
    train_loader, test_loader = load_tiny_imagenet_data()
    
    # Create model and pre-train
    # NOTE(review): the head has 200 outputs for real Tiny-ImageNet, but the
    # data used in this file is the CIFAR-100 proxy (100 classes) - confirm
    # this is intended before comparing numbers against real Tiny-ImageNet.
    model = SimpleResNet(num_classes=200)  # Tiny-ImageNet has 200 classes
    model = pretrain_model(model, train_loader, epochs=10, device=device)
    
    # Save original model
    original_model = copy.deepcopy(model)
    
    # Get full dataset
    full_dataset = torchvision.datasets.CIFAR100(  # Using CIFAR-100 as proxy
        root='./data', train=True, download=True,
        transform=transforms.Compose([
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
    )
    
    results = {}
    
    for forget_ratio in forget_ratios:
        logger.info(f"Testing forget ratio: {forget_ratio}")
        
        # Create forget/retain split
        forget_loader, retain_loader = create_forget_retain_split(
            full_dataset, forget_ratio=forget_ratio
        )
        
        ratio_results = {}
        
        for baseline in baselines:
            try:
                # Vanilla baseline
                vanilla_model = run_baseline_experiment(
                    baseline, model, forget_loader, retain_loader, device
                )
                
                vanilla_metrics = evaluate_model(
                    vanilla_model, forget_loader, retain_loader, test_loader, original_model, device
                )
                
                ratio_results[baseline] = vanilla_metrics
                
                # DataOpt-enhanced baseline
                enhanced_model = run_dataopt_enhanced_experiment(
                    baseline, model, forget_loader, retain_loader, k=9, device=device
                )
                
                enhanced_metrics = evaluate_model(
                    enhanced_model, forget_loader, retain_loader, test_loader, original_model, device
                )
                
                ratio_results[f"{baseline} + DataOpt"] = enhanced_metrics
                
            except Exception as e:
                # Best effort: keep going with the remaining baselines, but
                # record the traceback so the failure can be diagnosed.
                logger.exception(f"Error running {baseline} with ratio {forget_ratio}: {e}")
        
        results[f"forget_ratio_{forget_ratio}"] = ratio_results
    
    return results


def _json_default(value):
    """Fallback for ``json.dump``: convert NumPy / torch scalars and arrays.

    Only called for objects the JSON encoder cannot serialise on its own.
    Anything exposing ``tolist()`` is converted through it; everything else is
    rejected with the encoder's usual ``TypeError``.
    """
    to_list = getattr(value, 'tolist', None)
    if callable(to_list):
        return to_list()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _format_metrics(metrics) -> str:
    """Render the four headline metrics of one method as a single summary line part."""
    return (f"Acc_rt: {metrics.get('acc_rt', 0):.3f} | "
            f"Acc_ft: {metrics.get('acc_ft', 0):.3f} | "
            f"MIA: {metrics.get('mia', 0):.3f} | "
            f"RUD: {metrics.get('rud', 0):.3f}")


def main():
    """Command-line entry point for Experiment 1.

    Parses the arguments, runs the selected experiments, logs every method's
    metrics through ``ResultLogger``, writes ``exp1_summary.json`` into the
    output directory and prints a summary table.
    """
    parser = argparse.ArgumentParser(description='Experiment 1: SOTA Enhancement')
    parser.add_argument('--dataset', choices=['cifar100', 'tiny-imagenet', 'both'], 
                       default='both', help='Dataset to run experiments on')
    parser.add_argument('--baselines', nargs='+', 
                       default=['NEGGRAD', 'SCRUB', 'BadTeacher', 'SalUn'],
                       help='Baseline methods to test')
    parser.add_argument('--device', default='cuda', help='Device to use')
    parser.add_argument('--output_dir', default='results', help='Output directory')
    
    args = parser.parse_args()
    
    # The default device is CUDA; fall back to CPU rather than crashing on
    # machines without a usable GPU.
    device = args.device
    if device.startswith('cuda') and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available; falling back to CPU")
        device = 'cpu'
    
    # Setup output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Initialize result logger
    result_logger = ResultLogger(args.output_dir)
    
    all_results = {}
    
    # Run CIFAR-100 experiment
    if args.dataset in ['cifar100', 'both']:
        logger.info("Running CIFAR-100 experiments...")
        cifar100_results = run_cifar100_experiment(args.baselines, device)
        all_results['cifar100'] = cifar100_results
        
        # Log results
        for method, metrics in cifar100_results.items():
            result_logger.log_results(
                experiment_name='exp1_sota_enhancement',
                method_name=method,
                dataset='cifar100',
                metrics=metrics,
                hyperparams={'forget_classes': list(range(10))}
            )
    
    # Run Tiny-ImageNet experiment
    if args.dataset in ['tiny-imagenet', 'both']:
        logger.info("Running Tiny-ImageNet experiments...")
        forget_ratios = [0.01, 0.05, 0.10]
        tiny_results = run_tiny_imagenet_experiment(
            args.baselines, forget_ratios, device
        )
        all_results['tiny_imagenet'] = tiny_results
        
        # Log results
        for ratio_key, ratio_results in tiny_results.items():
            for method, metrics in ratio_results.items():
                result_logger.log_results(
                    experiment_name='exp1_sota_enhancement',
                    method_name=method,
                    dataset=f'tiny_imagenet_{ratio_key}',
                    metrics=metrics,
                    # ratio_key looks like "forget_ratio_0.05"; keep the trailing ratio text.
                    hyperparams={'forget_ratio': ratio_key.split('_')[-1]}
                )
    
    # Save summary results. The default hook keeps NumPy / torch scalar metrics
    # from aborting the dump after all experiments have already run.
    summary_file = os.path.join(args.output_dir, 'exp1_summary.json')
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=_json_default)
    
    logger.info(f"Experiment completed. Results saved to {summary_file}")
    
    # Print summary
    print("\n" + "="*50)
    print("EXPERIMENT 1 SUMMARY")
    print("="*50)
    
    for dataset, dataset_results in all_results.items():
        print(f"\n{dataset.upper()} Results:")
        print("-" * 30)
        
        if dataset == 'cifar100':
            for method, metrics in dataset_results.items():
                print(f"{method:25} | {_format_metrics(metrics)}")
        
        elif dataset == 'tiny_imagenet':
            for ratio_key, ratio_results in dataset_results.items():
                print(f"\n{ratio_key}:")
                for method, metrics in ratio_results.items():
                    print(f"  {method:23} | {_format_metrics(metrics)}")


if __name__ == "__main__":
    main()