import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Union, Tuple, Optional

import numpy as np
from algos.dist_module import *
 

# from sklearn.decomposition import PCA

# The twin critics, clipped target-policy noise and soft target updates below
# follow Twin Delayed Deep Deterministic Policy Gradients (TD3).
# Paper: https://arxiv.org/abs/1802.09477


class ActorProb(nn.Module):
    """Stochastic policy: a two-layer ReLU MLP (width 32) feeding a TanhDiagGaussian head.

    ``forward`` returns whatever ``TanhDiagGaussian`` produces for the 32-d
    feature vector computed from ``state``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 32
        # Attribute names (l1, l2, dist) are state_dict keys and their creation
        # order fixes seeded initialisation; keep both stable.
        self.l1 = nn.Linear(state_dim, width)
        self.l2 = nn.Linear(width, width)
        self.dist = TanhDiagGaussian(
            latent_dim=width,
            output_dim=action_dim,
            unbounded=True,
            conditioned_sigma=True,
            max_mu=1.0,
        )

    def forward(self, state):
        features = state
        for layer in (self.l1, self.l2):
            features = F.relu(layer(features))
        return self.dist(features)


class Critic(nn.Module):
    """Twin Q-network: two independent 3-layer MLPs over concat(state, action).

    ``forward`` returns a ``(batch, 2)`` tensor whose columns are Q1 and Q2.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        in_dim = state_dim + action_dim
        hidden = 256
        # Attribute names l1..l6 and their creation order are kept so that
        # state_dict keys and seeded parameter initialisation are unchanged.
        # Q1 head
        self.l1 = nn.Linear(in_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, 1)
        # Q2 head
        self.l4 = nn.Linear(in_dim, hidden)
        self.l5 = nn.Linear(hidden, hidden)
        self.l6 = nn.Linear(hidden, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)
        heads = ((self.l1, self.l2, self.l3), (self.l4, self.l5, self.l6))
        values = [
            out_layer(F.relu(mid_layer(F.relu(in_layer(sa)))))
            for in_layer, mid_layer, out_layer in heads
        ]
        return torch.cat(values, dim=1)
    
 

