import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from meta_test_algo.network import es_policy2, q_function
from meta_test_algo.base import base
from meta_test_algo.buffer import ReplayBuffer
class ESQ(base):
    """Test-time adapter combining a policy with twin Q-functions.

    Each call to :meth:`es_adapt_head` collects ``n_traj`` rollouts into a
    replay buffer (exploring with either parameter-space noise or action
    noise), then runs ``n_updates`` Q-function updates.  Every Q update is
    followed by gradient steps on the selected policy parameters
    (``task_filter_params``) that maximise the minimum of the two Q estimates.

    ``head_param`` always holds a flat NumPy copy of the *unperturbed*
    ``task_filter_params``; it is the centre around which parameter noise is
    sampled during data collection.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 esq_params,
                 scratch,
                 action_noise,
                 **kwargs):
        """Build the policy, the critics, their optimizers and the replay buffer.

        Args:
            obs_dim: Observation dimensionality.
            action_dim: Action dimensionality.
            net_size: Hidden size forwarded to the policy and Q networks.
            latent_action_dim: Forwarded to the policy network.
            device: Torch device on which all networks live.
            esq_params: Mapping with the keys ``policy_lr``, ``n_traj``,
                ``n_updates``, ``exp_noise_sigma``, ``w``, ``reset_step``,
                ``buffer_size`` and ``batch_size``.
            scratch: If truthy, every policy parameter is adapted; otherwise
                only ``policy.first_layer`` and ``policy.last_layer`` are.
            action_noise: If truthy, rollouts explore with Gaussian action
                noise instead of parameter-space noise.
            **kwargs: Forwarded to ``base.__init__``; must contain
                ``max_path_length`` and ``reward_scale``.
        """
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         **kwargs)

        self.max_path_length = kwargs['max_path_length']
        # NOTE(review): reward_scale is stored but never applied anywhere in
        # this class -- confirm whether rewards are meant to be scaled.
        self.reward_scale = kwargs['reward_scale']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.action_noise = action_noise

        self.policy_lr = esq_params['policy_lr']
        self.n_traj = esq_params['n_traj']
        self.n_updates = esq_params['n_updates']
        self.noise_sigma = esq_params['exp_noise_sigma']
        self.w = esq_params['w']
        self.reset_step = esq_params['reset_step']

        # Network initialisation: policy, twin critics and their target copies.
        self.policy = es_policy2(obs_dim, action_dim, net_size, latent_action_dim, self.w).to(self.device)
        self.q_function1 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.q_function2 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function1 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function2 = q_function(obs_dim + action_dim, net_size).to(self.device)
        # Targets start as exact copies of the online critics.
        self.target_q_function1.load_state_dict(self.q_function1.state_dict())
        self.target_q_function2.load_state_dict(self.q_function2.state_dict())
        # A single optimizer drives both critics.
        self.q_optimizer = optim.Adam(list(self.q_function1.parameters()) +
                                      list(self.q_function2.parameters()),
                                      lr=1e-4)

        # Select which policy parameters are adapted.
        if scratch:
            self.task_filter_params = list(self.policy.parameters())
        else:
            self.task_filter_params = (list(self.policy.first_layer.parameters()) +
                                       list(self.policy.last_layer.parameters()))
        self.head_param = self._flat_task_filter_params()  # shape: (total param,)
        self.param_shape = self.head_param.shape

        self.policy_optimizer = optim.Adam(self.task_filter_params, lr=self.policy_lr)

        self.memory = ReplayBuffer(esq_params['buffer_size'], self.device)
        self.batch_size = esq_params['batch_size']
        self.discount = 0.99
        self.steps = 0
        self.prev_best_return = -np.inf
        self.update_steps = 0

    def _flat_task_filter_params(self):
        """Return the adapted parameters as one flat NumPy vector (a copy)."""
        return nn.utils.parameters_to_vector(self.task_filter_params).detach().cpu().numpy()

    def reset_update_policy(self):
        """Reset the policy via ``policy.reset`` and re-read ``head_param``.

        NOTE(review): ``task_filter_params`` and ``policy_optimizer`` keep
        referencing the parameter objects captured in ``__init__``.  If
        ``policy.reset`` replaces layers instead of re-initialising them in
        place, those references go stale -- verify against ``es_policy2``.
        """
        self.policy.reset(self.device)
        self.head_param = self._flat_task_filter_params()

    def evaluate_head(self, env, action_noise=False, noise_std=0.1):
        """Run one episode with the current policy and store it in the buffer.

        Args:
            env: Environment exposing ``reset()``, ``step(action)`` returning
                ``(next_obs, reward, done, info)``, and ``action_space`` with
                ``low``/``high`` bounds.
            action_noise: If truthy, add zero-mean Gaussian noise to every
                action and clip it to the action-space bounds.
            noise_std: Standard deviation of that action noise.

        Returns:
            Tuple ``(total_reward, env_step)``: the undiscounted episode
            return and the number of environment steps taken.
        """
        obs = env.reset()
        env_step = 0
        total_reward = 0
        while env_step < self.max_path_length:
            env_step += 1
            # select_action is not defined in this class; presumably it is
            # inherited from `base` and returns a NumPy array.
            action = self.select_action(obs)
            if action_noise:
                # In-place add keeps the dtype of the array from select_action.
                action += np.random.normal(0, noise_std, size=action.shape)
                action = np.clip(action, env.action_space.low, env.action_space.high)
            next_obs, reward, done, _ = env.step(action)
            total_reward += reward
            self.memory.add(obs, action, reward, next_obs, done)
            obs = next_obs
            if done:
                break
        return total_reward, env_step

    def update_q_function(self, policy_updates=2):
        """Run one twin-Q update followed by ``policy_updates`` policy steps.

        Does nothing (and returns ``0``) while the replay buffer holds fewer
        than ``batch_size`` transitions.

        Args:
            policy_updates: Number of manual policy gradient steps taken on
                the sampled batch after the critic update.

        Returns:
            The summed critic loss as a Python float, or ``0`` if skipped.
        """
        if len(self.memory) < self.batch_size:
            return 0
        self.update_steps += 1
        obs, actions, rewards, next_obs, dones = self.memory.sample(self.batch_size)

        # The bootstrap target is a constant w.r.t. the critic parameters.
        with torch.no_grad():
            next_actions = self.policy(next_obs)
            next_q = torch.min(self.target_q_function1(next_obs, next_actions),
                               self.target_q_function2(next_obs, next_actions))
            # NOTE(review): relies on rewards/dones broadcasting element-wise
            # against the Q outputs -- confirm shapes from ReplayBuffer.sample.
            target_q = rewards + self.discount * (1 - dones) * next_q

        td_error1 = target_q - self.q_function1(obs, actions)
        td_error2 = target_q - self.q_function2(obs, actions)
        q_loss = (td_error1 ** 2).mean() + (td_error2 ** 2).mean()

        # Q-function update.
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        self.soft_update_target()
        for _ in range(policy_updates):
            self.update_policy(obs, optimizer=False)
        return q_loss.item()

    def soft_update_target(self, tau=0.005):
        """Polyak-average the online critic parameters into the target critics.

        Each target parameter becomes ``tau * online + (1 - tau) * target``.
        """
        critic_pairs = ((self.target_q_function1, self.q_function1),
                        (self.target_q_function2, self.q_function2))
        with torch.no_grad():
            for target_net, online_net in critic_pairs:
                for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                    target_param.copy_(tau * param + (1 - tau) * target_param)

    def update_policy(self, obs, optimizer=False):
        """Take one gradient step on ``task_filter_params`` to increase min-Q.

        Args:
            obs: Batch of observations (as returned by ``memory.sample``).
            optimizer: If truthy, step with ``policy_optimizer`` (Adam);
                otherwise apply a plain gradient-descent step with learning
                rate ``policy_lr``.

        Returns:
            The policy loss (negative mean of the minimum Q-value) as a float,
            evaluated before the step.
        """
        actions = self.policy.grad_action(obs)
        q_values = torch.min(self.q_function1(obs, actions),
                             self.q_function2(obs, actions))
        policy_loss = -q_values.mean()
        if optimizer:
            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()
        else:
            # autograd.grad leaves every .grad field untouched, so the critic
            # optimizer never sees gradients coming from the policy loss.
            grads = torch.autograd.grad(policy_loss, self.task_filter_params)
            # Apply the step in place on the device rather than round-tripping
            # the whole parameter vector through NumPy on the host.
            with torch.no_grad():
                for param, grad in zip(self.task_filter_params, grads):
                    param.sub_(self.policy_lr * grad)
        # Keep the flat NumPy copy in sync with the updated parameters.
        self.head_param = self._flat_task_filter_params()
        return policy_loss.item()

    def update_task_filter_param(self, vector):
        """Overwrite ``task_filter_params`` with the values in flat ``vector``.

        ``torch.tensor`` copies the data, so the parameters never alias the
        NumPy array passed in.
        """
        nn.utils.vector_to_parameters(
            torch.tensor(vector, dtype=torch.float32).to(self.device),
            self.task_filter_params
        )

    def es_adapt_head(self, env):
        """Collect ``n_traj`` rollouts, then run ``n_updates`` critic updates.

        Every ``reset_step`` calls the policy is reset first.  Unless
        ``action_noise`` is enabled, each rollout uses parameters perturbed by
        Gaussian noise of scale ``noise_sigma`` around ``head_param``; the
        unperturbed parameters are restored before training.

        Args:
            env: Environment passed through to :meth:`evaluate_head`.

        Returns:
            Total number of environment steps taken across all rollouts.
        """
        self.steps += 1
        total_steps = 0
        episode_returns = []
        if self.steps % self.reset_step == 0:
            print('Reset')
            self.reset_update_policy()
        # Obtain data from the environment with perturbed parameters.
        for _ in range(self.n_traj):
            if not self.action_noise:
                epsilon = np.random.randn(*self.param_shape)
                perturbed_params = self.head_param + self.noise_sigma * epsilon
                self.update_task_filter_param(perturbed_params)
            returns, steps = self.evaluate_head(env, action_noise=self.action_noise)
            total_steps += steps
            episode_returns.append(returns)
        print(episode_returns)

        # Return to the original (unperturbed) parameters.
        self.update_task_filter_param(self.head_param)

        for _ in range(self.n_updates):
            self.update_q_function()
        return total_steps

    def collect_data_and_train_filter(self, env):
        """Collect data and train; returns the environment steps consumed."""
        return self.es_adapt_head(env)

    def collet_data_and_train_filter(self, env):
        """Misspelled legacy name, kept so existing callers keep working."""
        return self.collect_data_and_train_filter(env)