#!/usr/bin/env python3
"""

Usage:
    python cifar_experiments.py --dataset cifar10 --forget_type class --forget_class 0
    python cifar_experiments.py --dataset cifar100 --forget_type random --forget_ratio 0.1
"""

import os
import sys
import argparse
import json
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# Add src to path for imports: appends "<parent of this file's directory>/src"
# to sys.path. Nothing in the visible part of this file imports from that
# directory.
sys.path.append(str(Path(__file__).parent.parent / "src"))


class ResNet18_CIFAR(nn.Module):
    """ResNet-18 architecture adapted for CIFAR datasets.

    Wraps ``torchvision.models.resnet18`` (randomly initialised, no pretrained
    weights) and swaps three layers so the network suits 32x32 inputs:

    * ``conv1`` becomes a 3x3, stride-1 convolution,
    * ``maxpool`` becomes an identity (no early down-sampling),
    * ``fc`` becomes a linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # The default call yields an un-pretrained network on every torchvision
        # release; the explicit ``pretrained=False`` keyword is deprecated in
        # newer releases and only triggers a warning there.
        self.backbone = torchvision.models.resnet18()
        # Modify first conv layer for CIFAR (32x32 input instead of 224x224)
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()  # Remove maxpool for CIFAR
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return the backbone's logits for the input batch ``x``."""
        return self.backbone(x)


