import torch
import torch.nn as nn

class AllPairsBlackBoxNet3C(nn.Module):
    """Score a 3-way relation for every ordered pair of rows in a batch.

    Each input row is encoded by ``mlp`` into a vector of width
    ``center_size + 1``. Every ordered pair (i, j) is then scored by feeding
    the difference ``z[i] - z[j]`` through the learned ``relation`` MLP,
    which emits 3 unnormalized scores per pair.

    Args:
        center_size: the encoding width is ``center_size + 1``.
        in_size: width ``D`` of each input row.
        girth: hidden width of both MLPs.
    """

    def __init__(self, center_size=2, in_size=1000, girth=512):
        super().__init__()
        self._clen = center_size
        # Layer positions determine the state_dict keys (mlp.0, mlp.2, ...);
        # keep them stable so existing checkpoints still load.
        self.mlp = nn.Sequential(
            nn.Linear(in_size, girth),
            nn.ReLU(),
            nn.Linear(girth, girth),
            nn.ReLU(),
            nn.Linear(girth, self._clen + 1),
        )
        self.relation = nn.Sequential(
            # TODO: check a variant with a ReLU in front of this first layer.
            nn.Linear(self._clen + 1, girth),
            nn.ReLU(),
            nn.Linear(girth, girth),
            nn.ReLU(),
            nn.Linear(girth, 3),
        )

    def run_tests(self, z):
        """Score every ordered pair of encodings.

        Args:
            z: encodings, shape ``(B, clen + 1)``.

        Returns:
            Scores of shape ``(B, B, 3)``; entry ``[i, j]`` is
            ``relation(z[i] - z[j])``.
        """
        # (B, 1, C) - (1, B, C) broadcasts to all pairwise differences (B, B, C).
        diffs = z.unsqueeze(1) - z.unsqueeze(0)
        return self.relation(diffs)

    def encode(self, x):
        """Map inputs of shape ``(B, D)`` to encodings ``(B, clen + 1)``."""
        return self.mlp(x)

    def forward(self, x):
        """Return pairwise scores ``(B, B, 3)`` for inputs ``(B, D)``."""
        z = self.encode(x)
        return self.run_tests(z)

    def predict(self, x, th=0.5):
        """Return the highest-scoring class index for every ordered pair.

        Args:
            x: inputs, shape ``(B, D)``.
            th: unused; kept only so existing call sites keep working.

        Returns:
            Long tensor of shape ``(B, B)`` with values in ``{0, 1, 2}``.
        """
        # The argmax result is not differentiable, so avoid recording an
        # autograd graph over the (B, B, girth) relation activations.
        with torch.no_grad():
            scores = self.forward(x)
        return scores.argmax(dim=-1)

class AllPairsOrthotopeNet3C(nn.Module):
    """Score a 3-way relation for every ordered pair via axis-aligned boxes.

    Each input row is encoded as a box: ``dimension`` center coordinates
    followed by ``dimension`` non-negative half-widths. For an ordered pair
    (i, j) the three scores are:

    0. always zero (the disjointness margin is computed but multiplied by 0);
    1. a containment margin that is ``>= 0`` iff box i lies inside box j;
    2. a containment margin that is ``>= 0`` iff box j lies inside box i.

    Args:
        dimension: number of box axes.
        in_size: width ``D`` of each input row.
        girth: hidden width of the encoder MLP.
    """

    # Experimental note on activations: PReLU > ELU > ReLU, LeakyReLU (??)

    def __init__(self, dimension=2, in_size=1000, girth=512):
        super().__init__()
        self._dim = dimension
        # Layer positions determine the state_dict keys; keep them stable.
        self.mlp = nn.Sequential(
            nn.Linear(in_size, girth),
            nn.PReLU(),
            nn.Linear(girth, girth),
            nn.PReLU(),
            nn.Linear(girth, girth),
            nn.PReLU(),
            nn.Linear(girth, 2 * dimension),
        )

    def encode(self, x):
        """Encode inputs ``(B, D)`` as boxes of shape ``(B, 2 * dim)``.

        The first ``dim`` columns are centers. The last ``dim`` columns go
        through ``abs`` so that half-widths are non-negative.
        """
        z = self.mlp(x)
        centers = z[:, :self._dim]
        half_widths = z[:, self._dim:].abs()
        # torch.cat allocates a new tensor, so no explicit clone() is needed.
        return torch.cat([centers, half_widths], dim=-1)

    def run_tests(self, z):
        """Compute pairwise box-relation scores.

        Args:
            z: boxes, shape ``(B, 2 * dim)``, laid out as produced by ``encode``.

        Returns:
            Scores of shape ``(B, B, 3)`` (see the class docstring).
        """
        zc, zr = z[:, :self._dim], z[:, self._dim:]
        # Per-axis center distance |c_i - c_j|, shape (B, B, dim).
        dist = (zc.unsqueeze(1) - zc.unsqueeze(0)).abs()
        rsum = zr.unsqueeze(1) + zr.unsqueeze(0)  # r_i + r_j, (B, B, dim)
        rdiff = zr.unsqueeze(0) - zr.unsqueeze(1)  # r_j - r_i, (B, B, dim)
        # > 0 iff some axis separates the two boxes.
        test_neg = (dist - rsum).max(dim=-1, keepdim=True).values
        # >= 0 iff on every axis box i fits inside box j (resp. j inside i).
        test_pos1 = (rdiff - dist).min(dim=-1, keepdim=True).values
        test_pos2 = (-rdiff - dist).min(dim=-1, keepdim=True).values
        # NOTE(review): the disjointness margin is multiplied by 0.0, so class 0
        # always scores 0 and acts as a fixed threshold for the two containment
        # margins. This looks experimental -- confirm it is intended.
        return torch.cat([0.0 * test_neg, test_pos1, test_pos2], dim=-1)

    def forward(self, x):
        """Return pairwise scores ``(B, B, 3)`` for inputs ``(B, D)``."""
        z = self.encode(x)  # (B, 2 * dim)
        return self.run_tests(z)

    def predict(self, x):
        """Return the highest-scoring class index ``(B, B)`` for every pair."""
        # The argmax result is not differentiable; skip recording autograd.
        with torch.no_grad():
            scores = self.forward(x)
        return scores.argmax(dim=-1)

