import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.lenet import LeNet
from models.resnet import ResNet9, ResNet18,CifarResNet18
from methods.our.unlearning import OurUnlearning
from utils import set_seed

# One entry per (architecture, dataset) pair swept by the lambda ablation.
# 'in_channels' / 'num_classes' are forwarded to get_model() when building the network.
EXPERIMENT_CONFIGS = [
    dict(model='resnet18', dataset='cifar10', in_channels=3, num_classes=10),
    dict(model='resnet18', dataset='cifar100', in_channels=3, num_classes=100),
]

def get_model(model_name, num_classes, in_channels):
    """Build the network registered under ``model_name``.

    Supported names are 'lenet', 'resnet9' and 'resnet18'. Note that 'resnet18'
    builds ``CifarResNet18``, not the imported ``ResNet18``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # Ordered (name, builder) pairs; each builder is a lambda so only the selected
    # class is looked up and instantiated.
    builders = (
        ('lenet', lambda: LeNet(num_classes=num_classes, in_channels=in_channels)),
        ('resnet9', lambda: ResNet9(num_classes=num_classes, in_channels=in_channels)),
        ('resnet18', lambda: CifarResNet18(num_classes=num_classes, in_channels=in_channels)),
    )
    for registered_name, build in builders:
        if model_name == registered_name:
            return build()
    raise ValueError(f"Unsupported model: {model_name}")

def get_model_path(model_name, dataset_name):
    """Return the relative checkpoint path ``checkpoints/<model>_<dataset>_best.pth``.

    The separator is always '/', independent of the host OS.
    """
    template = "checkpoints/{model}_{dataset}_best.pth"
    return template.format(model=model_name, dataset=dataset_name)

def _append_run_metrics(results, forget_acc, retain_acc, forget_change, retain_change):
    """Append one lambda run's metrics to the four parallel metric lists in ``results``."""
    results['forget_accuracies'].append(forget_acc)
    results['retain_accuracies'].append(retain_acc)
    results['forget_accuracy_changes'].append(forget_change)
    results['retain_accuracy_changes'].append(retain_change)


def _append_failed_run(results):
    """Record a failed lambda run as NaN so every metric list stays aligned with lambda_values."""
    # NaN instead of 0.0: a 0% forget accuracy with zero retain change is
    # indistinguishable from a perfect unlearning run, so a failed run would be
    # picked as the "best" lambda by the summary scoring.
    nan = float('nan')
    _append_run_metrics(results, nan, nan, nan, nan)


def lambda_ablation_single_experiment(config):
    """Sweep the unlearning lambda for a single (model, dataset, forget class) configuration.

    For every value in ``config['lambda_values']`` a fresh model is built, wrapped in
    ``OurUnlearning`` with ``lambda_value`` set to that value, and asked to forget
    ``config['forget_class']``. Per-lambda metrics are collected into parallel lists,
    saved to ``results.json`` in a timestamped directory under
    ``results/lambda_ablation/``, and plotted with ``plot_results``.

    Args:
        config: dict with keys 'model', 'dataset', 'num_classes', 'in_channels',
            'forget_class', 'lambda_values', 'model_path', 'seed', 'batch_size'.

    Returns:
        The results dict (lambda values, the four per-lambda metric lists, and the
        config), or None when the checkpoint at ``config['model_path']`` is missing.
        A failed lambda run is recorded as NaN in every metric list; ``json.dump``
        writes these as the non-standard ``NaN`` token, which Python's ``json``
        module reads back.
    """
    import traceback

    print("Starting Lambda hyperparameter ablation experiment")
    print(f"Model: {config['model']}")
    print(f"Dataset: {config['dataset']}")
    print(f"Forget class: {config['forget_class']}")
    print(f"Lambda range: {config['lambda_values']}")

    if not os.path.exists(config['model_path']):
        print(f"Error: Model file does not exist {config['model_path']}")
        print("Please train the model first or skip this configuration")
        return None

    set_seed(config['seed'])

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"results/lambda_ablation/{config['dataset']}_{config['model']}_class{config['forget_class']}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    results = {
        'lambda_values': config['lambda_values'],
        'forget_accuracies': [],
        'retain_accuracies': [],
        'forget_accuracy_changes': [],
        'retain_accuracy_changes': [],
        'config': config,
    }

    # Loop-invariant: the target device is the same for every lambda value.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_runs = len(config['lambda_values'])

    for i, lambda_val in enumerate(config['lambda_values']):
        print(f"\n{'='*60}")
        print(f"Experiment {i+1}/{num_runs}: Lambda = {lambda_val}")
        print(f"{'='*60}")

        try:
            # A new model instance per lambda; presumably OurUnlearning restores the
            # weights from checkpoint_path (TODO confirm in OurUnlearning).
            model = get_model(config['model'], config['num_classes'], config['in_channels'])

            log_dir = os.path.join(results_dir, f"lambda_{lambda_val}")
            unlearner = OurUnlearning(
                model=model,
                dataset_name=config['dataset'],
                checkpoint_path=config['model_path'],
                batch_size=config['batch_size'],
                device=device,
                log_dir=log_dir,
                seed=config['seed'],
            )
            unlearner.lambda_value = lambda_val

            forget_results = unlearner.unlearn_classes([config['forget_class']])

            if forget_results:
                # NOTE(review): result schema assumed from the keys read here; verify
                # against OurUnlearning.unlearn_classes.
                result = forget_results[0]
                forget_acc = result['final']['class_test']['accuracy']
                retain_acc = result['final']['active_test']['accuracy']
                forget_change = result['forget_acc_change']
                retain_change = result['retain_acc_change']

                print(f"Lambda {lambda_val} results:")
                print(f"  Forget class accuracy: {forget_acc:.2f}%")
                print(f"  Retain class accuracy: {retain_acc:.2f}%")
                print(f"  Forget class accuracy change: {forget_change:.2f}%")
                print(f"  Retain class accuracy change: {retain_change:.2f}%")

                # Append last: if anything above raises, the except branch records a
                # single failure entry instead of leaving a partial/duplicate row.
                _append_run_metrics(results, forget_acc, retain_acc, forget_change, retain_change)
            else:
                print(f"Lambda {lambda_val} experiment failed, recording NaN metrics")
                _append_failed_run(results)

        except Exception as e:
            # Best-effort sweep: log the failure (with traceback) and continue with
            # the next lambda value.
            print(f"Lambda {lambda_val} experiment error: {e}")
            traceback.print_exc()
            _append_failed_run(results)

    results_file = os.path.join(results_dir, 'results.json')
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nExperiment results saved to: {results_file}")

    plot_results(results, results_dir)

    return results

