class SelfAttentionLayer(nn.Module):
    """Single-head self-attention applied over a flattened token axis.

    The input is rearranged so that its second axis becomes the feature axis,
    projected to ``output_dim`` (fixed to 1) by three ``VectorizedLinear``
    modules (key / query / value), and passed through
    ``nn.MultiheadAttention`` with the user-supplied ``mask``.

    Args:
        feature_size: Stored on the instance; not used by ``forward``.
        ensemble_size: Input feature count of the key/query/value projections.
        mask: Passed unchanged as ``attn_mask`` to the attention module
            (may be ``None``).
    """

    def __init__(self, feature_size, ensemble_size, mask):
        super(SelfAttentionLayer, self).__init__()

        self.feature_size = feature_size
        self.ensemble_size = ensemble_size
        # Width of every projected token and the attention embedding size.
        self.output_dim = 1

        # Construction order (key, query, value, attention) is kept so that
        # random initialisation consumes the RNG exactly as before.
        self.key = VectorizedLinear(self.ensemble_size, self.output_dim, 1)
        self.query = VectorizedLinear(self.ensemble_size, self.output_dim, 1)
        self.value = VectorizedLinear(self.ensemble_size, self.output_dim, 1)

        self.attn_mask = mask
        self.multihead_attn = nn.MultiheadAttention(
            self.output_dim, num_heads=1, batch_first=True
        )

    def forward(self, x):
        """Attend over the flattened last two axes of ``x``.

        Args:
            x: 4-D tensor. Axis 0 is read as the batch size and axis 2 as the
                number of action dimensions. Axis 1 ends up as the feature
                axis fed to the projections, so it presumably has length
                ``ensemble_size`` -- TODO confirm against the caller.

        Returns:
            The attention output reshaped to
            ``(batch_size, 1, action_dims, -1)``.
        """
        batch_size = x.shape[0]
        action_dims = x.shape[2]

        # (B, d1, d2, d3) -> (B, d3, d2, d1) -> (B, d3 * d2, d1)
        tokens = x.permute(0, 3, 2, 1).flatten(1, 2)

        key = self.key(tokens)
        queries = self.query(tokens)
        value = self.value(tokens)

        # The attention weights are not needed by callers of this layer.
        output, _ = self.multihead_attn(
            queries, key, value, attn_mask=self.attn_mask
        )

        # NOTE(review): tokens were flattened in (d3, d2) order but are
        # unflattened here as (action_dims, -1) = (d2, d3) -- confirm that
        # downstream code expects this layout.
        output = output.reshape(batch_size, 1, action_dims, -1)

        # The original computed ``output`` but never returned it, so the
        # module always evaluated to ``None``.
        return output

    def mask_mat(self):
        """Placeholder for building the attention mask.

        The original body was entirely commented out, which left the method
        without a body and made the file unparsable. The commented sketch
        built a square matrix of ones of side ``action_dims * action_bins``
        and overwrote each diagonal ``action_bins x action_bins`` block with
        an identity matrix; neither ``action_dims`` nor ``action_bins`` is
        available on this class, so the sketch cannot be completed here.

        Raises:
            NotImplementedError: Always. Supply the mask through the
                constructor's ``mask`` argument instead.
        """
        raise NotImplementedError(
            "mask_mat is not implemented; pass the attention mask to the constructor"
        )



class MLPResidualLayer(nn.Module):
    """Residual block: two ``dim -> dim`` linear layers, each followed by ReLU.

    The block's output is the input plus the result of the two-layer branch.
    """

    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x):
        """Return ``x + relu(fc2(relu(fc1(x))))``."""
        hidden = self.fc1(x).relu()
        branch = self.fc2(hidden).relu()
        return x + branch


