import optuna
import torch
import time
import numpy as np
import os
import sys
import tempfile
import shutil
import uuid
import atexit
import warnings
import subprocess
from pathlib import Path
from typing import Dict, Any
from copy import deepcopy

# Put the parent of this file's directory on the import path. Presumably this
# is what makes the project-local `lib` import below resolvable (TODO confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu

def get_path(relative_path: str) -> Path:
    """Resolve ``relative_path`` against the directory containing this file.

    A string that starts with ``/`` or that ``pathlib`` considers absolute is
    returned unchanged (wrapped in ``Path``). Anything else is joined onto the
    directory of this source file.
    """
    already_rooted = (
        relative_path.startswith('/') or Path(relative_path).is_absolute()
    )
    if not already_rooted:
        return Path(__file__).parent / relative_path
    return Path(relative_path)

# Module-level tuning state shared by objective(), save_checkpoint() and main().
# trial_configs: one full config per trial, appended in objective(); main()
# indexes it by the best trial's number to write best.toml.
trial_configs = []
# trial_stats: one stats record per finished trial, appended in objective()
# and dumped to trial_stats.json after every trial.
trial_stats = []

def sample_parameters(
    trial: optuna.trial.Trial,
    space: Dict[str, Any],
    base_config: Dict[str, Any],
) -> Dict[str, Any]:
    """Sample one value for every entry of a (possibly nested) search space.

    Each value in ``space`` is either a nested dict, which is sampled
    recursively, or a list ``[distribution, *args]`` where ``distribution`` is:

    * ``'?<name>'``: optional parameter. ``args[0]`` is the default value; a
      categorical flag named ``optional_<label>`` decides whether to sample
      with ``trial.suggest_<name>(label, *args[1:])`` or to use the default.
    * ``'$mlp_d_layers'``: ``args`` is ``(min_n_layers, max_n_layers, d_min,
      d_max)``; produces a list of layer sizes built from the ``n_layers``,
      ``d_first``, ``d_middle`` and ``d_last`` suggestions.
    * ``'$choice'``: categorical choice over ``args[0]``.
    * anything else: ``trial.suggest_<distribution>(label, *args)``; a trailing
      dict in ``args`` is forwarded as keyword arguments.

    ``base_config`` is only passed through to recursive calls.

    Returns a dict with the same nesting as ``space`` holding sampled values.
    """
    sampled = {}
    for label, subspace in space.items():
        if isinstance(subspace, dict):
            sampled[label] = sample_parameters(trial, subspace, base_config)
            continue

        assert isinstance(subspace, list)
        distribution, *params = subspace

        if distribution.startswith('?'):
            default_value = params[0]
            # The on/off flag is suggested before the value itself; keep this
            # order so the sequence of suggestions stays the same.
            enabled = trial.suggest_categorical(f'optional_{label}', [False, True])
            if enabled:
                suggest = getattr(trial, f"suggest_{distribution.lstrip('?')}")
                sampled[label] = suggest(label, *params[1:])
            else:
                sampled[label] = default_value

        elif distribution == '$mlp_d_layers':
            min_n_layers, max_n_layers, d_min, d_max = params
            n_layers = trial.suggest_int('n_layers', min_n_layers, max_n_layers)
            widths = []
            if n_layers:
                widths.append(trial.suggest_int('d_first', d_min, d_max))
            if n_layers > 2:
                # All middle layers share a single sampled width.
                middle_width = trial.suggest_int('d_middle', d_min, d_max)
                widths.extend([middle_width] * (n_layers - 2))
            if n_layers > 1:
                widths.append(trial.suggest_int('d_last', d_min, d_max))
            sampled[label] = widths

        elif distribution == '$choice':
            sampled[label] = trial.suggest_categorical(label, params[0])

        else:
            suggest = getattr(trial, f'suggest_{distribution}')
            if params and isinstance(params[-1], dict):
                *positional, keyword = params
                sampled[label] = suggest(label, *positional, **keyword)
            else:
                sampled[label] = suggest(label, *params)

    return sampled

