from turtle import distance
import jax
import json
import optax
import jax.numpy as jnp
from pathlib import Path
from typing import Any, NamedTuple, Callable
from modules.registries import EMULATOR_REGISTRY, SUMMARY_REGISTRY, DISTANCE_REGISTRY, OPTIMIZER_REGISTRY
from pipeline.training.get_experiment_info import Experiment
from pipeline.training.rollout import anchored_rollout, full_rollout
from pipeline.training.train_configs import TrainConfig, ExperimentInputs
from pipeline.training.get_experiment_info import (
    get_problem_name,
    get_exp_name,
    get_run_name,
    get_exp_path,
    get_timestamp,
)

class Runtime(NamedTuple):
    """Immutable bundle of the runtime objects assembled by ``resolve_runtime``.

    Groups the JAX backend/device/dtype with the init/apply callables of the
    emulator, summary and distance components, the two rollout callables, and
    one optax optimizer per component.
    """
    backend: str  # backend name as returned by jax.default_backend()
    device: Any  # first entry of jax.devices(backend)
    dtype: Any  # jax.numpy attribute looked up by the configured precision name

    # Emulator f_theta[.]: init/apply pair returned by the emulator factory
    emulator_init: Callable
    emulator_apply: Callable

    # Summary g_phi[.]: "init"/"apply" entries of the summary factory result
    summary_init: Callable
    summary_apply: Callable

    # Distance D_omega[., .]: "init"/"apply" entries of the distance factory result
    distance_init: Callable
    distance_apply: Callable

    # Rollouts built around emulator_apply: rollout_mse is always the anchored
    # rollout; rollout_ot is anchored or full depending on the configuration.
    rollout_mse: Callable
    rollout_ot: Callable

    # One optimizer per component, each built from its configured type and
    # learning rate.
    emulator_optimizer: optax.GradientTransformation
    summary_optimizer: optax.GradientTransformation
    critic_optimizer: optax.GradientTransformation

def _get_state_index(train_config: TrainConfig) -> int:
    """Read ``state_index`` from the JSON config of the selected summary type.

    The file is looked up at ``configs/summary/<type>_config.json`` relative
    to the current working directory, where ``<type>`` is
    ``train_config.summary_config["type"]``.

    Args:
        train_config: Training configuration; only ``summary_config["type"]``
            is read.

    Returns:
        The value stored under ``"state_index"``. Presumably an integer index
        (the caller adds 1 to it when building a name); it is not validated
        here.

    Raises:
        KeyError: If the config file does not exist or has no
            ``"state_index"`` entry. The message names the file consulted.
    """
    summary_type = train_config.summary_config["type"]
    summary_config_path = Path(f"configs/summary/{summary_type}_config.json")

    if not summary_config_path.exists():
        # A missing file used to degrade to an empty dict and fail with a bare
        # KeyError('state_index'). Keep the exception type, but say which file
        # was expected so the failure is actionable.
        raise KeyError(
            f"'state_index' unavailable: summary config file "
            f"'{summary_config_path}' does not exist"
        )

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(summary_config_path, "r", encoding="utf-8") as f:
        summary_config: dict[str, Any] = json.load(f)

    try:
        return summary_config["state_index"]
    except KeyError:
        raise KeyError(
            f"'state_index' missing from summary config '{summary_config_path}'"
        ) from None

def _get_output_dim(train_config: TrainConfig) -> int:
    """Read ``output_dim`` from the JSON config of the selected summary type.

    The file is looked up at ``configs/summary/<type>_config.json`` relative
    to the current working directory, where ``<type>`` is
    ``train_config.summary_config["type"]``.

    Args:
        train_config: Training configuration; only ``summary_config["type"]``
            is read.

    Returns:
        The value stored under ``"output_dim"``. Presumably an integer
        dimension; it is not validated here.

    Raises:
        KeyError: If the config file does not exist or has no
            ``"output_dim"`` entry. The message names the file consulted.
    """
    summary_type = train_config.summary_config["type"]
    summary_config_path = Path(f"configs/summary/{summary_type}_config.json")

    if not summary_config_path.exists():
        # A missing file used to degrade to an empty dict and fail with a bare
        # KeyError('output_dim'). Keep the exception type, but say which file
        # was expected so the failure is actionable.
        raise KeyError(
            f"'output_dim' unavailable: summary config file "
            f"'{summary_config_path}' does not exist"
        )

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(summary_config_path, "r", encoding="utf-8") as f:
        summary_config: dict[str, Any] = json.load(f)

    try:
        return summary_config["output_dim"]
    except KeyError:
        raise KeyError(
            f"'output_dim' missing from summary config '{summary_config_path}'"
        ) from None

