# file: user_extensions/experiments/utils.py
import copy
from pathlib import Path
import sys
import time
import subprocess
import yaml

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from main import run_evaluation, find_last_checkpoint
from prism.core.registry import DATASETS, SYSTEMS
from prism.evaluation.metrics import calculate_probe_gap
from prism.utils.config import AttrDict, process_derived_config


def apply_trial_suggestions(config, trial, search_space):
    """Return a copy of *config* with every search-space parameter set for *trial*.

    Args:
        config: Base configuration exposing ``to_dict()``. It is deep-copied and
            never mutated.
        trial: Trial object providing ``suggest_int``, ``suggest_float`` and
            ``suggest_categorical``.
        search_space: Mapping from a dotted config path (e.g.
            ``"model.latent_space.latent_dim"``) to a settings dict. ``type``
            selects ``categorical`` (uses ``choices``), ``float`` (uses ``low``,
            ``high``, optional ``log``) or ``int`` (uses ``low``, ``high``,
            optional ``step`` and ``log``).

    Returns:
        A new ``AttrDict`` with the suggested values written at their paths.

    Two parameters are range-capped at half of a sibling value:
    ``model.latent_space.target_dim`` (capped by ``latent_dim``) and
    ``model.fcn_params.target_channels`` (capped by ``latent_channels``).
    They are resolved after all other parameters so the cap is computed from
    the value chosen for this trial, not from the base configuration, even if
    the search space lists them first. Entries with an unrecognised ``type``
    are reported on stderr and left at their base-config value.
    """
    # Capped parameter path -> accessor for the value that limits its range.
    capped_by = {
        'model.latent_space.target_dim':
            lambda cfg: cfg.model.latent_space.latent_dim,
        'model.fcn_params.target_channels':
            lambda cfg: cfg.model.fcn_params.latent_channels,
    }

    trial_config = AttrDict(copy.deepcopy(config.to_dict()))

    # Stable sort on a boolean key: independent parameters keep their order and
    # come first, capped parameters keep their order and come last.
    ordered_space = sorted(
        search_space.items(), key=lambda item: item[0] in capped_by
    )

    for param_path, settings in ordered_space:
        *parent_keys, param_name = param_path.split('.')
        parent = trial_config
        for key in parent_keys:
            parent = parent[key]

        if param_path in capped_by:
            limit = capped_by[param_path](trial_config)
            high = min(settings['high'], limit // 2)
            low = min(settings['low'], high)
            if low >= high:
                # The range collapsed to a single value; nothing to search.
                parent[param_name] = low
            else:
                parent[param_name] = trial.suggest_int(
                    param_path, low, high, step=settings.get('step', 1)
                )
            continue

        param_type = settings['type']
        if param_type == 'categorical':
            parent[param_name] = trial.suggest_categorical(
                param_path, settings['choices']
            )
        elif param_type == 'float':
            parent[param_name] = trial.suggest_float(
                param_path,
                settings['low'],
                settings['high'],
                log=settings.get('log', False),
            )
        elif param_type == 'int':
            parent[param_name] = trial.suggest_int(
                param_path,
                settings['low'],
                settings['high'],
                step=settings.get('step', 1),
                log=settings.get('log', False),
            )
        else:
            # Keep the base value (as before) but make the misconfiguration visible.
            print(
                f"Warning: unknown search-space type {param_type!r} for "
                f"'{param_path}'; keeping the base-config value.",
                file=sys.stderr,
            )
    return trial_config


@torch.no_grad()
def calculate_hpo_objective(trainer, pl_module):
    """Compute the probe-gap objective over the validation set.

    The module is switched to eval mode, every batch from
    ``trainer.datamodule.val_dataloader()`` is passed through
    ``pl_module.encoder`` on the module's device, and the encodings and labels
    are gathered on the CPU. The concatenated tensors are handed to
    ``calculate_probe_gap`` together with ``pl_module.config``.

    Each batch must unpack into three items; the third one is ignored.

    Returns:
        The ``'probe_gap'`` entry of the ``calculate_probe_gap`` result.
    """
    pl_module.eval()
    loader = trainer.datamodule.val_dataloader()
    target_device = pl_module.device

    latents = []
    labels = []
    for inputs, batch_labels, _ in loader:
        inputs = inputs.to(target_device)
        batch_labels = batch_labels.to(target_device)
        encoded = pl_module.encoder(inputs)
        # Accumulate on the CPU so device memory is not held across batches.
        latents.append(encoded.cpu())
        labels.append(batch_labels.cpu())

    z_full = torch.cat(latents)
    y_full = torch.cat(labels)
    return calculate_probe_gap(z_full, y_full, pl_module.config)['probe_gap']


def _conv_layer_plan(arch):
    """Return ``(h_dims, block_repeats)`` derived from a conv architecture node.

    Channel widths start at ``arch.start_channels`` and double at every level;
    each of the ``arch.depth`` levels repeats ``arch.block_repeats_per_layer``
    blocks.
    """
    h_dims = [arch.start_channels * (2 ** level) for level in range(arch.depth)]
    block_repeats = [arch.block_repeats_per_layer] * arch.depth
    return h_dims, block_repeats


def _sync_derived_architecture(trial, trial_config_raw):
    """Rewrite per-layer architecture lists after scalar values were suggested.

    Only runs for a network whose ``...conv.depth`` parameter was suggested in
    this trial (i.e. appears in ``trial.params``).
    """
    if 'model.architecture.conv.depth' in trial.params:
        m_arch = trial_config_raw.model.architecture.conv
        h_dims, block_repeats = _conv_layer_plan(m_arch)

        m_arch.encoder.h_dims = h_dims
        # The decoder mirrors the encoder widths.
        m_arch.decoder.h_dims = list(reversed(h_dims))
        m_arch.encoder.block_repeats = block_repeats
        m_arch.decoder.block_repeats = block_repeats

        mlp_units = (
            [m_arch.encoder.mlp_h_units_size] * m_arch.encoder.mlp_h_units_depth
        )
        m_arch.encoder.mlp_h_units = mlp_units
        m_arch.decoder.mlp_h_units = mlp_units

    if 'discriminator_q.architecture.conv.depth' in trial.params:
        d_arch = trial_config_raw.discriminator_q.architecture.conv
        h_dims, block_repeats = _conv_layer_plan(d_arch)

        d_arch.encoder.h_dims = h_dims
        d_arch.encoder.block_repeats = block_repeats


def hpo_objective(trial, base_config, search_space, logger_info, device_id, resume=False):
    """Train one trial configuration and return its probe gap.

    Args:
        trial: Trial object; its ``params`` and ``number`` are read and
            ``set_user_attr`` is called with the result.
        base_config: Base configuration passed to ``apply_trial_suggestions``.
        search_space: Search-space mapping passed to ``apply_trial_suggestions``.
        logger_info: Dict with ``save_dir`` and ``name`` for the TensorBoard
            logger and an optional ``progress_bar`` flag (default ``True``).
        device_id: Index of the GPU handed to the trainer.
        resume: When true, training continues from the checkpoint returned by
            ``find_last_checkpoint`` for this trial's log directory.

    Returns:
        The probe gap computed by ``calculate_hpo_objective``, or ``-1.0`` when
        anything in the trial raises.
    """
    try:
        trial_config_raw = apply_trial_suggestions(base_config, trial, search_space)
        _sync_derived_architecture(trial, trial_config_raw)
        trial_config = process_derived_config(trial_config_raw)

        datamodule = DATASETS.get(trial_config.data.name)(trial_config)
        system = SYSTEMS.get("PrismSystem")(trial_config)

        logger = TensorBoardLogger(
            save_dir=logger_info['save_dir'],
            name=logger_info['name'],
            version=f"trial_{trial.number}"
        )
        log_dir = Path(logger.log_dir)

        checkpoint_callback = ModelCheckpoint(
            dirpath=log_dir / "checkpoints",
            save_last=True
        )

        ckpt_path = find_last_checkpoint(log_dir) if resume else None

        trainer = pl.Trainer(
            max_epochs=trial_config.training.epochs,
            accelerator="gpu",
            devices=[device_id],
            logger=logger,
            callbacks=[checkpoint_callback],
            enable_progress_bar=logger_info.get('progress_bar', True),
            log_every_n_steps=trial_config.evaluation.log_interval,
        )
        trainer.fit(model=system, datamodule=datamodule, ckpt_path=ckpt_path)

        probe_gap = calculate_hpo_objective(trainer, system)
        trial.set_user_attr("probe_gap", probe_gap)
        return probe_gap

    except Exception as e:
        # A failed trial must not abort the whole study, so the error is
        # reported and a sentinel is returned instead of re-raising.
        # NOTE(review): -1.0 is only a safe "worst" value if the study
        # maximises a non-negative objective - confirm against the caller.
        import traceback

        print(f"Trial {trial.number} failed: {e}", file=sys.stderr)
        # The message alone rarely identifies the failing component.
        traceback.print_exc()
        return -1.0


def run_post_hoc_evaluation(study_dir):
    """Run ``run_evaluation`` on the CPU for every logged version of a study.

    Every sub-directory of *study_dir* whose name does not start with ``_`` is
    treated as one run. Within a run, each path matching ``**/version_*`` is
    evaluated in sorted order. A run without such paths is skipped, and an
    exception raised while evaluating a run is reported on stderr and ends
    processing of that run only.

    Args:
        study_dir: ``pathlib.Path`` of the study directory.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STEP 3: RUNNING POST-HOC EVALUATION")
    print(rule)

    run_dirs = []
    for entry in study_dir.iterdir():
        if not entry.is_dir():
            continue
        if entry.name.startswith('_'):
            # Underscore-prefixed directories are not runs.
            continue
        run_dirs.append(entry)

    for run_dir in run_dirs:
        print(f"\n--- Evaluating run: {run_dir.name} ---")
        try:
            version_dirs = list(run_dir.glob("**/version_*"))
            version_dirs.sort()
            if len(version_dirs) == 0:
                print(f"  Skipping {run_dir.name}, no version directory found.")
                continue

            for version_dir in version_dirs:
                print(f"  > Evaluating {version_dir}")
                run_evaluation(run_dir=version_dir, device='cpu')

        except Exception as e:
            print(f"  Evaluation failed for {run_dir.name}: {e}", file=sys.stderr)


def run_trial(config_dict, max_retries, device_id, resume=False):
    trial_name = config_dict['run']['sweep_name']
    study_dir = Path(config_dict['run']['log_dir'])

    print(f"\n--- Starting Trial: {trial_name} ---")

    trial_log_dir = study_dir / trial_name
    version_dirs = sorted(list(trial_log_dir.glob("version_*")))
    if resume and version_dirs:
        latest_version_dir = version_dirs[-1]
        success_marker = latest_version_dir / ".success"
        if success_marker.exists():
            print(f"Success marker found for '{trial_name}' in {latest_version_dir}. Skipping training.")
            return True

    elif not resume:
        pass

    run_config_dir = study_dir / "_temp_configs" / trial_name.replace("/", "_")
    run_config_dir.mkdir(parents=True, exist_ok=True)
    trial_config_path = run_config_dir / "_trial_config.yaml"
    with open(trial_config_path, 'w') as f:
        yaml.dump(config_dict, f, default_flow_style=False)

    for attempt in range(max_retries + 1):
        try:
            main_script_path = Path(__file__).resolve().parent.parent.parent / "main.py"
            if not main_script_path.exists():
                raise FileNotFoundError(f"Could not find main.py at {main_script_path}")

            command = [
                sys.executable, str(main_script_path), 'train',
                '--config', str(trial_config_path)
            ]
            if device_id is not None:
                # Note: The main.py will need to be updated to handle device_id correctly
                # for direct training calls, or this argument should be removed if PL handles it.
                # Assuming PL handles it via CUDA_VISIBLE_DEVICES set by the orchestrator.
                pass

            if resume:
                command.append('--resume')

            subprocess.run(command, check=True)

            final_logger = TensorBoardLogger(save_dir=study_dir, name=trial_name)
            final_log_dir = Path(final_logger.log_dir)
            (final_log_dir / ".success").touch()

            print(f"--- Trial {trial_name} completed successfully. ---")
            return True

        except Exception as e:
            print(f"--- FAILED: {trial_name} (Attempt {attempt + 1}): {e} ---", file=sys.stderr)
            if attempt < max_retries:
                print("  Retrying after a short delay...")
                time.sleep(5)
            else:
                print(f"  Max retries reached for {trial_name}.", file=sys.stderr)

    return False