import os
import csv
import time
import random
import numpy as np
import torch
from typing import Dict, List, Any, Tuple
import traceback
from datetime import datetime

from data.data_utils import build_id_ood_from_config  
from high_dimension_exp.models.MLP import main_energy_mlp
from main_bl import BL


def set_seed(seed: int):
    """Seed every RNG used by the experiments and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU generator
    and the PyTorch CUDA generators (current device and all devices), then
    disables cuDNN autotuning and requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Autotuning may select different kernels per run, so turn it off.
    cudnn.benchmark = False
    cudnn.deterministic = True


class ExperimentRunner:
    """Run a grid of (dataset, model, seed) ID-vs-OOD experiments and summarise them.

    The grid is the Cartesian product of the dataset/model configurations from
    ``_create_experiment_configs`` and ``self.seeds``. Each run yields one flat
    result dict, appended to ``self.results``. ``save_results`` writes those
    rows to CSV. ``calculate_statistics`` / ``save_statistics`` aggregate them
    per (model, dataset) across seeds.
    """

    DEFAULT_SEEDS = (42, 123, 456, 789, 999)

    # Bookkeeping fields written by the runner itself; every other key in a
    # result dict is treated as a metric.
    _META_KEYS = ('dataset', 'model', 'seed', 'status', 'runtime', 'timestamp', 'error')

    def __init__(self, seeds: List[int] = None):
        """Build the experiment grid.

        Args:
            seeds: Seeds to run for every configuration. Defaults to
                ``DEFAULT_SEEDS``.
        """
        self.results = []
        self.experiment_configs = self._create_experiment_configs()
        # Copy so neither the caller's list nor the class default is shared.
        self.seeds = list(self.DEFAULT_SEEDS if seeds is None else seeds)

    def _create_experiment_configs(self) -> List[Dict]:
        """Return every (dataset, model) pairing, dataset-major then model order."""
        dataset_configs = [
            {
                "name": "MNIST_vs_FashionMNIST",
                "config": {
                    "data": {
                        "id_dataset": "mnist",
                        "ood_dataset": "fashion_mnist",
                        "num_workers": 2,
                        "pin_memory": True,
                    }
                },
            },
            {
                "name": "FashionMNIST_vs_MNIST",
                "config": {
                    "data": {
                        "id_dataset": "fashion_mnist",
                        "ood_dataset": "mnist",
                        "num_workers": 2,
                        "pin_memory": True,
                    }
                },
            },
        ]

        model_configs = [
            {"name": "MLP_depth1", "type": "mlp", "params": {"hidden_sizes": [256]}},
            {"name": "MLP_depth2", "type": "mlp", "params": {"hidden_sizes": [256, 128]}},
            {"name": "MLP_depth3", "type": "mlp", "params": {"hidden_sizes": [256, 128, 32]}},
            {
                "name": "BL_shallow_depth1",
                "type": "bl",
                "params": {"layer_n_sub": [88], "max_epochs": 50},
            },
            {
                "name": "BL_shallow_depth2",
                "type": "bl",
                "params": {"layer_n_sub": [88, 42], "max_epochs": 50},
            },
            {
                "name": "BL_shallow_depth3",
                "type": "bl",
                "params": {"layer_n_sub": [88, 42, 20], "max_epochs": 50},
            },
        ]

        return [
            {"dataset": dataset_config, "model": model_config}
            for dataset_config in dataset_configs
            for model_config in model_configs
        ]

    @staticmethod
    def _stats_key(model: str, dataset: str) -> str:
        """Return the string key under which statistics for (model, dataset) are stored."""
        return f"{model}_{dataset}"

    @staticmethod
    def _to_scalar(value: Any) -> Any:
        """Unwrap objects exposing ``.item()`` (tensors, NumPy scalars); pass others through."""
        return value.item() if hasattr(value, 'item') else value

    @classmethod
    def _flatten_metrics(cls, metrics: Any) -> Dict[str, Any]:
        """Turn a trainer's metric output into flat result columns.

        A list of dicts is reduced to its last entry, and a dict is copied
        key-by-key with scalars unwrapped. Anything else is stored as its
        string form under ``'metrics'``.
        """
        if isinstance(metrics, list) and metrics and isinstance(metrics[-1], dict):
            metrics = metrics[-1]
        if isinstance(metrics, dict):
            return {key: cls._to_scalar(value) for key, value in metrics.items()}
        return {'metrics': str(metrics)}

    def run_single_experiment(self, dataset_config: Dict, model_config: Dict, seed: int) -> Dict[str, Any]:
        """Train and evaluate one model on one dataset pair with a fixed seed.

        Args:
            dataset_config: Entry with ``name`` and ``config``. ``config`` is
                passed to ``build_id_ood_from_config``.
            model_config: Entry with ``name``, ``type`` (``'mlp'`` or ``'bl'``)
                and ``params``.
            seed: Seed applied via ``set_seed`` before the data bundle is built.

        Returns:
            A flat dict. It holds dataset, model, seed, status, runtime (in
            seconds, excluding data loading) and timestamp, followed by the
            reported metrics.

        Raises:
            ValueError: If ``model_config['type']`` is not ``'mlp'`` or ``'bl'``.
        """
        model_type = model_config['type']
        if model_type not in ('mlp', 'bl'):
            # Fail before the (expensive) data build instead of hitting an
            # unbound ``metrics`` later on.
            raise ValueError(f"Unknown model type {model_type!r} for {model_config['name']}")

        print(f"Running experiment: {model_config['name']} on {dataset_config['name']} with seed {seed}")
        set_seed(seed)
        bundle = build_id_ood_from_config(dataset_config['config'])
        params = model_config['params']

        start_time = time.time()
        if model_type == 'mlp':
            metrics = main_energy_mlp(
                train_loader=bundle.train_loader,
                test_loader=bundle.test_loader,
                input_dim=bundle.in_dim,
                num_classes=bundle.num_classes,
                hidden_sizes=params['hidden_sizes'],
                ood_loader=bundle.ood_loader
            )
        else:  # 'bl'
            _, metrics = BL(
                train_loader=bundle.train_loader,
                x_test=bundle.x_test,
                y_test=bundle.y_test,
                in_dim=bundle.in_dim,
                layer_n_sub=params['layer_n_sub'],
                num_classes=bundle.num_classes,
                max_epochs=params['max_epochs'],
                x_ood_test=bundle.x_ood_test
            )
        runtime = time.time() - start_time

        result = {
            'dataset': dataset_config['name'],
            'model': model_config['name'],
            'seed': seed,
            'status': 'success',
            'runtime': runtime,
            'timestamp': datetime.now().isoformat()
        }
        result.update(self._flatten_metrics(metrics))
        return result

    def _run_guarded(self, config: Dict, seed: int) -> Dict[str, Any]:
        """Run one experiment and turn any exception into a ``'failed'`` result row."""
        try:
            return self.run_single_experiment(config['dataset'], config['model'], seed)
        except Exception as exc:  # top-level boundary: one bad run must not abort the sweep
            traceback.print_exc()
            return {
                'dataset': config['dataset']['name'],
                'model': config['model']['name'],
                'seed': seed,
                'status': 'failed',
                'error': f"{type(exc).__name__}: {exc}",
                'timestamp': datetime.now().isoformat(),
            }

    def run_all_experiments(self, start_from: int = 0, end: int = None) -> None:
        """Run the (config, seed) grid, optionally restricted to a slice.

        Experiments are numbered 1..N, iterating configs in the outer loop and
        seeds in the inner loop. Only those with ``start_from < index <= end``
        run; ``end=None`` means run to the last one. A run that raises is
        recorded with ``status='failed'`` and the sweep continues.
        """
        total_experiments = len(self.experiment_configs) * len(self.seeds)
        print(f"Starting {total_experiments} experiments...")

        completed = 0
        for config in self.experiment_configs:
            for seed in self.seeds:
                completed += 1

                if completed <= start_from:
                    continue
                if end is not None and completed > end:
                    return

                self.results.append(self._run_guarded(config, seed))
                print(f"Completed {completed}/{total_experiments} experiments")

        print(f"Completed all {total_experiments} experiments!")

    def save_results(self, filename: str = None) -> None:
        """Write every result row to CSV.

        The header is the sorted union of keys across all rows. Missing cells
        are left empty.

        Args:
            filename: Output path. Defaults to a timestamped
                ``experiment_results_*.csv``.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"experiment_results_{timestamp}.csv"

        fieldnames = sorted({key for result in self.results for key in result})

        with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(self.results)

        print(f"Results saved to: {filename}")

    def calculate_statistics(self) -> Dict[str, Dict[str, float]]:
        """Aggregate numeric metrics per (model, dataset) across seeds.

        Only results with ``status == 'success'`` are used, and only groups
        with at least two such runs are reported. Non-numeric metric values
        are ignored.

        Returns:
            ``{"<model>_<dataset>": {metric: {'mean', 'std_err', 'count'}}}``.
            ``std_err`` is the standard error of the mean computed from the
            sample standard deviation (``ddof=1``). It is 0.0 when only one
            numeric value exists for a metric.
        """
        grouped_results: Dict[str, List[Dict]] = {}
        for result in self.results:
            if result.get('status') != 'success':
                continue
            key = self._stats_key(result['model'], result['dataset'])
            grouped_results.setdefault(key, []).append(result)

        stats = {}
        for key, results in grouped_results.items():
            if len(results) < 2:
                continue

            # Sorted so the output order does not depend on set hashing.
            metric_keys = sorted({k for result in results for k in result if k not in self._META_KEYS})

            group_stats = {}
            for metric_key in metric_keys:
                values = [
                    float(result[metric_key])
                    for result in results
                    if isinstance(result.get(metric_key), (int, float))
                ]
                if not values:
                    continue
                count = len(values)
                # SEM = s / sqrt(n) with the *sample* std; ddof=1 is undefined for n == 1.
                std_err = float(np.std(values, ddof=1) / np.sqrt(count)) if count > 1 else 0.0
                group_stats[metric_key] = {
                    'mean': float(np.mean(values)),
                    'std_err': std_err,
                    'count': count
                }
            stats[key] = group_stats

        return stats

    def save_statistics(self, filename: str = None) -> None:
        """Write per-(model, dataset) statistics to CSV, one row per group.

        Each metric ``m`` contributes ``m_mean``, ``m_std_err``, ``m_count``
        and a formatted ``m_mean_pm_stderr`` column.

        Args:
            filename: Output path. Defaults to a timestamped
                ``experiment_statistics_*.csv``.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"experiment_statistics_{timestamp}.csv"

        stats = self.calculate_statistics()

        # Map each stats key back to its (model, dataset) pair. Splitting the
        # key string cannot work because both names themselves contain '_'.
        names_by_key = {
            self._stats_key(result['model'], result['dataset']): (result['model'], result['dataset'])
            for result in self.results
        }

        flattened_stats = []
        for key, metrics in stats.items():
            model, dataset = names_by_key[key]
            row = {'model': model, 'dataset': dataset}
            for metric_name, metric_stats in metrics.items():
                row[f"{metric_name}_mean"] = metric_stats['mean']
                row[f"{metric_name}_std_err"] = metric_stats['std_err']
                row[f"{metric_name}_count"] = metric_stats['count']
                row[f"{metric_name}_mean_pm_stderr"] = f"{metric_stats['mean']:.4f} ± {metric_stats['std_err']:.4f}"
            flattened_stats.append(row)

        if not flattened_stats:
            print("No statistics to save (need at least two successful runs per model/dataset).")
            return

        # Different models may report different metrics, so the header must
        # be the union over all rows, not just the first row's keys.
        fieldnames = sorted({field for row in flattened_stats for field in row})

        with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(flattened_stats)

        print(f"Statistics saved to: {filename}")


def main():
    """Run a fixed slice of the experiment grid and write its CSV outputs.

    Runs experiments 41-45 (``start_from=40, end=45``) and reports wall-clock
    time in hours. It then writes the per-run results CSV and the aggregated
    statistics CSV using their default timestamped filenames.
    """
    runner = ExperimentRunner()
    separator = "=" * 80

    print("EXPERIMENT CONFIGURATION")
    began = time.time()
    runner.run_all_experiments(start_from=40, end=45)
    elapsed_hours = (time.time() - began) / 3600

    for line in (separator, "EXPERIMENT COMPLETION", separator):
        print(line)
    print(f"Total runtime: {elapsed_hours:.2f} hours")

    runner.save_results()
    runner.save_statistics()

    print("All experiments completed!")


if __name__ == "__main__":
    main()
