import torch


class ZSQModel_United(torch.nn.Module):
    """United Zero Sum Q-Model.

    Wraps ``pre_q_model`` and reshapes its flat output into a grid with
    ``u_action_n`` rows and ``v_action_n`` columns.

    Args:
        pre_q_model: Callable module applied to the states. Its output must
            hold ``u_action_n * v_action_n`` values per state.
        u_action_n: Size of the first action axis of the reshaped output
            (presumably the number of actions of the "u" player -- confirm
            against the caller).
        v_action_n: Size of the second action axis of the reshaped output
            (presumably the number of actions of the "v" player -- confirm
            against the caller).
    """

    def __init__(
        self,
        pre_q_model: torch.nn.Module,
        u_action_n: int,
        v_action_n: int,
    ) -> None:
        super().__init__()
        self.pre_q_model = pre_q_model
        self.u_action_n = u_action_n
        self.v_action_n = v_action_n

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        """Evaluate ``pre_q_model`` and reshape the result per action pair.

        Args:
            states: Input passed unchanged to ``pre_q_model``. For batched
                input, ``states.shape[0]`` is used as the batch size.

        Returns:
            A tensor of shape ``(u_action_n, v_action_n)`` when the wrapped
            model returns a 1-D tensor (single state), otherwise a tensor of
            shape ``(states.shape[0], u_action_n, v_action_n)``.

        Raises:
            RuntimeError: If the number of values produced by
                ``pre_q_model`` does not match the target shape.
        """
        q_values = self.pre_q_model(states)
        # Decide "single state vs. batch" from the rank of the output rather
        # than by comparing shape[0] with u_action_n * v_action_n: a batch
        # whose size happens to equal u_action_n * v_action_n would otherwise
        # be mistaken for a single state and fail to reshape.
        if q_values.dim() == 1:
            return q_values.reshape(self.u_action_n, self.v_action_n)
        return q_values.reshape(
            states.shape[0], self.u_action_n, self.v_action_n
        )
