import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
from datetime import datetime

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.lenet import LeNet
from models.resnet import ResNet9, ResNet18, CifarResNet18
from methods.our.unlearning import OurUnlearning
from utils import set_seed

# Base configurations for the coverage ablation runs. Each entry carries the
# model name (resolved by get_model), the dataset name, and the in_channels /
# num_classes values passed to the model constructor. Run-specific settings
# (forget class, coverage ranges, seed, batch size, checkpoint path) are added
# to a copy of the entry in coverage_ablation_all_experiments.
EXPERIMENT_CONFIGS = [
    {'model': 'resnet18', 'dataset': 'cifar100', 'in_channels': 3, 'num_classes': 100},
]

def get_model(model_name, num_classes, in_channels):
    """Build the network registered under ``model_name``.

    Supported names are ``'lenet'``, ``'resnet9'`` and ``'resnet18'`` (the
    latter is constructed as ``CifarResNet18``).

    Args:
        model_name: Name selecting the architecture.
        num_classes: Forwarded to the model constructor as ``num_classes``.
        in_channels: Forwarded to the model constructor as ``in_channels``.

    Returns:
        A newly constructed model instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    ctor_kwargs = {'num_classes': num_classes, 'in_channels': in_channels}

    if model_name == 'lenet':
        return LeNet(**ctor_kwargs)
    if model_name == 'resnet9':
        return ResNet9(**ctor_kwargs)
    if model_name == 'resnet18':
        return CifarResNet18(**ctor_kwargs)

    raise ValueError(f"Unsupported model: {model_name}")

def get_model_path(model_name, dataset_name):
    """Return the relative path of the best checkpoint for a model/dataset pair.

    The path has the form ``checkpoints/<model>_<dataset>_best.pth``; the file
    is not checked for existence here.
    """
    checkpoint_path = f'checkpoints/{model_name}_{dataset_name}_best.pth'
    return checkpoint_path

def generate_coverage_ranges():
    """Return the ten degenerate coverage ranges used by the ablation.

    Each range is a ``(coverage, coverage)`` tuple whose two bounds are equal;
    the coverage goes from 0.003 to 0.030 in steps of 0.003, rounded to three
    decimals.
    """
    levels = (round(multiple * 0.003, 3) for multiple in range(1, 11))
    return [(level, level) for level in levels]

def coverage_ablation_single_experiment(config):
    """Run the coverage hyperparameter ablation for one configuration.

    For every ``(coverage_min, coverage_max)`` pair in
    ``config['coverage_ranges']`` a fresh model is built, wrapped in
    ``OurUnlearning`` with that pair as ``alpha_range``, and
    ``config['forget_class']`` is unlearned. The resulting accuracies are
    collected, written to ``results.json`` in a timestamped directory under
    ``results/coverage_ablation/`` and plotted with ``plot_results``.

    A run that raises, or that returns no result, is recorded with 0.0 for
    all four metrics so that the metric lists stay aligned with
    ``coverage_ranges``.

    Args:
        config: Dict providing 'model', 'dataset', 'num_classes',
            'in_channels', 'forget_class', 'coverage_ranges', 'seed',
            'batch_size' and 'model_path'.

    Returns:
        The results dict ('coverage_ranges', 'coverage_values', the four
        metric lists and the 'config' itself), or ``None`` when the checkpoint
        at ``config['model_path']`` does not exist.
    """
    print("Starting coverage hyperparameter ablation experiment")
    print(f"Model: {config['model']}")
    print(f"Dataset: {config['dataset']}")
    print(f"Forget class: {config['forget_class']}")
    print(f"Coverage ranges: {config['coverage_ranges']}")

    if not os.path.exists(config['model_path']):
        print(f"Error: Model file does not exist {config['model_path']}")
        print("Please train the model first or skip this configuration")
        return None

    set_seed(config['seed'])

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = f"results/coverage_ablation/{config['dataset']}_{config['model']}_class{config['forget_class']}_{timestamp}"
    os.makedirs(results_dir, exist_ok=True)

    coverage_ranges = config['coverage_ranges']
    metric_keys = (
        'forget_accuracies',
        'retain_accuracies',
        'forget_accuracy_changes',
        'retain_accuracy_changes',
    )
    # Values recorded for a run that produced no usable result.
    placeholder = (0.0, 0.0, 0.0, 0.0)

    results = {
        'coverage_ranges': coverage_ranges,
        # Three decimals are required: the ranges are spaced 0.003 apart, so
        # two decimals would map several distinct ranges onto the same label.
        'coverage_values': [f"{low:.3f}-{high:.3f}" for low, high in coverage_ranges],
        'forget_accuracies': [],
        'retain_accuracies': [],
        'forget_accuracy_changes': [],
        'retain_accuracy_changes': [],
        'config': config
    }

    # The device does not depend on the coverage range, so pick it once.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for i, (coverage_min, coverage_max) in enumerate(coverage_ranges):
        span = f"[{coverage_min:.3f}, {coverage_max:.3f}]"

        print(f"\n{'='*60}")
        print(f"Experiment {i+1}/{len(coverage_ranges)}: Coverage range = {span}")
        print(f"{'='*60}")

        metrics = placeholder
        try:
            model = get_model(config['model'], config['num_classes'], config['in_channels'])

            # Same three-decimal precision as the labels, so that every range
            # gets its own log directory.
            log_dir = os.path.join(results_dir, f"coverage_{coverage_min:.3f}_{coverage_max:.3f}")
            unlearner = OurUnlearning(
                model=model,
                dataset_name=config['dataset'],
                checkpoint_path=config['model_path'],
                batch_size=config['batch_size'],
                device=device,
                log_dir=log_dir,
                seed=config['seed'],
                alpha_range=(coverage_min, coverage_max)
            )

            forget_results = unlearner.unlearn_classes([config['forget_class']])

            if forget_results:
                result = forget_results[0]
                metrics = (
                    result['final']['class_test']['accuracy'],
                    result['final']['active_test']['accuracy'],
                    result['forget_acc_change'],
                    result['retain_acc_change'],
                )
                forget_acc, retain_acc, forget_change, retain_change = metrics

                print(f"Coverage {span} results:")
                print(f"  Forget class accuracy: {forget_acc:.2f}%")
                print(f"  Retain class accuracy: {retain_acc:.2f}%")
                print(f"  Forget class accuracy change: {forget_change:.2f}%")
                print(f"  Retain class accuracy change: {retain_change:.2f}%")
            else:
                print(f"Coverage {span} experiment failed, using default values")

        except Exception as e:
            print(f"Coverage {span} experiment error: {e}")
            metrics = placeholder

        # Append exactly once per range, after the try block, so an error can
        # never leave the metric lists with extra or missing entries.
        for key, value in zip(metric_keys, metrics):
            results[key].append(value)

    results_file = os.path.join(results_dir, 'results.json')
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"\nExperiment results saved to: {results_file}")

    plot_results(results, results_dir)

    return results

def coverage_ablation_all_experiments(experiment_indices=None):
    """Run the coverage ablation for all, or a selection of, configurations.

    Each selected entry of ``EXPERIMENT_CONFIGS`` is copied and extended with
    the fixed run settings (forget class 0, the ranges from
    ``generate_coverage_ranges()``, seed 42, batch size 128) and its checkpoint
    path, then passed to ``coverage_ablation_single_experiment``. When at
    least one experiment returns results, they are written to
    ``all_experiments_summary.json`` in a timestamped summary directory and a
    text report is generated.

    Args:
        experiment_indices: Iterable of indices into ``EXPERIMENT_CONFIGS``,
            or ``None`` to run every configuration. Indices outside
            ``0 .. len(EXPERIMENT_CONFIGS) - 1`` are skipped with a warning.

    Returns:
        Dict mapping ``"<model>_<dataset>"`` to the results dict of each
        experiment that produced results (empty if none did).
    """
    if experiment_indices is None:
        experiment_indices = list(range(len(EXPERIMENT_CONFIGS)))

    all_results = {}

    for idx in experiment_indices:
        # Negative indices are rejected as well: Python would otherwise count
        # them from the end of the list (or raise IndexError), silently
        # running a configuration the caller did not ask for.
        if not 0 <= idx < len(EXPERIMENT_CONFIGS):
            print(f"Warning: Experiment index {idx} out of range, skipping")
            continue

        exp_config = {
            **EXPERIMENT_CONFIGS[idx],
            'forget_class': 0,
            'coverage_ranges': generate_coverage_ranges(),
            'seed': 42,
            'batch_size': 128,
        }
        exp_config['model_path'] = get_model_path(exp_config['model'], exp_config['dataset'])

        exp_key = f"{exp_config['model']}_{exp_config['dataset']}"

        print(f"\n{'='*80}")
        print(f"Executing experiment {idx+1}/{len(EXPERIMENT_CONFIGS)}: {exp_config['model']} on {exp_config['dataset']}")
        print(f"{'='*80}")

        # One failing experiment must not abort the remaining ones.
        try:
            results = coverage_ablation_single_experiment(exp_config)
        except Exception as e:
            print(f"Experiment {exp_key} error: {e}")
            continue

        if results:
            all_results[exp_key] = results
            print(f"Experiment {exp_key} completed successfully")
        else:
            print(f"Experiment {exp_key} failed or skipped")

    if all_results:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        summary_dir = f"results/coverage_ablation/summary_{timestamp}"
        os.makedirs(summary_dir, exist_ok=True)

        summary_file = os.path.join(summary_dir, 'all_experiments_summary.json')
        with open(summary_file, 'w') as f:
            json.dump(all_results, f, indent=2)

        print(f"\nAll experiment results summary saved to: {summary_file}")

        generate_summary_report(all_results, summary_dir)

    return all_results

def generate_summary_report(all_results, save_dir):
    """Write ``summary_report.txt`` naming the best coverage range per experiment.

    The best run of an experiment is the one with the lowest combined score
    ``forget_accuracy + 2 * |retain_accuracy_change|``; on ties the earliest
    run wins. Runs whose four metrics are all zero are excluded from the
    ranking: that is the placeholder stored for runs that failed, and it
    would otherwise score 0 and always be reported as the best.

    Args:
        all_results: Dict mapping an experiment name to a results dict with
            'coverage_values', 'forget_accuracies', 'retain_accuracies',
            'forget_accuracy_changes' and 'retain_accuracy_changes' lists.
        save_dir: Directory the report file is written into (not created
            here).
    """
    report_lines = ["Coverage Ablation Experiment Summary Report", "=" * 50, ""]

    for exp_name, results in all_results.items():
        if not results['forget_accuracies']:
            continue

        # One tuple per run:
        # (label, forget_acc, retain_acc, forget_change, retain_change)
        runs = zip(
            results['coverage_values'],
            results['forget_accuracies'],
            results['retain_accuracies'],
            results['forget_accuracy_changes'],
            results['retain_accuracy_changes'],
        )
        # Keep only runs with at least one non-zero metric (see docstring).
        candidates = [run for run in runs if any(value != 0 for value in run[1:])]

        if not candidates:
            report_lines.extend([
                f"Experiment: {exp_name}",
                "No successful runs (every coverage range recorded placeholder results)",
                "-" * 40, ""
            ])
            continue

        # min() returns the first minimal element, matching a strict "<" scan.
        best = min(candidates, key=lambda run: run[1] + abs(run[4]) * 2)
        best_coverage, forget_acc, retain_acc, forget_change, retain_change = best

        report_lines.extend([
            f"Experiment: {exp_name}",
            f"Best coverage range: {best_coverage}",
            f"Forget class accuracy: {forget_acc:.2f}%",
            f"Retain class accuracy: {retain_acc:.2f}%",
            f"Forget class accuracy change: {forget_change:.2f}%",
            f"Retain class accuracy change: {retain_change:.2f}%",
            "-" * 40, ""
        ])

    report_file = os.path.join(save_dir, 'summary_report.txt')
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(report_lines))

    print(f"Summary report saved to: {report_file}")

def plot_results(results, save_dir):
    """Save the accuracy plot and the combined-score plot for one ablation run.

    Two PNG files are written into ``save_dir``:

    * ``coverage_accuracy_plot.png``: forget-class and retain-class accuracy
      against the coverage rate, titled with the model and dataset names.
    * ``coverage_combined_score.png``: the combined score
      ``forget_accuracy + 2 * |retain_accuracy_change|`` (lower is better)
      against the coverage rate, with the lowest score annotated.

    Args:
        results: Dict with 'forget_accuracies', 'retain_accuracies',
            'retain_accuracy_changes', 'config' (providing 'model' and
            'dataset') and either 'coverage_ranges' (pairs of numbers) or
            'coverage_values' ("min-max" labels).
        save_dir: Directory the PNG files are written into (not created
            here).
    """
    forget_accs = results['forget_accuracies']
    retain_accs = results['retain_accuracies']
    retain_changes = results['retain_accuracy_changes']

    # Prefer the numeric ranges for the x axis: the text labels are rounded
    # for display, so parsing them back can merge distinct coverage rates
    # into a single x position. The lower bound of each range is plotted.
    coverage_ranges = results.get('coverage_ranges')
    if coverage_ranges:
        coverage_values = [float(bounds[0]) for bounds in coverage_ranges]
    else:
        coverage_values = [float(cv.split('-')[0]) for cv in results['coverage_values']]

    model = results['config']['model']
    dataset = results['config']['dataset']

    # Figure 1: raw accuracies.
    plt.figure(figsize=(8, 6))
    plt.plot(coverage_values, forget_accs, 'r-o', label='Forget Class Accuracy', linewidth=2, markersize=6)
    plt.plot(coverage_values, retain_accs, 'b-s', label='Retain Classes Accuracy', linewidth=2, markersize=6)
    plt.xlabel('Coverage Rate', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.title(f'Accuracy vs Coverage Rate ({model} on {dataset})', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plot_file = os.path.join(save_dir, 'coverage_accuracy_plot.png')
    plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    plt.close()

    # Figure 2: combined score (forget accuracy plus twice the absolute
    # change in retain accuracy).
    combined_scores = [
        forget_acc + abs(retain_change) * 2
        for forget_acc, retain_change in zip(forget_accs, retain_changes)
    ]

    plt.figure(figsize=(12, 6))
    plt.plot(coverage_values, combined_scores, 'g-o', linewidth=3, markersize=8, label='Combined Score')
    plt.xlabel('Coverage Rate', fontsize=12)
    plt.ylabel('Combined Score (Lower is Better)', fontsize=12)
    plt.title('Overall Performance vs Coverage Rate', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=11)

    # np.argmin raises ValueError on an empty sequence, so only annotate when
    # there is at least one score.
    if combined_scores:
        best_idx = int(np.argmin(combined_scores))
        # Three decimals so neighbouring coverage rates stay distinguishable.
        plt.annotate(f'Best: {coverage_values[best_idx]:.3f}',
                     xy=(coverage_values[best_idx], combined_scores[best_idx]),
                     xytext=(10, 10), textcoords='offset points',
                     bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7),
                     arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

    plt.tight_layout()

    combined_plot_file = os.path.join(save_dir, 'coverage_combined_score.png')
    plt.savefig(combined_plot_file, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Accuracy curve plot saved to: {plot_file}")
    print(f"Combined score plot saved to: {combined_plot_file}")

if __name__ == "__main__":
    # Interactive entry point: list the available configurations, ask which
    # ones to run, then run them.
    print("Supported experiment configurations:")
    for i, config in enumerate(EXPERIMENT_CONFIGS):
        print(f"{i}: {config['model']} on {config['dataset']}")

    print(f"\nCoverage ranges: {generate_coverage_ranges()}")

    print("\nSelect execution mode:")
    print("1. Execute all experiments")
    print("2. Execute specific experiments")

    choice = input("Please enter your choice (1/2): ").strip()

    # None means "run every configuration".
    indices = None
    if choice == "2":
        indices_input = input("Please enter experiment indices (comma separated, e.g.: 0,1,2): ").strip()
        # Only the parsing is guarded: a ValueError raised while the
        # experiments themselves run must not be reported as bad input, nor
        # trigger a second, full run.
        try:
            indices = [int(x.strip()) for x in indices_input.split(',')]
        except ValueError:
            print("Invalid input format, executing all experiments")
            indices = None
    elif choice != "1":
        print("Invalid choice, executing all experiments")

    results = coverage_ablation_all_experiments(indices)
