from collections import deque, namedtuple
import random
import time

import torch as torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys
import math
from statistics import NormalDist
import numpy as np
import gym
import d4rl
import matplotlib.pyplot as plt
import matplotlib
from create_dataset import *
from networks import *
from planning_agent import Planning_Agent
from timeit import default_timer as timer
from sac import SAC
import seaborn as sns
from sklearn.decomposition import PCA


# --- Hardware -------------------------------------------------------------
# Device handed to SAC(...) and Planning_Agent(...) inside do_test.
device = torch.device('cuda')

# --- Flags forwarded to planning_agent.init(train_q, train_f_b, train_f_m) --
# (names suggest which components get trained; semantics live in Planning_Agent)
train_q = True
train_f_b = True
train_f_m = True

# --- Planning_Agent construction settings ---------------------------------
normalize = True
sphere_norm = True
batch_size = 128
epochs = 40

# --- Evaluation length (number of episodes run by do_test) ----------------
episodes = 100

# --- Names -----------------------------------------------------------------
# c_name: suffix of the saved SAC policy / critic files (policies/, q_networks/).
# d4rl_name: dataset name given to Planning_Agent.
# env_name: Gym environment id passed to gym.make.
c_name = "med_hopper"
d4rl_name = "hopper-medium-v2"
env_name = "Hopper-v4"

# --- Planning hyperparameters (forwarded as-is; meaning defined in Planning_Agent)
var_scale = 1.2
kappa = 0.3
H = 4
N_traj = 100
beta = 0.0


def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Gym >= 0.26 returns ``(observation, info)`` from ``reset()``; older
    releases return the observation alone. Both forms are accepted.
    """
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Step ``env`` once and return ``(observation, reward, done)``.

    Gym >= 0.26 returns ``(obs, reward, terminated, truncated, info)``; older
    releases return ``(obs, reward, done, info)``. For the 5-tuple form the
    episode is over when it is either terminated or truncated, which is what
    the old single ``done`` flag meant.
    """
    result = env.step(action)
    if len(result) == 5:
        observation, reward, terminated, truncated, _info = result
        return observation, reward, terminated or truncated
    observation, reward, done, _info = result
    return observation, reward, done


def do_test(d4rl_name, env_name, agent_c_name, kappa, H, episodes, var_scale):
    """Evaluate a Planning_Agent that plans with a pretrained SAC agent.

    Loads the SAC policy and critic weights saved under ``agent_c_name``,
    builds a ``Planning_Agent`` (remaining settings come from the module-level
    configuration: epochs, batch_size, N_traj, beta, normalize, sphere_norm and
    the train_* flags), then runs ``episodes`` episodes in the Gym environment
    ``env_name``. After each episode it prints the score and the running
    mean +- std.

    Args:
        d4rl_name: Dataset name passed to ``Planning_Agent``.
        env_name: Gym environment id passed to ``gym.make``.
        agent_c_name: Suffix of ``policies/sac_policy_<name>.pth`` and
            ``q_networks/sac_q_network_<name>.pth``; also passed as ``q_name2``.
        kappa, H, var_scale: Forwarded unchanged to ``Planning_Agent``
            (their meaning is defined there).
        episodes: Number of evaluation episodes to run.

    Returns:
        List of per-episode returns (sum of rewards), in episode order.
    """
    env = gym.make(env_name)
    try:
        agent_c = SAC(device, env.observation_space.shape[0], env.action_space)
        # Map the checkpoints onto the same device the agents are built for,
        # rather than a separately hard-coded device string.
        agent_c.policy.load_state_dict(
            torch.load(f"policies/sac_policy_{agent_c_name}.pth", map_location=device)
        )
        agent_c.critic.load_state_dict(
            torch.load(f"q_networks/sac_q_network_{agent_c_name}.pth", map_location=device)
        )

        planning_agent = Planning_Agent(
            d4rl_name,
            device,
            q_name2=agent_c_name,
            epochs=epochs,
            batch_size=batch_size,
            lr=0.001,
            tau=0.001,
            discount=0.99,
            H=H,
            N_traj=N_traj,
            beta=beta,
            kappa=kappa,
            normalize=normalize,
            var_scale=var_scale,
            sphere_norm=sphere_norm,
        )
        print("Running " + d4rl_name + " + " + agent_c_name + "...")

        planning_agent.init(train_q, train_f_b, train_f_m)

        scores = []
        for i in range(episodes):
            planning_agent.reset()
            observation = _reset_env(env)
            score = 0
            done = False
            while not done:
                action = planning_agent.plan_action_agent(observation, agent_c)
                observation, reward, done = _step_env(env, action)
                score += reward

            scores.append(score)
            print('-----------------------------------')
            print('episode ', i, ' score %.2f' % score)
            print('average score so far: ' + str(np.average(scores)) + " +- " + str(np.std(scores)))

        return scores
    finally:
        # Release the simulator even if loading or an episode fails.
        env.close()


# Script entry: evaluate with the module-level configuration and report the
# mean and standard deviation of the episode scores.
# NOTE(review): this executes on import; consider an `if __name__ == "__main__":` guard.
scores = do_test(
    d4rl_name,
    env_name,
    c_name,
    kappa=kappa,
    H=H,
    episodes=episodes,
    var_scale=var_scale,
)
print('score: %s +- %s' % (np.average(scores), np.std(scores)))