from collections import deque, namedtuple
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


def soft_clamp(x: torch.Tensor, _min=None, _max=None, device=None):
    """Smoothly bound ``x`` using softplus instead of a hard clip.

    The upper bound is applied first as ``_max - softplus(_max - x)``, then
    the lower bound as ``_min + softplus(x - _min)``. Because the lower
    bound is applied last, the result is always strictly above ``_min`` but
    may sit slightly above ``_max``.

    Args:
        x: Tensor to clamp.
        _min: Optional lower bound (tensor broadcastable against ``x``, or a
            plain number). ``None`` disables the lower bound.
        _max: Optional upper bound (tensor broadcastable against ``x``, or a
            plain number). ``None`` disables the upper bound.
        device: Device on which the computation runs. ``None`` (the default)
            keeps ``x`` where it already lives, so the function also works on
            hosts without CUDA. Pass a device explicitly to force a move.

    Returns:
        The softly clamped tensor on ``device``.
    """
    if device is None:
        # Follow the input instead of assuming a GPU is present.
        device = x.device
    x = x.to(device)
    if _max is not None:
        # Only tensors need moving; plain numbers broadcast as-is.
        if torch.is_tensor(_max):
            _max = _max.to(device)
        x = _max - F.softplus(_max - x)
    if _min is not None:
        if torch.is_tensor(_min):
            _min = _min.to(device)
        x = _min + F.softplus(x - _min)
    return x

class MLP(nn.Module):
    """Fully connected network with ReLU activations.

    Layout: one input projection, ``num_hidden_layers`` square hidden layers
    of width ``hidden_dim``, and a linear output layer. ReLU follows every
    layer except the output layer.

    Args:
        device: Device the module is moved to after construction.
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Width of the input projection and every hidden layer.
        num_hidden_layers: Number of ``hidden_dim -> hidden_dim`` layers.
    """

    def __init__(self, device, input_dim, output_dim, hidden_dim, num_hidden_layers):
        super().__init__()
        # Layers are created in input -> hidden -> output order so that the
        # random initialisation sequence stays the same for a given seed.
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        hidden_stack = []
        for _ in range(num_hidden_layers):
            hidden_stack.append(nn.Linear(hidden_dim, hidden_dim))
        self.hidden_layers = nn.ModuleList(hidden_stack)
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.to(device)

    def forward(self, x):
        """Run ``x`` through the network and return the raw (linear) output."""
        activation = F.relu(self.input_layer(x))
        for hidden in self.hidden_layers:
            activation = F.relu(hidden(activation))
        return self.output_layer(activation)