def resolve_experiment(train_config: TrainConfig, exp_inputs: ExperimentInputs) -> Experiment:
    """
    Resolve experiment identity, filesystem layout, and logging metadata.

    Derives the experiment and run names from the training configuration,
    creates the experiment directory, and packs everything into an
    ``Experiment`` record.

    Args:
        train_config: Training hyper-parameters and component configs.
        exp_inputs: Dataset paths and the seed for this run.

    Returns:
        An ``Experiment`` whose ``train_config`` field is a plain-dict
        snapshot of the hyper-parameters together with the experiment
        directory, dataset name, timestamp and seed.

    Side effects:
        Creates the experiment directory (and any missing parents) on disk.
    """
    problem_name = get_problem_name(exp_inputs.train_data_path)
    dataset_name = Path(exp_inputs.train_data_path).parent.name

    # ---- experiment naming ----
    distance_type = train_config.distance_config["type"]
    uses_ot = distance_type != "no_ot"

    if not uses_ot:
        # Without an OT term the summary type is irrelevant and is not read.
        ot_type = "no_ot"
    else:
        summary_type = train_config.summary_config["type"]
        if summary_type == "projection":
            # The label uses state_index + 1 -- presumably turning a 0-based
            # index into a 1-based state number.
            state_index = _get_state_index(train_config)
            ot_type = f"projection_on_state{state_index + 1}"
        elif summary_type == "mlp":
            output_dim = _get_output_dim(train_config)
            ot_type = f"learnable_mlp_dim{output_dim}"
        elif summary_type == "linear":
            output_dim = _get_output_dim(train_config)
            ot_type = f"learnable_linear_dim{output_dim}"
        else:
            # Any other summary type is used verbatim as the label.
            ot_type = summary_type

    clean_or_noisy = (
        f"noisy_{train_config.noise_level:.1f}"
        if train_config.noise_level > 0
        else "clean"
    )
    sinkhorn_or_wgan = distance_type if uses_ot else ""
    rollout_tag = (
        "anchored_for_ot"
        if train_config.ot_rollout == "anchored"
        else "full_rollout_on_ot"
    )

    experiment_name = get_exp_name(
        problem_name,
        ot_type,
        clean_or_noisy,
        sinkhorn_or_wgan,
        rollout_tag,
    )

    run_name = get_run_name(
        train_config.epochs,
        train_config.lambda_ot,
        train_config.noise_level,
        exp_inputs.seed,
    )

    # ---- timestamp & filesystem ----
    timestamp = get_timestamp()
    experiment_dir = get_exp_path(experiment_name, run_name, timestamp)
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # ---- logging metadata ----
    wandb_project = "Learning to Emulate Chaos: Adversarial Optimal Transport Regularization"
    wandb_group = experiment_name
    wandb_run_name = experiment_dir.name

    # Plain-dict snapshot stored on the Experiment. Kept under its own name so
    # it does not shadow (and re-type) the ``train_config`` parameter.
    logged_config: dict[str, Any] = {
        "experiment": str(experiment_dir),
        "dataset": dataset_name,
        "timestamp": timestamp,
        "seed": exp_inputs.seed,

        "epochs": train_config.epochs,
        "batch_size": train_config.batch_size,
        "anchor_after": train_config.anchor_after,
        "adversarial_steps": train_config.adversarial_steps,

        "emulator_lr": train_config.emulator_lr,
        # TODO: this is only the dict containing 'type'. For all configs, add
        # a small function that fetches the full config from its .json file
        # (which is always in the same path).
        "emulator_config": train_config.emulator_config,
        "emulator_optimizer_type": train_config.emulator_optimizer_type,

        "summary_config": train_config.summary_config,
        "summary_optimizer_type": train_config.summary_optimizer_type,
        "summary_lr": train_config.summary_lr,

        "distance_config": train_config.distance_config,
        "critic_optimizer_type": train_config.critic_optimizer_type,
        "critic_lr": train_config.critic_lr,

        "lambda_ot": train_config.lambda_ot,
        "noise_level": train_config.noise_level,
        "crop_window_size": train_config.crop_window_size,
    }

    return Experiment(
        experiment_name=experiment_name,
        # The stored run name is the directory name, not the raw ``run_name``
        # (which is only an input to ``get_exp_path``).
        run_name=experiment_dir.name,
        timestamp=timestamp,
        experiment_dir=experiment_dir,
        train_data_path=exp_inputs.train_data_path,
        val_data_path=exp_inputs.val_data_path,
        test_data_path=exp_inputs.test_data_path,
        problem_name=problem_name,
        dataset_name=dataset_name,
        seed=exp_inputs.seed,
        wandb_project=wandb_project,
        wandb_run_name=wandb_run_name,
        wandb_group=wandb_group,
        train_config=logged_config,
    )

