from torch import nn
import torch

try:
    from .base_model import BaseModel
    from fqf_iqn_qrdqn.network import DQNBase, NoisyLinear
except:
    from fqf_iqn_qrdqn.model.base_model import BaseModel
    from fqf_iqn_qrdqn.network import DQNBase, NoisyLinear


class PQRDQN(BaseModel):
    """Quantile network: ``DQNBase`` features followed by a quantile head.

    ``forward`` returns ``N`` quantile values per action as a tensor of shape
    ``(batch_size, N, num_actions)``. The ``calculate_*`` methods reduce those
    quantiles over the quantile axis (dim 1) into per-action scores of shape
    ``(batch_size, num_actions)``.
    """

    def __init__(self, num_channels, num_actions, N=200, embedding_dim=7*7*64,
                 dueling_net=False, noisy_net=False):
        """Build the feature extractor and the quantile head(s).

        Args:
            num_channels: Number of input channels passed to ``DQNBase``.
            num_actions: Number of discrete actions.
            N: Number of quantile values predicted per action.
            embedding_dim: Input width of the quantile head(s); must equal the
                size of the feature vector produced by ``DQNBase``.
            dueling_net: If True, use separate advantage and baseline heads
                combined as ``baseline + advantage - mean_a(advantage)``.
            noisy_net: If True, use ``NoisyLinear`` layers instead of
                ``nn.Linear`` in the head(s).
        """
        super(PQRDQN, self).__init__()
        linear = NoisyLinear if noisy_net else nn.Linear

        # Feature extractor of DQN.
        self.dqn_net = DQNBase(num_channels=num_channels)
        # Quantile network.
        if not dueling_net:
            self.q_net = nn.Sequential(
                linear(embedding_dim, 512),
                nn.ReLU(),
                linear(512, num_actions * N),
            )
        else:
            # One advantage value per (quantile, action) ...
            self.advantage_net = nn.Sequential(
                linear(embedding_dim, 512),
                nn.ReLU(),
                linear(512, num_actions * N),
            )
            # ... and one baseline value per quantile, shared by all actions.
            self.baseline_net = nn.Sequential(
                linear(embedding_dim, 512),
                nn.ReLU(),
                linear(512, N),
            )

        self.N = N
        self.num_channels = num_channels
        self.num_actions = num_actions
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net

    @staticmethod
    def _batch_size(states, state_embeddings):
        """Return the leading (batch) dimension of whichever input is given."""
        # At least one input is required: embeddings are derived from
        # ``states`` when ``state_embeddings`` is omitted.
        assert states is not None or state_embeddings is not None
        source = states if states is not None else state_embeddings
        return source.shape[0]

    def forward(self, states=None, state_embeddings=None):
        """Predict quantile values for every action.

        Args:
            states: Batch of inputs for ``DQNBase``; only used to compute
                embeddings when ``state_embeddings`` is None.
            state_embeddings: Precomputed ``DQNBase`` features, one row of
                ``embedding_dim`` values per sample; skips the extractor.

        Returns:
            Tensor of shape ``(batch_size, N, num_actions)``.
        """
        batch_size = self._batch_size(states, state_embeddings)

        if state_embeddings is None:
            state_embeddings = self.dqn_net(states)

        if not self.dueling_net:
            quantiles = self.q_net(
                state_embeddings).view(batch_size, self.N, self.num_actions)
        else:
            advantages = self.advantage_net(
                state_embeddings).view(batch_size, self.N, self.num_actions)
            baselines = self.baseline_net(
                state_embeddings).view(batch_size, self.N, 1)
            # Centring advantages over actions makes the per-quantile mean
            # over actions equal to the baseline.
            quantiles = baselines + advantages\
                - advantages.mean(dim=2, keepdim=True)

        assert quantiles.shape == (batch_size, self.N, self.num_actions)

        return quantiles

    def calculate_q(self, states=None, state_embeddings=None):
        """Return the mean of each action's quantiles, shape ``(batch, num_actions)``."""
        batch_size = self._batch_size(states, state_embeddings)

        # Calculate quantiles.
        quantiles = self(states=states, state_embeddings=state_embeddings)

        # Calculate expectations of value distributions.
        q = quantiles.mean(dim=1)
        assert q.shape == (batch_size, self.num_actions)

        return q

    def calculate_abs_q(self, states=None, state_embeddings=None):
        """Return the mean absolute quantile value per action, shape ``(batch, num_actions)``."""
        batch_size = self._batch_size(states, state_embeddings)

        # Calculate quantiles.
        quantiles = self(states=states, state_embeddings=state_embeddings)

        # Average the magnitudes (not the signed values) over the quantile axis.
        abs_q = torch.abs(quantiles).mean(dim=1)
        assert abs_q.shape == (batch_size, self.num_actions)

        return abs_q

    def calculate_std_mean(self, states=None, state_embeddings=None, coefficient_C=0):
        """Return ``mean + coefficient_C * std`` of each action's quantiles.

        ``std`` is ``torch.std`` over the quantile axis with its default
        (Bessel-corrected) estimator, which is NaN when ``N == 1``.

        Args:
            states: See ``forward``.
            state_embeddings: See ``forward``.
            coefficient_C: Weight of the standard-deviation term.

        Returns:
            Tensor of shape ``(batch_size, num_actions)``.
        """
        batch_size = self._batch_size(states, state_embeddings)

        # Calculate quantiles.
        quantiles = self(states=states, state_embeddings=state_embeddings)

        q = quantiles.mean(dim=1)
        assert q.shape == (batch_size, self.num_actions)

        # A plain zero coefficient means the std term contributes nothing;
        # skip it so an undefined std (e.g. N == 1) cannot poison the result
        # through 0 * NaN.
        if isinstance(coefficient_C, (int, float)) and coefficient_C == 0:
            return q

        std = quantiles.std(dim=1)
        return q + coefficient_C * std

    def calculate_P(self, states=None, state_embeddings=None, xi=None):
        """Return the xi-weighted mean of each action's quantiles.

        Args:
            states: See ``forward``.
            state_embeddings: See ``forward``.
            xi: Per-quantile weights with ``N`` entries (a single entry
                broadcasts); anything accepted by ``torch.as_tensor``, e.g. a
                tensor or a list. It is moved to the quantiles' device.
                Defaults to all ones, in which case the result equals
                ``calculate_q``.

        Returns:
            Tensor of shape ``(batch_size, num_actions)``.
        """
        batch_size = self._batch_size(states, state_embeddings)

        # Calculate quantiles.
        quantiles = self(states=states, state_embeddings=state_embeddings)

        if xi is None:
            # Create the default on the quantiles' device/dtype; a bare
            # torch.ones(N) lives on the CPU and breaks on GPU models.
            xi = torch.ones(
                self.N, device=quantiles.device, dtype=quantiles.dtype)
        else:
            xi = torch.as_tensor(xi, device=quantiles.device)

        # Broadcast one weight per quantile across the batch and action axes,
        # then average over the quantile axis (perturbed mean).
        p = (quantiles * xi.reshape(1, -1, 1)).mean(dim=1)
        assert p.shape == (batch_size, self.num_actions)

        return p