class QNetwork(nn.Module):
    """Scalar critic over concatenated (state, action) inputs.

    Wraps an ``MLP`` with two hidden layers and a single output unit, and
    carries its own Adam optimizer and an MSE loss module.

    Args:
        device: Device the module is moved to.
        lr: Learning rate for the Adam optimizer.
        state_dims: Size of the state's last dimension.
        actions_dims: Size of the action's last dimension.
        neurons: Hidden width of the underlying MLP.
    """

    def __init__(self, device, lr, state_dims, actions_dims, neurons = 500):
        super().__init__()
        q_input_dim = state_dims + actions_dims
        self.QNetwork = MLP(device, q_input_dim, 1, neurons, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.to(device)

    def forward(self, state, action):
        """Return the network output for ``state`` and ``action`` joined on the last dim."""
        return self.QNetwork(torch.cat([state, action], dim=-1))

class DynamicsMLP(nn.Module):
    """MLP mapping (state, action) to a ``state_dim``-sized vector plus one scalar.

    The underlying network has ``state_dim + 1`` outputs; ``forward`` splits
    them into the first ``state_dim`` entries and the final entry.
    Presumably these are the predicted next state and reward -- confirm
    against the training code.

    Args:
        device: Device the module is moved to.
        lr: Learning rate for the Adam optimizer.
        state_dim: Size of the state's last dimension.
        action_dim: Size of the action's last dimension.
        neurons: Hidden width of the underlying MLP.
    """

    def __init__(self, device, lr, state_dim, action_dim, neurons = 500):
        super(DynamicsMLP, self).__init__()
        self.state_dims= state_dim
        self.action_dim = action_dim
        self.dynamics = MLP(device, state_dim+action_dim, state_dim+1, neurons, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device=device
        self.to(device)

    def forward(self, state, action):
        """Split the network output into (all-but-last, last) along the final dim.

        Ellipsis indexing is used so the split works for any number of
        leading batch dimensions (including none); for the usual 2-D
        ``(batch, features)`` input the result is unchanged.

        Returns:
            Tuple ``(head, tail)`` where ``head`` has last dimension
            ``state_dim`` and ``tail`` has that dimension removed.
        """
        encoded_input = torch.cat((state, action), dim=-1)
        x = self.dynamics(encoded_input)
        return x[..., :-1], x[..., -1]

class StateEncoder(nn.Module):
    """MLP encoder from state space to a latent vector.

    Args:
        device: Device the module is moved to.
        state_dim: Size of the state's last dimension.
        latent_dim: Size of the latent vector.
        neurons: Hidden width of the underlying MLP.
        sphere_norm: When true, the latent vector is rescaled to unit L2
            norm along the last dimension.
    """

    def __init__(self, device, state_dim, latent_dim, neurons = 500, sphere_norm=True):
        super(StateEncoder, self).__init__()
        self.encoder = MLP(device, state_dim, latent_dim, neurons, 2)
        self.sphere_norm=sphere_norm
        self.to(device)

    def forward(self, state):
        """Encode ``state``; project onto the unit sphere if ``sphere_norm`` is set."""
        encoded_state = self.encoder(state)
        if self.sphere_norm:
            # F.normalize divides by max(norm, eps), so an all-zero latent
            # yields zeros instead of NaN from a 0/0 division. For any
            # non-degenerate norm the result equals encoded_state / norm.
            return F.normalize(encoded_state, p=2, dim=-1)
        return encoded_state

class StateDecoder(nn.Module):
    """MLP decoder from a latent vector back to state space.

    Args:
        device: Device the module is moved to.
        state_dim: Size of the decoded output's last dimension.
        latent_dim: Size of the latent input's last dimension.
        neurons: Hidden width of the underlying MLP.
    """

    def __init__(self, device, state_dim, latent_dim, neurons = 500):
        super().__init__()
        # Note the swap: the MLP consumes latent_dim and emits state_dim.
        self.decoder = MLP(device, latent_dim, state_dim, neurons, 2)
        self.to(device)

    def forward(self, encoded_state):
        """Return the decoder MLP's output for ``encoded_state``."""
        return self.decoder(encoded_state)

class LatentDynamicsMLP(nn.Module):
    """MLP predicting the next latent state and a scalar reward.

    The underlying network has ``latent_state_dim + 1`` outputs: the first
    ``latent_state_dim`` are treated as the next latent state and the last
    one as the reward.

    Args:
        device: Device the module is moved to.
        latent_state_dim: Size of the latent state's last dimension.
        action_dim: Size of the action's last dimension.
        neurons: Hidden width of the underlying MLP.
        sphere_norm: When true, the predicted next latent state is rescaled
            to unit L2 norm along the last dimension.
    """

    def __init__(self, device, latent_state_dim, action_dim, neurons = 500, sphere_norm=True):
        super(LatentDynamicsMLP, self).__init__()
        self.neurons = neurons
        self.dynamics_model = MLP(device, latent_state_dim + action_dim, latent_state_dim + 1, neurons, 2)
        self.device=device
        self.sphere_norm=sphere_norm
        self.to(device)

    def forward(self, latent_state, latent_action):
        """Return ``(next_latent_state, reward)`` for the given latent state and action.

        Ellipsis indexing keeps the split valid for any number of leading
        batch dimensions; 2-D inputs behave exactly as before.
        """
        encoded_input = torch.cat((latent_state, latent_action), dim = -1)
        encoded_next_state_reward = self.dynamics_model(encoded_input)
        encoded_next_state = encoded_next_state_reward[..., :-1]
        reward = encoded_next_state_reward[..., -1]
        if self.sphere_norm:
            # F.normalize clamps the norm at eps, avoiding NaN when the
            # predicted latent is all zeros; otherwise equal to x / ||x||.
            encoded_next_state = F.normalize(encoded_next_state, p=2, dim=-1)
        return encoded_next_state, reward
        

class LatentDynamicsModel(nn.Module):
    """Container for a state encoder, decoder, target encoder and latent dynamics.

    The target encoder's parameters have ``requires_grad`` disabled; it is
    meant to be updated through ``update_target_network``. It is built as a
    separate ``StateEncoder`` with its own initialisation and is not synced
    to the online encoder here; ``update_target_network(1)`` makes it an
    exact copy. A single Adam optimizer is built over ``self.parameters()``.
    No ``forward`` is defined on this class in the visible source.

    Args:
        device: Device the module is moved to.
        lr: Learning rate for the Adam optimizer.
        state_dim: Size of the raw state's last dimension.
        latent_state_dim: Size of the latent state.
        action_dim: Size of the action's last dimension.
        neurons: Hidden width used by every sub-network.
        sphere_norm: Forwarded to the encoders and the dynamics model.
    """

    def __init__(self, device, lr, state_dim, latent_state_dim, action_dim, neurons = 500, sphere_norm=True):
        super().__init__()
        self.neurons = neurons
        # Construction order is kept fixed so seeded initialisation matches.
        self.state_encoder = StateEncoder(
            device, state_dim, latent_state_dim, neurons, sphere_norm=sphere_norm
        )
        self.state_decoder = StateDecoder(device, state_dim, latent_state_dim, neurons)
        self.state_encoder_target = StateEncoder(
            device, state_dim, latent_state_dim, neurons, sphere_norm=sphere_norm
        )
        self.dynamics_model = LatentDynamicsMLP(
            device, latent_state_dim, action_dim, neurons, sphere_norm=sphere_norm
        )
        # Freeze the target encoder: it receives no gradient updates.
        for frozen in self.state_encoder_target.parameters():
            frozen.requires_grad = False
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.device = device
        self.sphere_norm = sphere_norm
        self.to(device)

    def update_target_network(self, tau):
        """Blend the online encoder into the target encoder.

        Every target entry becomes ``(1 - tau) * target + tau * online``.

        Args:
            tau: Interpolation weight given to the online encoder.
        """
        online_state = self.state_encoder.state_dict()
        target_state = self.state_encoder_target.state_dict()
        with torch.no_grad():
            for name, target_tensor in target_state.items():
                target_tensor.mul_(1-tau)
                target_tensor.add_(online_state[name].mul(tau))
        self.state_encoder_target.load_state_dict(target_state)

class LatentQMLP(nn.Module):
    """Scalar critic over concatenated (latent state, action) inputs.

    Args:
        device: Device the module is moved to.
        lr: Learning rate for the Adam optimizer.
        latent_state_dim: Size of the latent state's last dimension.
        action_dim: Size of the action's last dimension.
        neurons: Hidden width of the underlying MLP.
    """

    def __init__(self, device, lr, latent_state_dim, action_dim, neurons = 500):
        super().__init__()
        critic_input_dim = latent_state_dim + action_dim
        self.QNetwork = MLP(device, critic_input_dim, 1, neurons, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.to(device)

    def forward(self, latent_state, action):
        """Return the network output for the inputs joined on the last dim."""
        return self.QNetwork(torch.cat([latent_state, action], dim=-1))

class LatentStochasticPolicyMLP(nn.Module):
    """Gaussian policy head over latent states.

    The underlying MLP emits ``2 * action_dim`` values that are split into a
    mean and a log standard deviation; the log-std is softly clamped between
    ``min_logstd`` (initialised to -5) and ``max_logstd`` (initialised to 1).

    Args:
        device: Device the module and the log-std bounds are placed on.
        lr: Learning rate for the Adam optimizer.
        latent_state_dim: Size of the latent state's last dimension.
        action_dim: Number of action dimensions.
        neurons: Hidden width of the underlying MLP.
    """

    def __init__(self, device, lr, latent_state_dim, action_dim, neurons = 500):
        super(LatentStochasticPolicyMLP, self).__init__()
        self.policy_network = MLP(device, latent_state_dim, 2*action_dim, neurons, 2)
        # NOTE(review): the optimizer is built before max_logstd/min_logstd
        # are registered, so it does not update them even though they have
        # requires_grad=True. Left as-is to keep optimizer state compatible;
        # confirm whether the bounds are meant to be trained.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.to(device)
        self.max_logstd = nn.Parameter((torch.ones((1, action_dim)).float() * 1).to(device), requires_grad=True)
        self.min_logstd = nn.Parameter((torch.ones((1, action_dim)).float() * -5).to(device), requires_grad=True)

    def forward(self, latent_state, var_scale = 1):
        """Return a ``Normal`` distribution over actions for ``latent_state``.

        Args:
            latent_state: Input whose last dimension is ``latent_state_dim``.
            var_scale: Multiplier applied to the standard deviation.
        """
        output = self.policy_network(latent_state)
        mu, logstd = torch.chunk(output, 2, dim=-1)
        # Clamp on the device the network output lives on; relying on
        # soft_clamp's default device would force CUDA and fail on CPU hosts.
        logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd, device=logstd.device)
        expstd = torch.exp(logstd)
        dist = torch.distributions.Normal(mu, var_scale * expstd)
        return dist

class StochasticMLP(nn.Module):
    """MLP that outputs a diagonal Gaussian over ``output_dims`` values.

    The underlying MLP emits ``2 * output_dims`` values that are split into a
    mean and a log standard deviation; the log-std is softly clamped between
    ``min_logstd`` (initialised to -5) and ``max_logstd`` (initialised to 1).

    Args:
        device: Device the module and the log-std bounds are placed on.
        lr: Learning rate for the Adam optimizer.
        input_dims: Size of the input's last dimension.
        output_dims: Number of output dimensions of the distribution.
        neurons: Hidden width of the underlying MLP.
    """

    def __init__(self, device, lr, input_dims, output_dims, neurons = 500):
        super(StochasticMLP, self).__init__()
        self.output_dims = output_dims
        self.input_dims = input_dims
        self.neurons = neurons
        self.net = MLP(device, input_dims, 2*output_dims, neurons, 2)
        # NOTE(review): the optimizer is built before max_logstd/min_logstd
        # are registered, so it does not update them even though they have
        # requires_grad=True. Left as-is to keep optimizer state compatible;
        # confirm whether the bounds are meant to be trained.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.GaussianNLLLoss()
        self.max_logstd = nn.Parameter((torch.ones((1, self.output_dims)).float() * 1).to(device), requires_grad=True)
        self.min_logstd = nn.Parameter((torch.ones((1, self.output_dims)).float() * -5).to(device), requires_grad=True)
        self.to(device)

    def forward(self, state, var_scale = 1):
        """Return a ``Normal`` distribution for ``state``.

        Args:
            state: Input whose last dimension is ``input_dims``.
            var_scale: Multiplier applied to the standard deviation.
        """
        x = self.net(state)
        mu, logstd = torch.chunk(x, 2, dim=-1)
        # Clamp on the device the network output lives on; relying on
        # soft_clamp's default device would force CUDA and fail on CPU hosts.
        logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd, device=logstd.device)
        dist = torch.distributions.Normal(mu, var_scale * torch.exp(logstd))
        return dist

