import argparse
import numpy as np
import time
import json
import itertools
from typing import Dict, List, Any

from datasets import get_dataset
from train_eval import run
from rbm import get_model


def load_config(config_path: str) -> Dict[str, Any]:
    """Load an experiment configuration from a JSON file.

    Args:
        config_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The parsed top-level JSON object as a dict.

    Raises:
        ValueError: If the top-level JSON value is not an object. Later
            stages iterate the config with ``.items()``, so fail early with
            a clear message instead.
    """
    # Explicit encoding: the default text encoding is platform-dependent.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration in {config_path!r} must be a JSON object, "
            f"got {type(config).__name__}"
        )
    return config


def generate_parameter_grid(config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Expand a configuration into the list of concrete experiment configs.

    Every non-empty list value is a sweep axis; all other values (including
    empty lists) are copied unchanged into every generated config. The
    Cartesian product of the sweep axes is taken over the keys in sorted
    order, so the output order is deterministic.

    Args:
        config: Mapping of parameter names to a value or a list of values.

    Returns:
        One dict per combination. Every returned dict is a new object, so
        mutating a result never alters ``config``. With no sweep axes the
        result is a single shallow copy of ``config``.
    """
    sweep_params = {
        key: value for key, value in config.items()
        if isinstance(value, list) and value
    }
    fixed_params = {
        key: value for key, value in config.items() if key not in sweep_params
    }

    if not sweep_params:
        # Return a copy, consistent with the swept case, so callers can
        # mutate the result without touching the input config.
        return [dict(config)]

    keys = sorted(sweep_params)
    return [
        {**fixed_params, **dict(zip(keys, combo))}
        for combo in itertools.product(*(sweep_params[key] for key in keys))
    ]


def generate_logger(config: Dict[str, Any]) -> str:
    """Build a descriptive run/logger name from an experiment config.

    Keys absent from ``config`` fall back to placeholder values
    (``experiment``, ``default``, ``unknown``, ``public``, ``1``), and
    ``None`` values are rendered literally, e.g. ``residual-None``.
    """
    lookup = config.get
    segments = [
        f"{lookup('logger', 'experiment')}",
        f"{lookup('model', 'default')}",
        f"{lookup('dataset', 'unknown')}",
        f"split-{lookup('split', 'public')}",
        f"k-{lookup('k', 1)}",
        f"residual-{lookup('residual', None)}",
        f"alpha-{lookup('alpha', None)}",
    ]
    return "_".join(segments)


# Fallback for every setting the runner reads from an experiment config.
# Resolving all defaults in one place guarantees that the model, the training
# call and the results log can never disagree about a default value.
_RUN_DEFAULTS: Dict[str, Any] = {
    'dataset': 'cora',
    'split': 'public',
    'normalize_features': True,
    'model': 'RBMConvNet',
    'k': 1,
    'alpha': None,
    'residual': None,
    'num_layers': 3,
    'forward_sampling': 'gumbel_softmax',
    'backward_sampling': 'sigmoid',
    'optimizer': 'Adam',
    'preconditioner': None,
    'runs': 10,
    'epochs': 200,
    'lr': 0.01,
    'loss': 'cross_entropy',
    'loss_type': 'cd',
    'lambda_rbm': 0.0,
    'weight_decay': 0.0005,
    'early_stopping': 0,
    'momentum': 0.9,
    'eps': 0.01,
    'update_freq': 50,
    'gamma': None,
    'md_file': 'benchmark_results.md',
    'csv_file': 'benchmark_results.csv',
}


def _resolve_settings(exp_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``exp_config`` layered over ``_RUN_DEFAULTS`` (config values win)."""
    return {**_RUN_DEFAULTS, **exp_config}


def _dataset_key(settings: Dict[str, Any]) -> tuple:
    """Return the (name, split, normalize_features) triple identifying one dataset load."""
    return settings['dataset'], settings['split'], settings['normalize_features']


def _load_datasets(all_settings: List[Dict[str, Any]]) -> Dict[tuple, Any]:
    """Load every distinct dataset variant required by the sweep exactly once.

    Results are keyed by ``_dataset_key``. Swept ``split`` or
    ``normalize_features`` values therefore load one variant each, instead of
    a whole list being passed to ``get_dataset``.
    """
    print("\n============================================")
    print(" Loading required datasets ONCE each ")
    print("============================================\n")

    loaded: Dict[tuple, Any] = {}
    for settings in all_settings:
        key = _dataset_key(settings)
        if key in loaded:
            continue
        name, split, normalize_features = key
        print(f"Loading dataset '{name}' (split={split})...", end='', flush=True)
        start_t = time.time()
        loaded[key] = get_dataset(
            name=name,
            normalize_features=normalize_features,
            split=split
        )
        print(f" Done ({time.time() - start_t:.2f}s)")
    return loaded


def _build_run_kwargs(settings: Dict[str, Any], dataset: Any, model: Any) -> Dict[str, Any]:
    """Assemble the keyword arguments for ``run`` (including its ``log_dict``).

    Every value comes from the fully resolved ``settings``. The logged
    hyper-parameters and the logger name therefore match what the model and
    the training call actually used.
    """
    log_dict = {
        'model': settings['model'],
        'dataset': settings['dataset'],
        'split': settings['split'],
        'optimizer': settings['optimizer'],
        'epochs': settings['epochs'],
        'runs': settings['runs'],
        'lr': settings['lr'],
        'k': settings['k'],
        'residual': settings['residual'],
        'forward_sampling': settings['forward_sampling'],
        'backward_sampling': settings['backward_sampling'],
        'num_layers': settings['num_layers'],
        'lambda_rbm': settings['lambda_rbm'],
        'loss_type': settings['loss_type'],
        'md_file': settings['md_file'],
        'csv_file': settings['csv_file'],
    }
    return {
        'dataset': dataset,
        'model': model,
        'split': settings['split'],
        'str_optimizer': settings['optimizer'],
        'str_preconditioner': settings['preconditioner'],
        'runs': settings['runs'],
        'epochs': settings['epochs'],
        'lr': settings['lr'],
        'loss': settings['loss'],
        'loss_type': settings['loss_type'],
        'lambda_rbm': settings['lambda_rbm'],
        'weight_decay': settings['weight_decay'],
        'early_stopping': settings['early_stopping'],
        'logger': generate_logger(settings),
        'momentum': settings['momentum'],
        'eps': settings['eps'],
        'update_freq': settings['update_freq'],
        'gamma': settings['gamma'],
        'alpha': settings['alpha'],
        'log_dict': log_dict,
        'visualize_classes': False,
        'hyperparam': None,
    }


def main() -> None:
    """Parse ``--config``, expand the parameter sweep and run every experiment."""
    print("start main experiment script")
    parser = argparse.ArgumentParser(
        description='Run experiments with JSON configuration and parameter sweeps'
    )
    parser.add_argument('--config', type=str, required=True,
                        help='Path to JSON configuration file')
    args = parser.parse_args()

    print("=" * 60)
    print("Starting experiment runner")
    print("=" * 60)

    print(f"\nLoading configuration from: {args.config}")
    config = load_config(args.config)

    param_configs = generate_parameter_grid(config)
    n_experiments = len(param_configs)
    print(f"Generated {n_experiments} configuration(s) for execution")

    all_settings = [_resolve_settings(cfg) for cfg in param_configs]
    loaded_datasets = _load_datasets(all_settings)

    # Heavy/optional imports are deferred until experiments actually start.
    import os
    import psutil
    import torch

    process = psutil.Process(os.getpid())
    total_experiment_start = time.time()

    for exp_idx, (exp_config, settings) in enumerate(zip(param_configs, all_settings), 1):
        print(f"\n{'=' * 60}")
        print(f"Experiment {exp_idx}/{n_experiments}")
        print(f"{'=' * 60}")

        # Echo only the user-supplied values; output-file/logger keys are noise.
        print("\nConfiguration:")
        for key, value in sorted(exp_config.items()):
            if key not in ('logger', 'md_file', 'csv_file'):
                print(f"   {key}: {value}")

        dataset = loaded_datasets[_dataset_key(settings)]

        print("\nInitializing model...", end='', flush=True)
        start_time = time.time()
        model = get_model(
            settings['model'],
            dataset,
            k=settings['k'],
            alpha=settings['alpha'],
            residual=settings['residual'],
            num_layers=settings['num_layers'],
            forward_sampling=settings['forward_sampling'],
            backward_sampling=settings['backward_sampling'],
        )
        init_time = time.time() - start_time
        print(f" Done ({init_time:.2f}s)")

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"   Total parameters: {total_params:,}")
        print(f"   Trainable parameters: {trainable_params:,}")

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            torch.cuda.reset_peak_memory_stats()

        kwargs = _build_run_kwargs(settings, dataset, model)

        print("\nStarting training...")
        print("=" * 60 + "\n")

        exp_start = time.time()
        run(**kwargs)
        exp_time = time.time() - exp_start

        gpu_peak = torch.cuda.max_memory_allocated() / 1024 ** 2 if use_cuda else 0
        # psutil's memory_info().rss is the *current* resident set size, not a
        # high-water mark, so it is labelled accordingly below.
        cpu_rss = process.memory_info().rss / 1024 ** 2

        # exp_time spans every run; presumably run() repeats training `runs`
        # times for up to `epochs` epochs each (TODO confirm in train_eval.run).
        # With early stopping fewer epochs may execute, so this is an upper bound.
        samples_processed = len(dataset) * settings['epochs'] * settings['runs']
        throughput = samples_processed / exp_time if exp_time > 0 else float('inf')

        print(f"\nExperiment {exp_idx} completed in {exp_time:.2f}s")
        print(f"   Peak GPU VRAM usage: {gpu_peak:.2f} MB")
        print(f"   CPU RAM usage (current RSS): {cpu_rss:.2f} MB")
        print(f"   Throughput: {throughput:.2f} samples/sec\n")

    total_experiment_time = time.time() - total_experiment_start
    print(f"\nAll experiments completed in {total_experiment_time:.2f}s")
    print("=" * 60)


if __name__ == '__main__':
    main()
