import hydra
import tempfile
import subprocess
import os
import copy
import d4rl
import random
import time
import shutil

import torch.multiprocessing as mp

from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, random_split

from sgcrl.utils.seeding import set_seed
from sgcrl.utils.imports import instantiate_class, get_class
from sgcrl.data.dbs.on_disk import DiskPythonObjectDB, load_model, DiskPickleObjectDB
from sgcrl.utils.logger import TensorBoardLogger
from sgcrl.evaluation.d4rl_evaluation import parallel_evaluation_loop
from sgcrl.utils.conf import omegaconf_to_conf
from sgcrl.data.torch_datasets.datasets import ContrastiveTripletsDataset, ContrastiveTripletsDatasetVisual, TransformerDataset, PytorchEpisodeFrameDatasetComplexRelabel
from sgcrl.trainers import train_quantizer, train_transformer, train_low_hiql_goal, train_gcbc_goal, train_gcbc_subgoal, train_hiql_goal, train_low_hiql_subgoal
from sgcrl.utils.plotting import plot_maze, plot_godot, plot_visual_maze
from sgcrl.data.torch_datasets.relabel import NextTokenRelabeller, BatchRelabeller, TokenRelabeller
from sgcrl.models.quantizer import Tokenizer, special_tokens
from sgcrl.models.dual_policy import DualPolicy
from omegaconf import DictConfig, open_dict, ListConfig
from sgcrl.data.dbs.on_disk import PytorchOnDiskEpisodesDB
from sgcrl.evaluation.godot_evaluation import single_evaluation
from scripts.utils.application import start_server_with_model_idx
# from sgcrl.evaluation.godot_evaluation import start_server_with_model_idx

from sgcrl.utils.imports import instantiate_class, get_class, get_arguments
from sgcrl.gym_helpers import Bot, PytorchD4RLGymEnv
from sgcrl.utils.conf import omegaconf_to_conf
from sgcrl.evaluation.d4rl_evaluation import serial_evaluation_loop
import gym
from tqdm import tqdm
import torch
import numpy as np

def get_leaf_paths(root_folder, keys):
    """
    Return the relative paths of every folder under ``root_folder`` (the root
    itself included) that directly contains a file or folder whose name is
    in ``keys``.

    Args:
        root_folder: directory that is searched recursively.
        keys: a single name, or an iterable of names, to look for. A bare
            string is treated as ONE name (not as a sequence of characters),
            so ``keys='models'`` only matches entries called exactly
            ``models``.

    Returns:
        A list of paths expressed relative to ``root_folder`` and prefixed
        with ``'./'`` (the root itself is reported as ``'./'``). Each matching
        folder appears exactly once. Folders are visited depth-first with
        siblings in sorted order, so the result is deterministic.
    """
    # A plain string would turn `entry in keys` into a substring test
    # ('model' in 'models' is True), so wrap it before building the lookup set.
    if isinstance(keys, str):
        keys = [keys]
    wanted = set(keys)

    leaf_paths = []

    def dfs(current_folder, path):
        # os.listdir returns entries in arbitrary order; sort so that callers
        # indexing into the result (e.g. taking the first match) are stable.
        entries = sorted(os.listdir(current_folder))

        # Each folder is visited once, so a single check here guarantees the
        # folder is reported at most once even if it holds several key items.
        if any(entry in wanted for entry in entries):
            leaf_paths.append(os.path.join(*path))

        for entry in entries:
            entry_path = os.path.join(current_folder, entry)
            if os.path.isdir(entry_path):
                dfs(entry_path, path + [entry])

    # Paths are accumulated as component lists starting from './'
    dfs(root_folder, ['./'])
    return leaf_paths

