import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributions import Categorical, Normal

class VectorizedLinear(nn.Module):
    """A stack of ``ensemble_size`` independent linear layers applied in one matmul.

    Parameters:
        weight: ``[ensemble_size, in_features, out_features]``
        bias:   ``[ensemble_size, 1, out_features]``

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        ensemble_size: Number of independent layers held in the stack.
    """

    def __init__(self, in_features, out_features, ensemble_size):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        weight_shape = (ensemble_size, in_features, out_features)
        bias_shape = (ensemble_size, 1, out_features)
        self.weight = nn.Parameter(torch.empty(*weight_shape))
        self.bias = nn.Parameter(torch.empty(*bias_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight, then bias, from U(-b, b) with b = 1 / sqrt(fan_in)."""
        fan_in = self.weight.shape[-2]
        limit = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-limit, limit)
        if self.bias is not None:
            self.bias.data.uniform_(-limit, limit)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member's affine map.

        Shapes:
            x:      ``[ensemble_size, batch_size, in_features]``
            return: ``[ensemble_size, batch_size, out_features]``

        A lower-rank ``x`` (e.g. ``[batch_size, in_features]``) is broadcast
        across the ensemble by ``torch.matmul``.
        """
        projected = torch.matmul(x, self.weight)
        return projected + self.bias


class VectorizedLinearHead(nn.Module):
    """Independent linear maps for every (ensemble member, action dimension) pair.

    Parameters:
        weight: ``[ensemble_size, action_dims, in_features, out_features]``
        bias:   ``[ensemble_size, action_dims, 1, out_features]``

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        action_dims: Number of per-action-dimension heads per ensemble member.
        ensemble_size: Number of ensemble members.
    """

    def __init__(self, in_features, out_features, action_dims, ensemble_size):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.action_dims = action_dims

        lead = (ensemble_size, action_dims)
        self.weight = nn.Parameter(torch.empty(*lead, in_features, out_features))
        self.bias = nn.Parameter(torch.empty(*lead, 1, out_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight, then bias, from U(-b, b) with b = 1 / sqrt(fan_in)."""
        fan_in = self.weight.shape[-2]
        limit = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-limit, limit)
        if self.bias is not None:
            self.bias.data.uniform_(-limit, limit)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all per-(member, action-dim) affine maps in one batched matmul.

        Shapes:
            x:      ``[ensemble_size, action_dims, batch_size, in_features]``
            return: ``[ensemble_size, action_dims, batch_size, out_features]``

        Lower-rank ``x`` is broadcast over the leading axes by ``torch.matmul``.
        """
        projected = torch.matmul(x, self.weight)
        return projected + self.bias




class DiscreteCritic(nn.Module):
    """Ensemble critic with one output head per action dimension over discrete bins.

    Head ``j`` receives the state concatenated with a selection of the *other*
    action dimensions (never dimension ``j`` itself) and predicts one value per
    bin of dimension ``j``. Every linear layer keeps separate weights per
    ensemble member and per action dimension (``VectorizedLinearHead``); the
    ``LayerNorm`` is shared by all of them.

    Args:
        state_dim: Size of the state vector.
        hidden_dim: Width of the hidden layers.
        action_dims: Number of action dimensions (= number of heads).
        action_bins: Number of outputs per head.
        ensemble_size: Number of ensemble members.
        constrained_dim: If not ``None``, head ``j`` sees only the next
            ``min(action_dims - 1, constrained_dim)`` action dimensions in
            cyclic order ``j+1, j+2, ...`` (mod ``action_dims``). If ``None``,
            head ``j`` sees all other action dimensions in ascending order.
    """

    def __init__(self, state_dim, hidden_dim, action_dims, action_bins, ensemble_size, constrained_dim=10):

        super(DiscreteCritic, self).__init__()

        # Submodules are created in the original order so parameter
        # initialisation (RNG consumption) and parameter ordering are unchanged.
        self.output_heads = VectorizedLinearHead(hidden_dim, action_bins, action_dims, ensemble_size)
        self.fc2 = VectorizedLinearHead(hidden_dim, hidden_dim, action_dims, ensemble_size)
        self.fc3 = VectorizedLinearHead(hidden_dim, hidden_dim, action_dims, ensemble_size)

        self.action_dims = action_dims
        self.ensemble_size = ensemble_size

        self.layer_norm = nn.LayerNorm(hidden_dim)

        if constrained_dim is not None:
            self.constrained_dim = min(action_dims - 1, constrained_dim)
            # Row j: the constrained_dim action indices following j, wrapping around.
            self.action_idx = [
                [(j + 1 + x) % action_dims for x in range(self.constrained_dim)]
                for j in range(action_dims)
            ]
            self.fc1 = VectorizedLinearHead(state_dim + self.constrained_dim, hidden_dim, action_dims, ensemble_size)
            print(f'Constraining action inputs to {self.constrained_dim} dimensions\n')
        else:
            print(f'Not constraining action dimensions\n')
            print(f'Update actor every step \n')
            # Every head sees all other action dimensions; record that width so
            # `constrained_dim` is defined in both configurations.
            self.constrained_dim = action_dims - 1
            idxs = list(range(self.action_dims))
            # Row j: every action index except j, in ascending order.
            self.action_idx = [idxs[:i] + idxs[i + 1:] for i in range(self.action_dims)]
            self.fc1 = VectorizedLinearHead(state_dim + action_dims - 1, hidden_dim, action_dims, ensemble_size)

    def forward(self, state, action):
        """Return per-bin values for every ensemble member and action dimension.

        Args:
            state: assumed ``[batch, state_dim]`` — TODO confirm with callers.
            action: assumed ``[batch, action_dims]``, one value per action
                dimension; only read when ``action_dims > 1``.

        Returns:
            Tensor of shape ``[ensemble_size, batch, action_dims, action_bins]``.
            The last axis is squeezed away when ``action_bins == 1``.
        """
        if self.action_dims == 1:
            # No other action dimensions to condition on. The 2-D state
            # broadcasts against fc1's [E, 1, state_dim, H] weight.
            x = state
        else:
            # [B, A, k]: row j holds the action values head j may see.
            other_actions = action[:, self.action_idx]
            # expand() is a zero-copy view. torch.cat below materialises the
            # result anyway, so repeat()'s extra copy was wasted work.
            tiled_state = state.unsqueeze(1).expand(-1, self.action_dims, -1)
            # -> [A, B, state_dim + k], the per-action-dim layout fc1 expects.
            x = torch.cat([tiled_state, other_actions], dim=2).permute(1, 0, 2)

        residual = self.fc1(x)
        x = torch.relu(self.fc2(residual))
        x = torch.relu(self.fc3(x))
        x = self.layer_norm(residual + x)

        # Call the module rather than `.forward` so registered hooks run, as
        # they already do for fc1/fc2/fc3.
        output = self.output_heads(x)
        # [E, A, B, bins] -> [E, B, A, bins]
        return output.permute(0, 2, 1, 3).squeeze(-1)
        


class Critic(nn.Module):
    """Ensemble critic over a concatenated (state, action) vector.

    A residual MLP built from ``VectorizedLinear`` layers (separate weights
    per ensemble member) with a shared ``LayerNorm``. It outputs one value per
    action dimension for every ensemble member.

    Args:
        state_dim: Size of the state vector.
        hidden_dim: Width of the hidden layers.
        action_dims: Number of action dimensions (= number of outputs).
        action_bins: Bins per action dimension. The action input is expected
            to have ``action_dims * action_bins`` features (presumably a
            flattened per-dimension encoding — TODO confirm with callers).
        ensemble_size: Number of ensemble members.
    """

    def __init__(self, state_dim, hidden_dim, action_dims, action_bins, ensemble_size):

        super(Critic, self).__init__()

        # Submodules are created in the original order so parameter
        # initialisation (RNG consumption) and parameter ordering are unchanged.
        self.fc1 = VectorizedLinear(state_dim + action_dims * action_bins, hidden_dim, ensemble_size)
        self.fc2 = VectorizedLinear(hidden_dim, hidden_dim, ensemble_size)
        self.fc3 = VectorizedLinear(hidden_dim, hidden_dim, ensemble_size)

        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.output_heads = VectorizedLinear(hidden_dim, action_dims, ensemble_size)
        self.action_dims = action_dims
        self.ensemble_size = ensemble_size

    def forward(self, state, action):
        """Return values of shape ``[ensemble_size, batch, action_dims]``.

        Args:
            state: assumed ``[batch, state_dim]``.
            action: assumed ``[batch, action_dims * action_bins]``.
            Both are concatenated along dim 1 — TODO confirm shapes with callers.
        """
        x = torch.cat([state, action], dim=1)

        # The 2-D input broadcasts against fc1's [E, in, H] weight -> [E, B, H].
        residual = self.fc1(x)
        x = torch.relu(self.fc2(residual))
        x = torch.relu(self.fc3(x))
        x = self.layer_norm(residual + x)

        # Call the module rather than `.forward` so registered hooks run, as
        # they already do for fc1/fc2/fc3.
        return self.output_heads(x)