def resolve_runtime(train_config: TrainConfig) -> Runtime:
    """Assemble the ``Runtime`` bundle described by ``train_config``.

    Picks the default JAX backend and its first device, resolves the dtype by
    name, loads the emulator/summary/distance JSON configs from ``configs/``,
    instantiates each component through its registry, builds the three
    optimizers and the two rollout callables.

    Args:
        train_config: Training configuration naming the component types,
            optimizer types, learning rates, precision and rollout mode.

    Returns:
        A fully populated ``Runtime``.
    """

    def _load_json(config_path: Path) -> Any:
        # Small local reader shared by the three component configs.
        with open(config_path, 'r') as handle:
            return json.load(handle)

    # ---- device / precision ----
    backend = jax.default_backend()
    device = jax.devices(backend)[0]
    dtype = getattr(jnp, train_config.precision)

    # ---- emulator ----
    emulator_type = train_config.emulator_config['type']
    emulator_config = _load_json(Path(f"configs/emulator/{emulator_type}_config.json"))
    build_emulator = EMULATOR_REGISTRY[emulator_type]
    emulator_init, emulator_apply = build_emulator(emulator_config)

    # ---- summary ----
    summary_type = train_config.summary_config['type']
    summary_config = _load_json(Path(f"configs/summary/{summary_type}_config.json"))
    build_summary = SUMMARY_REGISTRY[summary_type]
    summary_parts = build_summary(summary_config)
    summary_init = summary_parts["init"]
    summary_apply = summary_parts["apply"]

    # ---- distance ----
    distance_type = train_config.distance_config['type']
    distance_config = _load_json(Path(f"configs/distance/{distance_type}_config.json"))

    if distance_type == 'wgan':
        # The wgan distance gets an 'input_dim' matching whatever the summary emits.
        if summary_type == 'projection':
            critic_input_dim = 1  # projection on single state variable
        elif summary_type in ('mlp', 'linear'):
            critic_input_dim = summary_config['output_dim']
        else:
            critic_input_dim = emulator_config['input_dim']
        distance_config['input_dim'] = critic_input_dim

    build_distance = DISTANCE_REGISTRY[distance_type]
    distance_parts = build_distance(distance_config)
    distance_init = distance_parts["init"]
    distance_apply = distance_parts["apply"]

    # ---- optimizers ----
    make_emulator_opt = OPTIMIZER_REGISTRY[train_config.emulator_optimizer_type]
    emulator_optimizer = make_emulator_opt(train_config.emulator_lr)

    make_summary_opt = OPTIMIZER_REGISTRY[train_config.summary_optimizer_type]
    summary_optimizer = make_summary_opt(train_config.summary_lr)

    make_critic_opt = OPTIMIZER_REGISTRY[train_config.critic_optimizer_type]
    critic_optimizer = make_critic_opt(train_config.critic_lr)

    # ---- rollouts ----
    rollout_mse = anchored_rollout(
        stepper_apply=emulator_apply,
        anchor_after=train_config.anchor_after,
    )
    if train_config.ot_rollout == "full":
        rollout_ot = full_rollout(stepper_apply=emulator_apply)
    else:
        # Every mode other than "full" uses the anchored rollout.
        rollout_ot = anchored_rollout(
            stepper_apply=emulator_apply,
            anchor_after=train_config.anchor_after,
        )

    return Runtime(
        backend=backend,
        device=device,
        dtype=dtype,
        emulator_init=emulator_init,
        emulator_apply=emulator_apply,
        summary_init=summary_init,
        summary_apply=summary_apply,
        distance_init=distance_init,
        distance_apply=distance_apply,
        rollout_mse=rollout_mse,
        rollout_ot=rollout_ot,
        emulator_optimizer=emulator_optimizer,
        summary_optimizer=summary_optimizer,
        critic_optimizer=critic_optimizer,
    )
