import os
import pickle
from typing import Any

import jax.numpy as jnp
from rich import print
from tqdm import tqdm

from medium_rl.config import Config
from medium_rl.envs.bit_seq import BitSequence
from medium_rl.eval import eval, eval_modes
from medium_rl.init import init_cfg
from medium_rl.policy import make_alg_policy
from medium_rl.train import make_train_step_fn
from medium_rl.utils import save_model


def save(obj: Any, path: str):
    """Pickle ``obj`` into the file at ``path``.

    The file is opened in binary write mode, so an existing file at the same
    location is overwritten.

    Args:
        obj: Any picklable Python object.
        path: Destination file path.

    Returns:
        None.
    """
    with open(path, mode="wb") as handle:
        pickle.dump(obj, handle)


def run(cfg: Config):
    """Train a policy as described by ``cfg``, evaluating it periodically.

    The run is initialised through ``init_cfg``; training then proceeds for
    ``cfg.num_gen_samples // cfg.num_envs`` steps. Every ``eval_iter`` steps
    (and on iteration 0) the current state is evaluated and the result dict is
    recorded. All samples and rewards returned by the training step are
    collected into pre-allocated buffers. When ``cfg.save`` is set, the
    collected data, the metrics and the final parameters are written under
    ``cfg.save_path``.

    Args:
        cfg: Run configuration. Fields read here: ``eps``, ``alg``,
            ``num_gen_samples``, ``num_envs``, ``samples_per_eval``,
            ``num_eval_samples``, ``top_k``, ``eval_modes``,
            ``target_update_steps``, ``save`` and ``save_path``.

    Returns:
        A ``(frames, metrics)`` tuple of equal-length lists: ``frames[k]`` is
        the number of samples generated before the k-th evaluation and
        ``metrics[k]`` is the result dict of that evaluation.
    """
    # --- Init ---
    run_state, env, forward, optimizer, buffer, buffer_state, network_cfg = init_cfg(cfg=cfg)
    policy_fn = make_alg_policy(cfg.eps, cfg.alg)
    # Evaluation policy is built with 0.0 in place of cfg.eps
    # (presumably "no exploration" -- confirm against make_alg_policy).
    eval_policy_fn = make_alg_policy(0.0, cfg.alg)

    train_step_fn = make_train_step_fn(env, buffer, forward, policy_fn, optimizer, cfg)

    # Calculate number of iterations based on number of samples to generate
    num_iter = int(cfg.num_gen_samples // cfg.num_envs)
    # Evaluation period in iterations. Forced to an int >= 1: the raw
    # expression is a float and becomes 0.0 when samples_per_eval < num_envs,
    # which would make `i % eval_iter` raise ZeroDivisionError.
    eval_iter = max(1, int(num_iter // (cfg.num_gen_samples / cfg.samples_per_eval)))

    # --- Main loop ---
    frames = []
    metrics = []

    # Collect all samples generated during training
    gen_samples = jnp.zeros((cfg.num_gen_samples, env.max_len), dtype=jnp.int8)
    gen_rewards = jnp.zeros((cfg.num_gen_samples,), dtype=jnp.float32)
    is_bit_seq = isinstance(env, BitSequence)
    if is_bit_seq:
        # Running element-wise minimum of the extra oracle info returned by
        # the training step, one entry per mode.
        min_dists = jnp.full((env.num_modes,), jnp.inf)

    # One extra iteration (i == num_iter) so that an evaluation can still be
    # taken after the last training step.
    with tqdm(range(num_iter + 1), dynamic_ncols=True) as pbar:
        for i in pbar:
            if i % eval_iter == 0:
                frames.append(i * cfg.num_envs)
                print(f"Evaluation results after {i * cfg.num_envs} samples:")
                res, _, _, _, top_samples = eval(
                    run_state, forward, eval_policy_fn, env, cfg.num_eval_samples, cfg.top_k
                )

                if is_bit_seq:
                    res["min_dists"] = min_dists
                elif cfg.eval_modes:
                    _, top_modes_res = eval_modes(run_state, forward, env, cfg)
                    res = {**res, **top_modes_res}

                print({k: round(v, 2) if isinstance(v, float) else v for k, v in res.items()})
                print("Top 5 samples:")
                print(env.vectorized_token_idx_to_str(top_samples[:5]))
                print("-----------------------------------------------------------------")
                metrics.append(res)

            if i == num_iter:
                # The extra iteration only exists for the evaluation above.
                # A training step here would produce samples for
                # [num_iter * num_envs, (num_iter + 1) * num_envs), which does
                # not fit in the num_gen_samples-row buffers.
                break

            run_state, buffer_state, loss_info, new_samples, new_rewards, extra_oracle_info = train_step_fn(
                run_state, buffer_state
            )

            # Update target network
            if cfg.target_update_steps and i % cfg.target_update_steps == 0:
                run_state = run_state.replace(target_params=run_state.params.copy())

            # Store generated samples
            start, end = i * cfg.num_envs, (i + 1) * cfg.num_envs
            gen_samples = gen_samples.at[start:end].set(new_samples.astype(jnp.int8))
            gen_rewards = gen_rewards.at[start:end].set(new_rewards.squeeze(-1))

            if is_bit_seq:
                min_dists = jnp.minimum(min_dists, extra_oracle_info.min(axis=0))

            pbar.set_description(f"Loss: {loss_info.mean():.4f}")

    # Final sweep eval
    if not is_bit_seq:
        sweep_res, top_modes_res = eval_modes(run_state, forward, env, cfg)
        print(top_modes_res)
        if cfg.save:
            save(sweep_res, os.path.join(cfg.save_path, "sweep.pkl"))
            save(top_modes_res, os.path.join(cfg.save_path, "top_modes.pkl"))

    if cfg.save:
        # Save generated samples/metrics during training
        save(
            {"samples": (gen_samples, gen_rewards), "metrics": (frames, metrics)},
            os.path.join(cfg.save_path, "data.pkl"),
        )

        save_model(
            network_cfg,
            run_state.params,
            os.path.join(cfg.save_path, "model.pkl"),
        )
    return frames, metrics
