import os
import numpy as np
import torch
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from rltorch.memory import MultiStepMemory, PrioritizedMemory

from model import TwinnedQNetwork, GaussianPolicy, RandomizedEnsembleNetwork
from utils import grad_false, hard_update, soft_update, to_batch, update_params, RunningMeanStats

from collections import deque
import itertools
import math
import random

class SacAgent:
    """Inference-only wrapper around a SAC ``GaussianPolicy``.

    The constructor keeps the full training signature so existing call sites
    continue to work, but only the policy network is built here: critics,
    replay memory, optimisers and logging are not set up. The agent can load
    policy weights from ``./models/`` and produce stochastic (``explore``) or
    deterministic (``exploit``) actions.
    """

    # robot name -> (ins_dim, state_dim, action_dim)
    _ROBOT_DIMS = {
        # qpos 15 + qvel 14
        'Ant': (8, 29, 8),
        # qpos 3 + qvel 3
        'Point': (8, 6, 2),
        'Point-v2': (8, 6, 2),
        # NOTE(review): the original comment said "qpos 5 qvel 5", which does
        # not sum to 8 -- confirm which entries are dropped.
        'Swimmer': (8, 8, 2),
        'Swimmer-v2': (8, 8, 2),
    }

    def __init__(self,  num_steps=3000000, batch_size=256,
                 lr=0.0003, hidden_units=None, memory_size=1e6,
                 gamma=0.99, tau=0.005, entropy_tuning=True, ent_coef=0.2,
                 multi_step=1, per=False, alpha=0.6, beta=0.4,
                 beta_annealing=0.0001, grad_clip=None, updates_per_step=1,
                 start_steps=10000, log_interval=10, target_update_interval=1,
                 eval_interval=1000, cuda=0, seed=0,
                 # added by TH 20210707
                 eval_runs=1, huber=0, layer_norm=0,
                 method=None, target_entropy=None, target_drop_rate=0.0, critic_update_delay=1,
                 robot = 'Ant'):
        """Seed the RNGs, pick the robot dimensions and build the policy.

        Args:
            hidden_units: Hidden layer sizes forwarded to ``GaussianPolicy``.
                ``None`` (the default) means ``[256, 256]``.
            cuda: Index of the CUDA device recorded in ``self.device`` when
                CUDA is available; otherwise ``self.device`` is the CPU.
            seed: Seed applied to the torch and numpy global RNGs.
            method, target_drop_rate, critic_update_delay: Stored on the
                instance unchanged.
            robot: One of the keys of ``_ROBOT_DIMS``; selects ``ins_dim``,
                ``state_dim`` and ``action_dim``.

        All remaining parameters are training hyper-parameters that are
        accepted for signature compatibility but are not used by this class.

        Raises:
            ValueError: If ``robot`` is not a supported robot name.
        """
        # Fail fast, before touching any global RNG / cuDNN state.
        if robot not in self._ROBOT_DIMS:
            raise ValueError(
                "Unsupported robot {!r}; expected one of {}".format(
                    robot, sorted(self._ROBOT_DIMS)))

        # A fresh list per instance avoids sharing one mutable default.
        if hidden_units is None:
            hidden_units = [256, 256]

        torch.manual_seed(seed)
        np.random.seed(seed)
        # Deterministic cuDNN kernels trade speed for reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.method = method
        self.critic_update_delay = critic_update_delay
        self.target_drop_rate = target_drop_rate

        self.device = torch.device(
            "cuda:" + str(cuda) if torch.cuda.is_available() else "cpu")

        self.ins_dim, self.state_dim, self.action_dim = self._ROBOT_DIMS[robot]

        # The policy input is the state concatenated with the instruction
        # vector, hence the summed input width. The network is NOT moved to
        # ``self.device`` here; call ``to_cuda`` for that.
        self.policy = GaussianPolicy(
            self.state_dim + self.ins_dim,
            self.action_dim,
            hidden_units=hidden_units)

    def _policy_device(self):
        """Return the device the policy parameters live on.

        Falls back to ``self.device`` when the policy is not a
        ``torch.nn.Module`` or has no parameters.
        """
        if isinstance(self.policy, torch.nn.Module):
            first_param = next(self.policy.parameters(), None)
            if first_param is not None:
                return first_param.device
        return self.device

    def explore(self, state):
        """Act with randomness.

        Args:
            state: Array-like observation (without batch dimension).

        Returns:
            Flat numpy array holding the first element returned by
            ``policy.sample`` for this state.
        """
        # Use the policy's real device so inputs and weights always agree,
        # even if ``to_cuda`` has not been called yet.
        batch = torch.FloatTensor(state).unsqueeze(0).to(self._policy_device())
        with torch.no_grad():
            action, _, _ = self.policy.sample(batch)
        return action.cpu().numpy().reshape(-1)

    def exploit(self, state):
        """Act without randomness.

        Args:
            state: Observation without batch dimension, either a tensor or
                an array-like that is converted to a float tensor.

        Returns:
            Flat numpy array holding the third element returned by
            ``policy.sample`` (presumably the deterministic action -- confirm
            against ``GaussianPolicy.sample``).
        """
        if not torch.is_tensor(state):
            state = torch.FloatTensor(state)
        batch = state.unsqueeze(0).to(self._policy_device())
        with torch.no_grad():
            _, _, action = self.policy.sample(batch)
        return action.cpu().numpy().reshape(-1)

    def load_models(self, ep=None, name='Ant'):
        """Load policy weights from the ``./models/`` directory.

        Args:
            ep: Episode number embedded (zero-padded to 8 digits) in the
                checkpoint file name. When ``None``, ``policy.pth`` is loaded.
            name: Robot name suffix of the checkpoint file. ``'Ant'``
                checkpoints carry no suffix.
        """
        if ep is None:
            path = './models/' + 'policy.pth'
        elif name == 'Ant':
            path = './models/' + '{:08d}_policy_.pth'.format(ep)
        else:
            path = './models/' + '{:08d}_policy_{}.pth'.format(ep, name)
        self.policy.load(path)

    def to_cuda(self, device):
        """Move the policy to ``device`` and record it as ``self.device``."""
        self.policy.to(device)
        # Keep the bookkeeping attribute in sync with where the policy lives.
        if device is not None:
            self.device = torch.device(device)

    def __del__(self):
        """No resources are held by this class, so there is nothing to release."""
        pass


