import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import norm_state

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Affine map applied after the tanh in Actor.decode:
#     action = action_scale * tanh(...) + action_offsite
# Both are int64 tensors; combining them with a float tensor promotes the
# result to float.
action_offsite = torch.tensor([0, 0], device=device)
action_scale = torch.tensor([1, 1], device=device)

class Actor(nn.Module):
    """VAE-style actor: encode a state into a latent Gaussian, decode a latent into an action.

    Args:
        state_dim: width of the input state vector.
        action_dim: width of the decoded action vector.
        latent_dim: width of the latent code (default 128).
    """

    def __init__(self, state_dim, action_dim, latent_dim=128):
        super().__init__()
        width = 750  # hidden width used by every encoder/decoder layer

        # Keep the layer creation order: it fixes both the RNG draws used for
        # weight initialisation and the key order of the state_dict.

        # Encoder trunk.
        self.e1 = nn.Linear(state_dim, width)
        self.e2 = nn.Linear(width, width)

        # Heads producing the parameters of the latent Gaussian.
        self.mean = nn.Linear(width, latent_dim)
        self.log_std = nn.Linear(width, latent_dim)

        # Decoder: latent -> action.
        self.d1 = nn.Linear(latent_dim, width)
        self.d2 = nn.Linear(width, width)
        self.d3 = nn.Linear(width, action_dim)

        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state, test=False):
        """Encode ``state`` and decode an action.

        Args:
            state: input tensor whose last dimension is ``state_dim``.
            test: if True, decode the latent mean deterministically.

        Returns:
            ``test=True``: the decoded action tensor only.
            ``test=False``: ``(action, mean, std)``, where ``action`` is
            decoded from a latent sampled from N(mean, std^2).
        """
        hidden = F.relu(self.e2(F.relu(self.e1(state))))
        mean = self.mean(hidden)

        if test:
            # Deterministic path: no sampling, decode the mean directly.
            return self.decode(mean)

        # log_std is clamped to [-4, 15] for numerical stability.
        std = self.log_std(hidden).clamp(-4, 15).exp()
        # Reparameterisation: latent = mean + std * eps, eps ~ N(0, I).
        latent = mean + std * torch.randn_like(std)
        return self.decode(latent), mean, std

    def decode(self, z):
        """Map latent ``z`` to an action: ``action_scale * tanh(d3(...)) + action_offsite``."""
        h = F.relu(self.d1(z))
        h = F.relu(self.d2(h))
        squashed = torch.tanh(self.d3(h))
        return action_scale * squashed + action_offsite

class SL_VAE(object):
    """Supervised trainer that fits the VAE ``Actor`` to actions from a replay buffer.

    The actor input is the concatenation of ``state`` and ``h_state`` (see
    ``train``), so the actor is built with ``state_dim + h_state_dim`` inputs.
    """

    def __init__(self, args):
        """Build the actor and its Adam optimizer.

        Args:
            args: config object; reads ``state_dim``, ``h_state_dim``,
                ``action_dim`` and ``lr``. The optimizer uses a learning rate
                of ``0.1 * args.lr``.
        """
        self.actor = Actor(args.state_dim + args.h_state_dim, args.action_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=0.1 * args.lr)

    def select_action(self, state):
        """Return the deterministic action (decoded latent mean) for a single state.

        Args:
            state: array-like with a ``reshape`` method; it is flattened into
                one row. NOTE(review): presumably this already includes the
                ``h_state`` features the actor was trained on (total width
                ``state_dim + h_state_dim``); confirm against callers.

        Returns:
            Tensor of shape ``(1, action_dim)`` on ``device``. It is computed
            under ``torch.no_grad()`` and carries no autograd graph.
        """
        # Inference only: skip graph construction. This also lets callers call
        # .cpu().numpy() on the result without detaching it first.
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(device)
            state = norm_state(state)
            return self.actor(state, test=True)

    def train(self, replay_buffer, batch_size=64):
        """Run one gradient step of the VAE objective on a sampled batch.

        Loss = ``MSE(recon, action) + 0.5 * KL(N(mean, std^2) || N(0, I))``.
        The KL term is averaged over all batch and latent elements.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns a
                9-tuple; only the first three entries (``state``,
                ``h_state``, ``action``) are used.
            batch_size: number of transitions to sample.

        Returns:
            ``{'sl_loss': loss}``, where ``loss`` is the scalar loss tensor,
            detached from the autograd graph.
        """
        state, h_state, action, _, _, _, _, _, _ = replay_buffer.sample(batch_size)

        # The actor is conditioned on the state plus the extra h_state features.
        state = norm_state(torch.cat([state, h_state], dim=1))
        recon, mean, std = self.actor(state, test=False)

        recon_loss = F.mse_loss(recon, action)
        # Closed-form KL divergence between N(mean, std^2) and N(0, 1).
        kl_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * kl_loss

        # Optimize the actor.
        self.actor_optimizer.zero_grad()
        vae_loss.backward()
        self.actor_optimizer.step()

        # Detach so the returned loss does not keep the autograd graph alive,
        # e.g. when a caller accumulates losses for logging.
        return {'sl_loss': vae_loss.detach()}

    def save(self, filename):
        """Save the actor's state_dict to ``filename + ".pth"``."""
        torch.save(self.actor.state_dict(), filename + ".pth")

    def load(self, filename):
        """Load actor weights from ``filename + ".pth"``.

        ``map_location=device`` remaps stored tensors to the current device,
        so a checkpoint saved on a GPU can be loaded on a CPU-only machine.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.actor.load_state_dict(torch.load(filename + ".pth", map_location=device))
        
    