import os
import time
import copy
import multiprocessing as mp
import numpy as np

from tqdm import tqdm

from sgcrl.utils.imports import instantiate_class, get_class, get_arguments
from sgcrl.gym_helpers import Bot, PytorchD4RLGymEnv
from sgcrl.models.quantizer import Tokenizer, special_tokens

def single_evaluation(model_db, env_yaml, cfg, logger, idx_model):
    """Evaluate the most recently stored model and log its average metrics.

    Builds a fresh environment from ``env_yaml``, fetches the last entry of
    the ``"model"`` table in ``model_db`` and gathers ``cfg.n_episodes``
    episodes with it (episode ``i`` is seeded with ``i``). The averages are
    printed and written to ``logger`` as ``avg_reward``,
    ``avg_reward_subgoal``, ``avg_length`` and, when the episodes carry a
    ``"normalized_score"`` entry, ``norm_score``.

    Args:
        model_db: Model store exposing ``size("model")`` and
            ``get("model", idx)``.
        env_yaml: Environment description passed to ``get_class`` and
            ``get_arguments``.
        cfg: Config providing ``render``, ``n_episodes``, ``bot_args`` and
            ``reward_variable``.
        logger: Object exposing ``add_scalar(name, value, step)``.
        idx_model: Currently ignored, see the note in the body.

    Raises:
        AssertionError: If the fetched object is not a ``Bot``.
    """
    env = PytorchD4RLGymEnv(get_class(env_yaml), **get_arguments(env_yaml))

    if cfg.render:
        env.set_render(True)

    # NOTE(review): the caller-supplied idx_model has always been overwritten
    # with the index of the last stored model; this is kept so existing
    # callers see the same behaviour. Confirm whether the argument was meant
    # to select the model instead.
    idx_model = model_db.size("model") - 1

    bot = model_db.get("model", idx_model)
    assert isinstance(bot, Bot)

    rewards = []
    subgoals_rewards = []
    scores = []
    lengths = []
    with tqdm(range(cfg.n_episodes), total=cfg.n_episodes, desc="Evaluating model") as pbar:
        for episode_idx in pbar:
            # The episode index doubles as the seed, so runs are repeatable.
            episode = env.gather_episode(
                bot=bot, seed=episode_idx, bot_args=cfg.bot_args
            )
            if "normalized_score" in episode.keys():
                scores.append(episode["normalized_score"])

            episode_return = episode[cfg.reward_variable].sum().item()
            rewards.append(episode_return)
            lengths.append(len(episode[cfg.reward_variable]))
            # An episode is credited with at least 1 when the bot finished it
            # with current_phase == 2, whatever the environment return was.
            subgoals_rewards.append(max(int(bot.current_phase == 2), episode_return))

            pbar.set_postfix(reward=np.mean(rewards), subgoal_reward=np.mean(subgoals_rewards), length=np.mean(lengths))

    avg_reward = np.mean(rewards)
    print(f"reward = {avg_reward}", end='')
    logger.add_scalar("avg_reward", avg_reward, idx_model)
    logger.add_scalar("avg_reward_subgoal", np.mean(subgoals_rewards), idx_model)
    logger.add_scalar("avg_length", np.mean(lengths), idx_model)
    # `scores` is only filled when episodes expose "normalized_score"; testing
    # it (rather than the last episode dict) also works with zero episodes.
    if scores:
        avg_score = np.mean(scores)
        # Log the normalized score itself, not the raw average reward.
        logger.add_scalar("norm_score", avg_score, idx_model)
        print(f" normalized score {avg_score}", end='')
    print("")

