import os
import time
import datetime
from copy import deepcopy

import jax
import jax.numpy as jnp
import pickle

from clu import parameter_overview

from algorithms.adqn.agents.q_agent import Q_Agent
from algorithms.utils.replay_buffer import ReplayBuffer
from algorithms.utils.env_functions import simulate_n_episodes
from algorithms.adqn.utils.utils import create_train_step


def _linear_decay(step, final_steps, final_value):
    """Return a schedule value decaying linearly from 1.0 to ``final_value``.

    The schedule equals ``final_value`` at ``step == final_steps``. Callers stop
    invoking it once ``step`` exceeds ``final_steps`` so the value stays frozen.
    """
    return 1 - (step / final_steps) * (1 - final_value)


def adqn(
    env,
    t_key,
    hidden_layers,
    criterion,
    t_seed,
    m_seeds,
    gamma,
    batch_size,
    n_initial_samples,
    eps_final,
    eps_final_steps,
    replay_buffer_size,
    n_epochs,
    min_steps_per_epoch,
    target_update_frequency,
    val_episodes,
    learning_rate,
    **kwargs,
):
    """Train several Q-agents against one shared, periodically re-selected target.

    Every environment step, one batch is sampled from the replay buffer, the
    regression targets are computed with the currently selected target
    parameters, and *all* agents take a gradient step on those same targets.
    Every ``target_update_frequency`` total steps the mean running loss of each
    agent is computed, one agent is selected according to ``criterion`` and its
    parameters become the new target parameters.

    Args:
        env: Environment. The code uses ``observation_space.sample()``,
            ``action_space.n``, ``n_actions``, ``reset()``,
            ``step(action) -> (next_state, reward, terminated, truncated)`` and
            ``collect_random_samples(key, replay_buffer, n)``.
        t_key: JAX PRNG key driving exploration, buffer sampling, random agent
            selection and (through a split) the evaluation rollouts.
        hidden_layers: Either one list describing a single agent or a list of
            such lists, one per agent. ``env.action_space.n`` is appended to
            each list before it is handed to ``Q_Agent``.
        criterion: Agent selection rule at each target update: ``"min"`` or
            ``"eps_min"`` (lowest mean loss), ``"max"`` (highest mean loss) or
            ``"random"``. With ``"eps_min"`` the acting agent is additionally
            replaced by a uniformly sampled agent with probability ``eps_reg``.
        t_seed: Only used as a prefix for the saved history file names.
        m_seeds: Per-agent integer seeds used to build the agents' PRNG keys.
        gamma: Discount factor applied to the bootstrapped next-state value.
        batch_size: Forwarded to ``ReplayBuffer``.
        n_initial_samples: Number of random samples collected before training.
        eps_final: Final exploration rate; also forwarded to
            ``simulate_n_episodes``.
        eps_final_steps: Total steps over which epsilon decays linearly from
            1.0 to ``eps_final``.
        replay_buffer_size: Forwarded to ``ReplayBuffer``.
        n_epochs: Number of epochs to run.
        min_steps_per_epoch: Minimum number of environment steps per epoch; the
            epoch keeps going until the episode in progress ends.
        target_update_frequency: Number of total steps between target updates.
        val_episodes: Forwarded to ``simulate_n_episodes``; evaluation is
            skipped when it is 0 or when there is only one agent.
        learning_rate: Forwarded to ``create_train_step``.
        **kwargs: Optional ``save_path`` (directory for histories and pickled
            parameters). Required when ``criterion == "eps_min"``:
            ``eps_regressor``, a mapping with keys ``"eps_reg_final_steps"``
            and ``"eps_reg_final"``.

    Returns:
        dict with keys ``"losses"``, ``"returns"``, ``"selected_index"`` (one
        entry per target update) and ``"avg_epoch_ret"`` (one entry per epoch).

    Raises:
        KeyError: If ``criterion == "eps_min"`` and ``eps_regressor`` is missing.
        ValueError: If ``criterion`` is not supported (raised at the first
            target update).
    """
    start_time = time.time()
    sample = env.observation_space.sample()

    save_path = kwargs.get("save_path")

    use_eps_reg = criterion == "eps_min"
    if use_eps_reg:
        eps_reg_args = kwargs["eps_regressor"]
        eps_reg = 1.0

    # copy environment for testing
    test_env = deepcopy(env)
    t_key, sim_key = jax.random.split(t_key)

    # a flat list describes a single agent: wrap it, then add the output layer
    if not isinstance(hidden_layers[0], list):
        hidden_layers = [hidden_layers]
    models_architectures = [layer + [env.action_space.n] for layer in hidden_layers]

    # create models, train step functions and train states
    models = []
    train_states = []
    train_step_fns = []
    for i, layers in enumerate(models_architectures):
        m_key = jax.random.PRNGKey(m_seeds[i])
        _, m_subkey = jax.random.split(m_key)
        model = Q_Agent(layers)
        fn, train_state = create_train_step(m_subkey, model, learning_rate, sample)
        models.append(model)
        train_step_fns.append(fn)
        train_states.append(train_state)
    n_models = len(models)

    # print parameter overview
    for i, train_state in enumerate(train_states):
        print(f"Agent {i}")
        print(parameter_overview.get_parameter_overview(train_state.params))

    # the first agent provides the initial (fixed) target params
    cur_index = 0
    adqn_target_params = train_states[cur_index].params

    # fill replay memory
    replay_buffer = ReplayBuffer(replay_buffer_size, batch_size, sample)
    t_key, subkey = jax.random.split(t_key)
    env.collect_random_samples(subkey, replay_buffer, n_initial_samples)

    total_steps = 0
    eps = 1.0
    running_loss = [[] for _ in range(n_models)]
    history = {"losses": [], "returns": [], "selected_index": [], "avg_epoch_ret": []}

    for ep in range(n_epochs):
        print(f"Starting epoch {ep}")
        state = env.reset()
        ep_steps = 0
        ep_return = 0
        n_episodes = 0
        terminated = truncated = False

        # sample from environment; never cut an episode in the middle
        while ep_steps < min_steps_per_epoch or not (terminated or truncated):
            if terminated or truncated:
                state = env.reset()
                n_episodes += 1

            t_key, subkey = jax.random.split(t_key)
            if jax.random.uniform(subkey) < eps:
                t_key, subkey = jax.random.split(t_key)
                action = jax.random.choice(subkey, env.n_actions)
            else:
                action_index = cur_index

                if use_eps_reg:
                    t_key, subkey = jax.random.split(t_key)
                    if jax.random.uniform(subkey) < eps_reg:
                        # Draw the agent with a fresh key: reusing the key
                        # already consumed by `uniform` would correlate the
                        # chosen index with the draw that gated it.
                        t_key, subkey = jax.random.split(t_key)
                        action_index = jax.random.choice(subkey, n_models)
                action = jnp.argmax(
                    models[action_index].__apply__(
                        train_states[action_index].params, state
                    )
                )
            next_state, reward, terminated, truncated = env.step(action)
            ep_return += reward
            # Only `terminated` is stored, not `truncated`.
            # NOTE(review): presumably this is the flag returned as
            # `absorbings` by `sample`, so truncated transitions still
            # bootstrap -- confirm against ReplayBuffer.
            replay_buffer.add(state, action, next_state, reward, terminated)
            state = next_state

            total_steps += 1
            ep_steps += 1

            # linear decay epsilon
            if total_steps <= eps_final_steps:
                eps = _linear_decay(total_steps, eps_final_steps, eps_final)

            if use_eps_reg and total_steps <= eps_reg_args["eps_reg_final_steps"]:
                eps_reg = _linear_decay(
                    total_steps,
                    eps_reg_args["eps_reg_final_steps"],
                    eps_reg_args["eps_reg_final"],
                )

            # sample from replay buffer
            t_key, subkey = jax.random.split(t_key)
            states, actions, next_states, rewards, absorbings = replay_buffer.sample(
                subkey
            )

            # compute targets with current target network
            targets = rewards + gamma * (1.0 - absorbings) * jnp.max(
                models[cur_index].__apply__(adqn_target_params, next_states),
                axis=1,
            )

            # gradient step for all regressors on the shared targets
            for i, train_step in enumerate(train_step_fns):
                train_states[i], loss = train_step(
                    train_states[i], states, actions, targets
                )
                running_loss[i].append(loss)

            if total_steps % target_update_frequency != 0:
                continue

            # ---- target update ----
            # mean loss of every agent since the previous target update
            losses = jnp.mean(jnp.array(running_loss), axis=1)
            history["losses"].append(losses)
            running_loss = [[] for _ in range(n_models)]

            # select current agent
            if criterion == "min":
                cur_index = jnp.argmin(losses)
            elif criterion == "max":
                cur_index = jnp.argmax(losses)
            elif criterion == "random":
                t_key, subkey = jax.random.split(t_key)
                cur_index = jax.random.choice(subkey, n_models)
            elif criterion == "eps_min":
                cur_index = jnp.argmin(losses)
            else:
                raise ValueError(f"Criterion {criterion} not supported")
            history["selected_index"].append(cur_index)

            # update target params
            adqn_target_params = train_states[cur_index].params

            # simulate episodes for evaluation
            evaluate = n_models > 1 and val_episodes > 0
            if evaluate:
                sim_start_time = time.time()
                sim_key, sim_subkey = jax.random.split(sim_key)
                # every agent is evaluated with the same key
                reg_returns = [
                    simulate_n_episodes(
                        test_env,
                        train_state.params,
                        model.__apply__,
                        sim_subkey,
                        val_episodes,
                        eps_final,
                    )
                    for model, train_state in zip(models, train_states)
                ]
                history["returns"].append(jnp.array(reg_returns))
                sim_elapsed_time = time.time() - sim_start_time

            # print intermediate summary
            intermediate_summary = f"   Total steps {total_steps}"
            intermediate_summary += f" | Losses {[round(float(loss), 3) for loss in history['losses'][-1]]}"
            if n_models > 1:
                intermediate_summary += f" | Cur. agent {cur_index}"
                if evaluate:
                    intermediate_summary += f" | Avg Returns {[round(float(avg_ret), 1) for avg_ret in jnp.mean(jnp.array(reg_returns), axis=1)]}"
                    intermediate_summary += (
                        f" | Sim. time {round(float(sim_elapsed_time), 1)} s"
                    )
            print(intermediate_summary)

        # termination of last episode not counted in while loop
        n_episodes += 1
        history["avg_epoch_ret"].append(ep_return / n_episodes)

        elapsed_time = time.time() - start_time
        elapsed_hms = str(datetime.timedelta(seconds=int(elapsed_time)))
        summary = f"Summary Epoch {ep}"
        summary += f" | Avg Epoch Ret {ep_return / n_episodes}"
        summary += f" | N Episodes {n_episodes}"
        summary += f" | total steps {total_steps}"
        summary += f" | epoch steps {ep_steps}"

        print(summary)

        # save intermediate results (overwritten after every epoch)
        if save_path is not None:
            for k, v in history.items():
                path = os.path.join(save_path, k)
                os.makedirs(path, exist_ok=True)
                jnp.save(os.path.join(path, f"{t_seed}_{k}"), v)
            path = os.path.join(save_path, "models")
            os.makedirs(path, exist_ok=True)
            for i, train_state in enumerate(train_states):
                with open(os.path.join(path, f"model_{i}"), "wb") as handle:
                    pickle.dump(train_state.params, handle)

        print(
            f"------------------------------------Elapsed time {elapsed_hms}------------------------------------"
        )
    return history