def merge_sampled_parameters(config: Dict[str, Any], sampled_parameters: Dict[str, Any]):
    """Recursively write ``sampled_parameters`` into ``config`` in place.

    Dict values are merged into the matching sub-dict of ``config`` (created
    when missing); every other value overwrites the entry with the same key.
    Returns ``None``.
    """
    for key, value in sampled_parameters.items():
        if not isinstance(value, dict):
            config[key] = value
            continue
        nested_config = config.setdefault(key, {})
        merge_sampled_parameters(nested_config, value)

def objective(trial: optuna.trial.Trial) -> float:
    """Run one tuning trial and return the validation metric to maximize.

    Builds the trial config from the global ``args`` (everything except the
    ``optimization`` and ``program`` sections) plus the parameters sampled
    from ``args['optimization']['space']``, runs the copied program on that
    config in a subprocess, then records the resulting stats.

    The returned metric is chosen by ``args['optimization_target']``:
    ``'auto'`` (default) and ``'accuracy'`` use the validation accuracy,
    ``'score'`` uses the validation score.

    Raises:
        ValueError: if ``optimization_target`` is not a supported value.
        subprocess.CalledProcessError: if the program exits with a non-zero code.
    """
    # Validate up front so a misconfigured target fails before an expensive
    # training run instead of after it.
    optimization_target = args.get('optimization_target', 'auto')
    if optimization_target not in ('auto', 'accuracy', 'score'):
        raise ValueError(
            f"Unknown optimization_target {optimization_target!r}; "
            "expected 'auto', 'accuracy' or 'score'"
        )

    base_config = {k: v for k, v in args.items() if k not in ['optimization', 'program']}
    config = deepcopy(base_config)
    merge_sampled_parameters(
        config, sample_parameters(trial, args['optimization']['space'], config)
    )
    trial_configs.append(config)

    with tempfile.TemporaryDirectory() as dir_:
        dir_ = Path(dir_)
        out = dir_ / f'trial_{trial.number}'
        # The config sits next to the output directory: <tmp>/trial_N.toml
        # and <tmp>/trial_N/ (stats.json is read from the latter below).
        config_path = out.with_suffix('.toml')
        lib.dump_toml(config, config_path)

        subprocess.run(
            [sys.executable, str(program_copy), str(config_path)],
            check=True,
        )

        stats = lib.load_json(out / 'stats.json')
        # Drop everything after the last '___' in the reported algorithm name
        # (main() names the program copy '<stem>___<uuid>').
        stats['algorithm'] = stats['algorithm'].rsplit('___', 1)[0]
        trial_stats.append(
            {
                **stats,
                'trial_id': trial.number,
                'tuning_time': lib.format_seconds(timer()),
            }
        )
        lib.dump_json(trial_stats, output / 'trial_stats.json', indent=4)
        lib.backup_output(output)
        print(f'Time: {lib.format_seconds(timer())}')

        val_metrics = stats['metrics']['val']

        if optimization_target == 'auto':
            target_value = val_metrics['accuracy']
            print(f"Auto-selected optimization target: accuracy = {target_value:.4f}")
        elif optimization_target == 'accuracy':
            target_value = val_metrics['accuracy']
            print(f"Using Accuracy = {target_value:.4f}")
        else:  # 'score' -- the only remaining value after the check above
            target_value = val_metrics['score']
            print(f"Using default Score = {target_value:.4f}")
        return target_value

def save_checkpoint(*_, **__):
    """Write the current tuning state to ``checkpoint_path`` with ``torch.save``.

    All arguments are accepted and ignored, so the function can be used
    directly as a ``study.optimize`` callback. The saved state holds the
    per-trial configs and stats, the study, the run stats, the timer and the
    random state reported by ``delu.random.get_state()``.
    """
    state = dict(
        trial_configs=trial_configs,
        trial_stats=trial_stats,
        study=study,
        stats=stats,
        timer=timer,
        random_state=delu.random.get_state(),
    )
    torch.save(state, checkpoint_path)

