import copy
import itertools
import json
import os
import subprocess
import sys
import tempfile
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import yaml

# Add this script's directory to the import path so that modules sitting
# next to it can be imported.
sys.path.append(str(Path(__file__).parent))

# Hyperparameter search space. ``main`` runs the Cartesian product of these
# value lists once per seed, and ``create_config_with_params`` writes each
# chosen value into the matching field of the training config.
param_grid = dict(
    r=[0.5, 1.0, 1.5, 2.0],
    kappa_feature=[0.5, 1.0, 1.5],
    kappa_edge=[0.3, 0.8, 1.2],
    kappa_sensitive=[1.0, 2.0, 3.0],
    kappa_label=[0.1, 0.3, 0.5],
    alpha=[0.3, 0.6, 0.9],
    lambda_lip=[1.0, 1.5, 2.0],
)

# Random seeds; every grid point is repeated once per seed.
seeds = [42, 123, 456, 789, 1024]

class GraphDROExperimentTracker:
    """Collect sweep results and persist them to a timestamped JSON file.

    Every call to ``add_result`` rewrites
    ``<output_dir>/results_<YYYYmmdd_HHMMSS>.json`` with the full list of
    results recorded so far, so partial progress survives an interrupted
    sweep.
    """

    # ``'accuracy'`` was the historical default metric name; the recorded
    # metrics store accuracy under ``'clean_acc'``.
    _METRIC_ALIASES = {'accuracy': 'clean_acc'}

    def __init__(self, output_dir='experiment_results_graphdro'):
        self.output_dir = output_dir
        # Fixed at construction so every result of one sweep lands in one file.
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.results = []
        os.makedirs(output_dir, exist_ok=True)

    def add_result(self, params, metrics):
        """Record one run and immediately persist all results.

        Args:
            params: Hyperparameters (and seed) the run used.
            metrics: Mapping of metric name to value for the run.
        """
        result = {
            'parameters': params,
            'metrics': metrics,
            'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        }
        self.results.append(result)
        self._save_results()
        print(f"Result saved: {len(self.results)} experiments completed")

    def _save_results(self):
        """Write all results to disk, replacing the previous file atomically.

        The JSON is written to a sibling ``.tmp`` file first and then swapped
        in with ``os.replace``, so a crash mid-write cannot leave a truncated
        file in place of the last complete one.
        """
        filename = os.path.join(self.output_dir, f'results_{self.timestamp}.json')
        tmp_filename = f'{filename}.tmp'
        with open(tmp_filename, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=4)
        os.replace(tmp_filename, filename)

    def get_best_results(self, metric='accuracy', top_k=5, higher_is_better=True):
        """Return up to ``top_k`` results ranked by ``metric``.

        Args:
            metric: Key in each result's ``metrics`` dict to rank by.
                ``'accuracy'`` (the default) is an alias for ``'clean_acc'``.
            top_k: Maximum number of results to return.
            higher_is_better: Sort descending when True, ascending when False
                (for metrics where a smaller value is better).

        Returns:
            List of result dicts, best first. Results that do not contain
            ``metric`` are skipped instead of raising ``KeyError``. Ties keep
            insertion order.
        """
        key = self._METRIC_ALIASES.get(metric, metric)
        ranked = [r for r in self.results if key in r['metrics']]
        ranked.sort(key=lambda r: r['metrics'][key], reverse=higher_is_better)
        return ranked[:top_k]

def create_config_with_params(base_config_path, params, seed):
    """Write a copy of the base YAML config with one sweep point applied.

    Args:
        base_config_path: Path to the base YAML config. It must contain
            ``training.kappa``, ``training.fairness`` and
            ``training.lipschitz`` sections.
        params: Dict with keys ``r``, ``kappa_feature``, ``kappa_edge``,
            ``kappa_sensitive``, ``kappa_label``, ``alpha`` and
            ``lambda_lip``.
        seed: Random seed stored at the top level of the config.

    Returns:
        Bare file name of the generated config, which is created in the
        current working directory. The caller is responsible for deleting it.
    """
    with open(base_config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    config['seed'] = seed
    training = config['training']
    training['r'] = params['r']
    kappa = training['kappa']
    kappa['feature'] = params['kappa_feature']
    kappa['edge'] = params['kappa_edge']
    kappa['sensitive'] = params['kappa_sensitive']
    kappa['label'] = params['kappa_label']
    training['fairness']['alpha'] = params['alpha']
    training['lipschitz']['lambda_lip'] = params['lambda_lip']

    # A unique name (instead of a fixed ``temp_config_<seed>.yaml``) keeps
    # concurrent sweeps launched from the same directory from overwriting or
    # deleting each other's config files.
    fd, abs_path = tempfile.mkstemp(
        prefix=f'temp_config_{seed}_', suffix='.yaml', dir=os.getcwd()
    )
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            yaml.dump(config, f, default_flow_style=False)
    except BaseException:
        # Don't leave a half-written config behind if dumping fails.
        os.remove(abs_path)
        raise

    # Return a bare name relative to the cwd, like the original fixed name,
    # so it stays safe to splice into a shell command even when the cwd path
    # contains spaces (mkstemp names use only [a-z0-9_]).
    return os.path.basename(abs_path)

def run_experiment(params, seed, base_config_path, timeout=None):
    """Run ``main.py experiment`` for one sweep point and parse its metrics.

    Args:
        params: Hyperparameter dict applied via ``create_config_with_params``.
        seed: Random seed for this run.
        base_config_path: Path to the base YAML config.
        timeout: Optional wall-clock limit in seconds for the child process;
            ``None`` (the default) waits indefinitely.

    Returns:
        Dict of parsed metrics plus ``'seed'``, or ``None`` in any of these
        cases: the config could not be written, the process failed or timed
        out, or its output contained no recognizable metrics.
    """
    print(f"Running experiment - seed: {seed}")
    print(f"Params: {params}")

    temp_config_path = None
    try:
        # Inside the try so a malformed base config is reported as one failed
        # run instead of aborting the whole sweep.
        temp_config_path = create_config_with_params(base_config_path, params, seed)

        # Argument list (no shell) avoids quoting problems with the path, and
        # sys.executable runs the child under this same interpreter rather
        # than whatever ``python`` resolves to on PATH.
        cmd = [sys.executable, 'main.py', 'experiment', '--config', temp_config_path]
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)

        if result.returncode != 0:
            print(f"Experiment failed: {result.stderr}")
            return None

        metrics = parse_experiment_output(result.stdout)
        if not metrics:
            # Otherwise a run that printed nothing parseable would be
            # recorded with only its seed.
            print("Experiment produced no parseable metrics")
            return None

        metrics['seed'] = seed
        if 'clean_acc' in metrics:
            print(f"Experiment success - clean_acc: {metrics['clean_acc']:.4f}")
        return metrics

    except Exception as e:
        # Best-effort: any failure (including subprocess.TimeoutExpired)
        # skips this run but lets the sweep continue.
        print(f"Experiment error: {e}")
        return None
    finally:
        if temp_config_path is not None and os.path.exists(temp_config_path):
            os.remove(temp_config_path)

def parse_experiment_output(output):
    """Extract evaluation metrics from the stdout of ``main.py experiment``.

    A line is recognized when it contains one of the labels ``Clean Acc:``,
    ``Attack Acc:``, ``∆Acc:``, ``∆SP:`` or ``∆EO:``. The number that
    immediately follows the label is taken as the value. When a label appears
    on several lines, the last occurrence wins.

    Args:
        output: Captured stdout text.

    Returns:
        Dict with any of ``clean_acc``, ``attack_acc``, ``delta_acc``,
        ``delta_sp`` and ``delta_eo``. A metric whose value cannot be parsed
        as a float is omitted.
    """
    labels = [('Clean Acc:', 'clean_acc'), ('Attack Acc:', 'attack_acc')]
    # '\u2206' (INCREMENT, the symbol originally matched) and '\u0394' (GREEK
    # CAPITAL DELTA) render identically, so accept either spelling.
    for delta in ('\u2206', '\u0394'):
        labels += [
            (f'{delta}Acc:', 'delta_acc'),
            (f'{delta}SP:', 'delta_sp'),
            (f'{delta}EO:', 'delta_eo'),
        ]

    metrics = {}
    for line in output.splitlines():
        for label, name in labels:
            _, found, rest = line.partition(label)
            if not found:
                continue
            # Parse the text right after the label, not after the first ':'
            # on the line. Otherwise a prefix such as "12:00:01 INFO" would
            # be read as the value. Only the first token is used, so trailing
            # text (e.g. "± 0.01") is ignored.
            tokens = rest.split()
            if tokens:
                try:
                    metrics[name] = float(tokens[0].rstrip(',;'))
                except ValueError:
                    pass  # unparseable value: leave this metric out
            break  # as in an if/elif chain, at most one label per line
    return metrics

def main(base_config_path='configs/code_graphdro.yaml'):
    """Run the full GraphDRO grid search and print the best configurations.

    Every combination of values in ``param_grid`` is run once per seed in
    ``seeds``. Successful runs are recorded (and saved) by a
    ``GraphDROExperimentTracker`` as they finish.

    Args:
        base_config_path: YAML config that each sweep point is applied to.
            Defaults to the path the script has always used.
    """
    print("Starting GraphDRO hyperparameter sweep")

    tracker = GraphDROExperimentTracker()

    # Sorted names give a combination order that does not depend on how
    # ``param_grid`` happens to be written.
    param_names = sorted(param_grid.keys())
    param_values = [param_grid[name] for name in param_names]
    param_combinations = list(itertools.product(*param_values))
    n_combinations = len(param_combinations)

    total_experiments = n_combinations * len(seeds)
    print(f"Total experiments: {total_experiments}")
    print(f"Parameter combinations: {n_combinations}")
    print(f"Random seeds: {len(seeds)}")

    completed = 0
    for i, values in enumerate(param_combinations):
        params = dict(zip(param_names, values))

        for seed in seeds:
            completed += 1
            print(f"\n{'='*60}")
            print(f"Progress: {completed}/{total_experiments}")
            print(f"Combination: {i+1}/{n_combinations}")
            print(f"Seed: {seed}")
            print(f"{'='*60}")

            metrics = run_experiment(params, seed, base_config_path)

            if not metrics:
                print("Experiment failed, skipping...")
                continue

            tracker.add_result(
                params={**params, 'seed': seed},
                metrics=metrics,
            )
            best_results = tracker.get_best_results(top_k=3)
            print("\nCurrent top results:")
            for j, result in enumerate(best_results):
                print(f"  {j+1}. clean_acc: {result['metrics']['clean_acc']:.4f}, "
                      f"r={result['parameters']['r']}, α={result['parameters']['alpha']}")

    print("Hyperparameter sweep completed!")
    print(f"Total experiments: {total_experiments}")
    print(f"Successful runs: {len(tracker.results)}")

    best_results = tracker.get_best_results(top_k=5)
    print("Final top results:")
    for rank, result in enumerate(best_results, start=1):
        print(f"  {rank}. clean_acc: {result['metrics']['clean_acc']:.4f}")
        print(f"     params: {result['parameters']}")
        print(f"     metrics: {result['metrics']}")
        print()

if __name__ == "__main__":
    main()