import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.lenet import LeNet
from models.resnet import ResNet9, ResNet18, CifarResNet18
from methods.our.unlearning import OurUnlearning
from utils import set_seed

# Experiment grid as (model architecture, dataset, input channels, number of classes).
# Order matters: experiments are selected and reported by their index in
# EXPERIMENT_CONFIGS.
_EXPERIMENT_GRID = [
    ('resnet9', 'mnist', 1, 10),
    ('lenet', 'mnist', 1, 10),
    ('lenet', 'svhn', 3, 10),
    ('resnet9', 'svhn', 3, 10),
    ('lenet', 'cifar10', 3, 10),
    ('resnet9', 'cifar10', 3, 10),
    ('resnet18', 'cifar10', 3, 10),
    ('resnet9', 'cifar100', 3, 100),
    ('resnet18', 'cifar100', 3, 100),
]

# One config dict per grid row; run-time settings (seed, batch size, checkpoint
# path) are added later by all_class_unlearning_experiments().
EXPERIMENT_CONFIGS = [
    {'model': arch, 'dataset': data, 'in_channels': channels, 'num_classes': n_cls}
    for arch, data, channels, n_cls in _EXPERIMENT_GRID
]

def get_model(model_name, num_classes, in_channels):
    """Build a new model instance for the architecture called ``model_name``.

    Args:
        model_name: 'lenet', 'resnet9' or 'resnet18'. Note that 'resnet18'
            is served by ``CifarResNet18``.
        num_classes: forwarded to the model constructor.
        in_channels: forwarded to the model constructor.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # Each entry resolves its class lazily, so only the matching class name
    # is ever looked up.
    registry = (
        ('lenet', lambda: LeNet),
        ('resnet9', lambda: ResNet9),
        ('resnet18', lambda: CifarResNet18),
    )
    for name, resolve_cls in registry:
        if model_name == name:
            model_cls = resolve_cls()
            return model_cls(num_classes=num_classes, in_channels=in_channels)
    raise ValueError(f"Unsupported model: {model_name}")

def get_model_path(model_name, dataset_name):
    """Return the relative path of the best checkpoint for a model/dataset pair.

    The path has the form ``checkpoints/<model>_<dataset>_best.pth`` and is
    relative to the current working directory.
    """
    template = 'checkpoints/{}_{}_best.pth'
    return template.format(model_name, dataset_name)

def plot_class_unlearning_results(results, save_dir):
    """Plot forget/retain accuracy per forgotten class for a single experiment.

    Args:
        results: dict holding the parallel lists 'class_indices',
            'forget_accuracies' and 'retain_accuracies' (accuracies in percent),
            plus a 'config' dict whose 'model' and 'dataset' entries are used
            in the title.
        save_dir: existing directory that receives ``class_unlearning_results.png``.

    If the three lists differ in length they are truncated to the shortest one.
    If no data points remain, a warning is printed and no figure is written.
    """
    class_indices = list(results['class_indices'])
    forget_accs = list(results['forget_accuracies'])
    retain_accs = list(results['retain_accuracies'])
    config = results['config']

    print("Plot data validation:")
    print(f"  Number of classes: {len(class_indices)}")
    print(f"  Number of forget accuracies: {len(forget_accs)}")
    print(f"  Number of retain accuracies: {len(retain_accs)}")

    # plot() needs x and y of equal length, so truncate whenever ANY of the
    # three series differs from the others, whichever one is longest.
    min_len = min(len(class_indices), len(forget_accs), len(retain_accs))
    if not len(class_indices) == len(forget_accs) == len(retain_accs):
        print(f"Warning: Inconsistent data length, truncating to {min_len} data points")
        class_indices = class_indices[:min_len]
        forget_accs = forget_accs[:min_len]
        retain_accs = retain_accs[:min_len]

    if min_len == 0:
        # min()/max() below would raise on empty sequences.
        print("Warning: No accuracy data to plot, skipping")
        return

    print(f"  Forget accuracy range: {min(forget_accs):.2f}% - {max(forget_accs):.2f}%")
    print(f"  Retain accuracy range: {min(retain_accs):.2f}% - {max(retain_accs):.2f}%")

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.plot(class_indices, forget_accs, 'r-o', label='Forget Class Accuracy',
                linewidth=2, markersize=8, markerfacecolor='red', markeredgecolor='darkred', alpha=0.8)
        ax.plot(class_indices, retain_accs, 'b-s', label='Retain Classes Accuracy',
                linewidth=2, markersize=8, markerfacecolor='blue', markeredgecolor='darkblue', alpha=0.8)

        ax.set_xlabel('Class Index', fontsize=14)
        ax.set_ylabel('Accuracy (%)', fontsize=14)
        ax.set_title(f'Single Class Unlearning Results\n{config["model"].upper()} on {config["dataset"].upper()}',
                     fontsize=16, fontweight='bold')
        ax.legend(fontsize=12, loc='best')
        ax.grid(True, alpha=0.3)

        # Label every class when there are few; otherwise roughly 20 ticks.
        if len(class_indices) <= 20:
            ax.set_xticks(class_indices)
        else:
            tick_step = max(1, len(class_indices) // 20)
            ax.set_xticks(class_indices[::tick_step])

        # Zoom the y-axis to the data with a 2-point margin, clamped to [0, 100].
        valid_accs = [acc for acc in forget_accs + retain_accs if acc >= 0]
        if valid_accs:
            ax.set_ylim(max(0, min(valid_accs) - 2), min(100, max(valid_accs) + 2))
        else:
            ax.set_ylim(0, 100)

        # Per-point value labels are only legible for small class counts.
        if len(class_indices) <= 10:
            for x, forget_y, retain_y in zip(class_indices, forget_accs, retain_accs):
                ax.annotate(f'{forget_y:.1f}', (x, forget_y), textcoords="offset points",
                            xytext=(0, 10), ha='center', fontsize=9, color='red')
                ax.annotate(f'{retain_y:.1f}', (x, retain_y), textcoords="offset points",
                            xytext=(0, -15), ha='center', fontsize=9, color='blue')

        fig.tight_layout()

        plot_file = os.path.join(save_dir, 'class_unlearning_results.png')
        fig.savefig(plot_file, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f"Results plot saved to: {plot_file}")

def plot_all_experiments_summary(all_results, save_dir):
    """Draw one subplot per experiment (at most 3 per row) and save the grid.

    Args:
        all_results: mapping of experiment name -> results dict with the
            parallel lists 'class_indices', 'forget_accuracies',
            'retain_accuracies' and a 'config' dict providing 'model' and
            'dataset'.
        save_dir: existing directory that receives ``all_experiments_summary.png``.

    Does nothing (apart from a message) when ``all_results`` is empty.
    """
    n_experiments = len(all_results)
    if n_experiments == 0:
        # A zero-column grid would also divide by zero when computing n_rows.
        print("No experiment results to summarize, skipping summary plot")
        return

    n_cols = min(3, n_experiments)
    n_rows = (n_experiments + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so flatten() yields
    # individual Axes for every grid shape, including a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
    axes = axes.flatten()
    try:
        for ax, results in zip(axes, all_results.values()):
            config = results['config']

            # Keep the three series aligned; plot() requires equal lengths.
            min_len = min(len(results['class_indices']),
                          len(results['forget_accuracies']),
                          len(results['retain_accuracies']))
            class_indices = list(results['class_indices'])[:min_len]
            forget_accs = list(results['forget_accuracies'])[:min_len]
            retain_accs = list(results['retain_accuracies'])[:min_len]

            ax.plot(class_indices, forget_accs, 'r-o', label='Forget Class Accuracy',
                    linewidth=2, markersize=6, alpha=0.8)
            ax.plot(class_indices, retain_accs, 'b-s', label='Retain Classes Accuracy',
                    linewidth=2, markersize=6, alpha=0.8)

            ax.set_xlabel('Class Index', fontsize=12)
            ax.set_ylabel('Accuracy (%)', fontsize=12)
            ax.set_title(f'{config["model"].upper()} on {config["dataset"].upper()}',
                         fontsize=14, fontweight='bold')
            ax.legend(fontsize=10, loc='best')
            ax.grid(True, alpha=0.3)

            # Limit x-ticks to roughly ten per panel to keep small subplots readable.
            if len(class_indices) <= 10:
                ax.set_xticks(class_indices)
            else:
                tick_step = max(1, len(class_indices) // 10)
                ax.set_xticks(class_indices[::tick_step])

            # Zoom the y-axis to the data with a 2-point margin, clamped to [0, 100].
            valid_accs = [acc for acc in forget_accs + retain_accs if acc >= 0]
            if valid_accs:
                ax.set_ylim(max(0, min(valid_accs) - 2), min(100, max(valid_accs) + 2))

        # Hide grid cells that have no experiment.
        for ax in axes[n_experiments:]:
            ax.set_visible(False)

        fig.tight_layout()

        summary_plot_file = os.path.join(save_dir, 'all_experiments_summary.png')
        fig.savefig(summary_plot_file, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f"Summary plot saved to: {summary_plot_file}")

def _unlearn_one_class(config, class_idx, results_dir, device):
    """Run OurUnlearning for a single class and return its four metrics.

    Returns:
        ``(forget_acc, retain_acc, forget_change, retain_change)`` as floats, or
        ``None`` when ``unlearn_classes`` returned an empty/falsy result.
    """
    # A new model is built for every class. OurUnlearning receives
    # checkpoint_path, presumably to load the trained weights (not visible here).
    model = get_model(config['model'], config['num_classes'], config['in_channels'])
    unlearner = OurUnlearning(
        model=model,
        dataset_name=config['dataset'],
        checkpoint_path=config['model_path'],
        batch_size=config['batch_size'],
        device=device,
        log_dir=os.path.join(results_dir, f"forget_class_{class_idx}"),
        seed=config['seed'],
    )

    forget_results = unlearner.unlearn_classes([class_idx])
    if not forget_results:
        return None

    result = forget_results[0]
    # float() turns numpy/torch scalars (should the unlearner return them)
    # into plain floats so that json.dump can serialise the results.
    return (
        float(result['final']['class_test']['accuracy']),
        float(result['final']['active_test']['accuracy']),
        float(result['forget_acc_change']),
        float(result['retain_acc_change']),
    )


def single_class_unlearning_experiment(config):
    """Unlearn every class of one model/dataset configuration, one class at a time.

    For each class index a fresh model is built and asked to forget only that
    class. Per-class metrics are saved to ``results.json`` and plotted, inside
    ``results/all_class_unlearning/<dataset>_<model>_<timestamp>``.

    Args:
        config: dict with keys 'model', 'dataset', 'num_classes', 'in_channels',
            'model_path', 'batch_size' and 'seed'.

    Returns:
        The results dict, or ``None`` if ``config['model_path']`` does not exist.
        Classes whose run failed get 0.0 placeholders in every metric list and
        their indices are listed in ``results['failed_classes']``, so a genuine
        0% forget accuracy can be told apart from a failure.
    """
    import traceback  # used only to report per-class failures

    print("Starting single class unlearning experiment")
    print(f"Model: {config['model']}")
    print(f"Dataset: {config['dataset']}")
    print(f"Number of classes: {config['num_classes']}")

    if not os.path.exists(config['model_path']):
        print(f"Error: Model file does not exist {config['model_path']}")
        print("Please train the model first or skip this configuration")
        return None

    set_seed(config['seed'])
    # Device choice does not change between classes, so resolve it once.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"results/all_class_unlearning/{config['dataset']}_{config['model']}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    num_classes = config['num_classes']
    metric_keys = ('forget_accuracies', 'retain_accuracies',
                   'forget_accuracy_changes', 'retain_accuracy_changes')
    results = {
        'class_indices': list(range(num_classes)),
        'forget_accuracies': [],
        'retain_accuracies': [],
        'forget_accuracy_changes': [],
        'retain_accuracy_changes': [],
        # Class indices whose metrics above are 0.0 failure placeholders.
        'failed_classes': [],
        'config': config
    }

    for class_idx in range(num_classes):
        print(f"\n{'='*60}")
        print(f"Forgetting class {class_idx}/{num_classes-1}")
        print(f"{'='*60}")

        try:
            metrics = _unlearn_one_class(config, class_idx, results_dir, device)
        except Exception as e:
            # Best effort: one failing class must not abort the whole sweep.
            print(f"Class {class_idx} unlearning experiment error: {e}")
            traceback.print_exc()
            metrics = None
        else:
            if metrics is None:
                print(f"Class {class_idx} unlearning failed, using default values")

        if metrics is None:
            results['failed_classes'].append(class_idx)
            metrics = (0.0, 0.0, 0.0, 0.0)
        else:
            forget_acc, retain_acc, forget_change, retain_change = metrics
            print(f"Class {class_idx} unlearning results:")
            print(f"  Forget class accuracy: {forget_acc:.2f}%")
            print(f"  Retain class accuracy: {retain_acc:.2f}%")
            print(f"  Forget class accuracy change: {forget_change:.2f}%")
            print(f"  Retain class accuracy change: {retain_change:.2f}%")

        for key, value in zip(metric_keys, metrics):
            results[key].append(value)

    print("\nFinal results validation:")
    print(f"Forget class accuracies: {results['forget_accuracies']}")
    print(f"Retain class accuracies: {results['retain_accuracies']}")

    results_file = os.path.join(results_dir, 'results.json')
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nExperiment results saved to: {results_file}")

    # The results are already on disk; a plotting problem should not make the
    # caller discard the whole (expensive) experiment.
    try:
        plot_class_unlearning_results(results, results_dir)
    except Exception as e:
        print(f"Warning: plotting results failed: {e}")
        traceback.print_exc()

    return results

def all_class_unlearning_experiments(experiment_indices=None):
    """Run the selected single-class unlearning experiments and summarize them.

    Args:
        experiment_indices: iterable of indices into EXPERIMENT_CONFIGS, or
            ``None`` to run every configuration. Indices outside
            ``[0, len(EXPERIMENT_CONFIGS))``, including negative ones, are
            skipped with a warning.

    Returns:
        dict mapping ``"<model>_<dataset>"`` to the results dict returned by
        single_class_unlearning_experiment, for experiments that produced one.
        When non-empty, a JSON dump, a text report and a summary plot are also
        written to ``results/all_class_unlearning/summary_<timestamp>``.
    """
    n_configs = len(EXPERIMENT_CONFIGS)
    if experiment_indices is None:
        experiment_indices = range(n_configs)

    all_results = {}

    for idx in experiment_indices:
        # Reject negatives explicitly: Python indexing would otherwise silently
        # pick a configuration from the end of the list.
        if not 0 <= idx < n_configs:
            print(f"Warning: Experiment index {idx} out of range, skipping")
            continue

        exp_config = {**EXPERIMENT_CONFIGS[idx], 'seed': 42, 'batch_size': 128}
        exp_config['model_path'] = get_model_path(exp_config['model'], exp_config['dataset'])
        exp_key = f"{exp_config['model']}_{exp_config['dataset']}"

        print(f"\n{'='*80}")
        print(f"Executing experiment {idx+1}/{n_configs}: {exp_config['model']} on {exp_config['dataset']}")
        print(f"{'='*80}")

        try:
            results = single_class_unlearning_experiment(exp_config)
        except Exception as e:
            # Best effort: continue with the remaining experiments.
            print(f"Experiment {exp_key} error: {e}")
            continue

        if results:
            all_results[exp_key] = results
            print(f"Experiment {exp_key} completed successfully")
        else:
            print(f"Experiment {exp_key} failed or skipped")

    if all_results:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary_dir = f"results/all_class_unlearning/summary_{timestamp}"
        os.makedirs(summary_dir, exist_ok=True)

        summary_file = os.path.join(summary_dir, 'all_experiments_summary.json')
        with open(summary_file, 'w') as f:
            json.dump(all_results, f, indent=2)

        print(f"\nAll experiment results summary saved to: {summary_file}")

        generate_summary_report(all_results, summary_dir)
        plot_all_experiments_summary(all_results, summary_dir)

    return all_results

def generate_summary_report(all_results, save_dir):
    """Write a plain-text report summarizing every experiment.

    For each experiment the report gives the mean/min/max forget and retain
    accuracies and the mean accuracy changes. These are computed over the
    classes whose unlearning run succeeded, using the same set of classes for
    every statistic. The report is written to ``<save_dir>/summary_report.txt``.

    A class run counts as failed when its index is listed in
    ``results['failed_classes']``. When that key is absent, a run counts as
    failed only if all four of its metrics are exactly 0.0, the placeholder
    values recorded on failure. A genuine forget accuracy of 0% (perfect
    forgetting) is therefore kept.

    Args:
        all_results: mapping of experiment name -> results dict. Required keys
            are 'config', 'forget_accuracies', 'retain_accuracies',
            'forget_accuracy_changes' and 'retain_accuracy_changes'. The keys
            'class_indices' and 'failed_classes' are optional.
        save_dir: existing directory for the report file.
    """
    report_lines = ["Single Class Unlearning Experiment Summary Report", "=" * 50, ""]

    for exp_name, results in all_results.items():
        if not results['forget_accuracies']:
            continue

        config = results['config']

        # zip keeps the four metric lists aligned per class (truncating to the
        # shortest), so all statistics below are computed over the same runs.
        rows = list(zip(
            results['forget_accuracies'],
            results['retain_accuracies'],
            results['forget_accuracy_changes'],
            results['retain_accuracy_changes'],
        ))

        failed = results.get('failed_classes')
        if failed is not None:
            failed = set(failed)
            class_indices = results.get('class_indices', range(len(rows)))
            valid_rows = [row for cls, row in zip(class_indices, rows) if cls not in failed]
        else:
            valid_rows = [row for row in rows if any(value != 0 for value in row)]

        if not valid_rows:
            continue

        forget_accs, retain_accs, forget_changes, retain_changes = (
            np.asarray(column, dtype=float) for column in zip(*valid_rows)
        )

        report_lines.extend([
            f"Experiment: {exp_name}",
            f"Model: {config['model']} | Dataset: {config['dataset']} | Classes: {config['num_classes']}",
            "",
            "Forget class accuracy statistics:",
            f"  Average: {forget_accs.mean():.2f}%",
            f"  Minimum: {forget_accs.min():.2f}%",
            f"  Maximum: {forget_accs.max():.2f}%",
            f"  Average change: {forget_changes.mean():.2f}%",
            "",
            "Retain class accuracy statistics:",
            f"  Average: {retain_accs.mean():.2f}%",
            f"  Minimum: {retain_accs.min():.2f}%",
            f"  Maximum: {retain_accs.max():.2f}%",
            f"  Average change: {retain_changes.mean():.2f}%",
            "",
            f"Successfully forgotten classes: {len(valid_rows)}/{config['num_classes']}",
            "-" * 60, ""
        ])

    report_file = os.path.join(save_dir, 'summary_report.txt')
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(report_lines))

    print(f"Summary report saved to: {report_file}")

def parse_experiment_indices(text):
    """Parse a comma-separated string of experiment indices into a list of ints.

    Blank entries, such as those from a trailing comma ("0,1,") or an empty
    string, are ignored rather than rejected. Whitespace around each entry is
    allowed.

    Raises:
        ValueError: if any non-blank entry is not a valid integer.
    """
    return [int(token) for token in text.split(',') if token.strip()]


if __name__ == "__main__":
    print("Supported experiment configurations:")
    for i, config in enumerate(EXPERIMENT_CONFIGS):
        print(f"{i}: {config['model']} on {config['dataset']} ({config['num_classes']} classes)")

    print("\nSelect execution mode:")
    print("1. Execute all experiments")
    print("2. Execute specific experiments")

    choice = input("Please enter your choice (1/2): ").strip()

    if choice == "1":
        results = all_class_unlearning_experiments()
    elif choice == "2":
        indices_input = input("Please enter experiment indices (comma separated, e.g.: 0,1,2): ").strip()
        # Guard only the parsing step. A ValueError raised while experiments
        # run must propagate instead of silently restarting all of them.
        try:
            indices = parse_experiment_indices(indices_input)
        except ValueError:
            indices = []
        if indices:
            results = all_class_unlearning_experiments(indices)
        else:
            print("Invalid input format, executing all experiments")
            results = all_class_unlearning_experiments()
    else:
        print("Invalid choice, executing all experiments")
        results = all_class_unlearning_experiments()
