import torch
from torch.distributions import Normal, MultivariateNormal
import jax.numpy as jnp
import flax
import jax
import numpy as np
from sklearn.datasets import make_circles

class NoisyModelDataset(torch.utils.data.Dataset):
    """Gaussian input sequences ``u`` paired with model outputs ``y``.

    The inputs are sampled once at construction time.  The outputs ``y`` are
    only available after :meth:`generate_output` has been called; indexing the
    dataset before that raises ``AttributeError`` because ``self.y`` does not
    exist yet.

    Args:
        model: object exposing ``apply(variables, inputs)`` (presumably a
            Flax module -- confirm against the caller).
        model_variables: variable collection from which ``'params'`` is
            popped in :meth:`generate_output`.
        N: number of sequences.
        T: length of each sequence.
        dim_u: dimension of a single input vector.
        mu: mean, shape ``(dim_u,)``.
        sigma: second distribution argument; shape ``(dim_u,)`` when
            ``dim_u == 1`` and ``(dim_u, dim_u)`` when ``dim_u > 1``.
    """

    def __init__(self, model, model_variables, N, T, dim_u, mu, sigma):
        print(f"{mu.shape=}")
        print(f"{sigma.shape=}")

        # Shape checks: vector mean, and either a matching 1-D sigma or a
        # square (dim_u, dim_u) sigma.
        assert mu.shape == (dim_u,)
        univariate_ok = dim_u == 1 and sigma.shape == mu.shape
        multivariate_ok = dim_u > 1 and sigma.shape == (dim_u, dim_u)
        assert univariate_ok or multivariate_ok

        self.model, self.model_variables = model, model_variables
        self.N, self.T = N, T
        self.dim_u = dim_u
        self.mu, self.sigma = mu, sigma

        # Inputs are drawn eagerly; outputs are produced on demand.
        self.u = self.generate_normal_data()

    def generate_normal_data(self):
        """Draw ``N * T`` samples, returned with leading shape ``(N, T)``."""
        if self.dim_u == 1:
            sampler = Normal(self.mu, self.sigma)
        else:
            sampler = MultivariateNormal(self.mu, self.sigma)
        return sampler.sample((self.N, self.T))

    def generate_output(self, noise=None, batch_size=16):
        """Run the model on all inputs, store the result in ``self.y``.

        Args:
            noise: optional object with a ``sample(shape)`` method; when
                truthy, a sample of the output's shape is added in place.
            batch_size: accepted but not used by this method.

        Returns:
            The (possibly noised) output tensor, also kept as ``self.y``.
        """
        # Only the 'params' collection is forwarded to the model; whatever
        # else model_variables holds is dropped.
        _, params = flax.core.pop(self.model_variables, 'params')
        model_inputs = jnp.array(self.u)
        raw_outputs = self.model.apply({'params': params}, model_inputs)
        self.y = torch.from_numpy(np.array(raw_outputs))
        if noise:
            self.y += noise.sample(self.y.shape)
        return self.y

    def __getitem__(self, idx):
        """Return the ``(input, output)`` pair stored at ``idx``."""
        inputs = self.u[idx]
        outputs = self.y[idx]
        return inputs, outputs

    def __len__(self):
        """Return the number of sequences ``N``."""
        return self.N

