import numpy as np
import torch
import torch.nn as nn

class AllPairsBlackBoxNet(nn.Module):
    """ Scores every ordered pair of rows in a batch.

        Each row is embedded by ``mlp``; the difference between every pair of
        embeddings is then pushed through ``relation`` to obtain one logit
        per pair.
    """

    def __init__(self, hidden_size=512, in_size=1000, girth=512):
        """ hidden_size: width H of the embedding produced by ``mlp``
            in_size: number of input features D
            girth: width of the hidden layers of both MLPs
        """
        super().__init__()
        embed_layers = [
            nn.Linear(in_size, girth),
            nn.ReLU(),
            nn.Linear(girth, girth),
            nn.ReLU(),
            # nn.Linear(girth, girth),
            # nn.ReLU(),
            nn.Linear(girth, hidden_size),
        ]
        self.mlp = nn.Sequential(*embed_layers)
        pair_layers = [
            # nn.ReLU(), ### TODO: check without this ReLU
            nn.Linear(hidden_size, girth),
            nn.ReLU(),
            nn.Linear(girth, girth),
            nn.ReLU(),
            nn.Linear(girth, 1),
        ]
        self.relation = nn.Sequential(*pair_layers)

    def encode(self, x):
        """ x.shape: (B, D)
            Returns the embeddings, shape (B, H).
        """
        return self.mlp(x)

    def run_tests(self, z):
        """ z.shape: (B, H)
            Returns raw logits of shape (B, B); entry [i, j] is
            ``relation(z[i] - z[j])``.
        """
        pair_diff = z.unsqueeze(1) - z.unsqueeze(0) # (B, B, H)
        logits = self.relation(pair_diff) # (B, B, 1)
        return logits.squeeze(-1) # (B, B)

    def forward(self, x):
        """ x.shape: (B, D)
            Returns pairwise probabilities, shape (B, B).
        """
        embeddings = self.encode(x) # (B, H)
        logits = self.run_tests(embeddings) # (B, B)
        return torch.sigmoid(logits) # (B, B)

    def predict(self, x, th=0.5):
        """ x.shape: (B, D)
            Returns a (B, B) int64 matrix: 1 where the probability is >= th.
        """
        probs = self.forward(x) # (B, B)
        return (probs >= th).long() # (B, B)

class AllPairsFastBlackBoxNet(nn.Module):
    """ All-pairs scorer with a ReLU-terminated encoder.

        ``forward`` rescales the raw input with ``_digit_norm`` before
        encoding; ``encode`` itself does not rescale.
    """

    def __init__(self, hidden_size=512, in_size=1000, girth=512):
        """ hidden_size: width H of the embedding produced by ``mlp``
            in_size: number of input features D
            girth: width of the hidden layers of both MLPs
        """
        super().__init__()
        embed_layers = [
            nn.Linear(in_size, girth),
            nn.ReLU(),
            nn.Linear(girth, girth),
            nn.ReLU(),
            nn.Linear(girth, hidden_size),
            nn.ReLU(),
        ]
        self.mlp = nn.Sequential(*embed_layers)
        pair_layers = [
            nn.Linear(hidden_size, girth),
            nn.ReLU(),
            nn.Linear(girth, girth),
            nn.ReLU(),
            nn.Linear(girth, 1),
        ]
        self.relation = nn.Sequential(*pair_layers)

    def _digit_norm(self, x):
        """ x.shape: (B, D)
            Shifts by 5 and divides by 10, element-wise.
        """
        return (x - 5.0) / 10.0

    def encode(self, x):
        """ x.shape: (B, D)
            Returns the embeddings, shape (B, H). No rescaling happens here.
        """
        return self.mlp(x)

    def run_tests(self, z):
        """ z.shape: (B, H)
            Returns raw logits of shape (B, B); entry [i, j] is
            ``relation(z[i] - z[j])``.
        """
        pair_diff = z.unsqueeze(1) - z.unsqueeze(0) # (B, B, H)
        logits = self.relation(pair_diff) # (B, B, 1)
        return logits.squeeze(-1) # (B, B)

    def forward(self, x):
        """ x.shape: (B, D)
            Returns pairwise probabilities, shape (B, B).
        """
        scaled = self._digit_norm(x) # (B, D)
        embeddings = self.encode(scaled) # (B, H)
        logits = self.run_tests(embeddings) # (B, B)
        return torch.sigmoid(logits) # (B, B)

    def predict(self, x, th=0.5):
        """ x.shape: (B, D)
            Returns a (B, B) int64 matrix: 1 where the probability is >= th.
        """
        probs = self.forward(x) # (B, B)
        return (probs >= th).long() # (B, B)