def evaluation_loop(model_db, env_yaml, cfg, logger):
    """Evaluate every model pushed to ``model_db``, in order, as it appears.

    Hides CUDA devices for this process, builds one environment from
    ``env_yaml`` and then, for model index 0, 1, 2, ...: waits until that
    model exists in the ``"model"`` table, gathers ``cfg.n_episodes`` episodes
    with it (episode ``i`` is seeded with ``i``) and logs the averages under
    the model index (``avg_reward``, ``avg_reward_subgoal``, ``avg_length``
    and, when available, ``norm_score``).

    The function does not return normally: once ``cfg.max_db_size + 1``
    models have been evaluated it raises ``SystemExit(0)``, terminating the
    process it runs in.

    Args:
        model_db: Model store exposing ``size("model")`` and
            ``get("model", idx)``.
        env_yaml: Environment description passed to ``get_class`` and
            ``get_arguments``.
        cfg: Config providing ``n_episodes``, ``bot_args``,
            ``reward_variable`` and ``max_db_size``.
        logger: Object exposing ``add_scalar(name, value, step)``.

    Raises:
        SystemExit: With code 0 after the last model has been evaluated.
        AssertionError: If a fetched object is not a ``Bot``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    env = PytorchD4RLGymEnv(get_class(env_yaml), **get_arguments(env_yaml))
    print(env)

    # Index of the next model to evaluate; also used as the logging step.
    idx_model = 0
    while True:
        # Poll once per second until the next model has been pushed.
        while model_db.size("model") == idx_model:
            time.sleep(1.0)
        bot = model_db.get("model", idx_model)
        assert isinstance(bot, Bot)

        rewards = []
        subgoals_rewards = []
        scores = []
        lengths = []
        with tqdm(range(cfg.n_episodes), total=cfg.n_episodes, desc="Evaluating model") as pbar:
            for episode_idx in pbar:
                # The episode index doubles as the seed, so every model is
                # evaluated on the same sequence of seeds.
                episode = env.gather_episode(
                    bot=bot, seed=episode_idx, bot_args=cfg.bot_args
                )
                if "normalized_score" in episode.keys():
                    scores.append(episode["normalized_score"])

                episode_return = episode[cfg.reward_variable].sum().item()
                rewards.append(episode_return)
                lengths.append(len(episode[cfg.reward_variable]))
                # Credited with at least 1 when the bot finished the episode
                # with current_phase == 2.
                subgoals_rewards.append(max(int(bot.current_phase == 2), episode_return))

                pbar.set_postfix(reward=np.mean(rewards), subgoal_reward=np.mean(subgoals_rewards), length=np.mean(lengths))

        avg_reward = np.mean(rewards)
        print(f"reward = {avg_reward}", end='')
        logger.add_scalar("avg_reward", avg_reward, idx_model)
        logger.add_scalar("avg_reward_subgoal", np.mean(subgoals_rewards), idx_model)
        logger.add_scalar("avg_length", np.mean(lengths), idx_model)
        # `scores` is rebuilt per model, so this cannot pick up a stale
        # episode dict from a previous iteration.
        if scores:
            avg_score = np.mean(scores)
            # Log the normalized score itself, not the raw average reward.
            logger.add_scalar("norm_score", avg_score, idx_model)
            print(f" normalized score {avg_score}", end='')
        print("")

        idx_model += 1
        if idx_model >= cfg.max_db_size + 1:  # + 1 to account for the model__0 (random bot)
            print("Evaluation done")
            # Same effect as the site-provided exit(0), without depending on
            # the `site` module having installed that helper.
            raise SystemExit(0)

def parallel_evaluation_loop(model_db, cfg, logger):
    """Launch ``evaluation_loop`` in a daemon child process.

    The child is started with ``(model_db, cfg.env, cfg, logger)`` as
    arguments. Returns the started ``multiprocessing.Process`` handle.
    """
    worker = mp.Process(
        target=evaluation_loop,
        args=(model_db, cfg.env, cfg, logger),
        daemon=True,
    )
    worker.start()
    return worker

def serial_evaluation_loop(model_db, cfg, logger, only_last=True):
    """Evaluate the models currently stored in ``model_db`` in this process.

    Hides CUDA devices for this process, builds one environment from
    ``cfg.env`` and gathers ``cfg.n_episodes`` episodes per model (episode
    ``i`` is seeded with ``i``). Averages are printed and logged under the
    model index as ``avg_reward``, ``avg_reward_subgoal``, ``avg_length`` and,
    when the episodes carry a ``"normalized_score"`` entry, ``norm_score``.

    Args:
        model_db: Model store exposing ``size("model")`` and
            ``get("model", idx)``.
        cfg: Config providing ``env``, ``n_episodes``, ``bot_args`` and
            ``reward_variable``.
        logger: Object exposing ``add_scalar(name, value, step)``.
        only_last: When true (default) only the last stored model is
            evaluated; otherwise every stored model is, in index order.

    Raises:
        AssertionError: If a fetched object is not a ``Bot``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    env = PytorchD4RLGymEnv(get_class(cfg.env), **get_arguments(cfg.env))

    n_models = model_db.size("model")
    if n_models == 0:
        print('No models to evaluate')
        return

    # Either just the newest model or all of them, oldest first.
    model_idxs = [n_models - 1] if only_last else list(range(n_models))

    for model_idx in model_idxs:
        bot = model_db.get("model", model_idx)
        assert isinstance(bot, Bot)

        rewards = []
        subgoals_rewards = []
        scores = []
        lengths = []
        with tqdm(range(cfg.n_episodes), total=cfg.n_episodes, desc="Evaluating model") as pbar:
            for episode_idx in pbar:
                # The episode index doubles as the seed, so every model is
                # evaluated on the same sequence of seeds.
                episode = env.gather_episode(
                    bot=bot, seed=episode_idx, bot_args=cfg.bot_args
                )
                if "normalized_score" in episode.keys():
                    scores.append(episode["normalized_score"])

                episode_return = episode[cfg.reward_variable].sum().item()
                rewards.append(episode_return)
                lengths.append(len(episode[cfg.reward_variable]))
                # Credited with at least 1 when the bot finished the episode
                # with current_phase == 2.
                subgoals_rewards.append(max(int(bot.current_phase == 2), episode_return))

                pbar.set_postfix(reward=np.mean(rewards), subgoal_reward=np.mean(subgoals_rewards), length=np.mean(lengths))

        avg_reward = np.mean(rewards)
        print(f"reward = {avg_reward}", end='')
        logger.add_scalar("avg_reward", avg_reward, model_idx)
        logger.add_scalar("avg_reward_subgoal", np.mean(subgoals_rewards), model_idx)
        logger.add_scalar("avg_length", np.mean(lengths), model_idx)
        # `scores` is rebuilt per model, so this neither reads an unbound
        # `episode` (zero episodes) nor a stale one from a previous model.
        if scores:
            avg_score = np.mean(scores)
            # Log the normalized score itself, not the raw average reward.
            logger.add_scalar("norm_score", avg_score, model_idx)
            print(f" normalized score {avg_score}", end='')
        print("")

