import torch
import numpy as np

from .base_model import BaseModel
from distributionalrl.network import DQNBase, ProportionNetwork


class MOGN(BaseModel):
    """Mixture-of-Gaussians network over returns.

    For every action the model outputs ``K`` Gaussian components: a mean
    (``mean_net``), a strictly positive standard deviation (``sigma_net``,
    Softplus output plus a small constant) and a mixture proportion
    (``proportion_net``).  ``calculate_q`` reduces each per-action mixture to a
    scalar score according to ``comparison_measure``.
    """

    # Number of Newton-Raphson iterations used by the "quantiles" and
    # "expectiles" comparison measures.
    _NEWTON_STEPS = 10

    # Added to the Softplus output so that dividing by the stds is always safe.
    _STD_EPS = 1e-7

    def __init__(
        self,
        state_shape,
        num_actions,
        K=32,
        embedding_dim=7*7*64,
        is_embedding=True,
        dueling_net=False,
        noisy_net=False,
        comparison_measure="wang",
        comparison_hyps=None,
    ):
        """Build the feature extractor and the three mixture heads.

        Args:
            state_shape: Forwarded to ``DQNBase`` when ``is_embedding`` is true.
            num_actions: Number of actions.
            K: Number of Gaussian components per action.
            embedding_dim: Input width of the three heads.
            is_embedding: If true, raw states go through a ``DQNBase`` feature
                extractor first; otherwise states are fed to the heads
                directly and must already have ``embedding_dim`` features.
            dueling_net: Stored on the instance; not used in this class.
            noisy_net: Stored on the instance; not used in this class.
            comparison_measure: One of ``"quantiles"``, ``"expectiles"``,
                ``"oneversusall"``, ``"oneversusone"``, ``"softwang"``; any
                other value selects the Wang risk measure.
            comparison_hyps: Hyper-parameters of the comparison measure; only
                the first entry is read.  Defaults to ``[0.0]``.
        """
        super().__init__()

        # Feature extractor of DQN.
        self.dqn_net = None
        if is_embedding:
            self.dqn_net = DQNBase(state_shape, embedding_dim=embedding_dim)

        # Gaussian parameters networks
        self.mean_net = torch.nn.Linear(embedding_dim, num_actions*K)
        self.sigma_net = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, num_actions*K),
            torch.nn.Softplus(),
        )

        # Proportion network.
        self.proportion_net = ProportionNetwork(
            num_actions=num_actions,
            embedding_dim=embedding_dim,
            num_mixtures=K,
        )

        self.num_mixtures = K
        self.num_actions = num_actions
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net
        self.comparison_measure = comparison_measure
        # A fresh list per instance: a mutable default argument would be
        # shared (and mutated) across every model created without this arg.
        self.comparison_hyps = [0.0] if comparison_hyps is None else comparison_hyps

    def _embed(self, states):
        """Apply the feature extractor, or pass states through if there is none."""
        if self.dqn_net is None:
            return states
        return self.dqn_net(states)

    def _mixture_parameters(self, embeddings):
        """Return (proportions, means, stds) for all actions.

        ``means`` and ``stds`` are reshaped to (N, num_actions, num_mixtures).
        ``proportions`` is returned as produced by ``proportion_net``
        (presumably the same shape -- it is combined element-wise with the
        other two everywhere in this class).
        """
        shape = (-1, self.num_actions, self.num_mixtures)
        proportions = self.proportion_net(embeddings)
        means = self.mean_net(embeddings).view(shape)
        # stds must be > 0 so that we can divide by them.
        stds = self.sigma_net(embeddings).view(shape) + self._STD_EPS
        return proportions, means, stds

    def calculate_density_parameters(self, states, actions=None):
        """Compute the mixture parameters for ``states``.

        Args:
            states: Batch of states (or embeddings when no extractor is used).
            actions: Optional index tensor; it is indexed as
                ``actions[:, :, None]`` and expanded to (N, 1, K), so it is
                expected to have shape (N, 1).

        Returns:
            ``(proportions, means, stds)``.  Without ``actions`` each covers
            every action (N, num_actions, K); with ``actions`` only the
            selected action's components are kept (N, K).
        """
        proportions, means, stds = self._mixture_parameters(self._embed(states))
        if actions is None:
            return proportions, means, stds

        # Pick, for every sample, the K components of its chosen action.
        index = actions[:, :, None].expand(-1, 1, self.num_mixtures)
        return (
            torch.gather(proportions, 1, index)[:, 0, :],
            torch.gather(means, 1, index)[:, 0, :],
            torch.gather(stds, 1, index)[:, 0, :],
        )

    def calculate_state_embeddings(self, states):
        """Return the feature-extractor output (states unchanged if there is no extractor)."""
        return self._embed(states)

    def calculate_q(self, states=None, density=None):
        """Score each action's mixture with the configured comparison measure.

        Args:
            states: Batch of states; used only when ``density`` is ``None``.
            density: Optional pre-computed ``(proportions, means, stds)``
                sequence, as returned by ``calculate_density_parameters``.

        Returns:
            One score per mixture: the component axis (last axis) is reduced.
            The "oneversusall" and "oneversusone" measures compare actions
            with each other and therefore need (N, num_actions, K) inputs.
        """
        if density is None:
            props, means, stds = self.calculate_density_parameters(states)
        else:
            props, means, stds = density[0], density[1], density[2]

        measure = self.comparison_measure

        if measure == "quantiles":
            # Solve CDF(q) = alpha by Newton's method, started at the mixture mean.
            norm = torch.distributions.normal.Normal(0., 1.)
            alpha = self.comparison_hyps[0]
            quantile = torch.sum(props*means, dim=-1)
            for _ in range(self._NEWTON_STEPS):
                z = (quantile[..., None] - means)/stds
                f_quantile = torch.sum(props*norm.cdf(z), dim=-1) - alpha
                # Mixture pdf; log_prob is a natural log, hence exp (not 10**).
                df_quantile = torch.sum(props*torch.exp(norm.log_prob(z))/stds, dim=-1)
                quantile = quantile - f_quantile/df_quantile

            return torch.nn.functional.softplus(quantile)

        if measure == "expectiles":
            # Newton's method on the expectile first-order condition; with
            # G(e) = E[(e - X)+] = stds*pdf(z) + (e - means)*cdf(z) per
            # component, the root of (1 - 2*alpha)*G(e) + alpha*(e - mean) is sought.
            norm = torch.distributions.normal.Normal(0., 1.)
            alpha = self.comparison_hyps[0]
            expectile = torch.sum(props*means, dim=-1)
            for _ in range(self._NEWTON_STEPS):
                delta = expectile[..., None] - means
                z = delta/stds
                cdf = norm.cdf(z)
                pdf = torch.exp(norm.log_prob(z))  # natural-log density -> exp
                f_expectile = torch.sum(props*((1-2*alpha)*(stds*pdf + delta*cdf) + alpha*delta), dim=-1)
                df_expectile = torch.sum(props*((1-2*alpha)*cdf + alpha), dim=-1)
                expectile = expectile - f_expectile/df_expectile

            return torch.nn.functional.softplus(expectile)

        if measure == "oneversusall":
            # NOTE(review): the Gauss-Hermite nodes are used without the usual
            # sqrt(2) rescaling and the integrand is divided by
            # stds*sqrt(2*pi); confirm this matches the intended estimator.
            norm = torch.distributions.normal.Normal(0., 1.)
            x, weights = np.polynomial.hermite.hermgauss(5)
            x = torch.as_tensor(x, device=stds.device)
            weights = torch.as_tensor(weights, device=stds.device)

            # Evaluation points drawn from component (a, k): (N, A, 1, K, 1, Q)
            points = stds[:, :, None, :, None, None]*x[None, None, None, None, None, :] + means[:, :, None, :, None, None]
            # Mixture CDF of action a' at those points, summed over K':
            # (N, A, A', K, K', Q) -> (N, A, A', K, Q)
            Fx = torch.sum(
                props[:, None, :, None, :, None]
                * norm.cdf((points - means[:, None, :, None, :, None])/stds[:, None, :, None, :, None]),
                dim=4,
            )
            # Product over all a', divided by the a' == a term: (N, A, K, Q)
            prod_f = torch.prod(Fx, dim=2)/torch.moveaxis(torch.diagonal(Fx, dim1=1, dim2=2), 3, 1)

            fx = torch.sum(props[..., None]*prod_f/(stds[..., None]*np.sqrt(2*np.pi)), dim=2)  # (N, A, Q)
            return torch.sum(weights[None, None, :]*fx, dim=2)

        if measure == "oneversusone":
            norm = torch.distributions.normal.Normal(0., 1.)
            epsilon = self.comparison_hyps[0]
            # Broadcast layout: (N, A, A', K, K').
            gap = means[:, :, None, :, None] - means[:, None, :, None, :]
            spread = torch.sqrt(stds[:, :, None, :, None]**2 + stds[:, None, :, None, :]**2)
            pair_weights = props[:, :, None, :, None]*props[:, None, :, None, :]
            # P(X_a > X_a'): sum over BOTH component axes, keeping the
            # opponent-action axis -> (N, A, A').
            probas = torch.sum(pair_weights*norm.cdf(gap/spread), dim=(3, 4))

            # A win counts 1, a tie (within epsilon of 0.5) counts 0.5;
            # average over opponents a'.
            outcome = (probas >= 0.5 + epsilon) + 0.5*(torch.abs(probas - 0.5) < epsilon)
            return torch.mean(outcome, dim=2)

        eta = self.comparison_hyps[0]  # if eta = 0, we get the expectation
        wang = torch.sum(props*(means + eta*stds), dim=-1)
        if measure == "softwang":
            return torch.nn.functional.softplus(wang)

        # Wang risk measure
        return wang

    def params(self, prop_lr):
        """Return parameter groups in the dict format used by ``torch.optim``.

        The proportion network gets its own learning rate ``prop_lr``; the
        feature extractor's group is omitted when the model has none.
        """
        groups = [
            {"params": self.mean_net.parameters()},
            {"params": self.sigma_net.parameters()},
        ]
        if self.dqn_net is not None:
            groups.append({"params": self.dqn_net.parameters()})
        groups.append({"params": self.proportion_net.parameters(), "lr": prop_lr})
        return groups

    def log_likelihood(self, states, actions, samples):
        """Mean log-density of ``samples`` under the chosen actions' mixtures.

        The Gaussian kernel is evaluated without the ``1/sqrt(2*pi)`` factor,
        so the result differs from the exact log-likelihood by a constant.
        ``samples`` is indexed as ``samples[:, None]`` and broadcast against
        the (N, K) means, i.e. it is expected to be 1-D of length N.
        """
        props, means, stds = self.calculate_density_parameters(states, actions=actions)
        kernel = torch.exp(-(samples[:, None] - means)**2/(2*stds**2))/stds
        return torch.mean(torch.log(torch.sum(props*kernel, dim=1)))