import sys

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.functional import F


class classifier(nn.Module):
    """Two stacked ``nn.Linear`` layers mapping ``inp_dim`` features to ``num_gens`` outputs.

    Args:
        num_gens: Number of output units produced by the final layer.
        inp_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the intermediate layer. Defaults to 4, the
            value that was previously hard-coded.
    """

    def __init__(self, num_gens, inp_dim, *, hidden_dim=4):
        super().__init__()
        # NOTE(review): there is no activation between the two Linear layers,
        # so the whole stack is a single affine map whose linear part has rank
        # at most ``hidden_dim`` -- confirm this bottleneck is intended.
        self.clf = nn.Sequential(
            nn.Linear(inp_dim, hidden_dim),
            nn.Linear(hidden_dim, num_gens),
        )

    def forward(self, x):
        """Apply both layers to ``x``; no softmax/normalization is applied.

        ``x`` is expected to have ``inp_dim`` as its last dimension; the result
        has ``num_gens`` as its last dimension.
        """
        return self.clf(x)

class player(nn.Module):
    """Module that owns a single learnable vector and returns it from ``forward``.

    Args:
        k: Length of the vector; entries are initialized with ``torch.rand``,
            i.e. uniformly in [0, 1).
    """

    def __init__(self, k):
        super().__init__()
        # nn.Parameter defaults to requires_grad=True, so the source tensor
        # does not need the flag itself.
        self.strategy = nn.Parameter(torch.rand(k))

    def forward(self):
        """Return the ``strategy`` parameter itself (the same object, not a copy)."""
        return self.strategy