def main():
    """Entry point: run (or resume) the hyperparameter search and save results.

    Steps:
      1. Load the config and output directory via ``lib.load_config_with_args``
         and seed the random generators.
      2. Copy the target program to a uniquely named sibling file that is
         removed at interpreter exit; trials run this copy.
      3. Resume from ``<output>/checkpoint.pt`` if it exists, otherwise create
         a fresh study, timer and stats.
      4. Run ``study.optimize`` with ``objective``, checkpointing after every
         trial.
      5. Write the best config to ``best.toml``, store the best stats and
         print a summary.
    """
    global args, output, program_copy, study, stats, timer, checkpoint_path

    print("Starting Universal Hyperparameter Tuning")

    overall_start_time = time.time()
    args, output = lib.load_config_with_args()
    lib.set_seeds(args.get('tuning', {}).get('seed', 42))

    program = get_path(args['program'])
    program_copy = program.with_name(
        program.stem + '___' + str(uuid.uuid4()).replace('-', '') + program.suffix
    )
    shutil.copyfile(program, program_copy)
    atexit.register(lambda: program_copy.unlink())

    checkpoint_path = output / 'checkpoint.pt'

    if checkpoint_path.exists():
        # The checkpoint pickles arbitrary Python objects (the study, the
        # timer), which torch >= 2.6 rejects under its default
        # weights_only=True. The file is written by save_checkpoint() into
        # our own output directory; do not point this at untrusted files.
        try:
            checkpoint = torch.load(checkpoint_path, weights_only=False)
        except TypeError:
            # Older torch versions do not accept the weights_only keyword.
            checkpoint = torch.load(checkpoint_path)
        trial_configs.extend(checkpoint['trial_configs'])
        trial_stats.extend(checkpoint['trial_stats'])
        study = checkpoint['study']
        stats = checkpoint['stats']
        timer = checkpoint['timer']
        delu.random.set_state(checkpoint['random_state'])

        # Only the remaining budget is passed on to study.optimize below.
        if 'n_trials' in args['optimization']['options']:
            args['optimization']['options']['n_trials'] -= len(study.trials)
        if 'timeout' in args['optimization']['options']:
            args['optimization']['options']['timeout'] -= timer()
        stats.setdefault('continuations', []).append(len(study.trials))
        print(f'Loading checkpoint ({len(study.trials)})')
    else:
        stats = lib.load_json(output / 'stats.json')
        timer = delu.Timer()
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(**args['optimization']['sampler']),
        )

    timer.run()

    warnings.filterwarnings('ignore', category=optuna.exceptions.ExperimentalWarning)
    print(f"Running optimization with {args['optimization']['options'].get('n_trials', 'unlimited')} trials")
    print(f"Program: {program}")

    # show_progress_bar is passed explicitly, so remove it from the forwarded
    # options to avoid supplying the keyword twice.
    optimization_options = args['optimization']['options'].copy()
    show_progress_bar = optimization_options.pop('show_progress_bar', True)

    study.optimize(
        objective,
        **optimization_options,
        callbacks=[save_checkpoint],
        show_progress_bar=show_progress_bar,
    )

    # trial_configs / trial_stats are indexed by trial number.
    best_trial_id = study.best_trial.number
    lib.dump_toml(trial_configs[best_trial_id], output / 'best.toml')
    stats['best_stats'] = trial_stats[best_trial_id]
    stats['time'] = lib.format_seconds(timer())
    lib.dump_stats(stats, output, True)
    lib.backup_output(output)

    # Wall-clock time of this process only; stats['time'] above comes from
    # the timer, which is restored from the checkpoint when resuming.
    overall_end_time = time.time()
    print(f"\n{'='*60}")
    print("OPTIMIZATION COMPLETED")
    print(f"Total time: {lib.format_seconds(overall_end_time - overall_start_time)}")
    print(f"Best trial: {study.best_trial.number}")
    print(f"Best value: {study.best_trial.value:.6f}")
    print(f"Best parameters: {study.best_trial.params}")
    print(f"Results saved to: {output}")
    print(f"{'='*60}")

# Run the tuning only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