class NoisyClassifDataset(torch.utils.data.Dataset):
    """Two-class dataset of input sequences ``u``, model outputs ``y`` and labels.

    Inputs and labels are produced at construction time by
    :meth:`generate_binary_data`.  ``y`` only exists after
    :meth:`generate_output` has been called, so indexing the dataset before
    that raises ``AttributeError``.

    Args:
        model: object exposing ``apply(variables, inputs, rngs=...)``
            (presumably a Flax module -- confirm against the caller).
        model_variables: variables passed unchanged to ``model.apply``.
        N: requested number of samples; each class receives ``N // 2``, so an
            odd ``N`` yields ``N - 1`` samples.
        T: sequence length.
        dim_u: input dimension; shape-checked against ``mu`` / ``sigma``.
        mu: mean, shape ``(dim_u,)``; used by :meth:`generate_normal_data`.
        sigma: shape ``(dim_u,)`` if ``dim_u == 1`` else ``(dim_u, dim_u)``;
            used by :meth:`generate_normal_data`.
        k: number of digit entries placed just before the final element.
        d: base of the positional weights ``d**k, ..., d**1``.
        u_T_arr: length-``k`` digits written into every positive sequence.
        neg_T_arr: length-``k`` digits written into every negative sequence.
        p_label_noise: probability with which each label is flipped (drawn
            from torch's global RNG, not from ``seed``).
        const_us: if true, the random bits of each class are replaced by one
            constant drawn uniformly from ``[-1, 1)``.
        seed: seed of the JAX PRNG key used for the binary data.
    """

    def __init__(self, model, model_variables, N, T, dim_u, mu, sigma,
                 k, d, u_T_arr, neg_T_arr,
                 p_label_noise = 0,
                 const_us=False,
                 seed=42):
        print(f"{mu.shape=}")
        print(f"{sigma.shape=}")
        assert mu.shape == (dim_u,)
        assert (
            dim_u == 1 and sigma.shape == mu.shape
            or dim_u > 1 and sigma.shape == (dim_u, dim_u)
        )
        self.model = model
        self.model_variables = model_variables
        self.N = N
        self.T = T
        self.dim_u = dim_u
        self.mu = mu
        self.sigma = sigma
        self.seed = seed
        self.k = k
        self.d = d

        self.u, self.labels = self.generate_binary_data(u_T_arr,
                                                        neg_T_arr,
                                                        const_us=const_us)

        if p_label_noise:
            # Adding a Bernoulli(p) draw modulo 2 flips each label with
            # probability p.
            flips = torch.bernoulli(p_label_noise * torch.ones_like(self.labels))
            self.labels += flips
            self.labels %= 2

    def _digit_weights(self):
        """Return the positional weights ``[d**k, ..., d**1]``."""
        exponents = jnp.arange(1, self.k + 1, 1)
        return jnp.power(self.d * jnp.ones((self.k)), exponents)[::-1]

    @staticmethod
    def _as_torch_sequences(rows):
        """Copy a 2-D array into a torch tensor with a trailing size-1 axis."""
        return torch.from_numpy(np.array(rows.reshape((*rows.shape, 1))))

    def _binary_half(self, n_rows, bits_key, const_key, digits, d_pows):
        """Build the sequences of one class.

        Each row is laid out as ``[0, body..., digits..., (T - 1) * code]``
        where ``code = sum(digits * d_pows)`` and ``body`` has
        ``T - 2 - k`` entries.  The body holds random bits drawn with
        ``bits_key``, unless ``const_key`` is given, in which case it is
        filled with a single uniform draw from ``[-1, 1)``.
        """
        body_shape = (n_rows, self.T - 2 - self.k)
        if const_key is not None:
            const = jax.random.uniform(const_key, shape=(1,),
                                       minval=-1.0, maxval=1.0)
            body = const.item() * jnp.ones(body_shape)
        else:
            body = jax.random.randint(bits_key, body_shape, 0, 2)
        code = jnp.sum(digits * d_pows).item()
        return jnp.hstack([jnp.zeros((n_rows, 1)),
                           body,
                           jnp.expand_dims(digits, 0).repeat(n_rows, axis=0),
                           (self.T - 1) * code * jnp.ones((n_rows, 1))])

    def generate_binary_data_old(self):
        """Legacy generator, not used by ``__init__``.

        Positive rows are ``[0, bits..., (T - 1) * code]`` where ``code`` is
        the weighted sum of the last ``k`` bits.  Negative rows use the same
        layout, but their code is shifted by a random non-zero offset modulo
        ``2**(k + 1) - 1`` before being scaled by ``T - 1``.

        Returns:
            ``(data, labels)``: ``data`` has shape ``(2 * (N // 2), T, 1)``
            with positives first; ``labels`` is 1 for positives, 0 for
            negatives.
        """
        N_half = self.N // 2
        key = jax.random.PRNGKey(self.seed)
        key_pos, key_neg, key_neg_2 = jax.random.split(key, 3)
        d_pows = self._digit_weights()

        u = jax.random.randint(key_pos, (N_half, self.T - 2), 0, 2)
        u = jnp.hstack([jnp.zeros((N_half, 1)), u])
        sums = (self.T - 1) * jnp.sum(u[:, -self.k:]*d_pows, axis=1)
        u = jnp.hstack([u, sums[:, None]])
        labels_pos = torch.ones(u.shape[0])

        negs = jax.random.randint(key_neg, (N_half, self.T - 2), 0, 2)
        negs = jnp.hstack([jnp.zeros((N_half, 1)), negs])
        sums = jnp.sum(negs[:, -self.k:]*d_pows, axis=1)
        negs = jnp.hstack([negs, sums[:, None]])

        # Perturb the last element so it no longer matches the encoded bits.
        # A NumPy copy is needed because JAX arrays cannot be edited in place.
        negs = np.array(negs)
        adds = np.array(jax.random.randint(key_neg_2, (N_half,),
                                           minval=1,
                                           maxval=2**(self.k+1)-1))
        negs[:, -1] += adds
        negs[:, -1] %= 2**(self.k + 1)-1
        # Scale only the final column, mirroring the (T - 1) factor applied
        # to the positive rows.  A row slice here would instead rescale every
        # column of all rows but the last.
        negs[:, -1] *= self.T - 1
        labels_neg = torch.zeros(negs.shape[0])

        data = torch.vstack([self._as_torch_sequences(u),
                             self._as_torch_sequences(negs)])
        labels = torch.cat([labels_pos, labels_neg])

        return data, labels

    def generate_binary_data(self, u_T_arr, neg_T_arr, const_us=False):
        """Generate the positive and negative input sequences.

        Both classes share the layout
        ``[0, body..., digits..., (T - 1) * sum(digits * d_pows)]`` and differ
        in the digits used (``u_T_arr`` vs. ``neg_T_arr``) and in the random
        key of the body.

        Args:
            u_T_arr: length-``k`` digits of the positive class (array or
                sequence).
            neg_T_arr: length-``k`` digits of the negative class.
            const_us: fill each class body with one uniform constant from
                ``[-1, 1)`` instead of random bits.

        Returns:
            ``(data, labels)``: ``data`` has shape ``(2 * (N // 2), T, 1)``
            with positives first; ``labels`` is 1 for positives, 0 for
            negatives.
        """
        assert len(u_T_arr) == self.k
        assert len(neg_T_arr) == self.k
        # Accept plain Python sequences as well as arrays.
        u_T_arr = jnp.asarray(u_T_arr)
        neg_T_arr = jnp.asarray(neg_T_arr)

        N_half = self.N // 2
        key = jax.random.PRNGKey(self.seed)
        key_pos, key_neg, key_const = jax.random.split(key, 3)
        d_pows = self._digit_weights()

        if const_us:
            const_key_pos, const_key_neg = jax.random.split(key_const, 2)
        else:
            const_key_pos = const_key_neg = None

        u = self._binary_half(N_half, key_pos, const_key_pos, u_T_arr, d_pows)
        labels_pos = torch.ones(u.shape[0])

        negs = self._binary_half(N_half, key_neg, const_key_neg,
                                 neg_T_arr, d_pows)
        labels_neg = torch.zeros(negs.shape[0])

        data = torch.vstack([self._as_torch_sequences(u),
                             self._as_torch_sequences(negs)])
        labels = torch.cat([labels_pos, labels_neg])

        return data, labels

    def generate_normal_data(self):
        """Sample two Gaussian classes; not used by ``__init__``.

        Class 0 is drawn with parameters ``(mu, sigma)`` and class 1 with
        ``(-mu, 5 * sigma)``, each with leading shape ``(N, T)``.

        Returns:
            ``(u, labels)`` with ``2 * N`` entries along the first axis,
            class 0 first.
        """
        distribution = Normal if self.dim_u == 1 else MultivariateNormal

        u_1 = distribution(self.mu, self.sigma).sample((self.N, self.T))
        labels_1 = torch.zeros(u_1.shape[0])
        u_2 = distribution(-self.mu, 5 * self.sigma).sample((self.N, self.T))
        labels_2 = torch.ones(u_2.shape[0])
        u = torch.cat([u_1, u_2])
        labels = torch.cat([labels_1, labels_2])
        return u, labels

    def generate_output(self, noise=None, batch_size=16):
        """Run the model on all inputs, store the result in ``self.y``.

        Args:
            noise: optional object with a ``sample(shape)`` method; when
                truthy, a sample of the output's shape is added in place.
            batch_size: accepted but not used by this method.

        Returns:
            The (possibly noised) output tensor, also kept as ``self.y``.

        Raises:
            AssertionError: if the model output contains NaN.
        """
        # Fixed keys are supplied for the 'params' and 'const' RNG streams.
        rngs = {'params': jax.random.PRNGKey(0),
                'const': jax.random.PRNGKey(1)}
        outputs = self.model.apply(self.model_variables,
                                   jnp.array(self.u),
                                   rngs=rngs)
        self.y = torch.from_numpy(np.array(outputs))
        assert not self.y.isnan().any()
        if noise:
            self.y += noise.sample(self.y.shape)
        return self.y

    def __getitem__(self, idx):
        """Return ``(input, output, label)`` for ``idx``."""
        return self.u[idx], self.y[idx], self.labels[idx]

    def __len__(self):
        """Return the number of generated sequences."""
        return self.u.shape[0]


