"""Run JSRL Exp"""

import os
from typing import Tuple

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2"

import time
from collections import deque

import d4rl
import gym
import ml_collections
import numpy as np
import pandas as pd
from models import JSRLAgent
from tqdm import trange
from utils import ReplayBuffer, get_logger, load_ckpt, normalize_reward

# Initial guide horizon per task, keyed by D4RL environment name.
# `train_and_evaluate` looks up `TASK_Hs[config.env_name]` and uses it as the
# number of leading steps in each episode that are driven by
# `agent.sample_guide_action`; the curriculum then shrinks it in
# `config.stages` equal decrements.
# NOTE(review): the origin of these numbers is not shown in this file --
# presumably they are typical episode lengths of the pretrained policy on each
# task; confirm before adding new environments.
TASK_Hs = {
    "hopper-random-v2": 100,
    "hopper-medium-v2": 570,
    "hopper-medium-replay-v2": 630,
    "halfcheetah-random-v2": 1000,
    "halfcheetah-medium-v2": 1000,
    "halfcheetah-medium-replay-v2": 1000,
    "walker2d-random-v2": 150,
    "walker2d-medium-v2": 945,
    "walker2d-medium-replay-v2": 700,
    "antmaze-large-play-v0": 880,
    "antmaze-large-diverse-v0": 820,
    "antmaze-medium-play-v0": 520,
    "antmaze-medium-diverse-v0": 510,
    "antmaze-umaze-v0": 220,
    "antmaze-umaze-diverse-v0": 440
}


def normalize_rewards(replay_buffer: ReplayBuffer, env_name: str):
    """Rescale the rewards stored in ``replay_buffer`` for ``env_name``.

    Environment names containing ``'v2'`` (the mujoco tasks) have their
    rewards divided by the max-minus-min trajectory reward listed for that
    environment in ``configs/minmax_traj_reward.csv`` and multiplied by 1000.
    Every other environment (the antmaze tasks) has 1.0 subtracted from each
    reward.

    Args:
        replay_buffer: Buffer whose ``rewards`` attribute is updated.
        env_name: D4RL environment name; selects the normalization scheme and,
            for mujoco tasks, the row of the CSV to use.
    """
    if 'v2' not in env_name:
        # antmaze environments: shift rewards down by one, in place
        replay_buffer.rewards -= 1.0
        return

    # mujoco environments: scale by the recorded trajectory-reward range
    stats = pd.read_csv('configs/minmax_traj_reward.csv', index_col=0)
    stats = stats.set_index('env_name')
    bounds = stats.loc[env_name, ['min_traj_reward', 'max_traj_reward']]
    lowest, highest = bounds
    reward_span = highest - lowest
    replay_buffer.rewards = replay_buffer.rewards / reward_span * 1000


def eval_policy(agent, env, eval_episodes=10) -> Tuple[float, float]:
    """Roll out ``agent`` in ``env`` and report its normalized score.

    Args:
        agent: Object exposing ``sample_action(obs, eval_mode=True)``.
        env: Environment exposing ``reset()``, ``step(action)`` (returning a
            4-tuple) and ``get_normalized_score(avg_return)``.
        eval_episodes: Number of full episodes to run.

    Returns:
        A pair ``(score, seconds)``: ``score`` is
        ``env.get_normalized_score`` of the mean undiscounted episode return,
        multiplied by 100; ``seconds`` is the wall-clock time the evaluation
        took.
    """
    started_at = time.time()
    total_return = 0.
    for _ in range(eval_episodes):
        obs = env.reset()
        while True:
            chosen = agent.sample_action(obs, eval_mode=True)
            obs, step_reward, finished, _ = env.step(chosen)
            total_return += step_reward
            if finished:
                break
    mean_return = total_return / eval_episodes
    d4rl_score = env.get_normalized_score(mean_return) * 100
    return d4rl_score, time.time() - started_at


