import os
import json
import time
import random
import numpy as np
import torch
from datetime import datetime
import pandas as pd
from typing import Dict, List, Any
import traceback

from data_preprocessing import build_id_ood_from_config  
from models.MLP import main_energy_mlp_deep
from main_bl import BL

class ExperimentRunner:
    """Run every (dataset, model, seed) combination and persist the results.

    For each combination the runner seeds the RNGs, builds the ID/OOD data
    bundle, trains/evaluates the configured model, and writes one JSON file
    per run under ``<results_dir>/detailed_results``.  After the grid is
    finished it writes ``final_results.json``, ``results.csv`` and summary
    statistics under ``<results_dir>/summary``.

    A run that raises is recorded (``success=False`` plus the error and
    traceback) in both ``self.results`` and ``self.failed_experiments`` so
    that the remaining runs and the final aggregation still happen.
    """

    def __init__(self, results_dir: str = "experiment_results_v4"):
        """Create the output directory tree and define the experiment grid.

        Args:
            results_dir: Root directory that receives all output files.
        """
        self.results_dir = results_dir
        self.create_results_directory()

        # Each dataset is used in turn as the in-distribution set; the other
        # one serves as the OOD set (see run_single_experiment).
        self.datasets = ["agnews", "yelp"]
        self.random_seeds = [42, 123, 456, 789, 2024]

        # Model grid.  "type" selects the training entry point, the size list
        # ("hidden_sizes" / "layer_n_sub") is forwarded to it unchanged.
        self.models = {
            # Baseline MLP models
            "mlp_depth1": {
                "type": "baseline_mlp",
                "hidden_sizes": [1024],
                "description": "MLP with depth=1"
            },
            "mlp_depth2": {
                "type": "baseline_mlp",
                "hidden_sizes": [1000, 256],
                "description": "MLP with depth=2"
            },
            "mlp_depth3": {
                "type": "baseline_mlp",
                "hidden_sizes": [512, 256, 128],
                "description": "MLP with depth=3"
            },

            # BL models
            "bl_shallow_depth1": {
                "type": "bl_shallow",
                "layer_n_sub": [380],
                "description": "Shallow BL with depth=1"
            },
            "bl_shallow_depth2": {
                "type": "bl_shallow",
                "layer_n_sub": [512, 128],
                "description": "Shallow BL with depth=2"
            },
            "bl_shallow_depth3": {
                "type": "bl_shallow",
                "layer_n_sub": [256, 128, 64],
                "description": "Shallow BL with depth=3"
            },
        }

        # Per-dataset batch sizes, forwarded as "batch_size_per_dataset".
        self.batch_configs = {
            "agnews": {"train": 128, "test": 256},
            "yelp": {"train": 2048, "test": 8192}
        }

        # One record per attempted run (successful or not).
        self.results = []
        # Records of the runs that raised; a subset of self.results.
        self.failed_experiments = []

    def create_results_directory(self) -> None:
        """Create the results root and its sub-directories if missing."""
        # exist_ok avoids the check-then-create race and makes this idempotent.
        os.makedirs(self.results_dir, exist_ok=True)
        for subdir in ("logs", "detailed_results", "summary"):
            os.makedirs(os.path.join(self.results_dir, subdir), exist_ok=True)

    def set_random_seed(self, seed: int) -> None:
        """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def get_data_config(self, id_dataset: str, ood_dataset: str) -> Dict:
        """Build the config dict passed to ``build_id_ood_from_config``.

        Args:
            id_dataset: Name of the in-distribution dataset.
            ood_dataset: Name of the out-of-distribution dataset.

        Returns:
            A dict with a single ``"data"`` section.

        Raises:
            KeyError: If either dataset has no entry in ``self.batch_configs``.
        """
        return {
            "data": {
                "id_dataset": id_dataset,
                "ood_dataset": ood_dataset,
                "id_dir": ".",
                "ood_dir": ".",
                "batch_size_per_dataset": {
                    id_dataset: self.batch_configs[id_dataset],
                    ood_dataset: self.batch_configs[ood_dataset]
                },
                "num_workers": 0,
                "pin_memory": True
            }
        }

    def run_baseline_mlp(self, bundle, model_config: Dict, seed: int) -> Dict:
        """Train/evaluate a baseline MLP and return whatever metrics it reports.

        ``seed`` is accepted for signature symmetry only; seeding is done by
        ``set_random_seed`` before this method is called.
        """
        print(f"    Running baseline MLP: {model_config['description']}")

        metrics = main_energy_mlp_deep(
            train_loader=bundle.train_loader,
            test_loader=bundle.test_loader,
            input_dim=bundle.in_dim,
            num_classes=bundle.num_classes,
            hidden_sizes=model_config["hidden_sizes"],
            ood_loader=bundle.ood_loader
        )

        return metrics

    def run_bl_model(self, bundle, model_config: Dict, seed: int) -> Dict:
        """Train/evaluate a BL model and return the last metrics entry.

        ``BL`` is expected to return ``(model, metrics)``.  When ``metrics`` is
        a non-empty list, a copy of its last element is returned (presumably
        the final-epoch metrics - confirm against ``main_bl.BL``); otherwise
        an empty dict is returned.  ``seed`` is unused here; seeding is done
        by ``set_random_seed`` before this method is called.
        """
        print(f"    Running BL model: {model_config['description']}")

        model, metrics = BL(
            train_loader=bundle.train_loader,
            x_test=bundle.x_test,
            y_test=bundle.y_test,
            in_dim=bundle.in_dim,
            layer_n_sub=model_config["layer_n_sub"],
            num_classes=bundle.num_classes,
            max_epochs=15,
            x_ood_test=bundle.x_ood_test
        )

        if isinstance(metrics, list) and len(metrics) > 0:
            final_metrics = metrics[-1].copy()
            return final_metrics
        else:
            return {}

    def run_single_experiment(self, dataset: str, model_name: str, model_config: Dict, seed: int) -> Dict:
        """Run one (dataset, model, seed) combination.

        Args:
            dataset: In-distribution dataset name; the OOD dataset is the
                other one of ``"agnews"`` / ``"yelp"``.
            model_name: Key of the model in ``self.models``.
            model_config: The model's config dict.
            seed: Random seed applied before data loading and training.

        Returns:
            A result record with run metadata and ``"success": True``.  Dict
            metrics are merged into the record; any other metrics object is
            stored under ``"metrics"``.

        Raises:
            Exception: Whatever data loading or training raises; this method
                does not catch anything.
        """
        experiment_start_time = time.time()

        self.set_random_seed(seed)

        ood_dataset = "yelp" if dataset == "agnews" else "agnews"
        cfg = self.get_data_config(dataset, ood_dataset)

        bundle = build_id_ood_from_config(cfg)

        if model_config["type"] == "baseline_mlp":
            metrics = self.run_baseline_mlp(bundle, model_config, seed)
        else:  # BL models
            metrics = self.run_bl_model(bundle, model_config, seed)

        experiment_time = time.time() - experiment_start_time

        result = {
            "dataset": dataset,
            "ood_dataset": ood_dataset,
            "model_name": model_name,
            "model_description": model_config["description"],
            "model_type": model_config["type"],
            "seed": seed,
            "experiment_time": experiment_time,
            "timestamp": datetime.now().isoformat(),
            "success": True
        }

        if isinstance(metrics, dict):
            result.update(metrics)
        else:
            result["metrics"] = metrics

        return result

    @staticmethod
    def _build_failure_record(dataset: str, model_name: str, model_config: Dict,
                              seed: int, exc: BaseException) -> Dict:
        """Describe a run that raised ``exc`` as a ``success=False`` record."""
        return {
            "dataset": dataset,
            "model_name": model_name,
            "model_description": model_config.get("description"),
            "model_type": model_config.get("type"),
            "seed": seed,
            "timestamp": datetime.now().isoformat(),
            "success": False,
            "error": repr(exc),
            "traceback": "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            ),
        }

    @staticmethod
    def _json_safe(o):
        """``json.dump`` fallback for NumPy scalars/arrays and torch tensors.

        Anything that is not one of those types is serialized as ``str(o)``.
        """
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.bool_):
            return bool(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, torch.Tensor):
            return o.detach().cpu().tolist()
        return str(o)

    def save_experiment_log(self, result: dict) -> None:
        """Write one run's record to ``detailed_results`` as a JSON file.

        The file name is built from the dataset, model name, seed and the
        current time (second resolution).
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{result['dataset']}_{result['model_name']}_seed{result['seed']}_{timestamp}.json"
        filepath = os.path.join(self.results_dir, "detailed_results", filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False, default=ExperimentRunner._json_safe)

    def run_all_experiments(self) -> None:
        """Run the whole grid, then write the aggregate result files.

        A run that raises an ``Exception`` does not abort the grid: it is
        stored as a failure record and the loop continues, so the results of
        the other runs are still saved and summarized.  ``KeyboardInterrupt``
        and other non-``Exception`` errors still propagate.
        """
        total_experiments = len(self.datasets) * len(self.models) * len(self.random_seeds)
        current_experiment = 0

        overall_start_time = time.time()

        for dataset in self.datasets:
            for model_name, model_config in self.models.items():
                print(f"\n  model: {model_name} ({model_config['description']})")

                for seed in self.random_seeds:
                    current_experiment += 1
                    print(f"  experiment {current_experiment}/{total_experiments} - seed: {seed}")
                    try:
                        result = self.run_single_experiment(dataset, model_name, model_config, seed)
                    except Exception as exc:
                        # Top-level boundary: keep the grid going and keep the
                        # evidence (error + traceback) in the saved record.
                        result = self._build_failure_record(
                            dataset, model_name, model_config, seed, exc
                        )
                        self.failed_experiments.append(result)
                        print(f"    experiment failed: {exc!r}")
                    self.results.append(result)
                    self.save_experiment_log(result)

        overall_time = time.time() - overall_start_time

        print("All experiments completed!")
        print(f"Total time: {overall_time:.2f}s ({overall_time/60:.2f} minutes)")
        if self.failed_experiments:
            print(f"Failed experiments: {len(self.failed_experiments)}/{total_experiments}")

        self.save_final_results()
        self.generate_summary_report()

    def save_final_results(self) -> None:
        """Write all records to ``final_results.json`` and successes to ``results.csv``.

        The CSV is only written when at least one run succeeded.
        """
        filepath = os.path.join(self.results_dir, "final_results.json")
        data = {
            "results": self.results,
            "total_experiments": len(self.results),
            "timestamp": datetime.now().isoformat()
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False, default=ExperimentRunner._json_safe)

        successful_results = [r for r in self.results if r["success"]]
        if successful_results:
            df = pd.DataFrame(successful_results)
            csv_filepath = os.path.join(self.results_dir, "results.csv")
            df.to_csv(csv_filepath, index=False)

    def generate_summary_report(self) -> None:
        """Aggregate numeric metrics over seeds per (dataset, model).

        Writes ``summary/summary_statistics.csv`` with one row per
        (dataset, model, metric) holding mean, sample standard deviation
        (``ddof=1``) and the number of values, then writes the text report.
        Does nothing when there are no successful runs or no numeric metrics.
        """
        successful_results = [r for r in self.results if r["success"]]
        if not successful_results:
            # An empty DataFrame has no 'dataset' column to filter on.
            print("No successful experiments; skipping summary report.")
            return

        df = pd.DataFrame(successful_results)

        # Column dtypes are the same for every row subset, so pick the metric
        # columns once.  'seed' and 'experiment_time' are metadata, not metrics.
        metric_columns = [
            col for col in df.select_dtypes(include=[np.number]).columns
            if col not in ('seed', 'experiment_time')
        ]

        summary_stats = []

        for dataset in self.datasets:
            for model_name in self.models:
                subset = df[(df['dataset'] == dataset) & (df['model_name'] == model_name)]
                if subset.empty:
                    continue

                for metric in metric_columns:
                    # Rows from models that never reported this metric are NaN
                    # after DataFrame construction; leave them out of the stats.
                    values = subset[metric].dropna().to_numpy()
                    if values.size == 0:
                        continue
                    summary_stats.append({
                        'dataset': dataset,
                        'model_name': model_name,
                        'model_description': self.models[model_name]['description'],
                        'metric': metric,
                        'mean': float(np.mean(values)),
                        # Sample std is undefined for a single value.
                        'std': float(np.std(values, ddof=1)) if values.size > 1 else float('nan'),
                        'count': int(values.size),
                    })

        if summary_stats:
            summary_df = pd.DataFrame(summary_stats)
            summary_filepath = os.path.join(self.results_dir, "summary", "summary_statistics.csv")
            summary_df.to_csv(summary_filepath, index=False)

            self.generate_readable_report(summary_df)

    def generate_readable_report(self, summary_df: "pd.DataFrame") -> str:
        """Write a plain-text version of the summary statistics.

        Args:
            summary_df: Frame with the columns produced by
                ``generate_summary_report`` (``dataset``, ``model_name``,
                ``model_description``, ``metric``, ``mean``, ``std``,
                ``count``).

        Returns:
            Path of the written ``summary/summary_report.txt`` file.
        """
        lines = [
            "Experiment summary",
            f"Generated: {datetime.now().isoformat()}",
            "",
        ]

        # sort=False keeps the dataset/model order of the experiment grid.
        grouped = summary_df.groupby(["dataset", "model_name"], sort=False)
        for (dataset, model_name), group in grouped:
            description = group["model_description"].iloc[0]
            lines.append(f"[{dataset}] {model_name} - {description}")
            for metric, mean, std, count in zip(
                group["metric"], group["mean"], group["std"], group["count"]
            ):
                lines.append(f"  {metric}: {mean:.4f} +/- {std:.4f} (n={count})")
            lines.append("")

        report_path = os.path.join(self.results_dir, "summary", "summary_report.txt")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(lines))
        return report_path
    
def main():
    """Script entry point: create a default runner and execute the full grid."""
    # Constructing the runner creates the results directory tree.
    experiment_runner = ExperimentRunner()
    print("begin experiments...")
    experiment_runner.run_all_experiments()

# Run the experiment grid only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()