class OFMUVisionTrainer:
    """OFMU trainer for vision tasks with bi-level optimization.

    Each call to :meth:`unlearn` alternates between

    * an inner loop of gradient *ascent* on
      ``Phi(theta) = L_forget - beta * cos(grad L_forget, grad L_retain)``, and
    * one outer SGD step on ``L_retain + rho * ||grad Phi(theta)||^2``.

    Recognised ``config`` keys (all optional): ``beta``, ``rho_init``,
    ``inner_steps``, ``inner_lr``, ``outer_lr``, ``rho_growth`` (multiplier
    applied to ``rho`` after every epoch, default 1.1) and ``rho_max`` (upper
    bound for ``rho``, default 1.0).
    """

    def __init__(self, model: nn.Module, device: torch.device, config: Dict):
        self.model = model
        self.device = device
        self.config = config

        # OFMU hyperparameters
        self.beta = config.get("beta", 0.1)
        self.rho_init = config.get("rho_init", 0.01)
        self.inner_steps = config.get("inner_steps", 5)
        self.inner_lr = config.get("inner_lr", 1e-3)
        self.outer_lr = config.get("outer_lr", 1e-3)

        # Penalty schedule; the defaults reproduce the previously hard-coded values.
        self.rho_growth = config.get("rho_growth", 1.1)
        self.rho_max = config.get("rho_max", 1.0)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.outer_lr, momentum=0.9, weight_decay=5e-4)

    @staticmethod
    def _cosine_similarity(grad_f: Tuple[torch.Tensor, ...],
                           grad_r: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return the cosine similarity of two gradient lists as a 0-dim tensor.

        The result stays attached to the autograd graph of its inputs so that
        it can be differentiated when it is part of a loss.
        """
        flat_f = torch.cat([g.reshape(-1) for g in grad_f])
        flat_r = torch.cat([g.reshape(-1) for g in grad_r])
        return torch.cosine_similarity(flat_f, flat_r, dim=0)

    def compute_similarity(self, grad_f: Tuple[torch.Tensor, ...],
                           grad_r: Tuple[torch.Tensor, ...]) -> float:
        """Compute cosine similarity between forget and retain gradients.

        Args:
            grad_f: Per-parameter gradients of the forget loss.
            grad_r: Per-parameter gradients of the retain loss.

        Returns:
            The similarity as a plain Python float (detached from autograd).
        """
        return self._cosine_similarity(grad_f, grad_r).item()

    def _sample_batch(self, loader: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw the first batch of a fresh pass over ``loader`` and move it to the device."""
        batch = next(iter(loader))
        return batch[0].to(self.device), batch[1].to(self.device)

    def _phi_gradient_from_batches(
        self,
        forget_batch: Tuple[torch.Tensor, torch.Tensor],
        retain_batch: Tuple[torch.Tensor, torch.Tensor],
        create_graph: bool = True,
    ) -> Tuple[Tuple[torch.Tensor, ...], float, float]:
        """Compute grad Phi on explicit ``(inputs, targets)`` batches already on the device."""
        self.model.train()
        params = list(self.model.parameters())

        forget_x, forget_y = forget_batch
        retain_x, retain_y = retain_batch

        # The first-order gradients must keep their graph: Phi depends on them
        # through the similarity term, so grad Phi needs second derivatives.
        forget_loss = self.criterion(self.model(forget_x), forget_y)
        forget_grads = torch.autograd.grad(forget_loss, params, create_graph=True)

        retain_loss = self.criterion(self.model(retain_x), retain_y)
        retain_grads = torch.autograd.grad(retain_loss, params, create_graph=True)

        # Kept as a tensor so that the similarity term actually contributes to
        # grad Phi; converting it to a float here would turn it into a constant.
        similarity = self._cosine_similarity(forget_grads, retain_grads)

        # Compute Phi = L_f - beta * similarity
        phi_loss = forget_loss - self.beta * similarity

        # ``create_graph`` is only needed when the caller differentiates
        # through grad Phi again (the outer penalty term).
        phi_grads = torch.autograd.grad(phi_loss, params, create_graph=create_graph)

        return phi_grads, phi_loss.item(), similarity.item()

    def compute_phi_gradient(
        self,
        forget_loader: DataLoader,
        retain_loader: DataLoader,
        create_graph: bool = True,
    ) -> Tuple[Tuple[torch.Tensor, ...], float, float]:
        """Compute gradient of the inner objective Phi(theta).

        One batch is drawn from each loader (forget first, then retain).

        Args:
            forget_loader: Loader over the forget set.
            retain_loader: Loader over the retain set.
            create_graph: Whether the returned gradients keep an autograd
                graph so they can be differentiated again.

        Returns:
            ``(phi_grads, phi_loss, similarity)``: the per-parameter gradients
            of Phi in ``model.parameters()`` order, and the values of Phi and
            of the cosine similarity as Python floats.
        """
        forget_batch = self._sample_batch(forget_loader)
        retain_batch = self._sample_batch(retain_loader)
        return self._phi_gradient_from_batches(forget_batch, retain_batch, create_graph=create_graph)

    def inner_maximization_step(self, forget_loader: DataLoader,
                                retain_loader: DataLoader) -> Tuple[float, float]:
        """Perform T steps of gradient ascent on Phi(theta).

        Returns:
            ``(phi_loss, similarity)`` from the last ascent step, or
            ``(nan, nan)`` when ``inner_steps`` is zero.
        """
        phi_loss = similarity = float("nan")

        for _ in range(self.inner_steps):
            # The update only consumes the gradient values, so no higher-order
            # graph needs to be retained here.
            phi_grads, phi_loss, similarity = self.compute_phi_gradient(
                forget_loader, retain_loader, create_graph=False
            )

            # Gradient ascent step
            with torch.no_grad():
                for param, grad in zip(self.model.parameters(), phi_grads):
                    param.add_(grad, alpha=self.inner_lr)

        return phi_loss, similarity

    def outer_minimization_step(self, retain_loader: DataLoader, rho: float,
                                forget_loader: Optional[DataLoader] = None) -> Tuple[float, float]:
        """Perform outer minimization with penalty term.

        Takes one optimizer step on ``L_r + rho * ||grad Phi(theta)||^2``.

        Args:
            retain_loader: Loader over the retain set.
            rho: Weight of the penalty term.
            forget_loader: Loader over the forget set used for the penalty.
                When omitted (legacy call signature) the forget batch is drawn
                from ``retain_loader`` instead, because no forget data is
                available to the method.

        Returns:
            ``(retain_loss, penalty)`` as Python floats.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Sample retain batch; it is reused for the retain part of Phi below.
        retain_batch = self._sample_batch(retain_loader)
        retain_x, retain_y = retain_batch

        # Compute retain loss
        retain_loss = self.criterion(self.model(retain_x), retain_y)

        # Compute penalty term ||grad Phi(theta)||^2 directly on the sampled
        # batches. Re-wrapping collated batches in another DataLoader would
        # prepend an extra batch dimension to every tensor.
        forget_source = forget_loader if forget_loader is not None else retain_loader
        forget_batch = self._sample_batch(forget_source)
        phi_grads, _, _ = self._phi_gradient_from_batches(forget_batch, retain_batch, create_graph=True)

        # Sum of squares instead of norm()**2: the latter has an undefined
        # (NaN) gradient when a gradient tensor is exactly zero.
        penalty_term = sum(grad.pow(2).sum() for grad in phi_grads)

        # Total loss: L_r + rho * ||grad Phi(theta)||^2
        total_loss = retain_loss + rho * penalty_term

        total_loss.backward()
        self.optimizer.step()

        return retain_loss.item(), penalty_term.item()

    def unlearn(self, forget_loader: DataLoader, retain_loader: DataLoader,
                num_epochs: int = 10) -> nn.Module:
        """Main OFMU unlearning loop.

        Args:
            forget_loader: Loader over the samples to forget.
            retain_loader: Loader over the samples to keep.
            num_epochs: Number of inner/outer alternations to run.

        Returns:
            The trainer's model, which is updated in place.
        """
        rho = self.rho_init

        for epoch in range(num_epochs):
            epoch_start = time.time()

            # Inner maximization
            phi_loss, similarity = self.inner_maximization_step(forget_loader, retain_loader)

            # Outer minimization
            retain_loss, penalty = self.outer_minimization_step(
                retain_loader, rho, forget_loader=forget_loader
            )

            # Update penalty parameter
            rho = min(rho * self.rho_growth, self.rho_max)  # Gradually increase penalty

            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch+1}/{num_epochs}: "
                  f"Phi_loss={phi_loss:.4f}, Retain_loss={retain_loss:.4f}, "
                  f"Similarity={similarity:.4f}, Penalty={penalty:.4f}, "
                  f"Time={epoch_time:.2f}s")

        return self.model


class CIFARExperiment:
    """CIFAR experiment runner for unlearning methods comparison.

    Reads these ``config`` keys: ``dataset`` (``"cifar10"``; any other value
    selects CIFAR-100), ``forget_type`` (``"class"``; any other value selects
    random forgetting), ``forget_class``, ``forget_ratio``, ``seed`` and the
    optional ``train_epochs`` (default 100).

    Constructing an instance configures logging, downloads/loads the data
    under ``./data`` and builds the forget/retain split.
    """

    # Per-channel normalisation statistics applied to train and test images.
    _NORM_MEAN = (0.4914, 0.4822, 0.4465)
    _NORM_STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, config: Dict):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.setup_logging()
        self.setup_data()

    def setup_logging(self):
        """Setup logging configuration (log file in the working directory plus the console)."""
        log_file = f"cifar_experiment_{self.config['dataset']}_{self.config['forget_type']}.log"
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def setup_data(self):
        """Setup CIFAR datasets with train/test splits.

        Besides the augmented training set, a second view of the training
        images with the deterministic test-time transform is created so that
        forget/retain accuracy is not measured on randomly cropped/flipped
        images.
        """
        normalize = transforms.Normalize(self._NORM_MEAN, self._NORM_STD)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # Load dataset
        if self.config["dataset"] == "cifar10":
            dataset_cls, self.num_classes = torchvision.datasets.CIFAR10, 10
        else:  # cifar100
            dataset_cls, self.num_classes = torchvision.datasets.CIFAR100, 100

        self.train_dataset = dataset_cls(
            root='./data', train=True, download=True, transform=transform_train
        )
        # Same images as train_dataset, without augmentation; the files have
        # already been fetched by the call above.
        self.train_eval_dataset = dataset_cls(
            root='./data', train=True, download=False, transform=transform_test
        )
        self.test_dataset = dataset_cls(
            root='./data', train=False, download=True, transform=transform_test
        )

        # Create forget and retain sets
        self.create_forget_retain_split()

    def _train_labels(self) -> np.ndarray:
        """Return the training labels as an array without decoding any image when possible."""
        targets = getattr(self.train_dataset, "targets", None)
        if targets is None:
            # Fallback for datasets that do not expose their labels directly.
            targets = [label for _, label in self.train_dataset]
        return np.asarray(targets)

    def create_forget_retain_split(self):
        """Create forget and retain sets based on configuration.

        Sets ``forget_dataset``/``retain_dataset`` (views of the training set)
        and ``forget_eval_dataset``/``retain_eval_dataset`` (the same indices
        over the un-augmented view when it exists, otherwise over the training
        set itself).
        """
        if self.config["forget_type"] == "class":
            # Class-wise forgetting: read the label array instead of iterating
            # the dataset, which would load and transform every image.
            is_forget = self._train_labels() == self.config["forget_class"]
            forget_indices = np.flatnonzero(is_forget).tolist()
            retain_indices = np.flatnonzero(~is_forget).tolist()
        else:  # random forgetting
            forget_ratio = self.config["forget_ratio"]
            total_samples = len(self.train_dataset)
            forget_size = int(total_samples * forget_ratio)

            # A private generator yields the same permutation as seeding the
            # global NumPy RNG, without resetting global random state.
            rng = np.random.RandomState(self.config["seed"])
            all_indices = np.arange(total_samples)
            rng.shuffle(all_indices)

            forget_indices = all_indices[:forget_size].tolist()
            retain_indices = all_indices[forget_size:].tolist()

        self.forget_dataset = Subset(self.train_dataset, forget_indices)
        self.retain_dataset = Subset(self.train_dataset, retain_indices)

        eval_source = getattr(self, "train_eval_dataset", self.train_dataset)
        self.forget_eval_dataset = Subset(eval_source, forget_indices)
        self.retain_eval_dataset = Subset(eval_source, retain_indices)

        self.logger.info("Created forget set: %d samples", len(self.forget_dataset))
        self.logger.info("Created retain set: %d samples", len(self.retain_dataset))

    def _clone_model(self, source: nn.Module) -> nn.Module:
        """Return a fresh ResNet-18 on the experiment device carrying ``source``'s weights."""
        clone = ResNet18_CIFAR(num_classes=self.num_classes).to(self.device)
        clone.load_state_dict(source.state_dict())
        return clone

    def train_original_model(self) -> nn.Module:
        """Train the original model on full dataset.

        Runs ``config["train_epochs"]`` epochs (100 when the key is absent) of
        SGD with a step learning-rate schedule and logs every 20th epoch.
        """
        model = ResNet18_CIFAR(num_classes=self.num_classes).to(self.device)

        # Training configuration
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        num_epochs = self.config.get("train_epochs", 100)

        # Data loader
        train_loader = DataLoader(self.train_dataset, batch_size=128, shuffle=True, num_workers=2)

        self.logger.info("Training original model...")
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for inputs, targets in train_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            scheduler.step()

            if epoch % 20 == 0:
                acc = 100. * correct / total
                self.logger.info('Epoch %d: Loss=%.3f, Acc=%.2f%%',
                                 epoch, running_loss / len(train_loader), acc)

        return model

    def evaluate_model(self, model: nn.Module, test_name: str = "") -> Dict:
        """Evaluate model on forget, retain, and test sets.

        Args:
            model: Network to evaluate; it is switched to eval mode.
            test_name: Label prepended to the log lines.

        Returns:
            Dict with ``forget_accuracy``, ``retain_accuracy``,
            ``test_accuracy`` (percentages) and ``unlearning_accuracy``
            (``100 - forget_accuracy``). An empty split is reported as 0.0
            accuracy and logged as a warning.
        """
        model.eval()
        results = {}

        # Prefer the un-augmented views so the metrics are deterministic.
        loaders = {
            "forget": DataLoader(getattr(self, "forget_eval_dataset", self.forget_dataset),
                                 batch_size=128, shuffle=False),
            "retain": DataLoader(getattr(self, "retain_eval_dataset", self.retain_dataset),
                                 batch_size=128, shuffle=False),
            "test": DataLoader(self.test_dataset, batch_size=128, shuffle=False),
        }

        with torch.no_grad():
            for set_name, loader in loaders.items():
                correct = 0
                total = 0

                for inputs, targets in loader:
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    _, predicted = model(inputs).max(1)

                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

                if total:
                    accuracy = 100. * correct / total
                else:
                    self.logger.warning("%s %s set is empty; reporting 0.00%% accuracy",
                                        test_name, set_name)
                    accuracy = 0.0
                results[f"{set_name}_accuracy"] = accuracy

                self.logger.info("%s %s Accuracy: %.2f%%", test_name, set_name.title(), accuracy)

        # Share of forget-set samples the model no longer classifies correctly.
        results["unlearning_accuracy"] = 100 - results["forget_accuracy"]

        return results

    def run_ofmu_experiment(self, original_model: nn.Module) -> Dict:
        """Run OFMU unlearning experiment on a copy of ``original_model`` and evaluate it."""
        self.logger.info("Running OFMU unlearning...")

        # Create model copy
        model = self._clone_model(original_model)

        # OFMU configuration
        ofmu_config = {
            "beta": 0.1,
            "rho_init": 0.01,
            "inner_steps": 5,
            "inner_lr": 0.001,
            "outer_lr": 0.001
        }

        # Initialize trainer
        trainer = OFMUVisionTrainer(model, self.device, ofmu_config)

        # Create data loaders
        forget_loader = DataLoader(self.forget_dataset, batch_size=32, shuffle=True)
        retain_loader = DataLoader(self.retain_dataset, batch_size=32, shuffle=True)

        # Run unlearning
        unlearned_model = trainer.unlearn(forget_loader, retain_loader, num_epochs=10)

        # Evaluate
        results = self.evaluate_model(unlearned_model, "OFMU")
        results["method"] = "ofmu"

        return results

    def _run_finetune_baseline(self, original_model: nn.Module) -> Dict:
        """Fine-tune a copy of the model on the retain set only (10 epochs) and evaluate it."""
        self.logger.info("Running Fine-tuning baseline...")
        ft_model = self._clone_model(original_model)

        retain_loader = DataLoader(self.retain_dataset, batch_size=128, shuffle=True)
        optimizer = optim.SGD(ft_model.parameters(), lr=0.01, momentum=0.9)
        criterion = nn.CrossEntropyLoss()

        for _ in range(10):
            ft_model.train()
            for inputs, targets in retain_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                loss = criterion(ft_model(inputs), targets)
                loss.backward()
                optimizer.step()

        results = self.evaluate_model(ft_model, "Fine-tuning")
        results["method"] = "finetuned"
        return results

    def _run_gradient_ascent_baseline(self, original_model: nn.Module) -> Dict:
        """Maximise the loss of a model copy on the forget set (5 epochs) and evaluate it."""
        self.logger.info("Running Gradient Ascent baseline...")
        ga_model = self._clone_model(original_model)

        forget_loader = DataLoader(self.forget_dataset, batch_size=128, shuffle=True)
        optimizer = optim.SGD(ga_model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        for _ in range(5):
            ga_model.train()
            for inputs, targets in forget_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                loss = -criterion(ga_model(inputs), targets)  # Negative for maximization
                loss.backward()
                optimizer.step()

        results = self.evaluate_model(ga_model, "Gradient Ascent")
        results["method"] = "gradient_ascent"
        return results

    def run_baseline_experiments(self, original_model: nn.Module) -> Dict[str, Dict]:
        """Run baseline unlearning methods.

        Returns:
            Mapping with the keys ``"finetuned"`` and ``"gradient_ascent"``,
            each holding that baseline's evaluation results.
        """
        return {
            "finetuned": self._run_finetune_baseline(original_model),
            "gradient_ascent": self._run_gradient_ascent_baseline(original_model),
        }

    def run_experiment(self) -> Dict:
        """Run complete CIFAR experiment.

        Trains the original model, runs OFMU and the baselines, writes all
        results as JSON under ``results/`` and returns them.
        """
        # Train original model
        original_model = self.train_original_model()
        original_results = self.evaluate_model(original_model, "Original")

        # Run OFMU
        ofmu_results = self.run_ofmu_experiment(original_model)

        # Run baselines
        baseline_results = self.run_baseline_experiments(original_model)

        # Combine results
        all_results = {
            "original": original_results,
            "ofmu": ofmu_results,
            **baseline_results
        }

        # Save results
        results_path = Path(f"results/cifar_{self.config['dataset']}_{self.config['forget_type']}.json")
        results_path.parent.mkdir(parents=True, exist_ok=True)

        with open(results_path, 'w') as f:
            json.dump(all_results, f, indent=2)

        self.logger.info("Results saved to %s", results_path)
        return all_results


def main():
    """Command-line entry point.

    Parses the experiment options, seeds torch and NumPy, runs the full
    CIFAR experiment and prints one summary row (UA / RA / TA) for every
    method except the original model.
    """
    cli = argparse.ArgumentParser(description="CIFAR Experiments for OFMU")
    cli.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10",
                     help="Dataset to use")
    cli.add_argument("--forget_type", choices=["class", "random"], default="class",
                     help="Type of forgetting")
    cli.add_argument("--forget_class", type=int, default=0,
                     help="Class to forget (for class-wise forgetting)")
    cli.add_argument("--forget_ratio", type=float, default=0.1,
                     help="Ratio of data to forget (for random forgetting)")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")
    args = cli.parse_args()

    # Seed both RNG families before any data or model is created.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # The experiment configuration mirrors the parsed options one-to-one.
    option_names = ("dataset", "forget_type", "forget_class", "forget_ratio", "seed")
    config = {name: getattr(args, name) for name in option_names}

    results = CIFARExperiment(config).run_experiment()

    # Summary table
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print("\n" + heavy_rule)
    print(f"CIFAR EXPERIMENT RESULTS - {args.dataset.upper()}")
    print(heavy_rule)
    print(f"Forget Type: {args.forget_type}")
    if args.forget_type != "class":
        print(f"Forget Ratio: {args.forget_ratio}")
    else:
        print(f"Forget Class: {args.forget_class}")
    print(light_rule)
    print(f"{'Method':<15} {'UA':<8} {'RA':<8} {'TA':<8}")
    print(light_rule)

    for method, result in results.items():
        if method != "original":
            ua = result.get('unlearning_accuracy', 0)
            ra = result.get('retain_accuracy', 0)
            ta = result.get('test_accuracy', 0)
            print(f"{method:<15} {ua:<8.2f} {ra:<8.2f} {ta:<8.2f}")

    print(heavy_rule)


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()