import jax
import tqdm
import json
import numpy as np
import jax.numpy as jnp
from pathlib import Path
from typing import Any, Dict, Callable
from modules.optimization.objectives import make_mse_objective, make_ot_objective
from pipeline.bookkeeping.saving import save_checkpoint
from pipeline.training.runtime_resolve import Experiment, resolve_runtime, Runtime
from pipeline.training.train_configs import TrainConfig
from pipeline.dataloader.transforms import compose, add_noise, crop, normalize
from pipeline.dataloader.jax_transforms import add_noise_jax, crop_jax, normalize_jax
from pipeline.dataloader.dataset import DynamicalSystemDataset
from pipeline.dataloader.batching import batch_iterator
from pipeline.bookkeeping.logging import init_logging, log_epoch
from pipeline.training.state_init import init_state, TrainState
from pipeline.training.train_helper import EpochFlags, make_train_step, make_val_step
from pipeline.eval.metrics import compute_trajectory_errors, compute_summary_errors, get_histogram

def init_rng(seed: int) -> jax.Array:
    """Create the root JAX PRNG key for a run from the integer ``seed``."""
    root_key = jax.random.PRNGKey(seed)
    return root_key

def load_datasets(experiment: Experiment) -> tuple[DynamicalSystemDataset, DynamicalSystemDataset]:
    """Build the training and validation datasets for ``experiment``.

    The training dataset is constructed first, then the validation dataset,
    each from the corresponding ``*_data_path`` attribute of ``experiment``.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    data_paths = (experiment.train_data_path, experiment.val_data_path)
    train_ds, val_ds = [DynamicalSystemDataset(data_path=p) for p in data_paths]
    return train_ds, val_ds


def compute_epoch_flags(epoch: int, train_config: TrainConfig) -> EpochFlags:
    """Decide whether the OT objective is active for ``epoch``.

    OT is active only when all three conditions hold:
    - the warm-up period (``train_config.ot_warm_up``) has elapsed;
    - a distance other than ``"no_ot"`` is configured (a missing ``"type"``
      counts as ``"no_ot"``);
    - ``train_config.lambda_ot`` is non-zero.

    Returns:
        ``EpochFlags`` with ``use_ot`` and the effective ``lambda_ot``, which is
        0.0 whenever OT is inactive.
    """
    distance_type = train_config.distance_config.get("type", "no_ot")
    ot_active = (
        epoch >= train_config.ot_warm_up
        and distance_type != "no_ot"
        and train_config.lambda_ot != 0.0
    )

    if ot_active:
        ot_weight = train_config.lambda_ot
    else:
        ot_weight = 0.0

    return EpochFlags(use_ot=ot_active, lambda_ot=ot_weight)

def run_train_epoch(
    state: TrainState,
    dataset: DynamicalSystemDataset,
    epoch: int,
    rng: jax.Array,
    data_seed: int,
    train_config: TrainConfig,
    runtime: Runtime,
    train_step: Callable
) -> tuple[TrainState, dict[str, float]]:
    """Run one training epoch as a single ``jax.lax.scan`` over shuffled batches.

    The sample order is shuffled with a key derived from ``data_seed + epoch``.
    It is therefore reproducible per epoch and independent of ``rng``. Samples
    that do not fill a final complete batch are dropped. If the dataset is
    smaller than one batch, all of its samples form a single, smaller batch.

    Args:
        state: Current training state, threaded through the scan.
        dataset: Training dataset. Its ``traj`` array is placed on
            ``runtime.device`` and indexed per batch.
        epoch: Zero-based epoch index. It selects the OT flags and the shuffle.
        rng: Key for per-step randomness (noise, normalization, train step).
        data_seed: Base seed for the per-epoch shuffle.
        train_config: Supplies ``batch_size``, ``noise_level`` and the OT settings.
        runtime: Supplies the target ``device``.
        train_step: Called as ``train_step(state=..., batch=..., rng=...,
            epoch_flags=...)`` and expected to return ``(state, metrics_dict)``.

    Returns:
        The updated state and the epoch mean of every step metric, as Python floats.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    epoch_flags = compute_epoch_flags(epoch, train_config)

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("run_train_epoch received an empty dataset")

    # A dataset smaller than one batch becomes a single batch. Otherwise the
    # scan would have zero steps and every epoch metric would be NaN.
    batch_size = min(train_config.batch_size, num_samples)
    num_batches = num_samples // batch_size

    shuffle_key = jax.random.PRNGKey(data_seed + epoch)
    indices = jax.random.permutation(shuffle_key, jnp.arange(num_samples))
    batch_indices = indices[: num_batches * batch_size].reshape((num_batches, batch_size))

    all_traj = jax.device_put(dataset.traj, runtime.device)

    def train_on_batch(carry: tuple[TrainState, jax.Array], batch_idx: jax.Array) -> tuple[tuple[TrainState, jax.Array], dict[str, jax.Array]]:
        state, rng = carry
        # Give each random consumer its own key. Reusing one key would make
        # the injected noise correlated with the train step's randomness.
        rng, noise_rng, norm_rng, step_rng = jax.random.split(rng, 4)
        traj_batch = all_traj[batch_idx]

        traj_batch = add_noise_jax(noise_level=train_config.noise_level)(traj_batch, noise_rng)
        # NOTE: cropping is currently disabled for training (crop_jax is not applied).
        traj_batch = normalize_jax()(traj_batch, norm_rng)

        state, step_metrics = train_step(
            state=state,
            batch=traj_batch,
            rng=step_rng,
            epoch_flags=epoch_flags,
        )
        return (state, rng), step_metrics

    (state, _), all_metrics = jax.lax.scan(
        train_on_batch,
        (state, rng),
        batch_indices,
    )

    # scan stacks each metric along a leading per-step axis; average over it.
    epoch_metrics = {k: float(jnp.mean(v)) for k, v in all_metrics.items()}
    return state, epoch_metrics

