import os
from typing import List, Tuple

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".4"

import random
import time

import gym
import ml_collections
import numpy as np
import pandas as pd
from models import ContrastiveEncoder, CEAgent
from tqdm import trange
from utils import EnsembleBuffer, get_logger, make_env


# Training-schedule constants shared by the functions in this module.
total_timesteps = 1_000_000  # upper bound of the training-loop step counter
decay_steps = 800_000        # step at which linear_decay's factor reaches zero
start_timesteps = 10_000     # random actions (and no updates) up to this step
ensemble_num = 4             # number of parallel environments / ensemble members


def linear_decay(val, step):
    """Scale ``val`` by a factor that falls linearly with ``step``.

    The factor is 1 at step 0 and 0 at ``decay_steps``. Nothing is clamped
    here, so the factor turns negative once ``step`` exceeds ``decay_steps``;
    callers that need a floor must apply it themselves.
    """
    elapsed_fraction = 1 / decay_steps * step
    remaining = 1 - elapsed_fraction
    return val * remaining


def eval_policy(agent: CEAgent,
                env: gym.Env,
                eval_episodes: int = 10) -> Tuple[float, float]:
    """Roll the agent out in evaluation mode and report its mean return.

    Args:
        agent: object exposing ``sample_action(obs, eval_mode=True)``; only
            element 0 of the value it returns is sent to the environment.
        env: environment to roll out in (``reset()`` / 4-tuple ``step()``).
        eval_episodes: number of episodes to run.

    Returns:
        A pair ``(mean return per episode, elapsed wall-clock seconds)``.
    """
    started_at = time.time()
    # A single running total over every step of every episode.
    total_return = 0.
    for _ in range(eval_episodes):
        obs = env.reset()
        done = False
        while not done:
            actions = agent.sample_action(obs, eval_mode=True)  # (E, act_dim)
            first_action = actions[0]
            obs, reward, done, _ = env.step(first_action)
            total_return += reward
    mean_return = total_return / eval_episodes
    elapsed = time.time() - started_at
    return mean_return, elapsed


def _join_floats(values, precision: int = 3) -> str:
    """Format every value with fixed ``precision`` and join them with ", "."""
    return ", ".join(f"{v:.{precision}f}" for v in values)


