import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from math import gcd
from torch.distributions.categorical import Categorical


def one_hot_argmax(x, temp_back=0.1, temp_for=None):
    """Return a hard (or sharpened) selection over dim 1 with soft gradients.

    The returned tensor has the *value* of the forward selection but the
    *gradient* of ``softmax(x / temp_back, dim=1)``: the difference between
    the two is added back in detached form, so it contributes no gradient.

    Args:
        x: Logits; the selection is taken along dimension 1.
        temp_back: Temperature of the softmax that provides the gradient.
        temp_for: If ``None``, the forward value is the one-hot encoding of
            the arg-max along dim 1. Otherwise the forward value is
            ``softmax(x / temp_for, dim=1)``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    backward_probs = F.softmax(x / temp_back, dim=1)
    if temp_for is not None:
        forward_value = F.softmax(x / temp_for, dim=1)
    else:
        hard_index = torch.max(x, dim=1)[1]
        forward_value = F.one_hot(hard_index, num_classes=x.shape[1])
    # Value: forward_value. Gradient: backward_probs only.
    return backward_probs + (forward_value - backward_probs).detach()


class ShiftScaleLayer(nn.Module):
    """Masked coupling layer applying a modular scale and shift per block.

    The input is a concatenation of blocks of widths ``dims``. Entries where
    ``mask`` is 1 pass through unchanged and are the only entries used to
    predict the transformation; entries where ``mask`` is 0 are replaced by
    the transformed block.

    For a block of width ``k`` the layer keeps four ``(k, k, k)`` lookup
    tables. Right-multiplying a one-hot row with index ``j`` by

    * ``shift_mat[idx][s]`` gives the one-hot at ``(j + s) % k``;
    * ``scale_mat[idx][s]`` puts mass on every ``i`` with ``(i * s) % k == j``;
    * ``*_rev[idx][s]`` are the transposes of the above.

    ``scale_forbid[idx]`` adds ``-1e3`` to the scale logits of every ``s``
    with ``gcd(s, k) > 1``.

    NOTE: the tables, ``scale_forbid`` and ``mask`` are plain attributes, not
    registered buffers, so ``Module.to`` / ``Module.cuda`` do not move them;
    they stay on the ``device`` given to the constructor.

    Args:
        dims: Sequence of block widths.
        e: Embedding width used per block.
        mask: 1-D tensor of length ``sum(dims)`` with 1 for pass-through
            entries and 0 for transformed entries.
        device: Device on which the constant tensors and the parameters of
            this layer are placed.
    """

    def __init__(self, dims, e, mask, device):
        super(ShiftScaleLayer, self).__init__()
        self.dims = dims
        self.temp_back = 0.1
        self.temp_for = None
        self.e = e
        self.mask = mask.requires_grad_(False)

        # Embedding starts block-diagonal: block idx only fills its own
        # e-wide column slice. All entries are trainable afterwards.
        tot_dim = sum(self.dims)
        embed = torch.zeros(tot_dim, len(dims) * self.e)
        beg = 0
        for idx, k in enumerate(dims):
            embed[beg:beg + k, self.e * idx:self.e * (idx + 1)] = torch.randn(k, self.e)
            beg += k
        self.embed = nn.Parameter(embed)

        layers = []
        self.shift_mat, self.scale_mat = [], []
        self.shift_mat_rev, self.scale_mat_rev = [], []
        self.scale_forbid = []
        for k in dims:
            # One head per block producing k shift logits and k scale logits.
            layers.append(nn.Linear(self.e, 2 * k))
            shift_mat, scale_mat = self._permutation_tables(k)
            self.shift_mat.append(shift_mat.to(device).requires_grad_(False))
            self.scale_mat.append(scale_mat.to(device).requires_grad_(False))
            self.shift_mat_rev.append(
                shift_mat.transpose(1, 2).contiguous().to(device).requires_grad_(False)
            )
            self.scale_mat_rev.append(
                scale_mat.transpose(1, 2).contiguous().to(device).requires_grad_(False)
            )
            forbid = torch.tensor(
                [-1e3 if gcd(i, k) > 1 else 0.0 for i in range(k)],
                dtype=torch.float32,
            )
            self.scale_forbid.append(forbid.to(device).requires_grad_(False))
        self.layer = nn.ModuleList(layers)

        # Keep the parameters on the same device as the constant tensors so
        # a freshly constructed layer is usable without an extra ``.to()``.
        self.to(device)

    @staticmethod
    def _permutation_tables(k):
        """Build the ``(k, k, k)`` shift and scale tables for one block."""
        amount = torch.arange(k).view(k, 1)
        col = torch.arange(k).view(1, k)
        shift_mat = torch.zeros((k, k, k))
        scale_mat = torch.zeros((k, k, k))
        # Index tensors broadcast to (k, k): one entry per (amount, column).
        shift_mat[amount, (col - amount + k) % k, col] = 1
        scale_mat[amount, (col * amount) % k, col] = 1
        return shift_mat, scale_mat

    @staticmethod
    def _mix(x, tables, coeffs):
        """Apply every table to ``x`` and blend the results with ``coeffs``.

        ``x`` and ``coeffs`` are used as ``(batch, k)`` tensors here;
        ``tables`` is ``(k, k, k)``.
        """
        out = torch.matmul(x, tables)  # (k, batch, k): one result per table
        return (out * coeffs.T.unsqueeze(2)).sum(0)

    def shift(self, idx, x, coeffs, reverse=False):
        """Shift block ``idx`` of ``x``; ``coeffs`` weights each shift amount.

        With ``reverse=True`` the transposed tables are used.
        """
        tables = self.shift_mat_rev[idx] if reverse else self.shift_mat[idx]
        return self._mix(x, tables, coeffs)

    def scale(self, idx, x, coeffs, reverse=False):
        """Scale block ``idx`` of ``x``; ``coeffs`` weights each scale factor.

        With ``reverse=True`` the transposed tables are used.
        """
        tables = self.scale_mat_rev[idx] if reverse else self.scale_mat[idx]
        return self._mix(x, tables, coeffs)

    def _transform(self, x, reverse):
        """Shared body of ``forward`` and ``inverse``."""
        masked_x = x * self.mask
        x_emb = torch.matmul(masked_x, self.embed)
        pieces = []
        beg = 0
        for idx, k in enumerate(self.dims):
            t = self.layer[idx](x_emb[:, idx * self.e:(idx + 1) * self.e])
            shift, scale = torch.split(t, [k, k], dim=1)
            # Push scale factors sharing a divisor with k far down.
            scale = scale + self.scale_forbid[idx]
            shift_one_hot = one_hot_argmax(shift, self.temp_back, self.temp_for)
            scale_one_hot = one_hot_argmax(scale, self.temp_back, self.temp_for)
            x_new = x[:, beg:beg + k]
            if reverse:
                # Undo in the opposite order of the forward pass.
                x_new = self.shift(idx, x_new, shift_one_hot, reverse=True)
                x_new = self.scale(idx, x_new, scale_one_hot, reverse=True)
            else:
                x_new = self.scale(idx, x_new, scale_one_hot)
                x_new = self.shift(idx, x_new, shift_one_hot)
            # Masked entries pass through; unmasked entries take the new value.
            x_new = masked_x[:, beg:beg + k] + (1 - self.mask[beg:beg + k]) * x_new
            pieces.append(x_new)
            beg += k
        return torch.cat(pieces, dim=1)

    def forward(self, x):
        """Apply scale then shift to every unmasked block of ``x``.

        Returns a tensor with the same column layout as ``x``.
        """
        return self._transform(x, reverse=False)

    def inverse(self, x):
        """Apply reversed shift then reversed scale to every unmasked block.

        Returns a tensor with the same column layout as ``x``.
        """
        return self._transform(x, reverse=True)


class DiscreteBipartiteFlow(nn.Module):
    """Stack of ``ShiftScaleLayer`` blocks with a learned categorical prior.

    The input is a concatenation of blocks of widths ``dims``. Block
    ``block_idx`` of the flow masks (passes through) every input block ``i``
    for which ``(i + block_idx) % 2 == 1``, so consecutive layers alternate
    which input blocks they transform. ``prior`` holds one logit per input
    column; each block of logits is normalised independently.

    Args:
        dims: Sequence of block widths.
        e: Embedding width forwarded to every ``ShiftScaleLayer``.
        num_blocks: Number of coupling layers.
        device: Device for the masks, the layers and the prior logits.
    """

    def __init__(self, dims, e, num_blocks, device):
        super(DiscreteBipartiteFlow, self).__init__()
        self.dims = dims
        self.num_blocks = num_blocks
        total_width = sum(self.dims)

        coupling_layers = []
        for block_idx in range(num_blocks):
            # One mask value per column, constant within each input block.
            pattern = [
                (i + block_idx) % 2
                for i, dim in enumerate(self.dims)
                for _ in range(dim)
            ]
            block_mask = torch.tensor(pattern).to(device)
            coupling_layers.append(ShiftScaleLayer(dims, e, block_mask, device))
        self.blocks = nn.Sequential(*coupling_layers)
        self.prior = nn.Parameter(torch.randn((total_width)).to(device))

    def forward(self, x):
        """Run ``x`` through all layers and score the result under the prior.

        Returns:
            ``(z, logp)`` where ``z`` is the transformed input and ``logp``
            is, per row, the sum over blocks of ``z`` weighted by the
            log-softmax of that block's prior logits.
        """
        z = x
        for layer in self.blocks:
            z = layer(z)

        log_prob = 0
        start = 0
        for width in self.dims:
            stop = start + width
            log_prior = F.log_softmax(self.prior[start:stop], dim=0)
            log_prob = log_prob + (z[:, start:stop] * log_prior).sum(1)
            start = stop
        return z, log_prob

    def inverse(self, x):
        """Apply every layer's ``inverse`` in reverse layer order."""
        for layer in reversed(self.blocks):
            x = layer.inverse(x)
        return x

    def update_temp(self, temp_back, temp_for):
        """Set ``temp_back`` and ``temp_for`` on every layer."""
        for layer in self.blocks:
            layer.temp_back = temp_back
            layer.temp_for = temp_for

    def sample(self, n):
        """Draw ``n`` rows from the prior and map them through ``inverse``.

        Each block is sampled independently from a categorical distribution
        with that block's prior logits and one-hot encoded before the blocks
        are concatenated.
        """
        draws = []
        start = 0
        for width in self.dims:
            stop = start + width
            categorical = Categorical(logits=self.prior[start:stop])
            picked = categorical.sample((n,))
            draws.append(F.one_hot(picked, num_classes=width))
            start = stop
        latent = torch.cat(draws, dim=1).float()
        return self.inverse(latent)




        
                
