import torch
import numpy as np
import networkx as nx
from torch import Tensor, DoubleTensor, LongTensor
import matplotlib.pyplot as plt
from torch.nn import Module, Conv2d, Linear, BatchNorm2d, ReLU
from itertools import combinations, product
from torch.types import Tuple
from torch_scatter import scatter_sum
from SIP.args import cmd_args
from numba import jit
import math


class Ising(object):
    """Ising model on a ``p x p`` 2-D grid graph.

    States are vectors of spins (``init_state`` returns all ones and
    ``flip_state`` negates entries).  The energy is

        E(x) = sum_{undirected edges (i, j)} e_w * x_i * x_j + sum_i n_w[i] * x_i

    where every edge weight is ``-lamda`` and each node weight is a draw from
    ``Uniform(-sigma, sigma)`` shifted by ``-mu`` inside a disc around the
    grid centre and by ``+mu`` outside it (see ``_weight``).

    The graph is stored as a directed edge list (``row``, ``col``) in which
    each undirected edge appears once per direction; that is why the pairwise
    term of ``energy`` is halved.
    """

    def __init__(self, device=torch.device("cpu"), p=100, mu=2.0, sigma=3.0, lamda=1.0, seed=0):
        self.device = device
        self.rng = np.random.default_rng(seed)
        self.info = f"ising_p-{p}_mu-{mu}_sigma-{sigma}_lamda-{lamda}_seed-{seed}_{self.device}"
        nx_g = nx.grid_2d_graph(p, p)
        nx.set_node_attributes(nx_g, {n: {'weight': (self.rng.uniform(- sigma, sigma)) + (
            self._weight(n, p, mu))} for n in nx_g.nodes()})
        nx.set_edge_attributes(nx_g, {e: {'weight': - lamda} for e in nx_g.edges()})
        # Relabel the (i, j) grid coordinates to consecutive integers, then
        # duplicate every edge in both directions.
        nx_g = nx.convert_node_labels_to_integers(nx_g)
        nx_g = nx.to_directed(nx_g)
        self.row = LongTensor([e[0] for e in nx_g.edges()]).to(self.device)
        self.col = LongTensor([e[1] for e in nx_g.edges()]).to(self.device)
        self.n_w = DoubleTensor([nx_g.nodes[n]['weight'] for n in nx_g.nodes()]).to(self.device)
        self.e_w = DoubleTensor([nx_g.edges[e]['weight'] for e in nx_g.edges()]).to(self.device)
        self._num_nodes = nx_g.number_of_nodes()
        self.nx_g = nx_g

    def _weight(self, n, p, mu):
        """Return ``-mu`` for grid node ``n`` inside the central disc, else ``+mu``.

        The disc is centred at (0.5, 0.5) in coordinates scaled by ``1 / p``
        and has squared radius ``0.5 / pi``.
        """
        if (n[0] / p - 0.5) ** 2 + (n[1] / p - 0.5) ** 2 < 0.5 / np.pi:
            return - mu
        else:
            return mu

    def init_state(self):
        """Return the all-ones state of length ``num_nodes`` on ``self.device``."""
        return torch.ones(self.num_nodes).to(self.device)

    def energy(self, x : Tensor):
        """Return the energy of ``x``.

        Args:
            x: a single state of shape ``(num_nodes,)`` or a batch of shape
                ``(batch, num_nodes)``.

        Returns:
            A scalar tensor for 1-D input, a ``(batch,)`` tensor for 2-D input.

        Raises:
            NotImplementedError: if ``x`` is neither 1-D nor 2-D.
        """
        if len(x.shape) == 1:
            # Each undirected edge is listed twice, hence the division by 2.
            return (x[self.row] * x[self.col] * self.e_w).sum() / 2 + (x * self.n_w).sum()
        elif len(x.shape) == 2:
            return (x[:, self.row] * x[:, self.col] * self.e_w).sum(dim=-1) / 2 + (x * self.n_w).sum(dim=-1)
        else:
            raise NotImplementedError("x type not supported")

    def grad(self, x : Tensor):
        """Return the gradient of the energy with respect to ``x``.

        For node ``i`` this is ``n_w[i]`` plus the sum of ``e_w * x_j`` over
        the directed edges ``j -> i``.  The result has the same shape as ``x``.

        Raises:
            NotImplementedError: if ``x`` is neither 1-D nor 2-D.
        """
        # dim_size pins the output length to the number of nodes; without it
        # scatter_sum sizes the output from the largest index in ``col``,
        # which is too short when trailing nodes have no incoming edge.
        if len(x.shape) == 1:
            grad = scatter_sum(x[self.row] * self.e_w, self.col,
                               dim_size=self._num_nodes) + (self.n_w)
        elif len(x.shape) == 2:
            grad = scatter_sum(x[:, self.row] * self.e_w, self.col, dim=1,
                               dim_size=self._num_nodes) + (self.n_w)
        else:
            raise NotImplementedError("x type not supported")
        return grad

    def energy_grad(self, x : Tensor):
        """Return ``(energy, grad)`` for ``x``.

        For batched input the energy is returned with a trailing singleton
        dimension, i.e. shape ``(batch, 1)``.
        """
        energy = self.energy(x)
        if len(x.shape) == 2:
            energy = energy.unsqueeze(-1)
        grad = self.grad(x)
        return energy, grad

    def delta(self, x : Tensor):
        """Return the per-entry change ``-2 * x`` caused by negating each entry."""
        return - 2 * x

    def change(self, x : Tensor):
        """Return ``delta(x) * grad(x)``, one value per node."""
        grad = self.grad(x)
        delta = self.delta(x)
        res = delta * grad
        return res

    def flip_state(self, x : Tensor, idx : Tuple, *args, **kwargs):
        """Return a copy of ``x`` with the entries at ``idx`` negated.

        Indices are applied one at a time, so an index that appears twice is
        negated twice.  ``x`` itself is not modified.
        """
        z = x.clone()
        for i in idx:
            z[i] *= - 1
        return z

    def trace(self, x : Tensor):
        """Return the sum of all entries of ``x`` as a Python number."""
        return x.sum().cpu().item()

    @property
    def num_nodes(self):
        """Number of nodes in the grid graph."""
        return self._num_nodes