def train_and_evaluate(config: ml_collections.ConfigDict):
    """Train the ensemble agent with a contrastive encoder and evaluate it.

    The loop steps ``ensemble_num`` environments in lock-step. Up to
    ``start_timesteps`` actions are sampled from the action space and nothing
    is updated. From then until ``decay_steps`` both the encoder and the agent
    are updated, and the encoder's reward is scaled by a linearly decayed
    ``config.lmbda``. After ``decay_steps`` only the agent is updated, with an
    all-zero intrinsic reward.

    Side effects: creates ``logs/...`` and ``saved_models/...`` directories,
    writes a ``.log`` file while training, and at the end writes a ``.csv`` of
    the evaluation records and saves the agent.

    Args:
        config: experiment configuration. Fields read here: ``env_name``,
            ``seed``, ``temperature``, ``lmbda``, ``memory_size``,
            ``memory_skip``, ``k``, ``horizon``, ``emb_dim``,
            ``contrast_batch_size``, ``contrast_ups`` and ``earlystop``.
    """
    start_time = time.time()
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())

    # logging
    exp_prefix = f"ce_t{config.temperature}_l{config.lmbda}_" + \
        f"mz{config.memory_size}_k{config.k}_h{config.horizon}"
    exp_name = f"s{config.seed}_{timestamp}"
    log_dir = f"logs/{exp_prefix}/{config.env_name}"
    ckpt_dir = f"saved_models/{exp_prefix}/{config.env_name}"
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)
    exp_info = f"# Running experiment for: {exp_prefix}_{exp_name}_{config.env_name} #"
    print("#" * len(exp_info) + f"\n{exp_info}\n" + "#" * len(exp_info))
    logger = get_logger(f"{log_dir}/{exp_name}.log")
    logger.info(f"Config:\n{config}\n")

    # initialize the mujoco/dm_control environment: one per ensemble member,
    # each with its own seed, plus a separately seeded evaluation environment
    envs = [make_env(config.env_name, config.seed + i)
            for i in range(ensemble_num)]
    eval_env = make_env(config.env_name, config.seed + 42)

    # environment parameter
    obs_dim = eval_env.observation_space.shape[0]
    act_dim = eval_env.action_space.shape[0]
    max_action = eval_env.action_space.high[0]

    # set random seed
    np.random.seed(config.seed)
    random.seed(config.seed)

    # vectorized SAC agent
    agent = CEAgent(ensemble_num=ensemble_num,
                    obs_dim=obs_dim,
                    act_dim=act_dim,
                    max_action=max_action,
                    seed=config.seed)

    # contrastive encoder
    encoder = ContrastiveEncoder(obs_dim=obs_dim,
                                 emb_dim=config.emb_dim,
                                 ensemble_num=ensemble_num,
                                 memory_size=config.memory_size,
                                 seed=config.seed + 10,
                                 k=config.k,
                                 temperature=config.temperature,
                                 contrast_batch_size=config.contrast_batch_size)

    # Replay buffer
    replay_buffer = EnsembleBuffer(ensemble_num=ensemble_num,
                                   obs_dim=obs_dim,
                                   act_dim=act_dim,
                                   memory_size=config.memory_size,
                                   memory_skip=config.memory_skip)
    buffer_iterator = replay_buffer.get_buffer_iterator()
    memory_iterator = replay_buffer.get_memory_iterator()
    contrast_iterator = replay_buffer.get_contrast_iterator(
        batch_size=config.contrast_batch_size,
        window=config.horizon)

    # reward for untrained agent
    eval_reward, eval_time = eval_policy(agent, eval_env)
    logs = [{"step": 0, "reward": eval_reward, "eval_time": eval_time}]

    # episode info
    ep_rewards = np.zeros((ensemble_num), dtype=np.float32)
    obses = np.array([env.reset() for env in envs])
    # NOTE(review): 256 presumably matches the buffer's batch size -- confirm.
    zero_rewards = np.zeros((ensemble_num, 256))
    rewards_str = ""
    traj_cnt = 0
    ep_step = 0

    # evaluation cadence: every 2 * eval_freq steps, then every eval_freq
    # steps once t exceeds dense_eval_start
    eval_freq = 5000
    dense_eval_start = int(9.5e5)

    # start training; t counts environment steps summed over the ensemble
    for t in trange(ensemble_num,
                    total_timesteps + ensemble_num,
                    ensemble_num):
        if t <= start_timesteps:
            actions = [env.action_space.sample() for env in envs]
        else:
            actions = agent.sample_action(obses)

        # interact with env
        any_done = False
        for i, env in enumerate(envs):
            next_obs, reward, done, info = env.step(actions[i])
            # gym's TimeLimit wrapper stores a boolean under this key, so read
            # its value: a real terminal on the final step must stay terminal.
            truncated = info.get("TimeLimit.truncated", False)
            done_bool = 0 if truncated else int(done)
            replay_buffer.add(obses[i],
                              actions[i],
                              next_obs,
                              reward,
                              done_bool,
                              traj_cnt)
            ep_rewards[i] += reward
            obses[i] = next_obs
            # Remember a finished episode in ANY member, not only the last one.
            any_done = any_done or done
        ep_step += 1

        if t > start_timesteps:
            if t <= decay_steps:
                lmbda = max(0, linear_decay(config.lmbda, t))
                for _ in range(ensemble_num):
                    for _ in range(config.contrast_ups):
                        contrast_batch = next(contrast_iterator)
                        encoder_log_info = encoder.update(contrast_batch)
                    batch = next(buffer_iterator)
                    memory = next(memory_iterator)
                    if memory is not None:
                        intrinsic_rewards = encoder.compute_reward(
                            batch.next_observations, memory) * lmbda
                    else:
                        intrinsic_rewards = zero_rewards
                    log_info = agent.update(batch, intrinsic_rewards)
                    log_info["encoder_loss"] = encoder_log_info["encoder_loss"]
                    log_info["cos_sim"] = encoder_log_info["cos_sim"]
                    log_info["pos_sim"] = encoder_log_info["pos_sim"]
            else:
                # Keep the logged intrinsic-reward statistics in sync with
                # what the agent is actually trained on in this phase.
                intrinsic_rewards = zero_rewards
                # Use the module-level ensemble size, as everywhere else.
                for _ in range(ensemble_num):
                    batch = next(buffer_iterator)
                    log_info = agent.update(batch, zero_rewards)
                log_info["encoder_loss"] = 0
                log_info["cos_sim"] = 0
                log_info["pos_sim"] = 0

        # Episodes run in lock-step and share traj_cnt, so every environment
        # is reset together once any member finishes or the step cap is hit.
        if any_done or ep_step == config.earlystop:
            for i, env in enumerate(envs):
                obses[i] = env.reset()
            traj_cnt += 1
            ep_step = 0
            rewards_str = _join_floats(ep_rewards, precision=2)
            ep_rewards[:] = 0

        eval_interval = eval_freq if t > dense_eval_start else 2 * eval_freq
        if t % eval_interval == 0:
            eval_reward, eval_time = eval_policy(agent, eval_env)
            if t > start_timesteps:
                # log_info / batch / intrinsic_rewards come from this
                # iteration's last update above.
                log_info["reward"] = eval_reward
                eval_reward = f"{eval_reward:.3f}"
                log_info.update({
                    "step": t,
                    "eval_time": eval_time,
                    "batch_reward": batch.rewards.mean(),
                    "batch_reward_max": batch.rewards.max(),
                    "batch_reward_min": batch.rewards.min(),
                    "batch_IR": _join_floats(intrinsic_rewards.mean(-1)),
                    "batch_IRmax": _join_floats(intrinsic_rewards.max(-1)),
                    "batch_IRmin": _join_floats(intrinsic_rewards.min(-1)),
                    "time": (time.time() - start_time) / 60,
                    "q": log_info["qs"].mean(),
                })
                logger.info(
                    f"\n[T {t//1000}][{log_info['time']:.2f} min] "
                    f"eval_time: {eval_time:.2f}\n"
                    f"\teval_R: {eval_reward}\n"
                    f"\tep_rewards: {rewards_str}\n"
                    f"\tR: {log_info['batch_reward']:.3f}, "
                    f"Rmax: {log_info['batch_reward_max']:.3f}, "
                    f"Rmin: {log_info['batch_reward_min']:.3f}\n"
                    f"\tIR: {log_info['batch_IR']}\n"
                    f"\tIRmax: {log_info['batch_IRmax']}\n"
                    f"\tIRmin: {log_info['batch_IRmin']}\n"
                    f"\tq: {_join_floats(log_info['qs'])}\n"
                    f"\tq_loss: {_join_floats(log_info['critic_loss'])}\n"
                    f"\talphas: {_join_floats(log_info['alpha'])}\n"
                    f"\teLoss: {log_info['encoder_loss']:.3f}, "
                    f"pSim: {log_info['pos_sim']:.3f} "
                    f"cSim: {log_info['cos_sim']:.3f}\n"
                )
                # Drop the per-member entries so the CSV row holds scalars.
                for pop_key in ["critic_loss", "qs", "alpha",
                                "batch_IR", "batch_IRmax", "batch_IRmin"]:
                    log_info.pop(pop_key)
                logs.append(log_info)
            else:
                # No update has run yet, so only evaluation figures exist.
                log_info = {
                    "step": t,
                    "reward": eval_reward,
                    "eval_time": eval_time,
                    "time": (time.time() - start_time) / 60,
                }
                logs.append(log_info)
                logger.info(
                    f"\n[T {t//1000}][{logs[-1]['time']:.2f} min] eval_reward: {eval_reward:.3f}, eval_time: {eval_time:.0f}\n"
                )

    # Save logs
    log_df = pd.DataFrame(logs)
    log_df.to_csv(f"{log_dir}/{exp_name}.csv")

    # Save checkpoints
    agent.save(ckpt_dir)
