import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import logging
import os
import time
from torch.utils.tensorboard import SummaryWriter

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Requesting deterministic kernels is not sufficient on its own: with
    # benchmark mode enabled cuDNN autotunes and may select different
    # algorithms from run to run, so it must be switched off as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def show_images(dataloader, num_images=16, title="Sample Images", nrow=4):
    """Display the first ``num_images`` images of one dataloader batch as a grid.

    Args:
        dataloader: Iterable whose batches unpack into ``(images, labels)``.
        num_images: Maximum number of images taken from the first batch.
        title: Figure title.
        nrow: Number of images per grid row (default 4, the previous
            hard-coded value).
    """
    # Only the first batch is shown; labels are not displayed.
    images, _ = next(iter(dataloader))
    # make_grid and matplotlib both work on CPU tensors, so there is no need
    # to copy the batch to an accelerator only to bring the grid back.
    grid = vutils.make_grid(images[:num_images].cpu(), nrow=nrow, normalize=True)
    plt.figure(figsize=(10, 10))
    plt.title(title)
    # (C, H, W) grid -> (H, W, C) array, the layout imshow expects.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.axis('off')
    plt.show()

def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Score tensor, one row per sample, classes along dim 1.
        target: Tensor of true class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        list: One single-element float tensor per k, holding the percentage
        of samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        largest_k = max(topk)

        # Indices of the largest_k best classes, transposed to (k, batch).
        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = top_idx.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]

def setup_logger(name, log_dir, level=logging.INFO):
    """Setup logger with a console handler and an optional file handler.

    Args:
        name: Logger name; also used as the log file stem (``<name>.log``).
        log_dir: Log file directory. Created if missing. If falsy, only
            console logging is configured.
        level: Log level applied to the logger and to newly created
            handlers, default is INFO.

    Returns:
        tuple: (logger, log_file). ``log_file`` is the path of the file the
        logger writes to, or None when it has no file handler.

    Note:
        If the logger already has handlers (e.g. from an earlier call), no
        new handlers are added. The logger level is still updated, and the
        path of an already attached FileHandler is returned (as stored by
        the handler, i.e. an absolute path) instead of None.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    if logger.handlers:
        # Re-entry: avoid stacking duplicate handlers, but still report the
        # file the logger is already writing to.
        existing_file = next(
            (h.baseFilename for h in logger.handlers
             if isinstance(h, logging.FileHandler)),
            None,
        )
        return logger, existing_file

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    log_file = None
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, f"{name}.log")
        # Explicit encoding so non-ASCII messages don't depend on the locale.
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger, log_file

def setup_tensorboard(log_dir):
    """Create a TensorBoard SummaryWriter that writes into ``log_dir``.

    The directory is created first if it does not already exist.

    Args:
        log_dir: TensorBoard log directory.

    Returns:
        writer: SummaryWriter bound to ``log_dir``.
    """
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    return writer

def print_log_info(log_file, tensorboard_dir):
    """Print log locations and the command for launching TensorBoard.

    Args:
        log_file: Log file path, or None when only console logging is active
            (which is what ``setup_logger`` returns when there is no file
            handler).
        tensorboard_dir: TensorBoard log directory.
    """
    separator = "=" * 80
    print("\n" + separator)
    if log_file:
        print(f"Log file saved to: {log_file}")
    else:
        # Avoid the misleading "Log file saved to: None".
        print("Log file: not written (console logging only)")
    print(f"TensorBoard logs saved to: {tensorboard_dir}")
    print("\nTo view TensorBoard visualization, run in terminal:")
    print(f"tensorboard --logdir={tensorboard_dir}")
    print(separator + "\n")

def summarize_unlearning_results(results, logger):
    """Summarize unlearning results and log them.

    Logs the mean forget/retain accuracy change across all results, then a
    per-class table of initial accuracy, final accuracy and forget change.

    Args:
        results: List of result dicts. Each must provide ``class_name``,
            ``forget_acc_change``, ``retain_acc_change``,
            ``initial['class_test']['accuracy']`` and
            ``final['class_test']['accuracy']``.
        logger: Logger instance (needs ``info`` and ``warning``).
    """
    if not results:
        # np.mean([]) would emit a RuntimeWarning and log "nan%".
        logger.warning("No unlearning results to summarize.")
        return

    avg_forget_change = np.mean([r['forget_acc_change'] for r in results])
    avg_retain_change = np.mean([r['retain_acc_change'] for r in results])

    logger.info("=" * 50)
    logger.info("Unlearning Summary Results:")
    logger.info(f"Average forget class accuracy change: {avg_forget_change:.2f}%")
    logger.info(f"Average retain class accuracy change: {avg_retain_change:.2f}%")
    logger.info("=" * 50)

    header = f"{'Class':<10} {'Initial Acc':<12} {'Final Acc':<12} {'Change':<10}"
    logger.info("\nClass Details:")
    logger.info(header)
    # Underline exactly as wide as the header row.
    logger.info("-" * len(header))

    for r in results:
        name = r['class_name']
        init_acc = r['initial']['class_test']['accuracy']
        final_acc = r['final']['class_test']['accuracy']
        change = r['forget_acc_change']
        logger.info(f"{name:<10} {init_acc:<12.2f} {final_acc:<12.2f} {change:<10.2f}")