class FHMM(object):
    """Factorial hidden Markov model with ``K`` binary chains of length ``L``.

    Generative process implemented by ``sample_X`` / ``sample_Y``:

    * ``X[0, k] ~ Bernoulli(alpha)``;
    * each later entry keeps the previous value of its chain with
      probability ``beta`` and switches with probability ``1 - beta``;
    * ``Y = X @ W + b + sigma * noise`` with standard normal noise.

    ``energy`` is the negative log joint ``-log p(x) - log p(Y | x)`` up to an
    additive constant, so lower energy means higher posterior probability.

    NOTE: ``W`` and ``b`` are drawn from the global torch RNG before any
    ``torch.manual_seed`` call made by this class, so they are not determined
    by ``seed``.
    """

    def __init__(self, L=1000, K=10, sigma2=0.5, alpha=0.1, beta=0.95, seed=0, device=torch.device("cpu")):
        self.L = L
        self.K = K
        self.device = device
        self.sigma = math.sqrt(sigma2)
        self.alpha = torch.FloatTensor([alpha]).to(device)
        self.beta = torch.FloatTensor([beta]).to(device)
        self.W = torch.randn((K, 1)).to(device)
        self.b = torch.randn((1, 1)).to(device)
        self.X = self.sample_X(seed)
        self.Y = self.sample_Y(self.X, seed)
        self.info = f"fhmm_L-{L}_K-{K}"
        self._num_nodes = self.L * self.K
        # Prior over the first time step.
        self.P_X0 = torch.distributions.Bernoulli(probs=self.alpha)
        # Distribution of the "state changed" indicator between consecutive
        # steps.  sample_X switches with *probability* 1 - beta, so this must
        # be parameterised with ``probs`` (``logits=1 - beta`` would give a
        # switch probability of sigmoid(1 - beta), i.e. about 0.5).
        self.P_XC = torch.distributions.Bernoulli(probs=1 - self.beta)

    def sample_X(self, seed):
        """Sample the ``(L, K)`` hidden state matrix after seeding torch with ``seed``."""
        torch.manual_seed(seed)
        X = torch.ones((self.L, self.K)).to(self.device)
        X[0] = torch.bernoulli(X[0] * self.alpha)
        for l in range(1, self.L):
            # Probability of a 1: beta if the previous entry is 1, else 1 - beta.
            p = self.beta * X[l - 1] + (1 - self.beta) * (1 - X[l - 1])
            X[l] = torch.bernoulli(p)
        return X

    def sample_Y(self, X, seed):
        """Sample the ``(L, 1)`` observations for hidden states ``X`` after seeding torch."""
        torch.manual_seed(seed)
        return torch.randn((self.L, 1)).to(self.device) * self.sigma + X @ self.W + self.b

    def trace(self, x):
        """Return the log mean squared reconstruction error of a single state ``x``."""
        return torch.log((x.view(self.L, self.K) @ self.W + self.b - self.Y).square().mean()).item()

    def init_state(self):
        """Return the all-ones flat state of length ``L * K`` on ``self.device``."""
        return torch.ones(self.L * self.K).to(self.device)

    def energy(self, input):
        """Return the negative log joint of the given state(s), up to a constant.

        Args:
            input: tensor with ``L * K`` entries per state; it is viewed as
                ``(batch, L, K)``, so a flat single state yields ``batch == 1``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        x = input.view(-1, self.L, self.K)
        x_0 = x[:, 0, :]
        x_cur = x[:, :-1, :]
        x_next = x[:, 1:, :]
        # XOR of consecutive steps: 1 where a chain switched state.
        x_c = x_cur * (1 - x_next) + (1 - x_cur) * x_next
        log_prior = self.P_X0.log_prob(x_0).sum(-1) + self.P_XC.log_prob(x_c).sum(dim=[1, 2])
        neg_log_lik = (self.Y - x @ self.W - self.b).square().sum(dim=[1, 2]) / (2 * self.sigma ** 2)
        # Both terms must be negative log-probabilities; the prior term is
        # therefore subtracted.
        return neg_log_lik - log_prior

    def grad(self, input):
        """Return the autograd gradient of ``energy`` for a single state, flattened to ``(L * K,)``."""
        z = input.view(self.L, self.K).clone().requires_grad_()
        energy = self.energy(z)
        grad = torch.autograd.grad(energy.sum(), z)[0].detach()
        return grad.view(self.L * self.K)

    def energy_grad(self, input):
        """Return ``(energy, grad)`` with shapes ``(batch,)`` and ``(batch, L, K)``."""
        x = input.view(-1, self.L, self.K).clone().requires_grad_()
        energy = self.energy(x)
        grad = torch.autograd.grad(energy.sum(), x)[0].detach()
        return energy.detach(), grad

    def flip_state(self, x, idx):
        """Return a copy of ``x`` with the binary entries at ``idx`` toggled.

        Indices are applied one at a time, so a repeated index is toggled twice.
        """
        z = x.clone()
        for i in idx:
            z[i] = 1 - z[i]
        return z

    def delta(self, x):
        """Return the per-entry change ``1 - 2 * x`` caused by toggling each entry."""
        return 1 - 2 * x

    @property
    def num_nodes(self):
        """Number of binary variables, ``L * K``."""
        return self._num_nodes


class TSP(object):
    """Travelling-salesman style model over permutations of ``p`` cities.

    A state is a 1-D integer tensor listing the cities in visiting order.  The
    energy is the total weight of the closed tour, using the (generally
    asymmetric) weight matrix ``w`` whose entries are ``sigma`` times standard
    normal draws.  A "move" swaps the cities at two positions; moves are
    indexed by the ``p * (p - 1) / 2`` position pairs ``(i1[k], i2[k])`` with
    ``i1[k] < i2[k]``.
    """

    def __init__(self, p=100, sigma=1.0, seed=0, device=torch.device("cpu")):
        torch.manual_seed(seed)
        self.device = device
        self.w = sigma * torch.randn((p, p)).to(device)
        self._num_nodes = p
        self.info = f"tsp_p-{p}_sigma-{sigma}"
        # Enumerate position pairs in the order (0,1), (0,2), (1,2), (0,3), ...
        self.i1 = LongTensor([j for i in range(p - 1) for j in range(i + 1)]).to(device)
        self.i2 = LongTensor([i for i in range(1, p) for _ in range(1, i + 1)]).to(device)

    def energy(self, x):
        """Return the weight of the closed tour(s) described by ``x``.

        Args:
            x: integer tensor of shape ``(p,)`` or ``(batch, p)``.

        Returns:
            A scalar tensor for 1-D input, a ``(batch,)`` tensor for 2-D input.
        """
        # Successor of every position; the last position wraps around to the
        # first so that the tour is closed, matching the cyclic indexing used
        # by ``change``.
        successor = torch.roll(x, shifts=-1, dims=-1)
        return self.w[x, successor].sum(dim=-1)

    def change(self, x):
        """Return the energy difference of every pairwise swap of a single tour.

        Args:
            x: integer tensor of shape ``(p,)``.

        Returns:
            Tensor of shape ``(p * (p - 1) / 2,)`` whose ``k``-th entry equals
            ``energy(flip_state(x, [k])) - energy(x)``.
        """
        n = self._num_nodes
        a, b = self.i1, self.i2
        city_a, city_b = x[a], x[b]

        def after_swap(pos):
            # City found at ``pos`` once positions a and b have been swapped.
            cities = torch.where(pos == a, city_b, x[pos])
            return torch.where(pos == b, city_a, cities)

        def edge_delta(start):
            # Weight change of the tour edge leaving position ``start``.
            end = (start + 1) % n
            return self.w[after_swap(start), after_swap(end)] - self.w[x[start], x[end]]

        a_prev, b_prev = (a - 1) % n, (b - 1) % n
        # A swap touches the edges leaving positions a-1, a, b-1 and b.  When
        # the two positions are neighbours on the cycle some of these coincide
        # and must be counted only once.
        res = edge_delta(a_prev) + edge_delta(a)
        res = res + edge_delta(b_prev) * (b_prev != a)
        res = res + edge_delta(b) * (b != a_prev)
        return res

    def flip_state(self, x, idx):
        """Return a copy of ``x`` with the position pair(s) ``idx`` swapped.

        Args:
            x: integer tensor of shape ``(p,)``.
            idx: a move index or a collection of move indices into
                ``i1`` / ``i2``.  All swaps read from the original ``x``; they
                are applied simultaneously, not sequentially.
        """
        if isinstance(idx, tuple):
            # A tuple would be interpreted as a multi-dimensional index.
            idx = list(idx)
        i, j = self.i1[idx], self.i2[idx]
        z = x.clone()
        # Read from ``x`` rather than ``z``: with scalar indices ``z[i]`` and
        # ``z[j]`` are views, so an in-place tuple swap on ``z`` would copy
        # the same value into both positions.
        z[i] = x[j]
        z[j] = x[i]
        return z

    def init_state(self):
        """Return the identity tour ``[0, 1, ..., p - 1]`` on ``self.device``."""
        return torch.arange(self._num_nodes, dtype=torch.int64).to(self.device)

    def trace(self, x):
        """Return 0; no trace statistic is defined for this model."""
        return 0

    @property
    def num_nodes(self):
        """Number of available swap moves, ``p * (p - 1) / 2``."""
        return int(self._num_nodes * (self._num_nodes - 1) / 2)


class BMM(object):
    """Mixture of ``m`` product-Bernoulli components over ``p`` binary variables.

    Component ``k`` assigns state ``x`` the log-probability
    ``-theta[:, k] @ x - normalizer[0, k]``; all components carry equal
    weight.  The energy is the negative log of the summed component
    probabilities, i.e. ``-log p(x)`` up to an additive constant.
    """

    def __init__(self, p=100, m=10, seed=0, device=torch.device("cpu")):
        """
        theta: (p, m) Tensor
        normalizer: (1, m) Tensor
        """
        self.rng = np.random.default_rng(seed)
        torch.manual_seed(seed)
        self.device = device
        self.info = f"bmm_p-{p}_m-{m}_seed-{seed}_{device}"
        theta = torch.ones((p, m))
        self.k = p // m
        # Component i gets a negative parameter on its own block of k
        # variables and a positive one everywhere else.
        for i in range(m):
            theta[i * self.k: (i+1)*self.k, i] *= -1
        theta += torch.randn_like(theta) * 0.1
        self.theta = theta.to(self.device)
        self.normalizer = torch.log(1 + torch.exp(-theta)).sum(dim=0, keepdim=True).to(self.device)
        self._num_nodes = p

    def _component_logits(self, x : Tensor):
        """Return per-component log-probabilities with shape ``(m, batch)``.

        A 1-D state is treated as a batch of one; a 2-D input is expected as
        ``(batch, p)``.
        """
        if len(x.shape) == 1:
            x = x.unsqueeze(-1)
        elif len(x.shape) == 2:
            x = x.T
        return - self.theta.T @ x - self.normalizer.T

    def init_state(self):
        """Return the all-ones state of length ``num_nodes`` on ``self.device``."""
        return torch.ones(self.num_nodes).to(self.device)

    def energy(self, x : Tensor):
        """Return the energy of ``x``.

        Returns:
            Shape ``(1,)`` for a 1-D state, ``(batch,)`` for ``(batch, p)`` input.
        """
        return - torch.logsumexp(self._component_logits(x), dim=0)

    def grad(self, x : Tensor):
        """Return the gradient of the energy with respect to ``x``.

        The gradient is ``theta`` applied to the posterior component weights
        (softmax of the component logits).

        Returns:
            Shape ``(p,)`` for a 1-D state, ``(batch, p)`` for batched input,
            i.e. the same layout as ``x``.
        """
        logits = self._component_logits(x)
        res = self.theta @ torch.softmax(logits, dim=0)  # (p, batch)
        if len(x.shape) == 1:
            return res.squeeze(-1)
        # Transpose back so that the result lines up with the (batch, p) input.
        return res.T

    def energy_grad(self, x : Tensor):
        """Return ``(energy(x), grad(x))``."""
        return self.energy(x), self.grad(x)

    def change(self, x : Tensor):
        """Return the exact energy difference of toggling each single variable.

        Args:
            x: a state of shape ``(p,)`` or a batch of shape ``(batch, p)``.

        Returns:
            Tensor shaped like ``x``; entry ``i`` is
            ``energy(x with bit i toggled) - energy(x)``.
        """
        single = len(x.shape) == 1
        logits = self._component_logits(x)  # (m, batch)
        flip = 1 - 2 * x
        if single:
            flip = flip.unsqueeze(0)  # (1, p)
        # Toggling bit i changes component k's logit by -theta[i, k] * (1 - 2 x_i).
        logits_change = - self.theta.T.unsqueeze(0) * flip.unsqueeze(1)  # (batch, m, p)
        new_energy = - torch.logsumexp(logits.T.unsqueeze(-1) + logits_change, dim=1)  # (batch, p)
        energy = - torch.logsumexp(logits, dim=0).unsqueeze(-1)  # (batch, 1)
        res = new_energy - energy
        return res.squeeze(0) if single else res

    def flip_state(self, x : Tensor, idx : Tuple):
        """Return a copy of ``x`` with the binary entries at ``idx`` toggled.

        Indices are applied one at a time, so a repeated index is toggled twice.
        """
        z = x.clone()
        for i in idx:
            z[i] = 1 - z[i]
        return z

    def delta(self, x : Tensor):
        """Return the per-entry change ``1 - 2 * x`` caused by toggling each entry."""
        return 1 - 2 * x

    def trace(self, x : Tensor):
        """Return the index of the most likely mixture component for a 1-D state."""
        assert len(x.shape) == 1
        logits = self._component_logits(x)
        return torch.argmax(logits).item()

    @property
    def num_nodes(self):
        """Number of binary variables ``p``."""
        return self._num_nodes


class Parity(object):
    """Parity model over ``p`` binary variables.

    The energy of a state is ``U`` when the number of ones is odd and 0 when
    it is even.
    """

    def __init__(self, p=100, U=1, seed=0, device=torch.device("cpu")):
        self.rng = np.random.default_rng(seed)
        self._num_nodes = p
        self.U = U
        self.device = device
        self.info = f"parity_p-{p}_U-{U}"

    def init_state(self):
        """Return the all-ones state of length ``num_nodes`` on ``self.device``."""
        return torch.ones(self.num_nodes).to(self.device)

    def energy(self, x : Tensor):
        """Return ``U * (sum(x) mod 2)``, reducing over the last dimension."""
        res = x.sum(dim=-1) % 2
        return res * self.U

    def change(self, x : Tensor):
        """Return the energy difference of toggling each single variable.

        Toggling any one bit flips the parity, so every entry equals ``+U``
        when the current parity is even and ``-U`` when it is odd.  The result
        has the same shape as ``x``.
        """
        parity = x.sum(dim=-1, keepdim=True) % 2
        return self.U * (1 - 2 * parity) * torch.ones_like(x)

    def flip_state(self, x : Tensor, idx : Tuple):
        """Return a copy of ``x`` with the binary entries at ``idx`` toggled.

        Indices are applied one at a time, so a repeated index is toggled twice.
        """
        z = x.clone()
        for i in idx:
            z[i] = 1 - z[i]
        return z

    def trace(self, x : Tensor):
        """Return ``x`` unchanged (the state itself serves as the trace)."""
        return x

    @property
    def num_nodes(self):
        """Number of binary variables ``p``."""
        return self._num_nodes


# class SK(object):
#     def __init__(self, p=10, seed=0):
#         self.p = p
#         self.device = torch.device(f"cuda:{cmd_args.device}" if torch.cuda.is_available() else "cpu")
#         self.rng = np.random.RandomState(seed)
#         self.A = torch.zeros((p, p), dtype=torch.float64)
#         self.A = torch.tril(torch.randn(p, p, dtype=torch.float64) / np.sqrt(2 * p), -1)
#         self.A += self.A.T
#         self.A = self.A.to(self.device)
#         self.h = torch.ones(p).to(self.device)
#
#     def init_state(self):
#         return torch.ones(self.p, dtype=torch.float64).to(self.device)
#         # return DoubleTensor(2 * self.rng.binomial(1, 0.5, self.p) - 1).to(self.device)
#
#     def energy(self, x):
#         if len(x.shape) == 1:
#             x = x.unsqueeze(0)
#             return (x * (self.A * x.unsqueeze(1)).sum(dim=-1)).sum(dim=-1) + (x * self.h).sum(dim=-1)
#         elif len(x.shape) == 2:
#             return (x * (self.A * x.unsqueeze(1)).sum(dim=-1)).sum(dim=-1) + (x * self.h).sum(dim=-1)
#         else:
#             raise NotImplementedError
#         # if len(x.shape) == 1:
#         #     return (x * (self.A * x.unsqueeze(-1)).sum(dim=-1)).sum()
#         # if len(x.shape) == 2:
#         #     return (x * (self.A * x.unsqueeze(-1)).sum(dim=-1)).sum(-1)
#         # return x @ (self.A * )
#
#     def grad(self, x):
#         if len(x.shape) <= 2:
#             y = x.detach().clone()
#             y.requires_grad = True
#             energy = self.energy(y)
#             energy.sum().backward()
#             if len(x.shape) == 1:
#                 energy = energy.detach()
#             else:
#                 energy = energy.detach().unsqueeze(-1)
#             return energy.detach(), y.grad
#
#     def delta(self, x : Tensor):
#         return - 2 * x
#
#     def flip_state(self, x : Tensor, idx : Tuple, *args, **kwargs):
#         z = x.clone()
#         for i in idx:
#             z[i] *= - 1
#         return z
#
#     @property
#     def num_nodes(self):
#         return self.p
#
# class GMM(object):
#     def __init__(self, m=3, p=2, d=5):
#         """
#         :param m: number of clusters
#         :param p: number of variables
#         :param d: each variable must be an integer in [-d, d]
#         """
#         self.m = m
#         self.p = p
#         self.d = d
#         self.device = torch.device(f"cuda:{cmd_args.device}" if torch.cuda.is_available() else "cpu")
#         self.rng = np.random.RandomState(cmd_args.seed)
#         self.pi = torch.rand(m, dtype=torch.float64).to(self.device)
#         self.pi /= self.pi.sum()
#         self.mu =  torch.randint(low=-self.d + 1, high=self.d, size=(m, self.p, 1), dtype=torch.float64).to(self.device)
#         B = int(d / 2) * torch.randn(size=(m, p, p), dtype=torch.float64)
#         S = B @ B.permute(0, 2, 1)
#         self.P = torch.linalg.inv(S).to(self.device)
#
#     def energy(self, x):
#         assert len(x.shape) == 3 and x.shape[2] == 1
#         v = (x - self.mu)
#         energy = self.pi * torch.exp(- v.permute(0, 2, 1) @ self.P @ v / 2).squeeze() # / self.sqdet
#         return -torch.log(energy.sum())
#
#     def init_state(self):
#         x = torch.randint(low=-self.d, high=self.d + 1, size=(1, self.p, 1))
#         return x
#
#     def flip_state(self, x, idx):
#         assert len(x.shape) == 3 and x.shape[0] == 1
#         if len(idx) == 0:
#             return x
#         else:
#             N = (2 * self.d) ** len(idx)
#             res = x.clone().repeat(N, 1, 1)
#             val = product(*[self.myrange(x[0, i, 0]) for i in idx])
#             for i, pattern in enumerate(val):
#                 for j in range(len(idx)):
#                     res[i, idx[j], 0] = pattern[j]
#             return res
#
#     def myrange(self, skip=None):
#         for i in range(-self.d, self.d+1):
#             if i != skip:
#                 yield  i


if __name__ == '__main__':
    # Smoke run: build a mixture model, evaluate its energy, gradient and
    # single-flip energy changes at the initial state, then print a marker.
    model = BMM(p=1000, m=10)
    x = model.init_state()
    energy, grad = model.energy_grad(x)
    change = model.change(x)
    print('123')



# class ResBlock(Module):
#     def __init__(self, inplanes, planes, stride=1, downsample=None):
#         super().__init__()
#         self.conv1 = Conv2d(inplanes, planes, kernel_size=3, stride=stride)
#         self.bn1 = BatchNorm2d(planes)
#         self.relu = ReLU(inplace=True)
#         self.conv2 = conv3x3(planes, planes)
#         self.bn2 = BatchNorm2d(planes)
#         self.downsample = downsample
#         self.stride = stride
#
#     def forward(self, x):
#         residual = x
#
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.conv2(out)
#         out = self.bn2(out)
#
#         if self.downsample is not None:
#             residual = self.downsample(x)
#
#         out += residual
#         out = self.relu(out)
#
#         return out
#
# class MNIST(Module):
#     def __init__(self):
#         super().__init__()