class AllPairsEvenFasterBlackBoxNet(nn.Module):
    def __init__(self, hidden_size=512, in_size=1000, girth=512):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, hidden_size),
            ### TODO: test removing the following
            nn.RMSNorm(hidden_size),
            nn.GELU(),
        )
        self.relation = nn.Sequential(
            nn.Linear(hidden_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, 1)
        )

    def run_tests(self, z):
        """ z.shape: (B, clen + 1)
        """
        z = z.unsqueeze(1) # (B, 1, clen + 1)
        z = (z - z.permute(1, 0, 2)) # (B, B, clen + 1)
        tests = self.relation(z)[:, :, 0] # (B, B)
        return tests

    def tgt_run_tests(self, z):
        """ z.shape: (B, clen + 1)
        """
        z = (z.clone()[-1:] - z) # (B, clen + 1)
        tests = self.relation(z)[:, 0] # (B,)
        return tests

    def encode(self, x):
        """ x.shape: (B, D)
        """
        x = self._digit_norm(x)
        z = self.mlp(x) # (B, clen + 1)
        return z

    def _digit_norm(self, x):
        """ x.shape: (B, D)
        """
        x = (x - 5.0) / 5.0
        return x

    def forward(self, x):
        """ x.shape: (B, D)
        """
        z = self.encode(x) # (B, clen + 1)
        tests = self.run_tests(z) # (B, B)
        tests = nn.functional.sigmoid(tests) # (B, B)
        return tests

    def tgt_forward(self, x):
        """ x.shape: (B, D)
            x[-1] is the only 'target' element
        """
        z = self.encode(x) # (B, clen + 1)
        tests = self.tgt_run_tests(z) # (B,)
        # tests = nn.functional.sigmoid(tests) # (B,)
        return tests

    def predict(self, x, th=0.5):
        test = self.forward(x) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class ResNormBlock(nn.Module):
    def __init__(
        self,
        in_size,
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_size, in_size),
            nn.RMSNorm(in_size),
            nn.GELU(),
        )

    def forward(self, x):
        return x + self.block(x)

class AllPairsSingleGateMoE(nn.Module):
    def __init__(
        self,
        num_experts,
        in_size,
        hidden_size,
        girth,
        top_k=1,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.encoder = nn.Sequential(
            nn.Linear(in_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, hidden_size),
            nn.RMSNorm(hidden_size),
            nn.GELU(),
        )
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, num_experts),
        )
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_size, girth),
                    nn.RMSNorm(girth),
                    nn.GELU(),
                    nn.Linear(girth, girth),
                    nn.RMSNorm(girth),
                    nn.GELU(),
                    nn.Linear(girth, 1)
                )
                for _ in range(num_experts)
            ]
        )

    def _digit_norm(self, x):
        """ x.shape: (B, D)
        """
        x = (x - 5.0) / 10.0
        return x

    def forward(self, x):
        """ x.shape: (B, D)
        """
        B = x.shape[0]
        x = self._digit_norm(x) # (B, D)
        x = self.encoder(x) # (B, H)
        x = x.unsqueeze(1) # (B, 1, H)
        x = (x - x.permute(1, 0, 2)) # (B, B, H)
        scores = nn.functional.softmax(self.gate(x), dim=-1) # (B, B, X)
        preds = torch.concat([expert(x) for expert in self.experts], dim=-1) # (B, B, X)
        preds = preds * scores # (B, B)
        preds = preds.sum(dim=-1) # (B, B)
        preds = nn.functional.sigmoid(preds) # (B, B)
        return preds

    def predict(self, x, th=0.5):
        """ x.shape: (B, D)
        """
        test = self.forward(x) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