def create_spiral_dataset(N=1000, N_traj = 1000,
                          radius = 2 * torch.pi, T = 5,
                          noise_scale=5,
                          sorting = "none",
                          normed = True):
    """Build noisy two-arm spiral trajectories with class labels.

    For each of the ``N_traj`` iterations one set of ``N`` angles is drawn as
    ``sqrt(U[0, 1)) * radius`` and used for both arms: arm A has signed radius
    ``2 * theta + pi`` (label 0), arm B has ``-2 * theta - pi`` (label 1).
    Gaussian noise scaled by ``noise_scale`` is added to the coordinates of
    each arm independently.

    Args:
        N: points per trajectory.
        N_traj: number of trajectories per arm.
        radius: upper bound of the sampled angle.
        T: end value of the returned time grid.
        noise_scale: multiplier of the standard-normal coordinate noise.
        sorting: ``"outward"`` sorts angles ascending, ``"inward"``
            descending, ``"none"`` keeps the sampled order.
        normed: divide all coordinates by the largest absolute coordinate
            found across the whole dataset.

    Returns:
        ``(res, ts)``: ``res`` has shape ``(2 * N_traj, N, 3)`` with
        ``[x, y, label]`` in the last axis and trajectories in random order;
        ``ts`` is ``torch.arange(N + 1) * (T / N)``.
    """
    assert sorting in ["outward", "inward", "none"]

    def _labelled_arm(angles, signed_radius, make_labels):
        # Noisy (x, y) coordinates of one arm followed by its label column.
        clean = torch.stack((torch.cos(angles) * signed_radius,
                             torch.sin(angles) * signed_radius), dim=1)
        noisy = clean + noise_scale * torch.randn(N, 2)
        return torch.cat([noisy, make_labels(N, 1)], dim=1)

    arms_a, arms_b = [], []
    for _ in range(N_traj):
        theta = torch.sqrt(torch.rand(N))*radius
        if sorting != "none":
            theta = torch.sort(theta, descending=(sorting == "inward")).values

        # Arm A is built before arm B so the noise draws keep their order.
        arms_a.append(_labelled_arm(theta, 2*theta + torch.pi, torch.zeros))
        arms_b.append(_labelled_arm(theta, -2*theta - torch.pi, torch.ones))

    # All A trajectories first, then all B, then shuffle along the first axis.
    res = torch.stack(arms_a + arms_b)
    shuffled_order = torch.randperm(res.size()[0])
    res = res[shuffled_order]
    if normed:
        res[..., :2] /= res[..., :2].abs().max()

    time_grid = torch.arange(N + 1) * (T / N)
    return res, time_grid

