#-*- coding:utf-8 -*-

import torch.nn as nn 
import torch

class Discriminator(nn.Module):
    """
        MLP that maps a flattened (observation, action) pair to a probability.

        The observation and action tensors are flattened per sample,
        concatenated as [s, a], passed through the hidden linear layers
        (each followed by `activation`) and finally through a 1-unit linear
        layer with a sigmoid.
    """

    def __init__(
        self, 
        input_length:int, 
        hidden_sizes=(128, 128), 
        activation=torch.tanh
    ):
        """
            - input_length: size of the flattened, concatenated [s, a] vector,
                            i.e. the input width of the first linear layer
            - hidden_sizes: output width of each hidden linear layer, in order
                            (any iterable of ints; the default is a tuple so
                            that no mutable object is shared between calls)
            - activation:   callable applied after every hidden linear layer
        """
        super().__init__()
        self._activation = activation
        self._layers = nn.ModuleList()
        self.input_length = input_length
        last_in_dim = input_length
        for out_dim in hidden_sizes:
            self._layers.append(nn.Linear(last_in_dim, out_dim))
            last_in_dim = out_dim

        self.logits = nn.Linear(last_in_dim, 1)
        # Shrink the output layer's initial weights and zero its bias so the
        # initial pre-sigmoid outputs are small.
        with torch.no_grad():
            self.logits.weight.mul_(0.1)
            self.logits.bias.zero_()

        self.loss_func = nn.BCELoss()

    def forward(self, a:torch.Tensor, s:torch.Tensor):
        """
            - a, action: (B, Tp, Da)
            - s, obs: (B, To, Do)
            returns: (B, 1) tensor of sigmoid outputs in [0, 1]
        """
        # Keep the batch dimension and flatten everything else.
        a = a.flatten(start_dim=1)
        s = s.flatten(start_dim=1)
        # Observation features come first, then action features.
        x = torch.cat([s, a], dim=1)
        for linear in self._layers:
            x = self._activation(linear(x))
        prob = torch.sigmoid(self.logits(x))
        return prob
    
    def gan_loss(self, a_generated:torch.Tensor, a_expert:torch.Tensor, s:torch.Tensor):
        """
            Binary cross-entropy discriminator loss.

            - a_generated: actions scored against target 1
            - a_expert:    actions scored against target 0
            - s:           observations shared by both action batches
            returns: scalar tensor, BCE(generated, 1) + BCE(expert, 0)
        """
        y_fake = self.forward(a=a_generated, s=s)
        y_real = self.forward(a=a_expert, s=s)

        # Build the targets from the predictions themselves so that shape,
        # dtype and device always match (e.g. after `.double()` / `.half()`).
        fake_loss = self.loss_func(y_fake, torch.ones_like(y_fake))
        real_loss = self.loss_func(y_real, torch.zeros_like(y_real))
        return fake_loss + real_loss

if __name__ == '__main__':
    # Smoke test: push random actions/observations through the discriminator
    # and print the output shape and the loss value.
    batch_size = 10
    obs_horizon, obs_dim = 2, 5
    act_horizon, act_dim = 16, 2

    # Sample the actions before the observations, and both before building the
    # model, so the random number stream is consumed in a fixed order.
    a = torch.randn(batch_size, act_horizon, act_dim)
    s = torch.randn(batch_size, obs_horizon, obs_dim)

    # The network consumes the flattened observation and action concatenated.
    flat_input_size = obs_horizon * obs_dim + act_horizon * act_dim
    discriminator = Discriminator(flat_input_size)

    print(discriminator(a, s).shape)

    # The same action batch is used as both "generated" and "expert" here.
    print(discriminator.gan_loss(a, a, s))