class VAE(nn.Module):
    """Conditional VAE over actions given observations.

    Encoder: concat(obs, action) -> Gaussian latent (mean, std).
    Decoder: concat(obs, z) -> action squashed by tanh into
    [-max_action, max_action].
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            hidden_dim: int,
            latent_dim: int,
            max_action: Union[int, float],
            device: str = "cpu"
    ) -> None:
        super(VAE, self).__init__()
        # Encoder
        self.e1 = nn.Linear(input_dim + output_dim, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.d1 = nn.Linear(input_dim + latent_dim, hidden_dim)
        self.d2 = nn.Linear(hidden_dim, hidden_dim)
        self.d3 = nn.Linear(hidden_dim, output_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim
        # Device recorded at construction time (kept for callers that read it).
        self.device = torch.device(device)

        self.to(device=self.device)

    def forward(
            self,
            obs: torch.Tensor,
            action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode (obs, action), sample z via the reparameterisation trick, then decode.

        Returns:
            ``(reconstructed_action, latent_mean, latent_std)``.
        """
        z = F.relu(self.e1(torch.cat([obs, action], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(obs, z)

        return u, mean, std

    def decode(self, obs: torch.Tensor, z: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode latent ``z`` into actions in [-max_action, max_action].

        When ``z`` is None, latents are drawn from N(0, I) and clipped to
        [-0.5, 0.5]. They are created directly on ``obs``'s device and dtype,
        so decoding still works if the module was moved with ``.to()`` after
        construction (``self.device`` is only captured in ``__init__``).
        """
        if z is None:
            z = torch.randn(
                (obs.shape[0], self.latent_dim), device=obs.device, dtype=obs.dtype
            ).clamp(-0.5, 0.5)

        a = F.relu(self.d1(torch.cat([obs, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))


class B2PD(object):
    """B2PD agent: maximum-entropy twin-critic actor-critic with a behaviour-anchor term.

    Components built in ``__init__``:
      * ``actor`` / ``actor_target``: ``ActorProb`` stochastic policies.
      * ``critic`` / ``critic_target``: twin-head ``Critic``.
      * ``behavior_policy``: a ``VAE`` fitted online to replayed (state, action)
        pairs, optionally pushed toward higher-Q reconstructions (``xi``).
      * ``_log_alpha``: automatically tuned entropy temperature.

    The actor loss adds ``eta`` times a KL divergence between the softmax
    (over action dimensions) of an "anchor" action (the highest-Q of ``H``
    behaviour-VAE samples) and of the actor's action. ``eta`` is multiplied
    by 0.97 every 20000 training iterations.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            discount=0.99,
            lamda=0.005,
            policy_noise=0.2,
            noise_clip=0.5,
            policy_freq=2,
            xi=0.1,
            eta=0.4,
            H=10,
            device='cuda:0',
    ):
        """Build networks, optimisers and hyper-parameters.

        Args:
            state_dim: observation size.
            action_dim: action size.
            max_action: clip bound for target actions; scale of the VAE decoder output.
            discount: TD discount factor.
            lamda: Polyak coefficient for the soft target-network updates.
            policy_noise: stored but not read by this class.
            noise_clip: clip bound for the target-policy noise.
            policy_freq: stored but not read by this class.
            xi: weight of the Q-guidance term in the behaviour-VAE loss (0 disables it).
            eta: weight of the anchor KL term in the actor loss (decayed by ``adjust_eta``).
            H: number of VAE action samples per state when choosing the anchor.
            device: torch device for all networks and the temperature parameter.
        """
        self.device = device

        # Construction order is kept stable so seeded initialisation is reproducible.
        self.actor = ActorProb(state_dim, action_dim).to(self.device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.behavior_policy = VAE(
            input_dim=state_dim,
            output_dim=action_dim,
            hidden_dim=750,
            latent_dim=action_dim * 2,
            max_action=max_action,
            device=self.device
        )
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.behavior_policy_optim = torch.optim.Adam(self.behavior_policy.parameters(), lr=3e-4)
        self.max_action = max_action
        self.discount = discount
        self.lamda = lamda
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0
        self.H = H
        self.eta = eta
        self.min_priority = 1.0
        self.xi = xi
        # Automatic entropy-temperature tuning with target entropy -action_dim.
        self._is_auto_alpha = True
        self._target_entropy = -self.action_dim
        self._log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = torch.optim.Adam([self._log_alpha], lr=3e-4)
        self._alpha = self._log_alpha.detach().exp()

    def adjust_eta(self):
        """Decay the anchor-term weight ``eta`` by a factor of 0.97."""
        self.eta = 0.97 * self.eta

    def get_achnor(self, state, batch_size=None):
        """Return, for each state, the highest-value action among ``H`` behaviour-VAE samples.

        Each state is repeated ``H`` times, the behaviour VAE decodes one action
        per copy, and the sample with the largest ``min(Q1, Q2)`` under
        ``critic_target`` is kept. Runs under ``torch.no_grad()``.

        Args:
            state: batch of states; rows are repeated along dim 0.
            batch_size: accepted for backward compatibility only. The batch
                size is read from ``state`` so the grouping of the ``n * H``
                samples always matches the actual batch.

        Returns:
            Tensor of shape ``(n, action_dim)`` with one anchor action per state.
        """
        with torch.no_grad():
            n = state.shape[0]
            # repeat_interleave keeps each state's H copies contiguous, so a
            # row-major reshape to (n, H) groups the samples per state.
            s_in_repeat = torch.repeat_interleave(state, self.H, 0)
            sampled_actions = self.behavior_policy.decode(s_in_repeat)
            q = self.critic_target(s_in_repeat, sampled_actions).min(dim=1).values
            best = q.reshape(n, self.H).argmax(dim=-1)
            rows = torch.arange(n, device=sampled_actions.device)
            return sampled_actions.reshape(n, self.H, -1)[rows, best]

    def online_PT(self, state, action):  # action value prior
        """Take one optimisation step of the behaviour VAE on (state, action).

        Loss = MSE reconstruction + KL(N(mean, std^2) || N(0, I)) averaged over
        elements, plus ``xi * -mean(Q(state, recon))`` when ``xi != 0``.
        The Q term also backpropagates into ``self.critic``'s parameters; those
        gradients are cleared by ``critic_optimizer.zero_grad()`` at the start
        of the next critic step.
        """
        recon, mean, std = self.behavior_policy(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()

        if self.xi == 0:
            vae_loss = recon_loss + KL_loss
        else:
            avpd = -self.critic(state, recon).mean()
            vae_loss = recon_loss + KL_loss + avpd * self.xi
        self.behavior_policy_optim.zero_grad()
        vae_loss.backward()
        self.behavior_policy_optim.step()

    def actforward(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the actor on ``obs`` and return ``(squashed_action, log_prob)``.

        Uses the distribution's ``mode()`` when ``deterministic`` and a
        reparameterised ``rsample()`` otherwise.
        """
        dist = self.actor(obs)
        if deterministic:
            squashed_action, raw_action = dist.mode()
        else:
            squashed_action, raw_action = dist.rsample()
        log_prob = dist.log_prob(squashed_action, raw_action)
        return squashed_action, log_prob

    def select_action(
        self,
        obs: np.ndarray,
        deterministic: bool = False
    ) -> np.ndarray:
        """Return a flat numpy action for a single observation (no gradient tracking)."""
        with torch.no_grad():
            obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)
            action, _ = self.actforward(obs, deterministic)
        return action.cpu().data.numpy().flatten()

    def _soft_update(self, net, target_net):
        """Polyak-average ``net``'s parameters into ``target_net`` with rate ``self.lamda``."""
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(self.lamda * param.data + (1 - self.lamda) * target_param.data)

    def train(self, replay_buffer, batch_size=256):
        """Run one training iteration: critic, behaviour VAE, actor, temperature, targets.

        Args:
            replay_buffer: object whose ``sample()`` returns
                ``(state, action, next_state, reward, not_done)``; presumably
                tensors already on ``self.device`` (not checked here).
            batch_size: accepted for backward compatibility. ``sample()`` is
                called without it, so the effective batch size is read from
                the sampled ``state``.
        """
        # NOTE: fires on the very first call too (total_it == 0 before increment).
        if self.total_it % 20000 == 0:
            self.adjust_eta()
        self.total_it += 1
        state, action, next_state, reward, not_done = replay_buffer.sample()
        n = state.shape[0]

        # ---- Critic update ------------------------------------------------
        with torch.no_grad():
            dist = self.actor_target(next_state)
            # State-dependent target noise: its scale decreases as the target
            # policy's ``dist.scale`` grows (``dist.scale`` is defined by the
            # distribution module, presumably a log-std / std; confirm there).
            sda_n = (0.2 / torch.exp(dist.scale) ** 1.5).mean()
            noise = (
                torch.randn_like(action) * sda_n
            ).clamp(-self.noise_clip, self.noise_clip)
            next_actions, raw_action = dist.rsample()
            next_log_probs = dist.log_prob(next_actions, raw_action)
            next_actions = (next_actions + noise).clamp(-self.max_action, self.max_action)
            # Clipped double-Q target with entropy bonus.
            target_Q = self.critic_target(next_state, next_actions)
            target_Q = torch.min(target_Q, dim=1, keepdim=True).values - self._alpha * next_log_probs
            target_Q = reward + not_done * self.discount * target_Q

        current_Q = self.critic(state, action)
        # current_Q has one column per critic head; target_Q is expected to be a
        # single column. Expand explicitly so both heads regress onto the same
        # target without mse_loss's implicit-broadcast UserWarning (same value).
        critic_loss = F.mse_loss(current_Q, target_Q.expand_as(current_Q))
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---- Behaviour-policy (VAE) update ---------------------------------
        self.online_PT(state, action)

        # ---- Actor update --------------------------------------------------
        a, log_probs = self.actforward(state)
        qs = self.critic(state, a)
        high_action = self.get_achnor(state, n)
        max_ent_loss = -qs.mean() + self._alpha * log_probs.mean()

        # KL between softmaxes over the action dimension. 'none' + mean() gives
        # the same value as the legacy reduce='mean' (average over all
        # elements) without deprecated arguments or warnings.
        bppd_loss = F.kl_div(
            F.log_softmax(a, dim=-1),
            F.softmax(high_action, dim=-1),
            reduction='none',
        ).mean()
        actor_loss = max_ent_loss + bppd_loss * self.eta
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ---- Entropy temperature update ------------------------------------
        if self._is_auto_alpha:
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0)

        # ---- Soft-update the frozen target models --------------------------
        self._soft_update(self.critic, self.critic_target)
        self._soft_update(self.actor, self.actor_target)

    def save(self, filename):
        """Save critic/actor weights and optimiser states to ``filename + suffix`` files.

        NOTE(review): the behaviour VAE, its optimiser and the entropy
        temperature are not saved, so a reloaded agent restarts them.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

    def load(self, filename):
        """Load files written by ``save``; target networks are reset to copies of the online ones.

        Tensors are mapped onto ``self.device``, so a checkpoint written on one
        device (e.g. GPU) can be loaded on another (e.g. CPU).
        """
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer", map_location=self.device))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer", map_location=self.device))
        self.actor_target = copy.deepcopy(self.actor)

