import torch
import torch.nn as nn
import numpy as np

from .base_model import BaseModel
from distributionalrl.network import DQNBase


class EUMNN(BaseModel):
    """Value network with separate per-action ``mean`` and ``sigma`` heads.

    Each head is a ``DQNBase`` feature extractor followed by a linear layer
    that emits ``num_actions * K`` values, later viewed as
    ``(batch, num_actions, K)``. A small scalar-to-scalar MLP (``umnn``) is
    also constructed.

    NOTE(review): ``calculate_density_parameters``, ``calculate_q`` and
    ``params`` all read ``self.proportion_net``, which is not created in this
    class's ``__init__``. It presumably has to come from ``BaseModel`` or be
    assigned by the caller; otherwise those methods raise ``AttributeError``.
    Confirm before relying on them.
    """

    def __init__(self, num_channels, num_actions, K=32, embedding_dim=7*7*64, dueling_net=False, noisy_net=False):
        """Build the two heads and the UMNN sub-network.

        Args:
            num_channels: Forwarded to ``DQNBase`` as ``num_channels``.
            num_actions: Number of actions; sizes the output of each head.
            K: Number of components per action (stored as ``num_mixtures``).
            embedding_dim: Forwarded to ``DQNBase`` and used as the input
                width of each head's linear layer.
            dueling_net: Stored on the instance; not otherwise used here.
            noisy_net: Stored on the instance; not otherwise used here.
        """
        super().__init__()

        def build_head():
            # Feature extractor of DQN, then a linear projection producing
            # one value per (action, component) pair.
            return nn.Sequential(
                DQNBase(num_channels=num_channels, embedding_dim=embedding_dim),
                nn.Linear(embedding_dim, num_actions * K),
            )

        # Each head owns its own feature extractor (no weight sharing).
        self.mean_net = build_head()  # beta
        self.sigma_net = build_head()  # alpha

        # UMNN: 1 -> 128 -> 1 MLP with a bias-free output layer.
        # NOTE(review): not referenced by any other method in this class.
        self.umnn = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 1, bias=False),
        )

        self.num_mixtures = K
        self.num_channels = num_channels
        self.num_actions = num_actions
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net

    def calculate_density_parameters(self, states, actions):
        """Return the component parameters of the chosen action per sample.

        Args:
            states: Batch of inputs accepted by the sub-networks.
            actions: Index tensor used with ``torch.gather``; the indexing
                below requires it to be 2-D with a second dimension of
                size 1, i.e. ``(batch, 1)``.

        Returns:
            Tuple ``(proportions, means, stds)``, each of shape
            ``(batch, num_mixtures)``. ``means`` and ``stds`` are the raw
            head outputs; no transform is applied to them here.
        """
        n_actions, n_components = self.num_actions, self.num_mixtures

        # (batch, 1) -> (batch, 1, K): the same action index repeated for
        # every component, so gather selects a whole row along dim 1.
        action_index = actions.unsqueeze(2).expand(-1, 1, n_components)

        def chosen_row(per_action):
            # Pick the chosen action's row, then drop the singleton
            # action axis.
            return per_action.gather(1, action_index)[:, 0, :]

        means = chosen_row(self.mean_net(states).view(-1, n_actions, n_components))
        stds = chosen_row(self.sigma_net(states).view(-1, n_actions, n_components))

        # NOTE(review): proportion_net is not defined in this class; its
        # output is used as-is and so is presumably already 3-D
        # (batch, action, component) -- confirm.
        proportions = chosen_row(self.proportion_net(states))

        return proportions, means, stds

    def calculate_q(self, states=None):
        """Return per-action values as the proportion-weighted sum of means.

        Computes ``sum_k proportions[..., k] * means[..., k]`` over the
        component axis (dim 2). ``states`` is passed straight to the
        sub-networks, including the default ``None``.
        """
        # NOTE(review): proportion_net is not defined in this class.
        weights = self.proportion_net(states)
        component_means = self.mean_net(states).view(-1, self.num_actions, self.num_mixtures)

        return (weights * component_means).sum(dim=2)

    def params(self, prop_lr):
        """Return optimizer parameter groups for the three sub-networks.

        The proportion network's group carries its own learning rate
        ``prop_lr``; the mean and sigma groups specify none. The ``umnn``
        parameters are not part of any group.
        """
        mean_group = {"params": self.mean_net.parameters()}
        sigma_group = {"params": self.sigma_net.parameters()}
        # NOTE(review): proportion_net is not defined in this class.
        proportion_group = {"params": self.proportion_net.parameters(), "lr": prop_lr}
        return [mean_group, sigma_group, proportion_group]