import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import copy
import os

from meta_test_algo.network import es_policy2, q_function
from meta_test_algo.base import base
from meta_test_algo.buffer import ReplayBuffer

class RandomProcess:
    """Minimal base class for stateful noise processes."""

    def reset_states(self):
        """Reset any internal state; the base implementation does nothing."""
        return None

class AnnealedGaussianProcess(RandomProcess):
    """Noise process whose standard deviation is linearly annealed.

    The effective sigma follows ``m * n_steps + c`` and is clamped from
    below at ``sigma_min``.  When ``sigma_min`` is ``None`` the slope is
    zero, so sigma stays at its initial value.
    """

    def __init__(self, mu, sigma, sigma_min, n_steps_annealing):
        """Store the noise parameters and derive the annealing line.

        Args:
            mu: Mean of the process.
            sigma: Initial standard deviation.
            sigma_min: Lower bound for the annealed standard deviation, or
                ``None`` to disable annealing.
            n_steps_annealing: Number of steps over which sigma goes from
                ``sigma`` down to ``sigma_min``.
        """
        self.mu = mu
        self.sigma = sigma
        self.n_steps = 0

        if sigma_min is None:
            # No annealing requested: flat line pinned at the initial sigma.
            slope = 0.
            floor = sigma
        else:
            slope = -float(sigma - sigma_min) / float(n_steps_annealing)
            floor = sigma_min

        self.m = slope
        self.c = sigma
        self.sigma_min = floor

    @property
    def current_sigma(self):
        """Return the annealed sigma for the current step count."""
        annealed = self.m * float(self.n_steps) + self.c
        return max(self.sigma_min, annealed)
    
class OrnsteinUhlenbeckProcess(AnnealedGaussianProcess):
    """Discretised Ornstein-Uhlenbeck noise with an annealable sigma.

    Each sample pulls the previous value towards ``mu`` at rate ``theta``
    and adds Gaussian noise scaled by ``current_sigma * sqrt(dt)``.
    """

    def __init__(self, theta, mu=0., sigma=1., dt=1e-2, x0=None, size=1, sigma_min=None, n_steps_annealing=1000):
        """Configure the process and initialise its state.

        Args:
            theta: Mean-reversion rate.
            mu: Value the process reverts towards.
            sigma: Initial noise standard deviation.
            dt: Time-step used in the discretisation.
            x0: Initial state; zeros of shape ``size`` are used when ``None``.
            size: Shape passed to ``np.random.normal`` and ``np.zeros``.
            sigma_min: Lower bound for sigma annealing (``None`` disables it).
            n_steps_annealing: Number of steps over which sigma is annealed.
        """
        super().__init__(
            mu=mu,
            sigma=sigma,
            sigma_min=sigma_min,
            n_steps_annealing=n_steps_annealing,
        )
        self.theta = theta
        self.mu = mu
        self.dt = dt
        self.x0 = x0
        self.size = size
        self.reset_states()

    def sample(self):
        """Advance the process by one step and return the new state."""
        drift = self.theta * (self.mu - self.x_prev) * self.dt
        diffusion = self.current_sigma * np.sqrt(self.dt) * np.random.normal(size=self.size)
        new_state = self.x_prev + drift + diffusion
        self.x_prev = new_state
        # The step counter drives sigma annealing in the parent class.
        self.n_steps += 1
        return new_state

    def reset_states(self):
        """Restore the state to ``x0``, or to zeros when no ``x0`` was given."""
        if self.x0 is not None:
            self.x_prev = self.x0
        else:
            self.x_prev = np.zeros(self.size)