class AllPairsVennNet3C(nn.Module):
    """Score a 3-way relation for every ordered pair via Euclidean balls.

    Each input row is encoded as a ball: ``center_size`` center coordinates
    followed by one non-negative radius. For an ordered pair (i, j), with
    ``d = ||c_i - c_j||_2``, the three scores are:

    0. ``d - (r_i + r_j)``: ``> 0`` iff the balls are disjoint;
    1. ``(r_j - r_i) - d``: ``>= 0`` iff ball i lies inside ball j;
    2. ``(r_i - r_j) - d``: ``>= 0`` iff ball j lies inside ball i.

    Args:
        center_size: dimensionality of the ball centers.
        in_size: width ``D`` of each input row.
        girth: hidden width of the encoder MLP.
    """

    def __init__(self, center_size=2, in_size=1000, girth=512):
        super().__init__()
        self._clen = center_size
        # Layer positions determine the state_dict keys; keep them stable.
        self.mlp = nn.Sequential(
            nn.Linear(in_size, girth),
            nn.PReLU(),
            nn.Linear(girth, girth),
            nn.PReLU(),
            nn.Linear(girth, girth),
            nn.PReLU(),
            nn.Linear(girth, self._clen + 1),
        )

    def encode(self, x):
        """Encode inputs ``(B, D)`` as balls of shape ``(B, clen + 1)``.

        The first ``clen`` columns are the center. The last column goes
        through ``abs`` so that the radius is non-negative.
        """
        z = self.mlp(x)
        centers = z[:, :-1]
        radius = z[:, -1:].abs()
        # torch.cat allocates a new tensor, so no explicit clone() is needed.
        return torch.cat([centers, radius], dim=-1)

    def run_tests(self, z):
        """Compute pairwise ball-relation scores.

        Args:
            z: balls, shape ``(B, clen + 1)``, laid out as produced by ``encode``.

        Returns:
            Scores of shape ``(B, B, 3)`` (see the class docstring).
        """
        zc, zr = z[:, :-1], z[:, -1]
        # Euclidean center distance, shape (B, B).
        dist = torch.linalg.vector_norm(zc.unsqueeze(1) - zc.unsqueeze(0), dim=-1)
        rsum = zr.unsqueeze(1) + zr.unsqueeze(0)  # r_i + r_j
        rdiff = zr.unsqueeze(0) - zr.unsqueeze(1)  # r_j - r_i
        return torch.stack([dist - rsum, rdiff - dist, -rdiff - dist], dim=-1)

    def forward(self, x):
        """Return pairwise scores ``(B, B, 3)`` for inputs ``(B, D)``."""
        z = self.encode(x)  # (B, clen + 1)
        return self.run_tests(z)

    def predict(self, x):
        """Return the highest-scoring class index ``(B, B)`` for every pair."""
        # The argmax result is not differentiable; skip recording autograd.
        with torch.no_grad():
            scores = self.forward(x)
        return scores.argmax(dim=-1)