class AllPairsVennNet(nn.Module):
    def __init__(
            self,
            hidden_size=512,
            in_size=1000,
            girth=512,
            num_digits=21,
            dbase=10.0,
            sigmoid_temp=1.0,
            discount_factor=1.0,
            max_radius=1e2,
            device="cpu",
        ):
        super().__init__()
        _in_size = num_digits * (in_size + 0)
        self.mlp = nn.Sequential(
            nn.Linear(_in_size, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, girth),
            nn.RMSNorm(girth),
            nn.GELU(),
            nn.Linear(girth, hidden_size),
            # nn.ReLU(),
            # nn.RMSNorm(hidden_size),
            # nn.GELU(),
        )
        assert num_digits > 0
        self.num_digits = num_digits
        pos_digits = (num_digits // 2) + (num_digits % 2)
        neg_digits = num_digits // 2
        self._dbase = dbase
        self._fdt = (self._dbase ** torch.arange(-neg_digits, pos_digits))
        self._fdt = self._fdt.reshape(1, -1)
        self._fdt = self._fdt.float().to(device)
        self._sigmoid_temp = sigmoid_temp
        self._discount_factor = discount_factor
        self._max_radius = max_radius

    def _to_digits(self, x):
        """ x.shape: (B, D)
        """
        B, D = x.shape
        x = x.reshape(-1, 1) @ self._fdt
        x = x.sign() * (x.abs() % self._dbase)
        x = x.reshape(B, D * self.num_digits) # (B, D * d)
        return x

    def _digit_norm(self, x):
        """ x.shape: (B, D)
        """
        x = x / self._dbase
        return x

    def encode(self, x):
        """ x.shape: (B, D)
        """
        x = self._to_digits(x)
        x = self._digit_norm(x)
        z = self.mlp(x) # (B, hidden_size)
        # zc = z[:, :-1].clone()
        # zr = z[:, -1:].abs()
        # zr = torch.clamp(zr, max=self._max_radius)
        # z = torch.concat([zc, zr], dim=-1) # (B, hidden_size)
        return z

    # def run_tests(self, z):
    #     """ z.shape: (B, clen + 1)
    #     """
    #     zc, zr = z[:, :-1], z[:, -1]
    #     dist = zc.unsqueeze(1) - zc.unsqueeze(0) # (B, B, clen)
    #     dist = torch.linalg.vector_norm(dist, dim=-1) # (B, B)
    #     rdiff = -zr.unsqueeze(1) + zr.unsqueeze(0) # (B, B)
    #     test = (rdiff - self._discount_factor * dist) # (B, B)
    #     return test

    # def run_tests(self, z):
    #     """ z.shape: (B, D)
    #     """
    #     z_area = z.mean(dim=-1) # (B,)
    #     z_rep = z.unsqueeze(0).repeat(z.shape[0], 1, 1) # (B, B, D)
    #     z_rep = z_rep.unsqueeze(-1) # (B, B, D, 1)
    #     z_cat = torch.concat([z_rep, z_rep.permute(1, 0, 2, 3)], dim=-1) # (B, B, D, 2)
    #     z_min = z_cat.min(dim=-1).values # (B, B, D)
    #     z_inter = z_min.mean(dim=-1) # (B, B)
    #     test = (z_inter - z_area) # (B, B)
    #     return test

    def run_tests(self, z):
        """ z.shape: (B, D)
        """
        z = nn.functional.sigmoid(z) # (B, D)
        z_a = z.unsqueeze(0)
        z_b = z.unsqueeze(1)
        z_prec = (1.0 - z_a) + z_a * z_b # (B, B, D)
        test = z_prec.mean(dim=-1) # (B, B)
        return test

    #######

    # def tgt_run_tests(self, z):
    #     """ z.shape: (B, clen + 1)
    #     """
    #     z = (z.clone()[-1:] - z) # (B, clen + 1)
    #     tests = self.relation(z)[:, 0] # (B,)
    #     return tests

    def tgt_run_tests(self, z):
        """ z.shape: (B, clen + 1)
        """
        zdiff = (z.clone()[-1:] - z) # (B, clen + 1)
        dist = torch.linalg.vector_norm(zdiff[:, :-1], dim=-1) # (B,)
        rdiff = zdiff[:, -1] # (B,)
        tests = (rdiff - self._discount_factor * dist) # (B,)
        return tests

    def forward(self, x):
        """ x.shape: (B, D)
        """
        z = self.encode(x) # (B, clen + 1)
        tests = self.run_tests(z) # (B, B)
        # tests = nn.functional.sigmoid(self._sigmoid_temp * tests) # (B, B)
        # tests = (self._sigmoid_temp * tests).exp()
        return tests

    def tgt_forward(self, x):
        """ x.shape: (B, D)
            x[-1] is the only 'target' element
        """
        z = self.encode(x) # (B, clen + 1)
        tests = self.tgt_run_tests(z) # (B,)
        return tests

    def predict(self, x, th=0.5):
        test = self.forward(x) # (B, B)
        test = (test >= th).long() # (B, B)
        return test

    #######

    # def forward(self, x):
    #     """ x.shape: (B, D)
    #     """
    #     z = self.encode(x) # (B, clen + 1)
    #     test = self.run_tests(z) # (B, B)
    #     ####### WARNING: pure sigmoid is too stringent
    #     test = nn.functional.sigmoid(test) # (B, B)
    #     return test

    # def predict(self, x, th=0.5):
    #     test = self.forward(x) # (B, B)
    #     test = (test >= th).long() # (B, B)
    #     return test

class AllPairsOrthotopeNet(nn.Module):
    """ All-pairs scorer that maps every row to ``dimension`` centre values
        and ``dimension`` non-negative radius values, then compares pairs
        with a closed-form test instead of a learned relation head.

        Reading each (centre, radius) pair as an interval, the test for
        [i, j] is positive only when, in every dimension, interval i lies
        strictly inside interval j.
    """
    # PReLU > ELU > ReLU, LeakyReLU (??)

    # Class-level fallback so instances unpickled from before ``max_radius``
    # became configurable keep the previous clamp value.
    _max_radius = 100.0

    def __init__(self, dimension=2, in_size=1000, girth=512, max_radius=100.0):
        """ dimension: number of centre/radius pairs produced per row
            in_size: number of input features D
            girth: width of the hidden layers of ``mlp``
            max_radius: upper bound applied to the radii in ``run_tests``
        """
        super().__init__()
        self._dim = dimension
        self._max_radius = max_radius
        self.mlp = nn.Sequential(
            nn.Linear(in_size, girth),
            nn.PReLU(),
            nn.Linear(girth, girth),
            nn.PReLU(),
            nn.Linear(girth, girth),
            nn.PReLU(),
            # nn.Linear(girth, girth),
            # nn.PReLU(),
            nn.Linear(girth, 2 * dimension),
        )

    def encode(self, x):
        """ x.shape: (B, D)
            Returns shape (B, 2 * dim): the first ``dim`` columns are the
            centres, the last ``dim`` columns are the radii made non-negative
            with ``abs``.
        """
        z = self.mlp(x) # (B, 2 * dim)
        zc = z[:, :self._dim] # (B, dim)
        zr = z[:, self._dim:].abs() # (B, dim)
        # cat allocates a new tensor, so no explicit clone of zc is needed.
        return torch.cat([zc, zr], dim=-1) # (B, 2 * dim)

    def run_tests(self, z):
        """ z.shape: (B, 2 * dim)
            Entry [i, j] of the result is the minimum over dimensions of
            ``zr[j] - zr[i] - |zc[i] - zc[j]|``, with radii clamped to at most
            ``max_radius``. Returns shape (B, B).
        """
        zc, zr = z[:, :self._dim], z[:, self._dim:]
        zr = torch.clamp(zr, max=self._max_radius)
        dist = (zc.unsqueeze(1) - zc.unsqueeze(0)).abs() # (B, B, dim)
        rdiff = -zr.unsqueeze(1) + zr.unsqueeze(0) # (B, B, dim)
        # The min over dimensions makes the tightest dimension decide.
        test = (rdiff - dist).min(dim=2).values # (B, B)
        return test

    def forward(self, x):
        """ x.shape: (B, D)
            Returns the sigmoid of the pairwise test values, shape (B, B).
        """
        z = self.encode(x) # (B, 2 * dim)
        test = self.run_tests(z) # (B, B)
        return torch.sigmoid(test) # (B, B)

    def predict(self, x, th=0.5):
        """ x.shape: (B, D)
            Returns a (B, B) int64 matrix: 1 where the probability is >= th.
        """
        test = self.forward(x) # (B, B)
        return (test >= th).long() # (B, B)
