from os.path import join
import os
import numpy as np
import torch
import sys
import mage.utils as utils
import mage.datasets as datasets
import time


class Parser(utils.Parser):
    """Arguments for this evaluation script, overriding the base parser's defaults.

    ``dataset`` is the environment/dataset name (passed to
    ``datasets.load_environment`` and the model loaders); ``config`` names the
    config module the base parser presumably reads defaults from — confirm
    against ``utils.Parser``.
    """
    dataset: str = 'pen-expert-v0'
    config: str = 'config.vqvae'

# NOTE(review): the result of the 'train' parse is discarded — `args` is
# rebound to the 'plan' parse on the very next line. Left in place in case
# parse_args('train') has side effects this script relies on; confirm and
# remove it if it does not.
args = Parser().parse_args('train')
# Arguments actually used below come from the 'plan' experiment config.
args = Parser().parse_args('plan')

#######################
######## models #######
#######################

# Environment, plus the dataset object stored with the experiment.
env = datasets.load_environment(args.dataset)
dataset = utils.load_from_config(
    args.logbase, args.dataset, args.exp_name, 'data_config.pkl',
)

# Trajectory model at the requested checkpoint, then its prior model
# (epoch=None is forwarded unchanged to the prior loader).
gpt, gpt_epoch = utils.load_model(
    args.logbase, args.dataset, args.exp_name,
    epoch=args.gpt_epoch, device=args.device,
)
gpt.to(args.device)

prior, _ = utils.load_prior_model(
    args.logbase, args.dataset, args.exp_name,
    epoch=None, device=args.device,
)
prior.to(args.device)

if args.normalize:
    # Padding vector: a zero (transition_dim - 1)-vector pushed through the
    # normalizer that matches whether actions are part of the model input.
    gpt.set_padding_vector(torch.from_numpy(
        (dataset.normalize_joined_single if args.use_action else dataset.normalize_RandS)(
            np.zeros(gpt.transition_dim - 1)
        )
    ))

#######################
####### dataset #######
#######################

# Dataset-level constants, exposed as module globals.
discount, observation_dim, action_dim = (
    dataset.discount,
    dataset.observation_dim,
    dataset.action_dim,
)
def _append_goal(observation, env, dataset):
    """Return ``observation`` with the goal appended for goal-conditioned maze envs.

    For ``antmaze`` envs the goal is ``env.target_goal``, or a float32 zero
    2-vector when ``dataset.disable_goal`` is set; for ``maze2d`` envs it is
    ``env.get_target()``. For any other env ``observation`` is returned as-is.
    """
    if "antmaze" in env.name:
        if dataset.disable_goal:
            goal = np.zeros([2], dtype=np.float32)
        else:
            goal = env.target_goal
    elif "maze2d" in env.name:
        goal = env.get_target()
    else:
        return observation
    return np.concatenate([observation, goal])


@torch.no_grad()
def evaluator(test_times=10, dataset=None, gpt=None, prior=None, env=None, device=None, rtg_c=None):
    """Roll out ``test_times`` episodes in ``env`` and report return and score.

    At every step the current observation (preprocessed, optionally
    normalized, and concatenated with the previous action when
    ``gpt.use_action``) is written into a fixed-size context tensor, and
    ``prior.sample(rtg_c, gpt, context)`` produces the next action.

    Args:
        test_times: number of episodes to run.
        dataset: provides ``discount``, ``sequence_length``, ``step``,
            ``action_dim``, ``normalized_raw``, ``disable_goal`` and the
            ``normalize_states`` / ``denormalize_actions`` helpers.
        gpt: model exposing ``observation_dim`` and ``use_action``; moved to
            ``device`` and put in eval mode.
        prior: model whose ``sample`` yields the next action; moved to
            ``device`` and put in eval mode.
        env: environment with ``reset``, ``step``, ``name``,
            ``max_episode_steps`` and ``get_normalized_score``.
        device: torch device for the models and all rollout tensors; when
            ``None``, the module-level ``args.device`` is used.
        rtg_c: return-to-go conditioning passed unchanged to ``prior.sample``.

    Returns:
        ``(mean_return, stderr_return, mean_score, stderr_score)`` over episodes.
    """
    if device is None:
        # Tensors used to be created on args.device unconditionally; keep that
        # as the fallback so callers that omit `device` behave the same.
        device = args.device
    gpt.to(device)
    prior.to(device)

    all_reward = []
    all_path_length = []
    all_score = []

    discount = dataset.discount
    preprocess_fn = datasets.get_preprocess_fn(args.dataset)
    sequence_length = (dataset.sequence_length - 1) * dataset.step
    T = env.max_episode_steps

    for tt in range(test_times):
        observation = _append_goal(env.reset(), env, dataset)
        total_reward = 0
        discount_return = 0

        trajectory = torch.zeros(sequence_length, gpt.observation_dim, device=device)
        action = torch.zeros(1, dataset.action_dim, device=device)
        gpt.eval()
        prior.eval()

        rtg = rtg_c

        for t in range(T):
            # NOTE(review): the window is shifted left (oldest row dropped, zero
            # row appended at the end), but the new observation is then written
            # to row 0, overwriting the previous one — so the context only ever
            # holds the current observation. Confirm whether row -1 was intended.
            trajectory = torch.cat(
                [trajectory[1:, :], torch.zeros(1, gpt.observation_dim, device=device)], dim=0)

            observation = preprocess_fn(observation)
            if dataset.normalized_raw:
                observation = dataset.normalize_states(observation)

            # Match the context tensor's dtype/device so the row assignment
            # (and the concat with the previous action) needs no implicit cast.
            obs_tensor = torch.as_tensor(observation, dtype=trajectory.dtype, device=device)
            if gpt.use_action:
                trajectory[0] = torch.cat([obs_tensor, action[0]], dim=-1)
            else:
                trajectory[0] = obs_tensor

            action = prior.sample(rtg, gpt, trajectory)

            # The env consumes raw actions; `action` itself stays in model space
            # for the next step's context.
            if dataset.normalized_raw:
                act = dataset.denormalize_actions(action).cpu().numpy()
            else:
                act = action.cpu().numpy()

            next_observation, reward, terminal, _ = env.step(act[0])
            next_observation = _append_goal(next_observation, env, dataset)

            total_reward += reward
            discount_return += reward * discount ** t
            score = env.get_normalized_score(total_reward)

            if terminal or t == T - 1:
                # t is 0-based, so t + 1 env steps were taken in this episode.
                path_length = t + 1
                all_reward.append(total_reward)
                all_path_length.append(path_length)
                all_score.append(score)
                break

            observation = next_observation

        print(f"Runs\t{tt}---\n returns: {total_reward:4.2f}   |  path_length: {path_length:5}  |  discount_return: {discount_return:4.2f}  |  score: {score:4.4f}")

    reward_stderr = np.std(all_reward) / np.sqrt(len(all_reward))
    score_stderr = np.std(all_score) / np.sqrt(len(all_score))
    print('---------------------------', 'avg return:', np.mean(all_reward), 'stde return:', reward_stderr, "---------------------------\n",
          '---------------------------', 'avg score: ', np.mean(all_score),  'stde score: ', score_stderr,  "---------------------------\n")
    return np.mean(all_reward), reward_stderr, np.mean(all_score), score_stderr
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        