def lambda_ablation_all_experiments(experiment_indices=None):
    """Run the lambda ablation for all, or selected, entries of EXPERIMENT_CONFIGS.

    Every selected configuration is extended with the shared sweep settings
    (forget class 0, lambda 1..20, seed 42, batch size 128) and its checkpoint path,
    then passed to ``lambda_ablation_single_experiment``. When at least one
    experiment succeeds, a combined JSON summary and a text report are written to
    a timestamped ``results/lambda_ablation/summary_*`` directory.

    Args:
        experiment_indices: iterable of indices into EXPERIMENT_CONFIGS, or None to
            run every configuration. Indices outside ``[0, len(EXPERIMENT_CONFIGS))``
            -- including negative ones -- are reported and skipped.

    Returns:
        dict mapping "<model>_<dataset>" to that experiment's results dict.
    """
    num_configs = len(EXPERIMENT_CONFIGS)
    if experiment_indices is None:
        experiment_indices = list(range(num_configs))

    all_results = {}

    for idx in experiment_indices:
        # Explicit lower bound: a negative index would otherwise silently select a
        # config from the end (or raise IndexError below the per-experiment try).
        if not 0 <= idx < num_configs:
            print(f"Warning: Experiment index {idx} out of range, skipping")
            continue

        exp_config = EXPERIMENT_CONFIGS[idx].copy()

        # Sweep settings shared by every configuration.
        exp_config.update({
            'forget_class': 0,
            'lambda_values': list(range(1, 21)),
            'seed': 42,
            'batch_size': 128,
        })

        exp_config['model_path'] = get_model_path(exp_config['model'], exp_config['dataset'])
        exp_key = f"{exp_config['model']}_{exp_config['dataset']}"

        print(f"\n{'='*80}")
        print(f"Executing experiment {idx+1}/{num_configs}: {exp_config['model']} on {exp_config['dataset']}")
        print(f"{'='*80}")

        try:
            results = lambda_ablation_single_experiment(exp_config)
        except Exception as e:
            # Keep going with the remaining configurations.
            print(f"Experiment {exp_key} error: {e}")
            continue

        if results:
            all_results[exp_key] = results
            print(f"Experiment {exp_key} completed successfully")
        else:
            print(f"Experiment {exp_key} failed or skipped")

    if all_results:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary_dir = f"results/lambda_ablation/summary_{timestamp}"
        os.makedirs(summary_dir, exist_ok=True)

        summary_file = os.path.join(summary_dir, 'all_experiments_summary.json')
        with open(summary_file, 'w') as f:
            json.dump(all_results, f, indent=2)

        print(f"\nAll experiment results summary saved to: {summary_file}")

        generate_summary_report(all_results, summary_dir)

    return all_results

