import argparse
import copy
import itertools
import json
import logging
import os
import random
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import yaml

# NOTE(review): load_config, set_seed and comparison_experiment are called
# further down but no import for them is visible in this file -- add the
# project-local import(s) that provide them.

@dataclass
class SearchResult:
    """Outcome of evaluating one hyperparameter configuration.

    Instances are built by ``HyperparameterSearcher.run_single_experiment``
    from the ``'mean'`` entries of the baseline and GraphDRO statistics.
    """
    # Hyperparameter name -> value used for this experiment.
    hyperparams: Dict[str, Any]
    # --- Baseline model: accuracy on clean / attacked data, and fairness
    # --- gaps measured on clean data.
    baseline_clean_acc: float
    baseline_attack_acc: float
    # Clean accuracy minus attack accuracy.
    baseline_delta_acc: float
    baseline_sp_gap: float
    baseline_eo_gap: float
    # --- GraphDRO model: same quantities as for the baseline.
    graphdro_clean_acc: float
    graphdro_attack_acc: float
    # Clean accuracy minus attack accuracy.
    graphdro_delta_acc: float
    graphdro_sp_gap: float
    graphdro_eo_gap: float
    # --- Differences between the two models.
    # GraphDRO minus baseline accuracy (positive = GraphDRO is more accurate).
    clean_acc_improvement: float
    attack_acc_improvement: float
    # Baseline delta_acc minus GraphDRO delta_acc (positive = smaller drop).
    robustness_improvement: float
    # Baseline gap minus GraphDRO gap (positive = GraphDRO has the smaller gap).
    sp_improvement: float
    eo_improvement: float
    # Scalar returned by HyperparameterSearcher.calculate_score; used to rank
    # configurations.
    overall_score: float
    # run_single_experiment currently stores empty strings in both path fields.
    baseline_model_path: str
    graphdro_model_path: str
    # Identifier assigned by the search loop, formatted as "exp_0000".
    experiment_id: str
    # ISO-8601 string taken from datetime.now() when the result was built.
    timestamp: str


