import torch

from .base_model import BaseModel
from distributionalrl.network import DQNBase, CosineEmbeddingNetwork,\
    QuantileNetwork


class IQN(BaseModel):
    """Implicit Quantile Network (IQN) with optional risk-aware action scoring.

    The model chains a DQN feature extractor (``dqn_net``), a cosine embedding
    of quantile fractions (``cosine_net``) and a head (``quantile_net``) that
    maps both embeddings to per-action quantile values.

    ``calculate_q`` turns quantiles at K sampled fractions into one score per
    action; the rule is selected by ``alea_comparison_measure``:

    * ``"wang"``, ``"pow"``, ``"cvar"``: distort the sampled fractions, then
      average the quantiles (risk-sensitive expected value).
    * ``"oneversusall"``: estimated probability that the action's return is
      at least that of every other action (returns treated as independent).
    * ``"oneversusone"``: number of pairwise comparisons won, where a
      comparison within ``epsilon`` of a coin flip counts as half a win.
    * any other value: plain (risk-neutral) mean of the quantiles.
    """

    def __init__(self, state_shape, num_actions, alea_comparison_measure="wang",
                 alea_comparison_hyps=(0.0,), K=32, num_cosines=32,
                 embedding_dim=7*7*64, dueling_net=False, noisy_net=False):
        """Build the sub-networks and store the configuration.

        Args:
            state_shape: Observation shape, forwarded to ``DQNBase``.
            num_actions: Number of discrete actions.
            alea_comparison_measure: Name of the action-scoring rule used by
                ``calculate_q`` (see the class docstring).
            alea_comparison_hyps: Hyper-parameters of that rule; only element
                ``[0]`` is read. A bare number is wrapped in a 1-tuple.
            K: Number of quantile fractions sampled per state in
                ``calculate_q``.
            num_cosines: Forwarded to ``CosineEmbeddingNetwork``.
            embedding_dim: Embedding size passed to ``CosineEmbeddingNetwork``
                (presumably must match the ``DQNBase`` feature size — confirm).
            dueling_net: Forwarded to ``QuantileNetwork``.
            noisy_net: Forwarded to ``CosineEmbeddingNetwork`` and
                ``QuantileNetwork``.
        """
        super(IQN, self).__init__()

        # ``(0.0)`` is a float, not a tuple; a bare number would make the
        # ``alea_comparison_hyps[0]`` lookups in ``calculate_q`` raise.
        if isinstance(alea_comparison_hyps, (int, float)):
            alea_comparison_hyps = (alea_comparison_hyps,)

        # Feature extractor of DQN.
        self.dqn_net = DQNBase(state_shape)
        # Cosine embedding network.
        self.cosine_net = CosineEmbeddingNetwork(
            num_cosines=num_cosines, embedding_dim=embedding_dim,
            noisy_net=noisy_net)
        # Quantile network.
        self.quantile_net = QuantileNetwork(
            num_actions=num_actions, dueling_net=dueling_net,
            noisy_net=noisy_net)

        self.alea_comparison_measure = alea_comparison_measure
        self.alea_comparison_hyps = alea_comparison_hyps
        self.K = K
        self.state_shape = state_shape
        self.num_actions = num_actions
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net

    def calculate_state_embeddings(self, states):
        """Return the ``dqn_net`` features for ``states``."""
        new_states = self.dqn_net(states)
        return new_states

    def calculate_quantiles(self, taus, states=None, state_embeddings=None):
        """Return the quantile-network output for fractions ``taus``.

        Either ``states`` (passed through ``dqn_net``) or precomputed
        ``state_embeddings`` must be given; embeddings take precedence.
        ``calculate_q`` relies on taus of shape (batch, N) producing a tensor
        of shape (batch, N, num_actions).
        """
        assert states is not None or state_embeddings is not None

        if state_embeddings is None:
            state_embeddings = self.dqn_net(states)

        tau_embeddings = self.cosine_net(taus)
        return self.quantile_net(state_embeddings, tau_embeddings)

    @staticmethod
    def _pairwise_cdf(taus, quantiles):
        """Evaluate every action's empirical CDF at every sampled quantile.

        Args:
            taus: Fractions sorted along dim 1, shape (B, K).
            quantiles: Quantile values at ``taus``, shape (B, K, A).

        Returns:
            Tensor ``cdf`` of shape (B, K, A, A) with, for
            ``x = quantiles[b, k, a]``::

                cdf[b, k, a, c] = sum_j taus[b, j] * [q[b, j, c] <= x < q[b, j+1, c]]
                                  + taus[b, -1] * [q[b, -1, c] <= x]

            i.e. the step-function CDF of action ``c`` at ``x`` when the
            quantiles are non-decreasing in tau.
        """
        # Query points x = quantiles[b, k, a] -> (B, K, 1, A, 1).
        x = quantiles[:, :, None, :, None]
        # Bin edges [q_j, q_{j+1}) of every action's CDF grid -> (B, 1, K-1, 1, A).
        lower = quantiles[:, None, :-1, None, :]
        upper = quantiles[:, None, 1:, None, :]
        in_bin = (lower <= x) & (x < upper)
        cdf = torch.sum(taus[:, None, :-1, None, None] * in_bin, dim=2)

        # At or beyond the last grid quantile the CDF stays at the largest
        # fraction. Both operands keep the batch axis first: (B, 1, 1, A)
        # against (B, K, A, 1).
        last_quantile = quantiles[:, -1, :][:, None, None, :]
        above_last = last_quantile <= quantiles[:, :, :, None]
        tau_max = taus[:, -1][:, None, None, None]
        return cdf + tau_max * above_last

    def _one_versus_all(self, taus, state_embeddings):
        """Estimate per action the probability of matching or beating all others."""
        taus, _ = torch.sort(taus, dim=1)
        quantiles = self.calculate_quantiles(
            taus, state_embeddings=state_embeddings)
        cdf = self._pairwise_cdf(taus, quantiles)  # (B, K, A, A)
        # Drop each action's own CDF from the product by setting it to 1
        # instead of dividing by tau_k: exact even for non-monotone quantile
        # outputs and safe when a sampled tau is 0.
        num_actions = cdf.shape[-1]
        self_mask = torch.eye(num_actions, dtype=torch.bool, device=cdf.device)
        others = cdf.masked_fill(self_mask, 1.0)
        # prod_c F_c(x) for x sampled from the action's own distribution,
        # averaged over the K samples -> (B, A).
        return torch.mean(torch.prod(others, dim=3), dim=1)

    def _one_versus_one(self, taus, state_embeddings, epsilon):
        """Count per action the pairwise comparisons it wins (ties count 1/2)."""
        taus, _ = torch.sort(taus, dim=1)
        quantiles = self.calculate_quantiles(
            taus, state_embeddings=state_embeddings)
        # win_prob[b, a, c]: mean over samples x of action a of F_c(x).
        win_prob = torch.mean(self._pairwise_cdf(taus, quantiles), dim=1)
        wins = (win_prob >= 0.5 + epsilon) * 1.0
        ties = 0.5 * (torch.abs(win_prob - 0.5) < epsilon)
        # The sum runs over all c, including the action compared with itself.
        return torch.sum(wins + ties, dim=2)

    def calculate_q(self, states=None, state_embeddings=None):
        """Score each action from ``self.K`` freshly sampled quantile fractions.

        Either ``states`` or precomputed ``state_embeddings`` must be given.

        Returns:
            Tensor of shape (batch, num_actions) whose meaning depends on
            ``self.alea_comparison_measure`` (see the class docstring).
        """
        assert states is not None or state_embeddings is not None
        batch_size = states.shape[0] if states is not None\
            else state_embeddings.shape[0]

        if state_embeddings is None:
            state_embeddings = self.dqn_net(states)

        # Sample fractions uniformly in [0, 1).
        taus = torch.rand(
            batch_size, self.K, dtype=state_embeddings.dtype,
            device=state_embeddings.device)

        measure = self.alea_comparison_measure
        # Apply distortion (risk-aware policy) or a comparison-based score.
        if measure == "wang":
            # Shift in standard-normal quantile space; a positive
            # hyper-parameter moves the fractions toward 0.
            zero = torch.zeros((), dtype=taus.dtype, device=taus.device)
            normal_dist = torch.distributions.Normal(zero, torch.ones_like(zero))
            taus = normal_dist.cdf(
                normal_dist.icdf(taus) - self.alea_comparison_hyps[0])
        elif measure == "pow":
            # For a hyper-parameter > 0 the exponent is < 1, moving the
            # fractions toward 0; a value of -1 divides by zero.
            taus = 1 - (1 - taus) ** (1 / (1 + self.alea_comparison_hyps[0]))
        elif measure == "cvar":
            # Rescales fractions to [0, hyp): only the lower tail is averaged.
            taus = self.alea_comparison_hyps[0] * taus
        elif measure == "oneversusall":
            return self._one_versus_all(taus, state_embeddings)
        elif measure == "oneversusone":
            return self._one_versus_one(
                taus, state_embeddings, self.alea_comparison_hyps[0])

        # Calculate quantiles.
        quantiles = self.calculate_quantiles(
            taus, state_embeddings=state_embeddings)
        assert quantiles.shape == (batch_size, self.K, self.num_actions)

        # Calculate expectations of value distributions.
        q = quantiles.mean(dim=1)
        assert q.shape == (batch_size, self.num_actions)

        return q

    def params(self):
        """Return all parameters as a single optimizer parameter group."""
        return [{"params": self.parameters()}]
