import os

import fire
import ray
import ray.cloudpickle as cloudpickle
from ray.rllib.agents.dqn.dqn_torch_policy import compute_q_values
import ray.rllib.utils.torch_ops as torch_ops
from ray.tune.registry import get_trainable_cls
import torch
import torch.nn.functional as F

from offline_rl.agents.cql.ensemble_cql_dqn_torch_policy import compute_q_values as ensemble_compute_q_values
from offline_rl.agents.load_custom_agents import load_custom_agents
from offline_rl.agents.registry import CUSTOM_AGENTS
from offline_rl.envs.line_env import LineEnv
from offline_rl.envs.load_custom_envs import load_custom_envs
from offline_rl.envs.maze_env import MazeEnv
from offline_rl.utils.space_utils import get_space_size, get_index_to_space_converter


def get_config_filepath_from_checkpoint(checkpoint_filepath):
    """Locate the ``params.pkl`` config saved alongside a checkpoint.

    The config is looked up first in the checkpoint's own directory and then
    one directory above it.

    Args:
        checkpoint_filepath: Path to an existing checkpoint file.

    Returns:
        Absolute path of the config file.

    Raises:
        FileNotFoundError: if the checkpoint or the config file does not exist.
    """
    # Explicit raises rather than `assert`: asserts are stripped under `python -O`,
    # which would let a bad path slip through and fail later with a confusing error.
    if not os.path.exists(checkpoint_filepath):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_filepath}")
    config_dir = os.path.dirname(checkpoint_filepath)
    candidates = (
        os.path.join(config_dir, "params.pkl"),
        os.path.join(config_dir, "..", "params.pkl"),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return os.path.abspath(candidate)
    raise FileNotFoundError(
        f"No params.pkl found in {config_dir!r} or its parent directory"
    )


def load_config(filepath):
    """Deserialize and return the config object cloudpickled at ``filepath``.

    Unpickling can execute arbitrary code, so only point this at trusted
    checkpoint directories.
    """
    with open(filepath, "rb") as handle:
        loaded = cloudpickle.load(handle)
    return loaded


def one_hot_transform(s, space):
    """Wrap ``s`` in a batch of one and one-hot encode it for ``space``.

    Encoding is delegated to RLlib's ``torch_ops.one_hot``; ``s`` is presumably
    an integer state index (TODO confirm for non-Discrete spaces).
    """
    single_item_batch = torch.tensor([s])
    return torch_ops.one_hot(single_item_batch, space)


def state_generator(env):
    """Return a generator over every observation of ``env.observation_space``.

    The space size and the index-to-observation converter are resolved
    immediately; the returned generator then maps indices ``0 .. size-1``
    through the converter lazily.
    """
    space = env.observation_space
    n_states = get_space_size(space)
    to_state = get_index_to_space_converter(space)
    return (to_state(idx) for idx in range(n_states))


def ppo_forward(agent, obs):
    """Run the local worker's PPO policy model on ``obs``.

    ``obs`` is passed under the ``"obs_flat"`` key with no RNN state and no
    sequence lengths. Whatever ``model.forward`` returns is returned unchanged;
    the caller in this file unpacks it as ``(logits, state)``.
    """
    policy = agent.workers.local_worker().get_policy()
    input_dict = {"obs_flat": obs}
    return policy.model.forward(input_dict, None, None)


def sac_forward(agent, obs, device="cuda"):
    """Evaluate the local worker's SAC model on ``obs``.

    ``obs`` is moved to ``device`` (CUDA by default, so pass ``device="cpu"``
    on machines without a GPU). The policy head output is treated as action
    logits and turned into probabilities with a softmax over the last dim.

    Returns:
        ``(action_probs, q_values)`` as produced by the model's policy and Q heads.
    """
    sac_model = agent.workers.local_worker().get_policy().model
    features, _ = sac_model.forward({"obs": obs.to(device)}, None, None)
    action_log_probs = F.log_softmax(sac_model.get_policy_output(features), dim=-1)
    q_values = sac_model.get_q_values(features)
    return action_log_probs.exp(), q_values


def dqn_forward(agent, obs, device="cuda"):
    """Compute greedy actions and Q-values from the local worker's DQN model.

    ``agent`` is passed where RLlib's ``compute_q_values`` expects a policy
    (presumably only its ``.config`` is read — TODO confirm for the installed
    Ray version). ``obs`` is moved to ``device`` first.

    Returns:
        ``(greedy_actions, q_values)`` where actions are the argmax over dim 1.
    """
    dqn_model = agent.workers.local_worker().get_policy().model
    q_values, _, _, _ = compute_q_values(agent, dqn_model, {"obs": obs.to(device)})
    greedy_actions = torch.argmax(q_values, dim=1)
    return greedy_actions, q_values


def ensemble_dqn_forward(agent, obs, device="cuda"):
    """Compute greedy actions plus mean/std Q-values across an ensemble of Q-heads.

    ``ensemble_compute_q_values`` is expected to return a sequence of per-member
    Q tensors as its first element (presumably each ``(batch, num_actions)`` —
    TODO confirm). They are stacked along a new leading ensemble dimension.

    Returns:
        ``(greedy_actions, mean_q, std_q)``; actions are the argmax of ``mean_q``
        over dim 1.
    """
    obs = obs.to(device)
    batch = {"obs": obs}
    model = agent.workers.local_worker().get_policy().model
    q, _, _, _ = ensemble_compute_q_values(agent, model, batch)
    q = torch.stack(q, dim=0)
    mean_q = q.mean(dim=0)
    pi = mean_q.argmax(dim=1)
    # The unbiased (N-1) estimator is NaN for a single-member ensemble; fall back
    # to the biased estimator there so the rendered uncertainty is 0, not NaN.
    std_q = torch.std(q, dim=0, unbiased=q.shape[0] > 1)
    return pi, mean_q, std_q


def logits_to_pi(logits, states):
    """Map ``states[i]`` to the index of the highest-scoring entry in row ``i`` of ``logits``."""
    greedy = logits.argmax(dim=1).detach().cpu().numpy().tolist()
    policy = {}
    for idx, state in enumerate(states):
        policy[state] = greedy[idx]
    return policy


def get_policy_q_function(env, algo_name, agent, transform_fn=None):
    """Evaluate a trained agent's greedy policy (and Q-function) on every state of ``env``.

    Args:
        env: Environment whose observation space is enumerated with ``state_generator``.
        algo_name: Trainer name; selects which forward helper is used.
        agent: Restored trainer passed to the forward helper.
        transform_fn: Callable ``(state, observation_space) -> tensor`` producing the
            model input for one state; results are concatenated along dim 0.
            Required in practice despite the ``None`` default.

    Returns:
        ``(pi, q, info)``: ``pi`` maps state -> greedy action, ``q`` maps state -> that
        state's Q row converted with ``.tolist()`` (``None`` for PPO), and ``info``
        carries extras (``std_q`` for ``EnsembleCQLDQN``, otherwise empty).

    Raises:
        ValueError: if ``algo_name`` is not supported.
    """
    states = list(state_generator(env))
    formatted_states = torch.cat([transform_fn(s, env.observation_space) for s in states])

    def per_state(values):
        # Convert row i of a batched tensor into a Python value keyed by states[i].
        return {s: values[i].detach().cpu().numpy().tolist() for i, s in enumerate(states)}

    if algo_name == "PPO":
        logits, _ = ppo_forward(agent, formatted_states)
        return logits_to_pi(logits, states), None, {}
    if algo_name in ("SAC", "DiscreteCQLSAC"):
        action_probs, q = sac_forward(agent, formatted_states)
        # sac_forward returns softmax probabilities; their argmax equals the
        # argmax of the underlying logits, so they select the same greedy action.
        return logits_to_pi(action_probs, states), per_state(q), {}
    if algo_name in ("DQN", "CQLDQN"):
        pi, q = dqn_forward(agent, formatted_states)
        return per_state(pi), per_state(q), {}
    if algo_name == "EnsembleCQLDQN":
        pi, q, std_q = ensemble_dqn_forward(agent, formatted_states)
        return per_state(pi), per_state(q), dict(std_q=per_state(std_q))
    raise ValueError(f"{algo_name} not implemented.")


def get_env(env_name, env_config=None):
    """Construct one of the project's tabular environments by name.

    Args:
        env_name: ``"MazeEnv-v0"`` or ``"LineEnv-v0"``.
        env_config: Keyword arguments forwarded to ``LineEnv`` (ignored for
            ``MazeEnv``). ``None`` means no arguments; ``None`` replaces a shared
            mutable ``{}`` default.

    Raises:
        ValueError: for any other ``env_name`` (previously ``None`` was returned,
            which only failed later and far from the cause).
    """
    if env_name == "MazeEnv-v0":
        return MazeEnv()
    if env_name == "LineEnv-v0":
        return LineEnv(**(env_config or {}))
    raise ValueError(f"Unknown environment: {env_name}")


def main(checkpoint_filepath, algo_name, local_mode=True):
    """Restore a checkpoint and render its policy, Q-function and value function.

    Images are written to a ``viz`` directory next to the checkpoint. A sibling
    ``videos`` directory is also created, although this function writes nothing
    to it. The Q and value plots are skipped when the algorithm yields no Q-function.
    """
    config = load_config(get_config_filepath_from_checkpoint(checkpoint_filepath))

    ray.init(local_mode=local_mode)

    env_name = config.get("env")
    trainer_cls = get_trainable_cls(algo_name)
    agent = trainer_cls(env=env_name, config=config)
    agent.restore(checkpoint_filepath)

    checkpoint_dir = os.path.dirname(checkpoint_filepath)
    for subdir in ("videos", "viz"):
        os.makedirs(os.path.join(checkpoint_dir, subdir), exist_ok=True)
    viz_dir = os.path.join(checkpoint_dir, "viz")

    env = get_env(env_name, config.get("env_config", {}))
    pi, q, info = get_policy_q_function(env, algo_name, agent, transform_fn=one_hot_transform)

    tag = f"{env_name.lower()}_{algo_name.lower()}"
    env.render_policy(pi, filepath=os.path.join(viz_dir, f"{tag}_pi.png"))
    if q is None:
        return

    env.render_q_function(q,
                          std_q=info.get("std_q"),
                          filepath=os.path.join(viz_dir, f"{tag}_q.png"))
    # State value under the greedy policy: the best Q-value in each state.
    value_fn = {state: max(action_values) for state, action_values in q.items()}
    env.render_value_function(value_fn, filepath=os.path.join(viz_dir, f"{tag}_v.png"))


if __name__ == "__main__":
    # Register the project's custom environments and agents before anything else,
    # presumably so Ray can resolve the checkpoint's env name and
    # `get_trainable_cls(algo_name)` for custom algorithms (TODO confirm in load_custom_*).
    load_custom_envs()
    load_custom_agents(CUSTOM_AGENTS)
    # Expose `main` as a command-line interface; its parameters become CLI arguments.
    fire.Fire(main)