def evaluate(env, model, logger, model_idx, n_episodes, bot_args, reward_variable):
    """
    Roll out ``model`` in ``env`` for ``n_episodes`` episodes and log summary
    statistics to ``logger``.

    Episode ``i`` is gathered with
    ``env.gather_episode(bot=model, seed=i, bot_args=bot_args)``, so the
    episode index doubles as the episode seed.

    Args:
        env: object exposing ``gather_episode(bot=..., seed=..., bot_args=...)``
            and returning a mapping per episode.
        model: the bot passed to ``gather_episode``; its ``current_phase``
            attribute is read after each episode.
        logger: object exposing ``add_scalar(tag, value, step)``.
        model_idx: step value used for every logged scalar.
        n_episodes: number of episodes to gather (must be >= 1).
        bot_args: forwarded unchanged to ``gather_episode``.
        reward_variable: key of the per-step reward entry in the episode.

    Raises:
        ValueError: if ``n_episodes`` is smaller than 1 (there would be no
            episode to take the timing statistics from).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    rewards = []
    subgoals_rewards = []
    scores = []
    lengths = []
    episode = None
    with tqdm(range(n_episodes), total=n_episodes, desc="Evaluating model") as pbar:
        for episode_seed in pbar:
            episode = env.gather_episode(
                bot=model, seed=episode_seed, bot_args=bot_args
            )
            if "normalized_score" in episode.keys():
                scores.append(episode["normalized_score"])

            episode_reward = episode[reward_variable].sum().item()
            rewards.append(episode_reward)
            lengths.append(len(episode[reward_variable]))
            # Credit the episode with at least 1 when the model ends in phase 2.
            # NOTE(review): presumably phase 2 means the last subgoal was
            # reached -- confirm against the model implementation.
            subgoals_rewards.append(max(int(model.current_phase == 2), episode_reward))

            pbar.set_postfix(reward=np.mean(rewards), subgoal_reward=np.mean(subgoals_rewards), length=np.mean(lengths))

    print(f"reward = {np.mean(rewards)}", end='')
    logger.add_scalar("avg_reward", np.mean(rewards), model_idx)
    logger.add_scalar("avg_reward_subgoal", np.mean(subgoals_rewards), model_idx)
    logger.add_scalar("avg_length", np.mean(lengths), model_idx)

    # The timing / counting statistics below are read from the LAST gathered
    # episode only; they are not averaged over episodes.
    # NOTE(review): confirm whether these values are cumulative or per-episode.
    planning_time = episode['planning_time']
    inference_time = episode['inference_time']
    n_plannings = episode['n_plannings']
    n_inferences = episode['n_inferences']
    logger.add_scalar('planning_time', planning_time, model_idx)
    logger.add_scalar('inference_time', inference_time, model_idx)
    logger.add_scalar('n_plannings', n_plannings, model_idx)
    logger.add_scalar('n_inferences', n_inferences, model_idx)
    logger.add_scalar('time_ratio', planning_time/inference_time, model_idx)
    logger.add_scalar('time_proportion', planning_time/(inference_time+planning_time), model_idx)
    logger.add_scalar('n_ratio', n_plannings/n_inferences, model_idx)
    # Distinct tag (mirrors time_ratio / time_proportion) so that the
    # proportion does not collide with the ratio logged just above.
    logger.add_scalar('n_proportion', n_plannings/(n_inferences+n_plannings), model_idx)

    if scores:
        # Log the mean normalized score itself, not the mean raw reward.
        logger.add_scalar("norm_score", np.mean(scores), model_idx)
        print(f" normalized score {np.mean(scores)}", end='')
    print("")

def launch(cfg):
    """
    Evaluate the saved model of every seed folder for one env / stitching
    configuration.

    For each seed folder under
    ``{cfg.workdir}/save_best_models/{cfg.env_name}/{cfg.stitching}`` this:

    1. locates the first training folder that contains a ``models`` entry,
    2. fetches the model via ``DiskPythonObjectDB(...).get_last('model')``,
    3. plots the tokenization with ``plot_maze``,
    4. creates ``offline_evaluation/{cfg.date}`` inside the training folder,
       saves ``cfg`` there as ``config.yaml`` and attaches a TensorBoard logger,
    5. runs a seeded ``evaluate`` on a fresh ``PytorchD4RLGymEnv``.

    Args:
        cfg: configuration providing ``workdir``, ``env_name``, ``stitching``,
            ``date``, ``seed`` and ``evaluation.{n_episodes, bot_args,
            reward_variable}``.

    Raises:
        FileNotFoundError: if a seed folder holds no ``models`` entry.
        FileExistsError: if the evaluation folder for ``cfg.date`` already
            exists for a seed (raised by ``os.makedirs``).
    """
    # Path that contains the numbered seed folders (0/ 1/ 2/ ...) for a
    # unique env/dataset name.
    base_path = f'{cfg.workdir}/save_best_models/{cfg.env_name}/{cfg.stitching}'
    seeds = ['0', '1', '2', '3', '4', '5', '6', '7'] # list of seeds in the folder that will be evaluated

    # Loop-invariant output location of the tokenization background plot.
    background_dir = f"./sgcrl/algorithms/policies/qphil/experiments/offline_rl_experiment/backgrounds/{cfg.date}"
    background_file = f"{background_dir}/background_tokens.pickle"

    for seed in seeds:
        seed_path = os.path.join(base_path, seed)

        # Pass the key as a list: a bare string would be matched by substring
        # (e.g. a folder named 'model' would be accepted for 'models').
        leaf_paths = get_leaf_paths(seed_path, keys=['models'])
        if not leaf_paths:
            raise FileNotFoundError(f"No 'models' folder found under {seed_path}")

        # Only the first training folder found for this seed is evaluated.
        training_path = os.path.join(seed_path, leaf_paths[0])
        model_path = os.path.join(training_path, 'models')
        model_db = DiskPythonObjectDB(model_path)
        model = model_db.get_last('model')

        # plot tokenization
        os.makedirs(background_dir, exist_ok=True)
        plot_maze(model.tokenizer, cfg.env_name, 'cpu', resolution=1, background_file=background_file, points=None, classes=None, labels=None)

        # model.tokenizer.norm = torch.tensor([[90, 55]])

        # No exist_ok here: re-running with the same cfg.date raises
        # FileExistsError rather than writing into an existing evaluation folder.
        evaluation_dir = f'{training_path}/offline_evaluation/{cfg.date}'
        os.makedirs(evaluation_dir)
        evaluation_logger = TensorBoardLogger(log_dir=evaluation_dir, prefix='offline_evaluation', max_cache_size=1)
        config_save_path = os.path.join(evaluation_dir, "config.yaml")
        OmegaConf.save(config=cfg, f=config_save_path)

        # Seeded evaluation: the same cfg.seed is re-applied before every
        # seed folder's evaluation.
        env = PytorchD4RLGymEnv(gym.make, id=cfg.env_name)
        # env.set_render(True)
        set_seed(cfg.seed)
        evaluate(
            env=env,
            model=model,
            logger=evaluation_logger,
            model_idx=0,
            n_episodes=cfg.evaluation.n_episodes,
            bot_args=cfg.evaluation.bot_args,
            reward_variable=cfg.evaluation.reward_variable,
        )

if __name__ == '__main__':
    # Script entry point: Hydra loads yaml/antmaze-extreme and hands the
    # resulting config to `launch`.
    def main(cfg):
        """Run the offline evaluation with the Hydra-provided config."""
        launch(cfg)

    # Explicit form of the `@hydra.main(...)` decorator applied to `main`.
    main = hydra.main(config_path="yaml", config_name="antmaze-extreme")(main)
    main()