def train_and_evaluate(config: ml_collections.ConfigDict):
    """Fine-tune a checkpointed JSRL agent online with a shrinking guide horizon.

    The agent is restored with ``load_ckpt``, the replay buffer is seeded with
    the D4RL dataset of ``config.env_name`` (rewards normalized), and the agent
    then interacts with the environment for ``config.max_timesteps`` steps.
    During the first ``guide_h`` steps of every episode actions come from
    ``agent.sample_guide_action``; afterwards from ``agent.sample_action``.
    The horizon starts at ``TASK_Hs[config.env_name]`` and is reduced by
    ``1 / config.stages`` of that value whenever the moving average of recent
    evaluation scores exceeds 95% of the best earlier moving average.

    Args:
        config: Experiment configuration. Fields read here: ``env_name``,
            ``seed``, ``expectile``, ``adv_temperature``, ``max_timesteps``,
            ``base_algo``, ``eval_episodes``, ``stages``, ``ml`` (moving
            average window, in evaluations), ``start_timesteps``,
            ``batch_size`` and ``eval_freq``.

    Side effects:
        Creates ``logs/exp_{mujoco,antmaze}_baseline/jsrl/<env_name>/`` and
        writes ``<exp_name>.log`` during the run and ``<exp_name>.csv`` (one
        row per evaluation) once training finishes.
    """
    start_time = time.time()
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    # Env names containing "v2" are treated as mujoco tasks, the rest as antmaze.
    if "v2" in config.env_name:
        log_dir = f"logs/exp_mujoco_baseline/jsrl/{config.env_name}"
    else:
        log_dir = f"logs/exp_antmaze_baseline/jsrl/{config.env_name}"
    os.makedirs(log_dir, exist_ok=True)
    exp_name = f"jsrl_s{config.seed}_{timestamp}"
    exp_info = f"# Running experiment for: {exp_name}_{config.env_name} #"
    print("#" * len(exp_info) + f"\n{exp_info}\n" + "#" * len(exp_info))

    logger = get_logger(f"{log_dir}/{exp_name}.log")
    logger.info(f"Exp configurations:\n{config}")

    # initialize the environment: one copy for data collection, one reserved
    # for evaluation so evaluation rollouts never disturb a training episode
    env = gym.make(config.env_name)
    eval_env = gym.make(config.env_name)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    max_action = env.action_space.high[0]

    # initialize agent
    agent = JSRLAgent(obs_dim=obs_dim,
                      act_dim=act_dim,
                      max_action=max_action,
                      seed=config.seed,
                      expectile=config.expectile,
                      adv_temperature=config.adv_temperature,
                      max_timesteps=config.max_timesteps)

    # load checkpoint
    # NOTE(review): `cnt=200` presumably selects which saved checkpoint to
    # restore -- confirm against `utils.load_ckpt`.
    load_ckpt(agent, config.base_algo, config.env_name, cnt=200)

    # Step-0 baseline of the restored checkpoint. Measured on `eval_env`, like
    # every later evaluation, rather than on the training environment.
    logs = [{
        "step": 0,
        "reward": eval_policy(agent, eval_env, config.eval_episodes)[0]
    }]

    # initialize the replay buffer with the offline dataset
    replay_buffer = ReplayBuffer(obs_dim, act_dim, max_size=int(2e6))
    replay_buffer.convert_D4RL(d4rl.qlearning_dataset(env))
    normalize_rewards(replay_buffer, config.env_name)

    # fine-tuning
    obs, done = env.reset(), False
    episode_timesteps = 0

    # Curriculum Stages
    stages = config.stages
    # Holds the most recent `config.ml` evaluation scores for the moving average.
    stage_deque = deque(maxlen=config.ml)
    stage_h = TASK_Hs[config.env_name]  # current curriculum horizon
    guide_h = stage_h  # horizon applied to the episode in progress
    delta_h = int(stage_h / stages)  # horizon decrement per curriculum stage

    stage_rewards, ma_reward, best_ma_reward = [], 0.0, 0.0

    for t in trange(1, config.max_timesteps + 1):
        episode_timesteps += 1
        # The guide acts for the first `guide_h` steps of the episode, then the
        # agent's own policy takes over.
        if episode_timesteps <= guide_h:
            action = agent.sample_guide_action(obs, eval_mode=False)
        else:
            action = agent.sample_action(obs, eval_mode=False)
        next_obs, reward, done, info = env.step(action)
        if config.base_algo == "iql" or "antmaze" in config.env_name:
            reward = normalize_reward(config.env_name, reward)
        # A time-limit truncation is not stored as a terminal transition.
        done_bool = float(done) if "TimeLimit.truncated" not in info else 0

        # flag=1.0 tags the transition as online data (it feeds `online_ratio`).
        replay_buffer.add(obs, action, next_obs, reward, done_bool, flag=1.0)
        obs = next_obs

        if t > config.start_timesteps:
            batch = replay_buffer.sample(config.batch_size)
            log_info = agent.update(batch)
            log_info["online_ratio"] = batch.flags.sum() / len(batch.flags)
            # NOTE(review): presumably the mean distance between the write
            # pointer and the sampled indices -- confirm against ReplayBuffer.
            sample_age = (replay_buffer.ptr - batch.idx).mean()

        if done:
            obs, done = env.reset(), False
            episode_timesteps = 0
            # Pick up curriculum changes only at episode boundaries.
            guide_h = stage_h

        if t % config.eval_freq == 0:
            eval_reward, eval_time = eval_policy(agent, eval_env,
                                                 config.eval_episodes)

            # "v0" tasks feed every evaluation into the moving average; the
            # others only every second evaluation.
            if "v0" in config.env_name or t % (config.eval_freq * 2) == 0:
                stage_deque.append(eval_reward)
                if len(stage_deque) >= stage_deque.maxlen:
                    ma_reward = sum(stage_deque) / stage_deque.maxlen
                    if len(stage_rewards) > 0:
                        best_ma_reward = max(stage_rewards)
                        # Shorten the guide horizon once the moving average is
                        # above 95% of the best earlier one; never below zero.
                        if ma_reward > best_ma_reward * 0.95:
                            stage_h = max(stage_h - delta_h, 0)
                    stage_rewards.append(ma_reward)

            # `log_info` and `sample_age` only exist once updates have started.
            if t > config.start_timesteps:
                log_info.update({
                    "step": t,
                    "stage_h": stage_h,
                    "reward": eval_reward,
                    "eval_time": eval_time,
                    "time": (time.time() - start_time) / 60,
                    "sample_age": sample_age,
                    "ma_reward": ma_reward,
                    "best_ma_reward": best_ma_reward,
                    "buffer_size": replay_buffer.size,
                    "buffer_ptr": replay_buffer.ptr
                })
                logs.append(log_info)
                agent.logger(t, logger, log_info)
            else:
                logs.append({"step": t, "reward": eval_reward})
                logger.info(
                    f"\n[#Step {t}] eval_reward: {eval_reward:.2f}, eval_time: {eval_time:.2f}\n"
                )

    log_df = pd.DataFrame(logs)
    log_df.to_csv(f"{log_dir}/{exp_name}.csv")