if __name__ == '__main__':
    # Manual smoke test: assemble a DualPolicy from saved checkpoints, push it
    # into an on-disk model DB and run one serial evaluation pass on it.

    from sgcrl.data.dbs.on_disk import DiskPythonObjectDB, load_model
    from omegaconf import OmegaConf
    from sgcrl.utils.logger import TensorBoardLogger
    from sgcrl.models.dual_policy import DualPolicy

    env_name = 'antmaze-extreme-diverse-v0'
    log_dir = './tests'

    # The config must exist before the DualPolicy is built, because the
    # constructor call below reads from it.
    cfg = OmegaConf.create({
        'env': {'classname': 'gym.make', 'id': env_name},
        'n_episodes': 50,
        'reward_variable': 'reward',
        'max_db_size': 32,
        'parallel': False,
        'bot_args': {'eval': True, 'stochastic': False},
        # NOTE(review): the two keys below are read when building the
        # DualPolicy but were never defined by this script. The values here
        # are placeholders; set them to match the training config.
        'keys_to_tokenize': None,
        'dual_policy_relabellers': [],
    })

    # Create logger and model_db
    logger = TensorBoardLogger(log_dir=f'{log_dir}/offline_evaluation', prefix='offline_evaluation', max_cache_size=1)
    model_db = DiskPythonObjectDB(f'{log_dir}/models')

    # Load models
    low_level_policy_subgoals = load_model(env_name, 0, 'cuda', 'hiql_subgoals')
    low_level_policy_goals = load_model(env_name, 0, 'cuda', 'iql_goals')
    transformer = load_model(env_name, 0, 'cuda', 'transformer')
    # NOTE(review): this loads the same 'iql_goals' checkpoint as
    # low_level_policy_goals, which looks like a copy-paste slip. Confirm the
    # checkpoint name the tokenizer should be loaded from.
    tokenizer = load_model(env_name, 0, 'cuda', 'iql_goals')

    dual_policy = DualPolicy(
        low_level_policy_subgoals,
        low_level_policy_goals,
        transformer,
        tokenizer,
        special_tokens['SOS_TOKEN'],
        special_tokens['EOS_TOKEN'],
        keys_to_tokenize=cfg.keys_to_tokenize,
        frame_relabellers=[instantiate_class(r) for r in (cfg.dual_policy_relabellers or [])]
    )

    # Store a CPU copy; the evaluation loop hides CUDA devices.
    m = copy.deepcopy(dual_policy).cpu()
    model_db.push("model", m)

    serial_evaluation_loop(model_db, cfg, logger, only_last=True)