def run_val_epoch(
    state: TrainState,
    dataset: DynamicalSystemDataset,
    epoch: int,
    rng: jax.Array,
    data_seed: int,
    train_config: TrainConfig,
    runtime: Runtime,
    val_step: Callable,
) -> dict[str, dict[str, Any]]:
    """Evaluate ``state`` on every sample of ``dataset`` and collect diagnostics.

    Samples are taken in dataset order, in batches of ``train_config.batch_size``.
    The last batch may be smaller, so no sample is dropped. Each trajectory is
    passed through noise, crop and normalize transforms with its own PRNG key.
    Those keys are drawn from a NumPy generator seeded with
    ``data_seed + epoch``, so the transform randomness is reproducible per epoch.

    Args:
        state: Model state to evaluate.
        dataset: Validation dataset, read via ``dataset.get_batch(indices)``.
        epoch: Zero-based epoch index. It selects the OT flags and the transform seed.
        rng: Key that is split once per batch for ``val_step``.
        data_seed: Base seed for the transform randomness.
        train_config: Supplies ``batch_size``, ``noise_level``,
            ``crop_window_size`` and the OT settings.
        runtime: Supplies the target ``device``.
        val_step: Called per batch. It must return a mapping with
            ``"u_true"``, ``"u_hat"``, ``"s_true"`` and ``"s_hat"``.

    Returns:
        A dict with two entries:
        - ``"metrics"``: trajectory and summary error metrics.
        - ``"diagnostics"``: the first sample's arrays, plus the first element
          returned by ``get_histogram`` for the true and predicted trajectories.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    epoch_flags = compute_epoch_flags(epoch, train_config)

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("run_val_epoch received an empty dataset")

    batch_size = train_config.batch_size
    # Keep the trailing partial batch so validation covers the whole dataset.
    batch_indices = [
        np.arange(start, min(start + batch_size, num_samples))
        for start in range(0, num_samples, batch_size)
    ]

    transforms = compose([
        add_noise(noise_level=train_config.noise_level),
        crop(window_size=train_config.crop_window_size),
        normalize(),
    ])

    epoch_rng = np.random.default_rng(data_seed + epoch)

    u_true_all, u_hat_all, s_true_all, s_hat_all = [], [], [], []

    for batch_idx in batch_indices:
        rng, step_rng = jax.random.split(rng)
        traj_batch = dataset.get_batch(batch_idx)
        # Draw a fresh key per trajectory. A single key per batch would give
        # every sample the same noise realization and crop window.
        transformed_batch = []
        for i in range(len(batch_idx)):
            sample_rng = jax.random.PRNGKey(np.uint32(epoch_rng.integers(0, 2**32 - 1)))
            transformed_batch.append(transforms(traj_batch[i], sample_rng))
        traj_batch = jax.device_put(np.stack(transformed_batch, axis=0), runtime.device)

        outputs = val_step(
            state=state,
            batch=traj_batch,
            rng=step_rng,
            epoch_flags=epoch_flags,
        )

        u_true_all.append(outputs["u_true"])
        u_hat_all.append(outputs["u_hat"])
        s_true_all.append(outputs["s_true"])
        s_hat_all.append(outputs["s_hat"])

    u_true = np.asarray(jnp.concatenate(u_true_all, axis=0))
    u_hat = np.asarray(jnp.concatenate(u_hat_all, axis=0))
    s_true = np.asarray(jnp.concatenate(s_true_all, axis=0))
    s_hat = np.asarray(jnp.concatenate(s_hat_all, axis=0))

    traj_metrics = compute_trajectory_errors(u_true, u_hat)
    summary_metrics = compute_summary_errors(s_true, s_hat)

    # Bin count: ceil(sqrt(L)), where L is the larger axis-1 length of the two arrays.
    k = int(np.ceil(np.sqrt(max(u_true.shape[1], u_hat.shape[1]))))

    hist_true = get_histogram(u_true, num_bins=k, density=True)[0]
    hist_hat = get_histogram(u_hat, num_bins=k, density=True)[0]

    return {
        "metrics": {
            "trajectory": traj_metrics,
            "summary": summary_metrics,
        },
        "diagnostics": {
            "u_true": u_true[0],   # shape (T, d)
            "u_hat":  u_hat[0],    # shape (T, d)
            "s_true": s_true[0],   # shape (D,) or (T, D)
            "s_hat":  s_hat[0],
            "hist_true": hist_true,
            "hist_hat":  hist_hat,
        },
    }


def _jsonable_fields(obj: Any) -> dict[str, Any]:
    """Return a shallow copy of ``vars(obj)`` with ``Path`` values converted to ``str``."""
    return {k: str(v) if isinstance(v, Path) else v for k, v in vars(obj).items()}


def train_exp(
    train_config: TrainConfig,
    experiment: Experiment,
    use_wandb: bool = True,
) -> Path:
    """Train a model for ``train_config.epochs`` epochs and checkpoint it.

    The steps are:
    1. Write ``train_config.json`` and ``exp_config.json`` into
       ``experiment.experiment_dir``.
    2. Build the objectives, the train/val step functions and the initial state.
    3. Per epoch: run a train epoch and a val epoch, log the results, and
       checkpoint every ``checkpoint_every`` epochs (default 1; 0 or None
       disables periodic checkpoints).
    4. Save a ``final`` checkpoint.

    Args:
        train_config: Training hyper-parameters and model/objective configuration.
        experiment: Experiment paths and seed.
        use_wandb: If False, W&B logging is disabled even when
            ``train_config.use_wandb`` is set. If True (default), the config
            decides.

    Returns:
        ``experiment.experiment_dir``.
    """
    experiment_dir = experiment.experiment_dir
    with open(experiment_dir / "train_config.json", "w", encoding="utf-8") as f:
        json.dump(_jsonable_fields(train_config), f, indent=2)
    with open(experiment_dir / "exp_config.json", "w", encoding="utf-8") as f:
        json.dump(_jsonable_fields(experiment), f, indent=2)

    runtime = resolve_runtime(train_config)

    rng = init_rng(experiment.seed)
    data_seed = experiment.seed

    logger = init_logging(
        experiment=experiment,
        train_config=train_config,
    )

    train_dataset, val_dataset = load_datasets(experiment)

    mse_objective = make_mse_objective(
        g=runtime.rollout_mse,
    )
    ot_objective = make_ot_objective(
        g=runtime.rollout_ot,
        f=runtime.summary_apply,
        D=runtime.distance_apply,
        ot_horizon=train_config.ot_horizon,
    )

    # A missing "type" is treated as "no_ot", the same default compute_epoch_flags uses.
    distance_type = train_config.distance_config.get("type", "no_ot")
    critic_is_adversary = distance_type == "wgan"
    critic_weight_clip = (
        train_config.distance_config.get("weight_clip", 0.01) if critic_is_adversary else None
    )

    # NOTE: train_step is not wrapped in jax.jit here; run_train_epoch traces it
    # inside jax.lax.scan.
    train_step = make_train_step(
        mse_objective=mse_objective,
        ot_objective=ot_objective,
        emulator_optimizer=runtime.emulator_optimizer,
        summary_optimizer=runtime.summary_optimizer,
        critic_optimizer=runtime.critic_optimizer,
        summary_has_params=train_config.summary_config["type"] in ("mlp", "linear"),
        critic_is_adversary=critic_is_adversary,
        critic_weight_clip=critic_weight_clip,
        maximizer_steps=train_config.adversarial_steps,
        summary_adversarial=distance_type != "no_ot",
        summary_apply=runtime.summary_apply,
    )

    val_step = make_val_step(
        rollout_mse=runtime.rollout_mse,
        rollout_ot=runtime.rollout_ot,
        summary_apply=runtime.summary_apply,
    )

    # Use a dedicated key for initialization so it is not reused by the epoch splits.
    rng, init_key = jax.random.split(rng)
    state = init_state(
        train_config=train_config,
        runtime=runtime,
        rng=init_key
    )

    ckpt_dir = experiment_dir / "checkpoints"
    ckpt_every = getattr(train_config, "checkpoint_every", 1)
    log_to_wandb = use_wandb and train_config.use_wandb

    for epoch in tqdm.tqdm(range(train_config.epochs)):
        # Give train and val independent keys rather than sharing one.
        rng, train_rng, val_rng = jax.random.split(rng, 3)

        state, train_metrics = run_train_epoch(
            state=state,
            dataset=train_dataset,
            epoch=epoch,
            rng=train_rng,
            data_seed=data_seed,
            train_config=train_config,
            runtime=runtime,
            train_step=train_step
        )
        eval_metrics = run_val_epoch(
            state=state,
            dataset=val_dataset,
            epoch=epoch,
            rng=val_rng,
            data_seed=data_seed,
            train_config=train_config,
            runtime=runtime,
            val_step=val_step,
        )

        log_epoch(
            ckpt_dir=ckpt_dir,
            epoch=epoch,
            train_metrics=train_metrics,
            eval_metrics=eval_metrics,
            use_wandb=log_to_wandb,
        )

        # A falsy checkpoint_every (0/None) disables periodic checkpoints
        # instead of raising on the modulo.
        if ckpt_every and (epoch + 1) % ckpt_every == 0:
            save_checkpoint(
                path=ckpt_dir / f"epoch_{epoch:04d}",
                state=state,
                train_config=train_config,
                epoch=epoch,
                experiment=experiment
            )

    save_checkpoint(
        path=ckpt_dir / "final",
        state=state,
        train_config=train_config,
        epoch=train_config.epochs,
        experiment=experiment
    )
    return experiment_dir