class DDPG(base):
    """Actor-critic learner with twin Q-functions and OU exploration noise.

    Two Q-functions are regressed onto a shared one-step TD target built from
    the minimum of two slowly-tracking target Q-functions.  The policy is
    updated to maximise the minimum of the two online Q-functions.  During
    data collection, Ornstein-Uhlenbeck noise scaled by a linearly decaying
    ``epsilon`` is added to the policy's action.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_action_dim,
                 device,
                 esq_params,
                 scratch,
                 **kwargs):
        """Build networks, optimizers, exploration noise and the replay buffer.

        Args:
            obs_dim: Observation dimensionality.
            action_dim: Action dimensionality; also the size of the OU noise.
            net_size: Forwarded to the policy and Q-function constructors.
            latent_action_dim: Forwarded to the policy constructor.
            device: Torch device the networks are moved to; also handed to
                the replay buffer.
            esq_params: Mapping providing ``policy_lr``, ``n_traj``,
                ``n_updates``, ``exp_noise_sigma``, ``buffer_size`` and
                ``batch_size``.
            scratch: When truthy, every policy parameter is optimised;
                otherwise only ``policy.first_layer`` and
                ``policy.last_layer`` are.
            **kwargs: Must contain ``max_path_length`` and ``reward_scale``;
                forwarded unchanged to the base class.
        """
        super().__init__(obs_dim,
                         action_dim,
                         net_size,
                         latent_action_dim,
                         device,
                         **kwargs)

        self.max_path_length = kwargs['max_path_length']
        # NOTE(review): reward_scale is stored but not applied anywhere in the
        # methods visible in this class - confirm whether that is intended.
        self.reward_scale = kwargs['reward_scale']

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = device

        self.policy_lr = esq_params['policy_lr']
        self.n_traj = esq_params['n_traj']
        self.n_updates = esq_params['n_updates']
        # NOTE(review): noise_sigma is stored but the OU process below is
        # built with a fixed sigma of 0.2 - confirm whether that is intended.
        self.noise_sigma = esq_params['exp_noise_sigma']

        # Network initialisation: one policy, two online Q-functions and a
        # target copy of each Q-function starting from identical weights.
        self.policy = es_policy2(obs_dim, action_dim, net_size, latent_action_dim).to(self.device)
        self.q_function1 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.q_function2 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function1 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function2 = q_function(obs_dim + action_dim, net_size).to(self.device)
        self.target_q_function1.load_state_dict(self.q_function1.state_dict())
        self.target_q_function2.load_state_dict(self.q_function2.state_dict())
        # A single optimizer steps both online Q-functions together.
        self.q_optimizer = optim.Adam(list(self.q_function1.parameters()) +
                                      list(self.q_function2.parameters()),
                                      lr=1e-4)

        self.random_process = OrnsteinUhlenbeckProcess(size=action_dim,
                                                       theta=0.15,
                                                       mu=0.0,
                                                       sigma=0.2)
        # epsilon scales the exploration noise and reaches zero after
        # 10000 environment steps (it is decremented once per step).
        self.epsilon = 1.0
        self.depsilon = self.epsilon / 10000

        if scratch:
            self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=self.policy_lr)
        else:
            # Only the first and last policy layers are trained.
            trainable = (list(self.policy.first_layer.parameters()) +
                         list(self.policy.last_layer.parameters()))
            self.policy_optimizer = optim.Adam(trainable, lr=self.policy_lr)

        self.memory = ReplayBuffer(esq_params['buffer_size'], self.device)
        self.batch_size = esq_params['batch_size']
        self.discount = 0.99
        self.prev_best_return = -np.inf

    def update_q_function(self):
        """Run one Q-function step, a target update and two policy steps.

        Returns:
            ``0`` when the replay buffer holds fewer than ``batch_size``
            transitions (nothing is updated); otherwise the summed squared
            TD loss of both Q-functions as a Python float.
        """
        if len(self.memory) < self.batch_size:
            return 0
        obs, actions, rewards, next_obs, dones = self.memory.sample(self.batch_size)

        # The bootstrap value uses the current policy and the pessimistic
        # (minimum) estimate of the two target Q-functions.
        with torch.no_grad():
            next_actions = self.policy(next_obs)
            next_q1 = self.target_q_function1(next_obs, next_actions)
            next_q2 = self.target_q_function2(next_obs, next_actions)
            next_q = torch.min(next_q1, next_q2)

        target_q = rewards + self.discount * (1 - dones) * next_q
        td_error1 = target_q - self.q_function1(obs, actions)
        td_error2 = target_q - self.q_function2(obs, actions)
        q_loss = (td_error1 ** 2).mean() + (td_error2 ** 2).mean()

        # Q-function update.
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        self.soft_update_target()
        # Two policy steps are taken on the same batch per Q-function step.
        for _ in range(2):
            self.update_policy(obs)
        return q_loss.item()

    def soft_update_target(self, tau=0.005):
        """Move each target Q-function towards its online counterpart.

        Args:
            tau: Interpolation weight; each target parameter becomes
                ``tau * online + (1 - tau) * target``.
        """
        network_pairs = (
            (self.target_q_function1, self.q_function1),
            (self.target_q_function2, self.q_function2),
        )
        # no_grad replaces the discouraged ``.data`` accessor: the in-place
        # copy is still excluded from autograd, with identical numerics.
        with torch.no_grad():
            for target_net, online_net in network_pairs:
                for target_param, param in zip(target_net.parameters(),
                                               online_net.parameters()):
                    target_param.copy_(tau * param + (1 - tau) * target_param)

    def update_policy(self, obs):
        """Take one policy gradient step that maximises the minimum Q-value.

        Args:
            obs: Batch of observations, as returned by the replay buffer.

        Returns:
            The policy loss (negative mean of the minimum Q-value) as a float.
        """
        # presumably grad_action is the differentiable action path of the
        # policy; it is defined outside this file - verify against network.py.
        actions = self.policy.grad_action(obs)
        q_values = torch.min(self.q_function1(obs, actions),
                             self.q_function2(obs, actions))
        policy_loss = -q_values.mean()
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        return policy_loss.item()

    def collet_data_and_train_filter(self, env):
        """Collect ``n_traj`` noisy rollouts, training after every step.

        The method name (including its spelling) is kept as-is so existing
        callers keep working.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``, where
                ``step`` returns ``(next_obs, reward, done, info)``.

        Returns:
            Total number of environment steps taken across all rollouts.
        """
        total_step = 0
        for _ in range(self.n_traj):
            obs = env.reset()
            env_step = 0
            while env_step < self.max_path_length:
                env_step += 1
                # select_action is not defined in this class; presumably it
                # is inherited from ``base`` - verify.
                action = self.select_action(obs)
                # Exploration noise is scaled by epsilon, floored at zero so
                # the noise vanishes once epsilon has fully decayed.
                action += max(self.epsilon, 0) * self.random_process.sample()
                self.epsilon -= self.depsilon
                action = np.clip(action, -1., 1.)
                next_obs, reward, done, _env_info = env.step(action)
                self.memory.add(obs, action, reward, next_obs, done)
                self.update_q_function()
                obs = next_obs
                if done:
                    break
            # Start every rollout from a fresh noise state.
            self.random_process.reset_states()
            total_step += env_step
        return total_step