import math
from itertools import count, permutations

import torch
import torch.nn as nn


def make_mlp(l, act=nn.LeakyReLU(), tail=None):
    """Build an MLP as ``nn.Sequential`` with no activation after the last linear layer.

    Args:
        l: Sequence of layer widths; consecutive pairs ``(l[k], l[k+1])``
            become ``nn.Linear`` layers.
        act: Activation module inserted after every linear layer except the
            last one. The same module instance is reused at each position.
        tail: Optional iterable of extra modules appended after the final
            linear layer (e.g. an output non-linearity). ``None`` means no
            extra modules.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers = []
    last_linear = len(l) - 2  # index of the final Linear; it gets no activation
    for n, (i, o) in enumerate(zip(l, l[1:])):
        layers.append(nn.Linear(i, o))
        if n < last_linear:
            layers.append(act)
    if tail is not None:
        # extend() accepts any iterable, not only lists
        layers.extend(tail)
    return nn.Sequential(*layers)

class NaivePermEquiNet(nn.Module):
    """
    Equivariant w.r.t. permutation on first ndim dimensions
    Invariant to the last dimension (i.e. stop action)
    Implemented with naive enumeration of all permutations

    The input is viewed as (bs, ndim, horizon). Every permutation of the ndim
    axis is fed through the same MLP, each output is un-permuted with a
    scatter, and the results are averaged. Cost grows as ndim!.
    """
    def __init__(self, horizon, ndim, n_hid, n_layers) -> None:
        super().__init__()
        self.model = make_mlp([horizon * ndim] + [n_hid] * n_layers + [ndim+1])
        self.ndim = ndim
        self.horizon = horizon
        self.index_ls = list(permutations(range(self.ndim)))
        # Registered as a non-persistent buffer so it moves with the module
        # on .to()/.cuda() while leaving the state_dict layout unchanged.
        self.register_buffer(
            "gather_index", self._build_gather_index(ndim), persistent=False)  # (ndim!, 1, ndim+1)

    @staticmethod
    def _build_gather_index(ndim):
        """Return the (ndim!, 1, ndim+1) scatter index that undoes each permutation."""
        perms = torch.tensor(list(permutations(range(ndim))), dtype=torch.long)  # (ndim!, ndim)
        # the last output slot is never permuted, it always maps to itself
        stop = torch.full((perms.shape[0], 1), ndim, dtype=torch.long)
        return torch.cat([perms, stop], dim=-1).unsqueeze(dim=1)

    def forward(self, inp):
        """Map inp (bs, ndim*horizon) or (bs, ndim, horizon) to a (bs, ndim+1) output."""
        batch_size = inp.shape[0]
        n_perm = len(self.index_ls)
        inp = inp.reshape(batch_size, self.ndim, self.horizon)  # (bs, ndim, horizon)
        # one permuted copy of the batch per permutation
        aug_inp = torch.stack(
            [inp[:, list(perm), :] for perm in self.index_ls], dim=0)  # (ndim!, bs, ndim, horizon)

        aug_output = self.model(aug_inp.reshape(
            n_perm, batch_size, self.ndim*self.horizon))  # (ndim!, bs, ndim+1)
        gather_index = self.gather_index.expand_as(aug_output)  # (ndim!, bs, ndim+1)
        # output slot j of permutation p belongs to original dimension p[j]
        aug_output = torch.scatter(torch.zeros_like(aug_output), -1, gather_index, aug_output)
        return aug_output.mean(dim=0)

class NaivePermEquiNet_DB(nn.Module):
    """
    Permutation-averaged network with a (bs, 2*ndim+2) output.

    Output slots [0, ndim) and [ndim+1, 2*ndim] are equivariant to
    permutations of the ndim input dimensions; slots ndim and 2*ndim+1 are
    invariant (presumably stop action and flow, mirroring PermEquiNet_DB).
    Implemented with naive enumeration of all permutations (cost ~ ndim!).
    """
    def __init__(self, horizon, ndim, n_hid, n_layers) -> None:
        super().__init__()
        self.model = make_mlp([horizon * ndim] + [n_hid] * n_layers + [2*ndim+2])
        self.ndim = ndim
        self.horizon = horizon
        self.index_ls = list(permutations(range(self.ndim)))
        # Non-persistent buffer: follows .to()/.cuda(), not stored in state_dict.
        self.register_buffer(
            "gather_index", self._build_gather_index(ndim), persistent=False)  # (ndim!, 1, 2ndim+2)

    @staticmethod
    def _build_gather_index(ndim):
        """Return the (ndim!, 1, 2ndim+2) scatter index that undoes each permutation."""
        perms = torch.tensor(list(permutations(range(ndim))), dtype=torch.long)  # (ndim!, ndim)
        n_perm = perms.shape[0]
        first_fixed = torch.full((n_perm, 1), ndim, dtype=torch.long)
        last_fixed = torch.full((n_perm, 1), 2*ndim + 1, dtype=torch.long)
        # The second permuted group lives in slots [ndim+1, 2ndim]; it must be
        # offset, otherwise it collides with the first group in the scatter.
        return torch.cat(
            [perms, first_fixed, perms + (ndim + 1), last_fixed], dim=-1).unsqueeze(dim=1)

    def forward(self, inp):
        """Map inp (bs, ndim*horizon) or (bs, ndim, horizon) to a (bs, 2ndim+2) output."""
        batch_size = inp.shape[0]
        n_perm = len(self.index_ls)
        inp = inp.reshape(batch_size, self.ndim, self.horizon)  # (bs, ndim, horizon)
        aug_inp = torch.stack(
            [inp[:, list(perm), :] for perm in self.index_ls], dim=0)  # (ndim!, bs, ndim, horizon)

        aug_output = self.model(aug_inp.reshape(
            n_perm, batch_size, self.ndim*self.horizon))  # (ndim!, bs, 2ndim+2)
        gather_index = self.gather_index.expand_as(aug_output)  # (ndim!, bs, 2ndim+2)
        # undo each permutation on both equivariant groups, then average
        aug_output = torch.scatter(torch.zeros_like(aug_output), -1, gather_index, aug_output)
        return aug_output.mean(dim=0)  # (bs, 2ndim+2)

class NaivePermEquiNet_TB(nn.Module):
    """
    Permutation-averaged network with a (bs, 2*ndim+1) output.

    Output slots [0, ndim) and [ndim+1, 2*ndim] are equivariant to
    permutations of the ndim input dimensions; slot ndim is invariant
    (presumably the stop action, mirroring PermEquiNet_TB).
    Implemented with naive enumeration of all permutations (cost ~ ndim!).
    """
    def __init__(self, horizon, ndim, n_hid, n_layers) -> None:
        super().__init__()
        self.model = make_mlp([horizon * ndim] + [n_hid] * n_layers + [2*ndim+1])
        self.ndim = ndim
        self.horizon = horizon
        self.index_ls = list(permutations(range(self.ndim)))
        # Non-persistent buffer: follows .to()/.cuda(), not stored in state_dict.
        self.register_buffer(
            "gather_index", self._build_gather_index(ndim), persistent=False)  # (ndim!, 1, 2ndim+1)

    @staticmethod
    def _build_gather_index(ndim):
        """Return the (ndim!, 1, 2ndim+1) scatter index that undoes each permutation."""
        perms = torch.tensor(list(permutations(range(ndim))), dtype=torch.long)  # (ndim!, ndim)
        fixed = torch.full((perms.shape[0], 1), ndim, dtype=torch.long)
        # The trailing permuted group lives in slots [ndim+1, 2ndim]; it must be
        # offset, otherwise it collides with the leading group in the scatter.
        return torch.cat([perms, fixed, perms + (ndim + 1)], dim=-1).unsqueeze(dim=1)

    def forward(self, inp):
        """Map inp (bs, ndim*horizon) or (bs, ndim, horizon) to a (bs, 2ndim+1) output."""
        batch_size = inp.shape[0]
        n_perm = len(self.index_ls)
        inp = inp.reshape(batch_size, self.ndim, self.horizon)  # (bs, ndim, horizon)
        aug_inp = torch.stack(
            [inp[:, list(perm), :] for perm in self.index_ls], dim=0)  # (ndim!, bs, ndim, horizon)

        aug_output = self.model(aug_inp.reshape(
            n_perm, batch_size, self.ndim*self.horizon))  # (ndim!, bs, 2ndim+1)
        gather_index = self.gather_index.expand_as(aug_output)  # (ndim!, bs, 2ndim+1)
        # undo each permutation on both equivariant groups, then average
        aug_output = torch.scatter(torch.zeros_like(aug_output), -1, gather_index, aug_output)
        return aug_output.mean(dim=0)  # (bs, 2ndim+1)


class PermEquiNet(nn.Module):
    """
    Sort-based permutation-equivariant network.

    The first ndim outputs are equivariant to permutations of the ndim input
    dimensions; the last output (i.e. stop action) is invariant.
    The input is canonicalised by sorting the dimensions by their argmax,
    pushed through one MLP, and the outputs are scattered back.
    NOTE: dimensions with equal argmax are not guaranteed identical outputs.
    """
    def __init__(self, horizon, ndim, n_hid, n_layers) -> None:
        super().__init__()
        layer_sizes = [horizon * ndim, *([n_hid] * n_layers), ndim + 1]
        self.model = make_mlp(layer_sizes)
        self.ndim = ndim
        self.horizon = horizon

    def forward(self, inp):
        """Map inp (bs, ndim*horizon) or (bs, ndim, horizon) to a (bs, ndim+1) output."""
        bs = inp.shape[0]
        per_dim = inp.reshape(bs, self.ndim, self.horizon)  # (bs, ndim, horizon)

        # position of the max entry per dimension, values in [0, horizon-1]
        positions = per_dim.argmax(dim=-1)  # (bs, ndim)
        # order[b, j] = original dimension placed at sorted slot j
        _, order = torch.sort(positions, dim=-1, stable=True)  # (bs, ndim)

        # bring every sample into its canonical (sorted) dimension order
        row_index = order.unsqueeze(-1).expand_as(per_dim)
        canonical = torch.gather(per_dim, 1, row_index)  # (bs, ndim, horizon)
        canonical_out = self.model(
            canonical.reshape(bs, self.ndim * self.horizon))  # (bs, ndim+1)

        # the trailing slot always maps to itself
        stop_col = order.new_full((bs, 1), self.ndim)
        scatter_index = torch.cat([order, stop_col], dim=-1)  # (bs, ndim+1)

        # route each canonical output back to the dimension it came from
        restored = torch.zeros_like(canonical_out)
        return torch.scatter(restored, -1, scatter_index, canonical_out)  # (bs, ndim+1)

class PermEquiNet_DB(nn.Module):
    """
    Equivariant w.r.t. permutation on [0:ndim] dimensions and [ndim+1, 2*ndim] dimensions
    Invariant to the (ndim+1)-th and (2*ndim+2)-th dimension (i.e. stop action and edge-flow)

    The input is canonicalised by sorting the ndim dimensions by their argmax,
    pushed through one MLP, and both equivariant output groups are scattered
    back to the original dimension order.
    """
    def __init__(self, horizon, ndim, n_hid, n_layers) -> None:
        super().__init__()
        self.model = make_mlp([horizon * ndim] + [n_hid] * n_layers + [2*ndim+2])
        self.ndim = ndim
        self.horizon = horizon

    def forward(self, inp):
        """Map inp (bs, ndim*horizon) or (bs, ndim, horizon) to a (bs, 2ndim+2) output."""
        batch_size = inp.shape[0]

        inp = inp.reshape(batch_size, self.ndim, self.horizon)  # -> (bs, ndim, horizon)
        compact_inp = inp.argmax(dim=-1)  # (bs, ndim), value in [0, horizon-1]
        # sort_index[b, j] = original dimension placed at sorted slot j, value in [0, ndim-1]
        _, sort_index = torch.sort(compact_inp, dim=-1, stable=True)  # (bs, ndim)

        # first permute to a "standard order"
        inp0 = torch.gather(inp, 1, sort_index.unsqueeze(-1).expand_as(inp))  # -> (bs, ndim, horizon)
        output0 = self.model(inp0.reshape(batch_size, self.ndim*self.horizon))  # (bs, 2ndim+2)

        # new_full keeps sort_index's dtype (long) and device
        first_fixed = sort_index.new_full((batch_size, 1), self.ndim)
        last_fixed = sort_index.new_full((batch_size, 1), 2*self.ndim + 1)
        # The second equivariant group occupies slots [ndim+1, 2ndim], so its
        # targets must be offset; without the offset it would overwrite the
        # first group and leave its own slots at zero.
        sort_index2 = torch.cat([sort_index,
                                 first_fixed,
                                 sort_index + (self.ndim + 1),
                                 last_fixed,
                                 ], dim=-1)  # (bs, 2ndim+2)
        # then permute back
        return torch.scatter(torch.zeros_like(output0), -1, sort_index2, output0)  # (bs, 2ndim+2)

class PermEquiNet_TB(nn.Module):
    """
    Equivariant w.r.t. permutation on first ndim dimensions and last ndim dimensions
    Invariant to the (ndim+1)-th dimension (i.e. stop action)

    The input is canonicalised by sorting the ndim dimensions by their argmax,
    pushed through one MLP, and both equivariant output groups are scattered
    back to the original dimension order.
    """
    def __init__(self, horizon, ndim, n_hid, n_layers) -> None:
        super().__init__()
        self.model = make_mlp([horizon * ndim] + [n_hid] * n_layers + [2*ndim+1])
        self.ndim = ndim
        self.horizon = horizon

    def forward(self, inp):
        """Map inp (bs, ndim*horizon) or (bs, ndim, horizon) to a (bs, 2ndim+1) output."""
        batch_size = inp.shape[0]

        inp = inp.reshape(batch_size, self.ndim, self.horizon)  # -> (bs, ndim, horizon)
        compact_inp = inp.argmax(dim=-1)  # (bs, ndim), value in [0, horizon-1]
        # sort_index[b, j] = original dimension placed at sorted slot j, value in [0, ndim-1]
        _, sort_index = torch.sort(compact_inp, dim=-1, stable=True)  # (bs, ndim)

        # first permute to a "standard order"
        inp0 = torch.gather(inp, 1, sort_index.unsqueeze(-1).expand_as(inp))  # -> (bs, ndim, horizon)
        output0 = self.model(inp0.reshape(batch_size, self.ndim*self.horizon))  # (bs, 2ndim+1)

        # new_full keeps sort_index's dtype (long) and device
        fixed = sort_index.new_full((batch_size, 1), self.ndim)
        # The last ndim outputs occupy slots [ndim+1, 2ndim], so their targets
        # must be offset; without the offset they would overwrite the first
        # ndim slots and leave their own slots at zero.
        sort_index2 = torch.cat([sort_index,
                                 fixed,
                                 sort_index + (self.ndim + 1)], dim=-1)  # (bs, 2ndim+1)
        # then permute back
        return torch.scatter(torch.zeros_like(output0), -1, sort_index2, output0)  # (bs, 2ndim+1)


def unit_test_perminvnet():
    """Check that PermEquiNet is equivariant on a small one-hot example.

    Builds a randomly initialised PermEquiNet, evaluates it on two one-hot
    samples and on permuted copies of them, prints every output, and verifies
    that permuting the input dimensions permutes the first ndim outputs the
    same way while leaving the last output unchanged.

    Raises:
        AssertionError: if a permuted input does not produce the
            correspondingly permuted output.
    """
    horizon, ndim, n_hid, n_layers = 4, 3, 256, 3
    model = PermEquiNet(horizon, ndim, n_hid, n_layers)
    # every dimension has a distinct argmax, so the canonical order is unique
    inp = torch.tensor([
        [[0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]],
        [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0]],
    ]).float()   # (2, 3, 4)

    reference = model(inp)
    print(reference)
    for perm in ([1, 0, 2], [1, 2, 0]):
        permuted_out = model(inp[:, perm, :])
        print(permuted_out)
        # first ndim outputs follow the permutation, the last slot stays put
        expected = reference[:, perm + [ndim]]
        if not torch.allclose(permuted_out, expected):
            raise AssertionError(
                f"PermEquiNet is not equivariant under permutation {perm}")


# Script entry point: run the PermEquiNet check when this file is executed directly.
if __name__ == "__main__":
    unit_test_perminvnet()