def create_circle_dataset(N, T, factor=0.7, noise=0.15, normed=True):
    """Build labelled trajectories from ``make_circles`` samples.

    ``make_circles`` is called ``N`` times with ``2 * T`` unshuffled samples.
    Each call contributes two trajectories: the points labelled 0 and the
    points labelled 1, each stored as rows of ``[x, y, label]``.

    Args:
        N: number of ``make_circles`` draws (the result holds ``2 * N``
            trajectories).
        T: half the number of samples requested per draw.
        factor: forwarded to ``make_circles``.
        noise: forwarded to ``make_circles``.
        normed: divide all coordinates by the largest absolute coordinate
            found across the whole dataset.

    Returns:
        Tensor of stacked trajectories in random order, ``[x, y, label]`` in
        the last axis.
    """
    trajectories = []
    for _ in range(N):
        points, ring = make_circles(n_samples=2 * T,
                                    noise=noise,
                                    factor=factor,
                                    shuffle=False)
        # One trajectory per class, label 0 first and label 1 second.
        for label in (0, 1):
            members = ring == label
            coords = torch.from_numpy(points[members])
            label_col = torch.from_numpy(ring[members]).unsqueeze(1)
            trajectories.append(torch.cat([coords, label_col], dim=1))

    data = torch.stack(trajectories)
    shuffled_order = torch.randperm(data.size()[0])
    data = data[shuffled_order]
    if normed:
        data[..., :2] /= data[..., :2].abs().max()
    return data



if __name__ == "__main__":
    # Smoke test: sample a small multivariate-normal input set and print it.
    model = torch.nn.Linear(3, 2)
    N = 10
    T = 5
    dim_u = 3
    mu = torch.zeros(dim_u)
    sigma = torch.eye(dim_u)

    # The constructor signature is (model, model_variables, N, T, dim_u, mu,
    # sigma).  model_variables is only read by generate_output(), which is
    # not called here, so None is passed for it.
    model_variables = None
    dataset = NoisyModelDataset(model, model_variables, N, T, dim_u, mu, sigma)
    print(dataset.u)