def generate_summary_report(all_results, save_dir):
    """Write a plain-text report naming the best lambda of each experiment.

    The best lambda minimises ``forget_accuracy + 2 * |retain_accuracy_change|``:
    low accuracy on the forgotten class is rewarded, and drift on the retained
    classes is penalised at twice the weight. On ties the earliest lambda wins.
    Entries that are None or non-finite (such as NaN placeholders for failed runs)
    are never selected. An experiment with no usable entry is reported as such.

    Args:
        all_results: mapping experiment name -> results dict holding the parallel
            lists 'lambda_values', 'forget_accuracies', 'retain_accuracies',
            'forget_accuracy_changes' and 'retain_accuracy_changes'.
        save_dir: existing directory; the report is written to
            ``save_dir/summary_report.txt`` (UTF-8).
    """
    import math

    def _is_finite(value):
        return value is not None and math.isfinite(value)

    def _pct(value):
        # A None metric would make ':.2f' raise TypeError.
        return f"{value:.2f}%" if _is_finite(value) else "n/a"

    report_lines = ["Lambda Ablation Experiment Summary Report", "=" * 50, ""]

    for exp_name, results in all_results.items():
        if not results['forget_accuracies']:
            continue

        rows = list(zip(
            results['lambda_values'],
            results['forget_accuracies'],
            results['retain_accuracies'],
            results['forget_accuracy_changes'],
            results['retain_accuracy_changes'],
        ))  # zip() bounds the scan by the shortest list, so ragged lists cannot IndexError.

        best_row = None
        best_score = float('inf')
        for row in rows:
            _, forget_acc, _, _, retain_change = row
            if not (_is_finite(forget_acc) and _is_finite(retain_change)):
                continue  # failed / missing run
            score = forget_acc + abs(retain_change) * 2
            if score < best_score:
                best_score = score
                best_row = row

        if best_row is None:
            report_lines.extend([
                f"Experiment: {exp_name}",
                "Best Lambda: none (no successful runs)",
                "-" * 40, "",
            ])
            continue

        best_lambda, forget_acc, retain_acc, forget_change, retain_change = best_row
        report_lines.extend([
            f"Experiment: {exp_name}",
            f"Best Lambda: {best_lambda}",
            f"Forget class accuracy: {_pct(forget_acc)}",
            f"Retain class accuracy: {_pct(retain_acc)}",
            f"Forget class accuracy change: {_pct(forget_change)}",
            f"Retain class accuracy change: {_pct(retain_change)}",
            "-" * 40, "",
        ])

    report_file = os.path.join(save_dir, 'summary_report.txt')
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(report_lines))

    print(f"Summary report saved to: {report_file}")

def plot_results(results, save_dir):
    """Plot forget- and retain-class accuracy against lambda and save it as a PNG.

    Args:
        results: dict with parallel lists 'lambda_values', 'forget_accuracies' and
            'retain_accuracies' (the accuracy-change lists are not plotted). None or
            NaN accuracies are drawn as gaps and ignored when choosing the y-range.
        save_dir: existing directory; the figure is written to
            ``save_dir/lambda_accuracy_plot.png`` at 300 dpi.
    """
    lambda_values = results['lambda_values']

    def _to_float_array(values):
        # None -> NaN so a missing run becomes a gap instead of breaking the plot.
        return np.array([np.nan if v is None else v for v in values], dtype=float)

    forget_accs = _to_float_array(results['forget_accuracies'])
    retain_accs = _to_float_array(results['retain_accuracies'])

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(lambda_values, forget_accs, 'r-o', label='Forget Class Accuracy', linewidth=2, markersize=6)
        ax.plot(lambda_values, retain_accs, 'b-s', label='Retain Classes Accuracy', linewidth=2, markersize=6)

        ax.set_xlabel('Lambda Value', fontsize=12)
        ax.set_ylabel('Accuracy (%)', fontsize=12)
        ax.set_title('Effect of Lambda on Forget and Retain Class Accuracies', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(lambda_values[::2])

        # Pad the y-range by 5 points, clamped to [0, 100], using finite values only:
        # min()/max() on an empty list raises, and NaN breaks Python's ordering.
        all_accs = np.concatenate([forget_accs, retain_accs])
        finite_accs = all_accs[np.isfinite(all_accs)]
        if finite_accs.size:
            y_min = max(0, float(finite_accs.min()) - 5)
            y_max = min(100, float(finite_accs.max()) + 5)
            ax.set_ylim(y_min, y_max)

        fig.tight_layout()

        plot_file = os.path.join(save_dir, 'lambda_accuracy_plot.png')
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak figures.
        plt.close(fig)

    print(f"Result plot saved to: {plot_file}")


if __name__ == "__main__":
    # Show the available configurations so the user can choose indices.
    print("Supported experiment configurations:")
    for cfg_idx, cfg in enumerate(EXPERIMENT_CONFIGS):
        print(f"{cfg_idx}: {cfg['model']} on {cfg['dataset']}")

    for menu_line in ("\nSelect execution mode:",
                      "1. Execute all experiments",
                      "2. Execute specific experiments"):
        print(menu_line)

    mode = input("Please enter your choice (1/2): ").strip()

    if mode == "2":
        raw_indices = input("Please enter experiment indices (comma separated, e.g: 0,1,2): ").strip()
        try:
            chosen = [int(token.strip()) for token in raw_indices.split(',')]
            results = lambda_ablation_all_experiments(chosen)
        except ValueError:
            # Unparseable index list: fall back to running every configuration.
            print("Invalid input format, executing all experiments")
            results = lambda_ablation_all_experiments()
    else:
        # Mode "1" runs everything; any other answer also runs everything, with a notice.
        if mode != "1":
            print("Invalid choice, executing all experiments")
        results = lambda_ablation_all_experiments()