class HyperparameterSearcher:
    def __init__(self, config_path: str, dataset: str, base_seed: int = 42):
        self.config_path = config_path
        self.dataset = dataset
        self.base_seed = base_seed
        self.base_config = load_config(config_path, dataset)
        model_name = self.base_config.get('model', {}).get('name', 'unknown')
        self.results_dir = Path(f"hyperparameter_search_results/{dataset}_{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.search_space = self.define_search_space()
        self.results: List[SearchResult] = []
        self.best_result: SearchResult = None
        self.best_score = -float('inf')

    def define_search_space(self) -> Dict[str, List]:
        return {
            'r': [0.5, 1.0, 1.5],
            'eta': [0.3, 0.5, 0.7],
            'kappa_feature': [0.5, 1.0, 1.2],
            'kappa_sensitive': [0.3, 0.5, 0.8],
            'kappa_edge': [0.3, 0.5, 0.8],
            'kappa_label': [0.1, 0.2, 0.3],
            'alpha': [1.0, 2.0, 4.0],
            'beta': [0.01, 0.05, 0.1],
            'lambda_lip': [0.5, 1.0, 1.5],
            'eps_e': [0.05, 0.1, 0.15],
            'eps_x': [0.02, 0.05, 0.08],
            'eps_l': [0.1, 0.15, 0.2],
            'gamma': [0.2, 0.3, 0.4],
            'lr': [0.001, 0.003, 0.01],
            'hidden_dim': [32, 64, 128],
            'dropout': [0.3, 0.5, 0.7],
            'weight_decay': [0, 1e-5],
            'num_layers': [2, 3],
        }

    def generate_hyperparameter_combinations(self, max_combinations: int = 100) -> List[Dict[str, Any]]:
        total_combinations = 1
        for values in self.search_space.values():
            total_combinations *= len(values)
        if total_combinations <= max_combinations:
            all_combinations = list(itertools.product(*self.search_space.values()))
            combinations = all_combinations
        else:
            combinations = []
            param_values = list(self.search_space.values())
            np.random.seed(self.base_seed)
            for _ in range(max_combinations):
                combo = []
                for values in param_values:
                    combo.append(np.random.choice(values))
                combinations.append(combo)
        param_names = list(self.search_space.keys())
        return [dict(zip(param_names, combo)) for combo in combinations]

    def update_config_with_hyperparams(self, config: Dict[str, Any], hyperparams: Dict[str, Any]) -> Dict[str, Any]:
        config = config.copy()
        if 'r' in hyperparams:
            config['training']['r'] = hyperparams['r']
        if 'eta' in hyperparams:
            config['training']['eta'] = hyperparams['eta']
        if 'kappa_feature' in hyperparams:
            config['training']['kappa']['feature'] = hyperparams['kappa_feature']
        if 'kappa_sensitive' in hyperparams:
            config['training']['kappa']['sensitive'] = hyperparams['kappa_sensitive']
        if 'kappa_edge' in hyperparams:
            config['training']['kappa']['edge'] = hyperparams['kappa_edge']
        if 'kappa_label' in hyperparams:
            config['training']['kappa']['label'] = hyperparams['kappa_label']
        if 'alpha' in hyperparams:
            config['training']['fairness']['alpha'] = hyperparams['alpha']
        if 'beta' in hyperparams:
            config['training']['fairness']['beta'] = hyperparams['beta']
        if 'lambda_lip' in hyperparams:
            config['training']['lipschitz']['lambda_lip'] = hyperparams['lambda_lip']
        if 'lr' in hyperparams:
            config['optimizer']['lr'] = hyperparams['lr']
        if 'weight_decay' in hyperparams:
            config['optimizer']['weight_decay'] = hyperparams['weight_decay']
        if 'hidden_dim' in hyperparams:
            config['model']['hidden_dim'] = hyperparams['hidden_dim']
        if 'dropout' in hyperparams:
            config['model']['dropout'] = hyperparams['dropout']
        if 'num_layers' in hyperparams:
            config['model']['num_layers'] = hyperparams['num_layers']
        return config

    def calculate_score(self, result: Dict[str, Any]) -> float:
        try:
            if 'stats' in result['baseline']:
                baseline_stats = result['baseline']['stats']
                graphdro_stats = result['graphdro']['stats']
            else:
                baseline_stats = result['baseline']
                graphdro_stats = result['graphdro']
            baseline_score = 0
            baseline_clean_acc = baseline_stats['clean']['accuracy']['mean']
            baseline_score += baseline_clean_acc * 100
            graphdro_score = 0
            graphdro_clean_acc = graphdro_stats['clean']['accuracy']['mean']
            graphdro_attack_acc = graphdro_stats['attack']['accuracy']['mean']
            graphdro_score += graphdro_clean_acc * 100
            graphdro_score += graphdro_attack_acc * 80
            graphdro_delta = graphdro_clean_acc - graphdro_attack_acc
            if graphdro_delta > 0:
                graphdro_score += (1.0 - graphdro_delta) * 150
            else:
                graphdro_score -= abs(graphdro_delta) * 200
            fairness_score = 0
            sp_improvement = baseline_stats['clean']['sp_gap']['mean'] - graphdro_stats['clean']['sp_gap']['mean']
            eo_improvement = baseline_stats['clean']['eo_gap']['mean'] - graphdro_stats['clean']['eo_gap']['mean']
            if sp_improvement > 0:
                fairness_score += sp_improvement * 400
            else:
                fairness_score -= abs(sp_improvement) * 500
            if eo_improvement > 0:
                fairness_score += eo_improvement * 400
            else:
                fairness_score -= abs(eo_improvement) * 500
            total_score = baseline_score + graphdro_score + fairness_score
            if (graphdro_clean_acc > baseline_clean_acc and
                sp_improvement > 0 and eo_improvement > 0):
                total_score += 300
            if graphdro_delta < 0.05:
                total_score += 200
            return total_score
        except:
            return -1000.0

    def save_model(self, model, experiment_id: str, model_type: str) -> str:
        model_dir = self.results_dir / "models" / experiment_id
        model_dir.mkdir(parents=True, exist_ok=True)
        model_path = model_dir / f"{model_type}_model.pth"
        torch.save(model.state_dict(), model_path)
        return str(model_path)

    def run_single_experiment(self, hyperparams: Dict[str, Any], experiment_id: str) -> SearchResult:
        config = self.update_config_with_hyperparams(self.base_config, hyperparams)
        set_seed(self.base_seed)
        try:
            attack_params = {
                'eps_e': hyperparams.get('eps_e', 0.3),
                'eps_x': hyperparams.get('eps_x', 0.1),
                'eps_l': hyperparams.get('eps_l', 0.05),
                'gamma': hyperparams.get('gamma', 0.1)
            }
            result = comparison_experiment(
                config=config,
                num_runs=3,
                num_epochs=50,
                attack_params=attack_params
            )
            score = self.calculate_score(result)
            try:
                if 'stats' in result['baseline']:
                    baseline_stats = result['baseline']['stats']
                    graphdro_stats = result['graphdro']['stats']
                else:
                    baseline_stats = result['baseline']
                    graphdro_stats = result['graphdro']
            except:
                return None
            search_result = SearchResult(
                hyperparams=hyperparams,
                baseline_clean_acc=baseline_stats['clean']['accuracy']['mean'],
                baseline_attack_acc=baseline_stats['attack']['accuracy']['mean'],
                baseline_delta_acc=baseline_stats['clean']['accuracy']['mean'] - baseline_stats['attack']['accuracy']['mean'],
                baseline_sp_gap=baseline_stats['clean']['sp_gap']['mean'],
                baseline_eo_gap=baseline_stats['clean']['eo_gap']['mean'],
                graphdro_clean_acc=graphdro_stats['clean']['accuracy']['mean'],
                graphdro_attack_acc=graphdro_stats['attack']['accuracy']['mean'],
                graphdro_delta_acc=graphdro_stats['clean']['accuracy']['mean'] - graphdro_stats['attack']['accuracy']['mean'],
                graphdro_sp_gap=graphdro_stats['clean']['sp_gap']['mean'],
                graphdro_eo_gap=graphdro_stats['clean']['eo_gap']['mean'],
                clean_acc_improvement=graphdro_stats['clean']['accuracy']['mean'] - baseline_stats['clean']['accuracy']['mean'],
                attack_acc_improvement=graphdro_stats['attack']['accuracy']['mean'] - baseline_stats['attack']['accuracy']['mean'],
                robustness_improvement=baseline_stats['clean']['accuracy']['mean'] - baseline_stats['attack']['accuracy']['mean'] -
                                     (graphdro_stats['clean']['accuracy']['mean'] - graphdro_stats['attack']['accuracy']['mean']),
                sp_improvement=baseline_stats['clean']['sp_gap']['mean'] - graphdro_stats['clean']['sp_gap']['mean'],
                eo_improvement=baseline_stats['clean']['eo_gap']['mean'] - graphdro_stats['clean']['eo_gap']['mean'],
                overall_score=score,
                baseline_model_path="",
                graphdro_model_path="",
                experiment_id=experiment_id,
                timestamp=datetime.now().isoformat()
            )
            return search_result
        except:
            return None

    def search(self, max_combinations: int = 50, save_interval: int = 10):
        hyperparam_combinations = self.generate_hyperparameter_combinations(max_combinations)
        successful_experiments = 0
        failed_experiments = 0
        for i, hyperparams in enumerate(hyperparam_combinations):
            experiment_id = f"exp_{i:04d}"
            try:
                result = self.run_single_experiment(hyperparams, experiment_id)
                if result is not None:
                    self.results.append(result)
                    successful_experiments += 1
                    if result.overall_score > self.best_score:
                        self.best_score = result.overall_score
                        self.best_result = result
                    if (i + 1) % save_interval == 0:
                        self.save_results()
            except Exception as e:
                failed_experiments += 1
        self.save_results()
        self.save_best_result()

    def save_results(self):
        if not self.results:
            empty_df = pd.DataFrame(columns=[
                'experiment_id', 'timestamp', 'overall_score',
                'r', 'eta', 'kappa_feature', 'kappa_sensitive', 'kappa_edge', 'kappa_label',
                'alpha', 'beta', 'lambda_lip', 'eps_e', 'eps_x', 'eps_l', 'gamma',
                'lr', 'weight_decay', 'hidden_dim', 'dropout', 'num_layers',
                'baseline_clean_acc', 'baseline_attack_acc', 'baseline_delta_acc', 'baseline_sp_gap', 'baseline_eo_gap',
                'graphdro_clean_acc', 'graphdro_attack_acc', 'graphdro_delta_acc', 'graphdro_sp_gap', 'graphdro_eo_gap',
                'clean_acc_improvement', 'attack_acc_improvement', 'robustness_improvement', 'sp_improvement', 'eo_improvement'
            ])
            csv_path = self.results_dir / "all_results.csv"
            empty_df.to_csv(csv_path, index=False)
            return
        results_data = []
        for result in self.results:
            results_data.append({
                'experiment_id': result.experiment_id,
                'timestamp': result.timestamp,
                'overall_score': result.overall_score,
                **result.hyperparams,
                'baseline_clean_acc': result.baseline_clean_acc,
                'baseline_attack_acc': result.baseline_attack_acc,
                'baseline_delta_acc': result.baseline_delta_acc,
                'baseline_sp_gap': result.baseline_sp_gap,
                'baseline_eo_gap': result.baseline_eo_gap,
                'graphdro_clean_acc': result.graphdro_clean_acc,
                'graphdro_attack_acc': result.graphdro_attack_acc,
                'graphdro_delta_acc': result.graphdro_delta_acc,
                'graphdro_sp_gap': result.graphdro_sp_gap,
                'graphdro_eo_gap': result.graphdro_eo_gap,
                'clean_acc_improvement': result.clean_acc_improvement,
                'attack_acc_improvement': result.attack_acc_improvement,
                'robustness_improvement': result.robustness_improvement,
                'sp_improvement': result.sp_improvement,
                'eo_improvement': result.eo_improvement,
            })
        df = pd.DataFrame(results_data)
        csv_path = self.results_dir / "all_results.csv"
        df.to_csv(csv_path, index=False)
        json_path = self.results_dir / "all_results.json"
        with open(json_path, 'w') as f:
            json.dump(results_data, f, indent=2, default=str)

    def save_best_result(self):
        if self.best_result is None:
            return
        best_config = self.update_config_with_hyperparams(self.base_config, self.best_result.hyperparams)
        best_config_path = self.results_dir / "best_config.yaml"
        with open(best_config_path, 'w') as f:
            yaml.dump(best_config, f, default_flow_style=False)
        best_result_path = self.results_dir / "best_result.json"
        with open(best_result_path, 'w') as f:
            json.dump({
                'experiment_id': self.best_result.experiment_id,
                'timestamp': self.best_result.timestamp,
                'overall_score': self.best_result.overall_score,
                'hyperparams': self.best_result.hyperparams,
                'baseline_results': {
                    'clean_acc': self.best_result.baseline_clean_acc,
                    'attack_acc': self.best_result.baseline_attack_acc,
                    'delta_acc': self.best_result.baseline_delta_acc,
                    'sp_gap': self.best_result.baseline_sp_gap,
                    'eo_gap': self.best_result.baseline_eo_gap,
                },
                'graphdro_results': {
                    'clean_acc': self.best_result.graphdro_clean_acc,
                    'attack_acc': self.best_result.graphdro_attack_acc,
                    'delta_acc': self.best_result.graphdro_delta_acc,
                    'sp_gap': self.best_result.graphdro_sp_gap,
                    'eo_gap': self.best_result.graphdro_eo_gap,
                },
                'improvements': {
                    'clean_acc_improvement': self.best_result.clean_acc_improvement,
                    'attack_acc_improvement': self.best_result.attack_acc_improvement,
                    'robustness_improvement': self.best_result.robustness_improvement,
                    'sp_improvement': self.best_result.sp_improvement,
                    'eo_improvement': self.best_result.eo_improvement,
                }
            }, f, indent=2, default=str)

    def generate_report(self):
        if not self.results:
            return
        param_importance = self.analyze_parameter_importance()
        report = f"""
# Hyperparameter Search Report
## Search Overview
- Dataset: {self.dataset}
- Total Experiments: {len(self.results)}
- Search Duration: {self.results[0].timestamp} to {self.results[-1].timestamp}
## Best Result
- Experiment ID: {self.best_result.experiment_id}
- Overall Score: {self.best_score:.2f}
### Best Hyperparameters
"""
        for key, value in self.best_result.hyperparams.items():
            report += f"- {key}: {value}\n"
        report += f"""
### Performance Comparison
| Metric | Baseline Model | GraphDRO | Improvement |
|------|----------------|----------|-------------|
| Clean Acc | {self.best_result.baseline_clean_acc:.4f} | {self.best_result.graphdro_clean_acc:.4f} | {self.best_result.clean_acc_improvement:+.4f} |
| Attack Acc | {self.best_result.baseline_attack_acc:.4f} | {self.best_result.graphdro_attack_acc:.4f} | {self.best_result.attack_acc_improvement:+.4f} |
| ΔAcc | {self.best_result.baseline_delta_acc:.4f} | {self.best_result.graphdro_delta_acc:.4f} | {self.best_result.robustness_improvement:+.4f} |
| ΔSP | {self.best_result.baseline_sp_gap:.4f} | {self.best_result.graphdro_sp_gap:.4f} | {self.best_result.sp_improvement:+.4f} |
| ΔEO | {self.best_result.baseline_eo_gap:.4f} | {self.best_result.graphdro_eo_gap:.4f} | {self.best_result.eo_improvement:+.4f} |
## Result Analysis
"""
        for param, importance in param_importance.items():
            report += f"- {param}: {importance:.3f}\n"
        report_path = self.results_dir / "search_report.md"
        with open(report_path, 'w') as f:
            f.write(report)

    def analyze_parameter_importance(self) -> Dict[str, float]:
        if len(self.results) < 2:
            return {}
        param_importance = {}
        for param in self.search_space.keys():
            if param in self.results[0].hyperparams:
                param_values = [r.hyperparams[param] for r in self.results]
                scores = [r.overall_score for r in self.results]
                correlation = np.corrcoef(param_values, scores)[0, 1]
                if np.isnan(correlation):
                    correlation = 0
                param_importance[param] = abs(correlation)
        param_importance = dict(sorted(param_importance.items(), key=lambda x: x[1], reverse=True))
        return param_importance


def main():
    """Command-line entry point: parse options, run the search, write the report."""
    parser = argparse.ArgumentParser(description="Hyperparameter Search")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ("--config", dict(type=str, default="configs/graphdro.yaml", help="Path to config file")),
        ("--dataset", dict(type=str, required=True, help="Dataset name")),
        ("--max-combinations", dict(type=int, default=50, help="Maximum number of hyperparameter combinations")),
        ("--save-interval", dict(type=int, default=10, help="Save interval")),
        ("--seed", dict(type=int, default=42, help="Random seed")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    cli = parser.parse_args()

    searcher = HyperparameterSearcher(config_path=cli.config, dataset=cli.dataset, base_seed=cli.seed)
    searcher.search(max_combinations=cli.max_combinations, save_interval=cli.save_interval)
    searcher.generate_report()


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()