from numpy import add
import torch
import torch.optim as optim

from torch import nn as nn
from torch.nn import functional as F


class Predict_Network_OS(nn.Module):
    """Two-layer ReLU MLP with two linear output heads.

    The shared trunk feeds:
      * ``last_fc_obs``  -> ``obs_shape * n_agents`` outputs (returned as ``next_obs``)
      * ``last_fc_state`` -> ``state_shape`` outputs (returned as ``next_state``)

    Args:
        num_inputs: size of the last dimension of the input tensor.
        hidden_dim: width of both hidden layers.
        obs_shape: per-agent observation size; the obs head emits
            ``obs_shape * n_agents`` values.
        state_shape: size of the state head's output.
        n_agents: number of agents whose observations are predicted jointly.
        lr: learning rate of the Adam optimizer used by :meth:`update`.
    """

    def __init__(self, num_inputs, hidden_dim, obs_shape, state_shape, n_agents, lr=3e-4):
        super(Predict_Network_OS, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.last_fc_obs = nn.Linear(hidden_dim, obs_shape * n_agents)
        self.last_fc_state = nn.Linear(hidden_dim, state_shape)

        # `update` steps this optimizer. Previously `lr` was ignored and no
        # optimizer existed, so `update` failed with AttributeError.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, input):
        """Return ``(next_obs, next_state)`` predictions for ``input``."""
        h = F.relu(self.linear1(input))
        h = F.relu(self.linear2(h))
        next_obs = self.last_fc_obs(h)
        next_state = self.last_fc_state(h)
        return next_obs, next_state

    def update(self, own_variable, other_variable, mask):
        """Take one masked MSE gradient step on both heads.

        Args:
            own_variable: network input.
            other_variable: regression target, either an
                ``(obs_target, state_target)`` pair matching the two heads, or
                a single tensor whose last dimension is the obs head's outputs
                followed by the state head's outputs. Both forms yield the
                same loss.
            mask: weights multiplied into the per-row loss (which keeps a
                trailing dimension of size 1), so it must broadcast against it
                — presumably shaped ``(..., 1)``; confirm against callers.

        Returns:
            The scalar loss as a Python float, or ``None`` when ``mask`` sums
            to zero (no step is taken in that case).
        """
        if mask.sum() > 0:
            pred_obs, pred_state = self.forward(own_variable)

            if isinstance(other_variable, (tuple, list)):
                target_obs, target_state = other_variable
                loss = (
                    F.mse_loss(pred_obs, target_obs, reduction='none').sum(dim=-1, keepdim=True)
                    + F.mse_loss(pred_state, target_state, reduction='none').sum(dim=-1, keepdim=True)
                )
            else:
                # Single target tensor: compare against both heads side by side.
                predict_variable = torch.cat([pred_obs, pred_state], dim=-1)
                loss = F.mse_loss(predict_variable, other_variable,
                                  reduction='none').sum(dim=-1, keepdim=True)

            # Mask-weighted mean over the valid rows.
            loss = (loss * mask).sum() / mask.sum()

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.)
            self.optimizer.step()

            return loss.item()

        return None


class Predict_Network(nn.Module):
    """Two-layer ReLU MLP mapping inputs to ``output_shape`` predictions.

    Args:
        num_inputs: size of the last dimension of the input tensor.
        hidden_dim: width of both hidden layers.
        output_shape: size of the last dimension of the prediction.
        lr: accepted for signature compatibility; this class creates no
            optimizer and never reads it.
    """

    def __init__(self, num_inputs, hidden_dim, output_shape, lr=3e-4):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.last_fc = nn.Linear(hidden_dim, output_shape)

    def forward(self, input):
        """Apply linear1 -> ReLU -> linear2 -> ReLU -> last_fc."""
        return self.last_fc(F.relu(self.linear2(F.relu(self.linear1(input)))))

    def get_log_pi(self, own_variable, other_variable):
        """Return the negative squared error summed over the last dimension.

        The trailing dimension is reduced away (not kept).
        """
        squared_error = F.mse_loss(self.forward(own_variable), other_variable,
                                   reduction='none')
        negated = -1 * squared_error
        return negated.sum(dim=-1)

    def update(self, own_variable, other_variable, mask):
        """Compute (but do not apply) the masked prediction loss.

        The squared error is summed over the last two dimensions (the in-code
        shape notes suggest ``(bs, t, n_agents-1, feat) -> (bs, t, 1)``),
        weighted by ``mask`` and normalised by ``mask.sum()``.

        Returns:
            A scalar loss tensor still attached to the autograd graph (the
            caller is responsible for backward/step), or ``None`` when
            ``mask`` does not sum to a positive value.
        """
        mask_total = mask.sum()
        if not (mask_total > 0):
            return None

        squared_error = F.mse_loss(self.forward(own_variable), other_variable,
                                   reduction='none')
        # Collapse the feature axis, then the next axis, keeping a trailing 1.
        per_step = squared_error.sum(dim=-1).sum(dim=-1, keepdim=True)
        return (per_step * mask).sum() / mask_total