class DecoupledQNetwork(nn.Module):
    """Three-layer ReLU trunk followed by a ``VectorizedLinear`` output stage.

    ``resnet``, ``dropout`` and ``layer_norm`` are constructed (and therefore
    appear in the module's parameters) but ``forward`` does not apply them.

    Args:
        state_dim: Input feature size of the first linear layer.
        hidden_dim: Width of the trunk.
        num_states: Second argument given to the ``VectorizedLinear`` output
            stage.
        num_heads: Third argument given to the ``VectorizedLinear`` output
            stage.
    """

    def __init__(self, state_dim, hidden_dim, num_states, num_heads):
        super().__init__()

        # Hyper-parameters, kept for introspection by callers.
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.num_states = num_states
        self.num_heads = num_heads

        # Sub-modules are created in a fixed order so that parameter
        # initialisation draws from the RNG in a stable sequence.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.resnet = MLPResidualLayer(hidden_dim)
        self.dropout = nn.Dropout(0.2)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.output_heads = VectorizedLinear(hidden_dim, num_states, num_heads)

    def forward(self, x):
        """Run the trunk and return the output stage's result with axes 0 and 1 swapped."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = torch.relu(layer(x))
        head_out = self.output_heads.forward(x)
        return head_out.transpose(0, 1)






class VectorisedMLPResidualLayer(nn.Module):
    """Residual block built from two ``VectorizedLinear(dim, dim, ensemble_size)`` layers.

    Each layer is followed by a ReLU and the branch result is added to the
    block's input.
    """

    def __init__(self, dim, ensemble_size):
        super().__init__()
        self.fc1 = VectorizedLinear(dim, dim, ensemble_size)
        self.fc2 = VectorizedLinear(dim, dim, ensemble_size)

    def forward(self, x):
        """Return ``x + relu(fc2(relu(fc1(x))))``."""
        branch = x
        for layer in (self.fc1, self.fc2):
            branch = torch.relu(layer(branch))
        return x + branch


class EnsembleDecoupledDuellingQNetwork(nn.Module):
    """Ensemble Q-network with separate value and advantage output stages.

    The two streams are combined as
    ``value + (advantage - mean(advantage over the last axis))``.
    """

    def __init__(self, state_dim, hidden_dim, num_actions, num_heads, ensemble_size):
        super().__init__()
        self.num_heads = num_heads
        self.ensemble_size = ensemble_size

        # Shared trunk, then the two duelling streams.
        self.fc1 = VectorizedLinear(state_dim, hidden_dim, ensemble_size)
        self.fc2 = VectorizedLinear(hidden_dim, hidden_dim, ensemble_size)
        self.value_head = VectorizedLinearHead(hidden_dim, 1, ensemble_size, num_heads)
        self.advantage_head = VectorizedLinearHead(hidden_dim, num_actions, ensemble_size, num_heads)

    def forward(self, x):
        """Compute duelling Q-values.

        A 2-D input is first copied ``ensemble_size`` times along a new
        axis 1; any other input is used as given. Axes 0 and 1 are then
        swapped before the trunk is applied.
        """
        if x.dim() == 2:
            x = x.unsqueeze(dim=1).repeat(1, self.ensemble_size, 1)

        features = x.transpose(0, 1)
        for layer in (self.fc1, self.fc2):
            features = torch.relu(layer(features))

        # Insert a new axis 1 and tile it ``num_heads`` times.
        per_head = features.unsqueeze(dim=1).repeat(1, self.num_heads, 1, 1)

        value = self.value_head.forward(per_head)
        value = value.transpose(1, 2).transpose(0, 1)

        advantage = self.advantage_head.forward(per_head)
        advantage = advantage.transpose(1, 2).transpose(0, 1)

        centred_advantage = advantage - advantage.mean(dim=-1, keepdim=True)
        return value + centred_advantage


class VectorizedAttentionHead(nn.Module):
    """Bank of independent linear maps, one per (attention head, action dim) pair.

    Parameters are stored as a single weight tensor of shape
    ``(attn_head, action_dims, in_features, out_features)`` and an optional
    bias of shape ``(attn_head, action_dims, 1, out_features)``.
    """

    def __init__(self, in_features, out_features, action_dims, attn_head, bias=True):
        super(VectorizedAttentionHead, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.action_dims = action_dims
        self.attn_head = attn_head

        self.weight = nn.Parameter(
            torch.empty(attn_head, action_dims, in_features, out_features)
        )
        self.bias = (
            nn.Parameter(torch.empty(attn_head, action_dims, 1, out_features))
            if bias
            else None
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight (then bias) uniformly in ``[-1/sqrt(in), 1/sqrt(in)]``."""
        bound = 1. / math.sqrt(self.weight.size(2))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every per-(head, action dim) linear map via a batched matmul.

        ``x`` must broadcast against the weight's leading
        ``(attn_head, action_dims)`` axes and have ``in_features`` as its
        last axis, e.g. ``(action_dims, batch, in_features)``. The result
        then has shape ``(attn_head, action_dims, batch, out_features)``.
        """
        projected = torch.matmul(x, self.weight)
        if self.bias is None:
            return projected
        return projected + self.bias


class AttentionCritic(nn.Module):
    """Critic in which every action dimension attends to the other dimensions.

    For each action dimension the state is concatenated with that
    dimension's own action value and encoded. Keys, queries and values are
    extracted per attention head; each dimension then attends to all *other*
    dimensions (the diagonal of the attention matrix is masked out). The
    attended features are passed through ``fc2``, ``fc3`` and
    ``output_heads``.

    Args:
        state_dim: Size of the state vector.
        hidden_dim: Encoder output width; split into ``attn_head`` chunks of
            ``hidden_dim // attn_head`` for attention.
        action_dims: Number of action dimensions.
        action_bins: Output size given to ``output_heads``.
        ensemble_size: Forwarded to the ``VectorizedLinearHead`` layers.
        attn_head: Number of attention heads.
        batch_size: Stored as ``self.batch_size`` for backward compatibility;
            ``forward`` now works with whatever batch size it receives.
        enc_dim: Optional width between ``fc1`` and ``fc2``; defaults to
            ``hidden_dim``.

    NOTE(review): ``fc1`` is constructed but never applied in ``forward``,
    and ``fc2`` receives ``attn_head * (hidden_dim // attn_head)`` features,
    which only matches its declared input width when ``enc_dim`` is
    ``None``/``hidden_dim`` and ``hidden_dim`` is divisible by ``attn_head``.
    """

    def __init__(self, state_dim, hidden_dim, action_dims, action_bins, ensemble_size, attn_head, batch_size, enc_dim=None):

        super(AttentionCritic, self).__init__()

        self.encoder_dim = hidden_dim if enc_dim is None else enc_dim

        # Sub-modules are created in the original order so parameter
        # initialisation consumes the RNG identically.
        self.encoder = VectorizedLinear(state_dim + 1, hidden_dim, action_dims)

        self.attn_dim = hidden_dim // attn_head
        self.key_extractor = VectorizedAttentionHead(hidden_dim, self.attn_dim, action_dims, attn_head)
        self.query_extractor = VectorizedAttentionHead(hidden_dim, self.attn_dim, action_dims, attn_head)
        self.value_extractor = VectorizedAttentionHead(hidden_dim, self.attn_dim, action_dims, attn_head)

        self.fc1 = VectorizedLinearHead(state_dim + action_dims - 1, self.encoder_dim,
                                        action_dims, ensemble_size)
        self.fc2 = VectorizedLinearHead(self.encoder_dim, hidden_dim, action_dims, ensemble_size)
        self.fc3 = VectorizedLinearHead(hidden_dim, hidden_dim, action_dims, ensemble_size)

        self.output_heads = VectorizedLinearHead(hidden_dim, action_bins, action_dims, ensemble_size)

        self.action_dims = action_dims
        self.ensemble_size = ensemble_size
        self.batch_size = batch_size

        # Additive attention mask: 0 off the diagonal, -inf on it, so a
        # dimension never attends to itself. It was previously allocated on a
        # hard-coded 'cuda:0' device and tiled to a fixed batch size; a
        # broadcastable (1, 1, A, A) non-persistent buffer follows the module
        # across devices, accepts any batch size, and leaves the state_dict
        # unchanged.
        diag_mask = torch.zeros(1, 1, action_dims, action_dims)
        diag_mask.masked_fill_(torch.eye(action_dims, dtype=torch.bool), float("-inf"))
        self.register_buffer("eye", diag_mask, persistent=False)

    def forward(self, state, action, ret_advantage=False):
        """Evaluate the critic.

        Args:
            state: ``(batch, state_dim)`` tensor.
            action: ``(batch, action_dims)`` tensor; column ``i`` is the
                action taken along dimension ``i``.
            ret_advantage: Accepted for interface compatibility; unused.

        Returns:
            ``output_heads`` result with axes 1 and 2 swapped and a trailing
            singleton axis (if any) removed.
        """
        batch_size = state.shape[0]

        if self.action_dims == 1:
            # NOTE(review): with a single action dimension the only attention
            # entry is masked to -inf (softmax yields NaN) and the encoder is
            # built for ``state_dim + 1`` inputs, so this branch looks
            # unsupported -- confirm the intended behaviour.
            x = state
        else:
            # Pair every action dimension with a copy of the state:
            # (B, A, S) ++ (B, A, 1) -> (B, A, S + 1) -> (A, B, S + 1).
            own_action = action[:, :self.action_dims].unsqueeze(-1)
            tiled_state = state.unsqueeze(1).expand(-1, self.action_dims, -1)
            x = torch.cat([tiled_state, own_action], dim=2).permute(1, 0, 2)

        encode = torch.relu(self.encoder(x))

        # Extractors return (heads, A, B, D); rearrange for batched attention.
        keys = self.key_extractor(encode).permute(0, 2, 3, 1)                   # (heads, B, D, A)
        query = self.query_extractor(encode).permute(0, 2, 1, 3)                # (heads, B, A, D)
        values = torch.relu(self.value_extractor(encode)).permute(0, 2, 1, 3)   # (heads, B, A, D)

        attend_logits = query @ keys                                            # (heads, B, A, A)
        mask = self.eye.to(device=attend_logits.device, dtype=attend_logits.dtype)
        selected_logits = (attend_logits + mask) / math.sqrt(self.attn_dim)
        selected_weights = selected_logits.softmax(dim=-1)

        attended = selected_weights @ values                                    # (heads, B, A, D)

        # Bring the axes to (A, B, heads, D) *before* merging the heads. A
        # bare reshape of the (heads, B, A, D) tensor, as done previously,
        # reinterprets memory and mixes features across samples and action
        # dimensions. The runtime batch size is used rather than the one
        # given to the constructor.
        weighted_keys = attended.permute(2, 1, 0, 3).reshape(self.action_dims, batch_size, -1)

        x = torch.relu(self.fc2(weighted_keys))
        x = torch.relu(self.fc3(x))

        output = self.output_heads(x).permute(0, 2, 1, 3).squeeze(-1)
        return output



class DiscreteValue(nn.Module):
    """State-value network: three ReLU ``VectorizedLinearHead`` layers plus an output stage.

    Args:
        state_dim: Input width of ``fc1``.
        hidden_dim: Width of ``fc2``/``fc3`` outputs.
        action_dims: Third argument passed to every ``VectorizedLinearHead``.
        action_bins: Accepted for signature parity with the critics; unused.
        ensemble_size: Fourth argument passed to every ``VectorizedLinearHead``.
        enc_dim: Optional output width of ``fc1``; defaults to ``hidden_dim``.
    """

    def __init__(self, state_dim, hidden_dim, action_dims, action_bins, ensemble_size, enc_dim=None):
        super().__init__()

        self.action_dims = action_dims
        self.ensemble_size = ensemble_size
        self.encoder_dim = hidden_dim if enc_dim is None else enc_dim

        self.fc1 = VectorizedLinearHead(state_dim, self.encoder_dim, action_dims, ensemble_size)
        self.fc2 = VectorizedLinearHead(self.encoder_dim, hidden_dim, action_dims, ensemble_size)
        self.fc3 = VectorizedLinearHead(hidden_dim, hidden_dim, action_dims, ensemble_size)

        # Output size 1 per head.
        self.output_heads = VectorizedLinearHead(hidden_dim, 1, action_dims, ensemble_size)

    def forward(self, state):
        """Return the output stage's result with axes 1 and 2 swapped (last axis kept)."""
        hidden = state
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        return self.output_heads.forward(hidden).permute(0, 2, 1, 3)


class DiscreteCritic(nn.Module):
    """Per-dimension critic conditioned on the actions of the *other* dimensions.

    For every action dimension ``i`` the network input is the state
    concatenated with the action values of all dimensions except ``i``
    (``state_dim + action_dims - 1`` features).

    Args:
        state_dim: Size of the state vector.
        hidden_dim: Width of the hidden layers.
        action_dims: Number of action dimensions.
        action_bins: Output size given to ``output_heads``.
        ensemble_size: Fourth argument passed to every ``VectorizedLinearHead``.
    """

    def __init__(self, state_dim, hidden_dim, action_dims, action_bins, ensemble_size):
        super().__init__()

        self.action_dims = action_dims
        self.ensemble_size = ensemble_size

        input_dim = state_dim + action_dims - 1
        self.fc1 = VectorizedLinearHead(input_dim, hidden_dim, action_dims, ensemble_size)
        self.fc2 = VectorizedLinearHead(hidden_dim, hidden_dim, action_dims, ensemble_size)
        self.fc3 = VectorizedLinearHead(hidden_dim, hidden_dim, action_dims, ensemble_size)

        self.output_heads = VectorizedLinearHead(hidden_dim, action_bins, action_dims, ensemble_size)

    def forward(self, state, action, ret_advantage=False):
        """Evaluate the critic.

        Args:
            state: ``(batch, state_dim)`` tensor.
            action: ``(batch, action_dims)`` tensor; ignored when
                ``action_dims == 1``.
            ret_advantage: Accepted for interface compatibility; unused.

        Returns:
            ``output_heads`` result with axes 1 and 2 swapped and a trailing
            singleton axis (if any) removed.
        """
        n_dims = self.action_dims

        if n_dims == 1:
            features = state
        else:
            # Row i lists every action dimension except i.
            other_dims = [
                [j for j in range(n_dims) if j != i] for i in range(n_dims)
            ]
            other_actions = action[:, other_dims]                    # (B, A, A - 1)
            tiled_state = state.unsqueeze(1).repeat(1, n_dims, 1)    # (B, A, S)
            features = torch.cat([tiled_state, other_actions], dim=2).permute(1, 0, 2)

        hidden = features
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))

        return self.output_heads.forward(hidden).permute(0, 2, 1, 3).squeeze(-1)

    #   x = torch.relu(self.fc1(x))
    #   x = torch.relu(self.fc2(x))
    #  #x = torch.relu(self.fc3(x))
    #   x = x.unsqueeze(dim=1).repeat(1, self.action_dims, 1, 1)
    #   advantage = self.advantage_heads.forward(x).permute(0,2,1,3)

    #   if ret_advantage:
    #       return advantage
    #   else:
    #       value = self.value_heads.forward(x).permute(0,2,1,3)

    #       q_value = value + (advantage - advantage.mean(dim=-1,keepdim=True))
    #       return q_value


