import time
import warnings
from itertools import cycle, islice
import itertools
import matplotlib.pyplot as plt
import numpy as np
import torch.distributions as dist

from sklearn import cluster, datasets
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
from einops import rearrange, reduce, repeat

import torch
import sys
sys.path.append('./datasets') 
sys.path.append('../datasets')
import multidat
import pdb
from einops import rearrange, reduce, repeat, einsum


# Seed NumPy's global RNG at import time so that NumPy-based sampling in this
# module (e.g. np.random.random in sample_uniform_in_circle) is reproducible.
# NOTE(review): torch's RNG (used by torch.distributions sampling in
# sample_2d_gaussian) is not seeded in the code visible here - confirm whether
# callers seed it themselves.
np.random.seed(0)


class SplitShiftDataset(multidat.MultiDat):
    """Family of 2-D point clouds obtained by splitting and shifting a base shape.

    For every condition vector ``c`` in ``self.c_list`` the noise-free base
    point cloud (see ``original_datpts``) is:

    * split along the y axis: each point is moved by
      ``splitfield(points) * c[0] * c_lim / c_increment``, i.e. by
      ``+/- c[0] * c_lim / c_increment`` in y depending on whether its
      y coordinate is > 0; then
    * translated by ``c[1] * shift`` along every data axis.

    Args:
        c_num: Number of grid values per condition dimension (ignored when
            ``c_preset`` is given).
        c_lim: Upper end of the condition grid ``linspace(0, c_lim, c_num)``;
            also scales the split size.
        data_num: Number of points generated per condition.
        mode: Base shape passed to ``original_datpts`` ('circle' or 'moons').
        scale: Unused; kept for backward compatibility.
        c_dim: Length of each generated condition vector (ignored when
            ``c_preset`` is given). ``create_dat`` reads components 0 and 1.
        shift: Translation magnitude per unit of ``c[1]``.
        rival: Forwarded unchanged to ``multidat.MultiDat``.
        upto: Keep only the first ``upto`` generated grid conditions.
        c_preset: Optional explicit conditions (2-D, rows are conditions);
            replaces the generated grid and sets ``c_dim`` from its 2nd dim.
        c_increment: Divisor applied to the split size.
        **kwargs: Ignored.
    """

    def __init__(self, c_num, c_lim, data_num, mode='circle',
        scale=0.4, c_dim=2, shift=0.5,
        rival=None, upto=10, c_preset=None,
        c_increment=2, **kwargs):

        if c_preset is None:
            self.c_dim = c_dim

            # Grid values as a (c_num, 1) column; the cartesian product then
            # enumerates every c_dim-tuple of grid values.
            pre_c_list = torch.linspace(0., c_lim, c_num).unsqueeze(dim=1)  # special for p = 1
            self.c_list = torch.tensor(list(itertools.product(pre_c_list, repeat=c_dim)))
            self.c_list = self.c_list[:upto]

        else:
            self.c_list = c_preset
            self.c_dim = c_preset.shape[1]

        self.x_dim = 2
        self.c_increment = c_increment
        self.rival = rival
        self.c_lim = c_lim
        # Same magnitude along every data axis, i.e. a diagonal translation.
        self.shift = shift * torch.ones(self.x_dim)
        self.data_num = data_num
        self.mode = mode
        self.c_num = len(self.c_list)

        self.create_dat()

        super().__init__(self.pdata, self.c_list, self.rival)


    def __len__(self):
        """Return ``len(self.data)``.

        NOTE(review): ``self.data`` is not assigned in this class; presumably it
        is set by ``multidat.MultiDat.__init__`` - confirm.
        """
        return len(self.data)

    def create_dat(self):
        """Fill ``self.pdata`` with one transformed point cloud per condition.

        ``self.pdata`` has shape (len(self.c_list), self.data_num, self.x_dim).
        """
        self.pdata = torch.zeros(len(self.c_list), self.data_num, self.x_dim)

        # Noise-free base shape; only the points (first element) are used.
        base = torch.tensor(original_datpts(mode=self.mode, n_samples=self.data_num, noise=0.)[0])
        # The split direction depends only on the base points, so compute it once.
        split = splitfield(base)

        for i, c in enumerate(self.c_list):
            self.pdata[i] = base + split * c[0] * self.c_lim / self.c_increment
            self.pdata[i] = self.pdata[i] + c[1] * self.shift



def original_datpts(mode: str, n_samples: int, noise: float, **kwargs):
    """Generate a base 2-D toy dataset with scikit-learn.

    Args:
        mode: 'circle' for two concentric circles (``factor=0.5``) or
            'moons' for two interleaving half-moons.
        n_samples: Total number of points to generate.
        noise: Noise level passed straight to scikit-learn.
        **kwargs: Ignored; accepted so callers can pass extra options.

    Returns:
        Whatever ``datasets.make_circles`` / ``datasets.make_moons`` returns
        (the points are element 0).

    Raises:
        NotImplementedError: If ``mode`` is not 'circle' or 'moons'.
    """
    if mode == 'circle':
        return datasets.make_circles(n_samples=n_samples, factor=0.5, noise=noise)
    if mode == 'moons':
        return datasets.make_moons(n_samples=n_samples, noise=noise)
    # Unsupported shape: fail with the intended exception type.
    raise NotImplementedError(f"Unsupported mode {mode!r}; expected 'circle' or 'moons'.")


def mfield_dir(vfield0, vfield1, gammadot):
    """Blend two vector fields with the weights held in ``gammadot``.

    Returns ``gammadot[0] * vfield0 + gammadot[1] * vfield1``.
    """
    weight0, weight1 = gammadot[0], gammadot[1]
    return weight0 * vfield0 + weight1 * vfield1

def starfield(x, n=1):
    """Displacement field that rescales each point by ``cos(n * angle)``.

    The angle is ``arctan(x[:, 0] / x[:, 1])``, which lies in [-pi/2, pi/2].
    The result is ``cos(n * angle) * x - x``, so ``x + starfield(x, n)`` is the
    rescaled point. A point at the origin gets a zero displacement.

    Args:
        x: Points; columns 0 and 1 are read (assumed shape (num_points, 2)).
        n: Angular frequency of the modulation.

    Returns:
        Displacements with the same shape as ``x``.
    """
    angles = torch.arctan(x[:, 0] / x[:, 1])
    # 0/0 at the origin yields NaN, which would poison that point's
    # displacement; any finite angle gives a zero displacement there.
    angles = torch.nan_to_num(angles, nan=0.0)
    # detach() matches the no-grad copy torch.tensor(tensor) made, without the
    # copy-construct UserWarning.
    rx = torch.cos(angles * n).unsqueeze(1).detach()
    return rx * x - x
def splitfield(x):
    """Unit vertical displacement pushing points away from the x axis.

    Points with ``x[:, 1] > 0`` get ``(0, +1)``; all others (including
    ``x[:, 1] == 0``) get ``(0, -1)``.

    Args:
        x: Points whose column 1 is the vertical coordinate
            (assumed shape (num_points, 2)).

    Returns:
        float32 tensor of shape (num_points, 2) on the same device as ``x``.
    """
    # Map the boolean mask {False, True} to {-1., +1.}.
    ydel = ((x[:, 1] > 0).float() - 0.5) * 2
    # zeros_like follows ydel's dtype and device (torch.zeros defaulted to CPU).
    xdel = torch.zeros_like(ydel)
    return torch.stack([xdel, ydel], dim=1)


#######

def sample_2d_gaussian(mean, std_dev, num_samples=100, deform=0, v=5):
    """Draw star-deformed Gaussian points centred at ``mean``.

    x and y are sampled independently from N(0, std_dev[0]) and
    N(0, std_dev[1]), warped around the origin by ``star_transform`` and
    then translated by ``mean``.

    Returns:
        Tensor with one (x, y) row per sample.
    """
    normal_x, normal_y = dist.Normal(0, std_dev[0]), dist.Normal(0, std_dev[1])
    xs = normal_x.sample((num_samples,))
    ys = normal_y.sample((num_samples,))

    # Deform around the origin first so the star ends up centred on ``mean``.
    xs, ys = star_transform(xs, ys, deform=deform, v=v)
    return torch.stack([xs + mean[0], ys + mean[1]], dim=1)

# Sample points uniformly from a 2-D disk via rejection sampling.
def sample_uniform_in_circle(num_samples, radius=1):
    """Rejection-sample points uniformly from the disk of the given radius.

    Candidates are drawn with NumPy's global RNG (``np.random.random``) from
    the square [-radius, radius]^2 and kept when inside the disk. For
    ``radius == 1`` the candidate stream is identical to plain
    ``2 * np.random.random(2) - 1``.

    Args:
        num_samples: Number of points to return (<= 0 gives an empty result).
        radius: Disk radius.

    Returns:
        float64 tensor of shape (num_samples, 2).
    """
    if num_samples <= 0:
        # torch.stack rejects an empty list; return an empty batch instead.
        return torch.empty((0, 2), dtype=torch.float64)

    samples = []
    while len(samples) < num_samples:
        # Candidate from the bounding square [-radius, radius]^2.
        xy = radius * (2 * np.random.random(2) - 1)
        # Inside the disk iff squared distance <= squared radius.
        if np.sum(xy ** 2) <= radius ** 2:
            samples.append(torch.tensor(xy))
    return torch.stack(samples)


def sample_2d_uniform_star(mean, num_samples, deform=0, rot=0, v=5):
    """Sample a uniform unit-disk blob, star-deform it, rotate it, then move it to ``mean``.

    Uses ``sample_uniform_in_circle`` (default radius), ``star_transform``
    and ``rot_transform`` in that order, all about the origin.

    Returns:
        Tensor with one (x, y) row per sample.
    """
    disk = sample_uniform_in_circle(num_samples)
    xs, ys = star_transform(disk[:, 0], disk[:, 1], deform=deform, v=v)
    xs, ys = rot_transform(xs, ys, rot)
    return torch.stack([xs + mean[0], ys + mean[1]], dim=1)

def rot_transform(sx, sy, rot):
    """Rotate points ``(sx, sy)`` about the origin by ``rot`` radians.

    A positive ``rot`` rotates counter-clockwise.

    Args:
        sx: Tensor of x coordinates.
        sy: Tensor of y coordinates (same shape as ``sx``).
        rot: Rotation angle in radians (float or tensor).

    Returns:
        Tuple ``(sxx, syy)`` of rotated coordinates.
    """
    # torch ops (rather than NumPy ufuncs) keep this working for CUDA tensors
    # and for tensors that require grad.
    radius = torch.sqrt(sx ** 2 + sy ** 2)
    angles = torch.arctan2(sy, sx) + rot
    return radius * torch.cos(angles), radius * torch.sin(angles)

def sampled_2d_ngon(mean, num_samples, deform=0, v=5):
    """Placeholder for an n-gon sampler; not implemented yet.

    TODO: implement. Currently ignores every argument and returns ``None``.
    """
    return None

# Reference example (unused): means and standard deviations for 8 Gaussians in 2-D.
# means = torch.tensor([[i, j] for i in range(4) for j in range(2)]).float()  # 8 means in 2D
# std_devs = torch.ones(8, 2)  # Standard deviations (you can customize this)

def create_means(thetas):
    """Return points on the unit circle at the given angles.

    Args:
        thetas: 1-D tensor, or sequence of floats / 0-d tensors, of angles
            in radians.

    Returns:
        Tensor of shape (len(thetas), 2) whose rows are (cos(theta), sin(theta)).
    """
    # One vectorized cos/sin instead of a per-element Python loop; as_tensor
    # also accepts plain Python floats and keeps the input dtype.
    angles = torch.as_tensor(thetas)
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
    

def star_transform(x, y, v=5, deform=0.3):
    """Warp points into a star shape by modulating their radius with angle.

    Each point is taken to polar form (r, theta); r is multiplied by
    ``1 + deform * sin(v * theta)`` while theta is kept, giving ``v`` lobes
    for integer ``v``. Heuristic - ``deform`` and ``v`` may need tuning.

    Returns:
        Tuple ``(x_new, y_new)`` of transformed coordinates.
    """
    theta = torch.atan2(y, x)
    radius = torch.sqrt(x**2 + y**2)
    stretch = 1 + deform * torch.sin(v * theta)
    radius = radius * stretch
    return radius * torch.cos(theta), radius * torch.sin(theta)


class StarBloomDataset(multidat.MultiDat):
    """Family of "star bloom" point clouds: ``abs(N)`` star-shaped blobs on a ring.

    Each condition ``c = (deform, radius_frac)`` in ``self.c_list`` yields one
    point cloud of ``data_num`` points: ``abs(N)`` blobs whose centres sit at
    equally spaced angles on a circle of radius ``radmax * radius_frac``.
    How each blob is drawn depends on ``mode``:

    * 'gauss': Gaussian blob (per-axis std ``std``), star-deformed by ``deform``.
    * 'unif':  uniform unit-disk blob, star-deformed by ``deform``.
    * 'rot':   uniform unit-disk blob with deform fixed to 1, rotated by ``deform``.

    Args:
        num_def: Number of deform values on the grid linspace(0, 1, num_def).
        num_rad: Number of radius fractions on the grid linspace(0, 1, num_rad).
        data_num: Total number of points per condition.
        radmax: Scale applied to the radius fraction.
        std: Per-axis standard deviation of the blobs in 'gauss' mode.
        N: Number of blobs (absolute value used); also passed as the star
            frequency ``v`` to the per-blob samplers.
        rival: Forwarded unchanged to ``multidat.MultiDat``.
        upto: Keep only the first ``upto`` generated grid conditions.
        c_preset: Optional explicit conditions (2-D, rows are conditions);
            replaces the generated grid.
        mode: One of 'gauss', 'unif', 'rot'.
        **kwargs: Ignored.
    """

    def __init__(self, num_def, num_rad, data_num, radmax=1., std=0.1, N=5,
        rival=None, upto=100, c_preset=None, mode='gauss',  **kwargs):

        if c_preset is None:
            self.c_dim = 2

            deforms = np.linspace(0, 1, num_def)
            radii = np.linspace(0, 1, num_rad)
            # Every (deform, radius_frac) pair of the two grids.
            self.c_list = torch.tensor(list(itertools.product(deforms, radii))).float()
            self.c_list = self.c_list[:upto]

        else:
            self.c_list = c_preset
            self.c_dim = c_preset.shape[1]

        self.rival = rival
        self.data_num = data_num
        self.N = N
        self.x_dim = 2
        self.c_num = len(self.c_list)
        # Use the caller-supplied std (default 0.1) rather than a fixed value.
        self.std = std
        self.mode = mode
        self.radmax = radmax
        print(f"""Using {mode} mode for the dataset.""" )
        self.create_dat()

        super().__init__(self.pdata, self.c_list, self.rival)

    def starbloom(self, N, std, deform, radius, num_samples):
        """Sample ``abs(N)`` star-shaped blobs centred on a circle of ``radius``.

        Args:
            N: Number of blobs (absolute value used); also passed as ``v``.
            std: Per-axis std of each blob in 'gauss' mode.
            deform: Star deformation ('gauss'/'unif') or rotation angle ('rot').
            radius: Radius of the circle holding the blob centres.
            num_samples: Points per blob - a scalar (same for every blob) or a
                sequence of ``abs(N)`` per-blob counts.

        Returns:
            Tensor with all blobs' points concatenated along dim 0.

        Raises:
            ValueError: If a per-blob count sequence has the wrong length.
            NotImplementedError: If ``self.mode`` is not 'gauss', 'unif' or 'rot'.
        """
        Nabs = np.abs(N)
        # Equally spaced blob angles, kept in float64.
        ths = torch.tensor([2 * torch.pi / Nabs * i for i in range(Nabs)], dtype=torch.float64)
        means = radius * create_means(ths)
        std_devs = std * torch.ones(Nabs, 2)

        if np.ndim(num_samples) == 0:
            counts = [num_samples] * Nabs
        else:
            counts = list(num_samples)
            if len(counts) != Nabs:
                raise ValueError(f"Expected {Nabs} per-blob sample counts, got {len(counts)}.")

        allsamples = []
        for mean, std_dev, count in zip(means, std_devs, counts):
            if self.mode == 'gauss':
                samples_xy = sample_2d_gaussian(mean, std_dev, count, deform=deform, v=N)
            elif self.mode == 'unif':
                samples_xy = sample_2d_uniform_star(mean, count, deform=deform, v=N)
            elif self.mode == 'rot':
                samples_xy = sample_2d_uniform_star(mean, count, deform=1, rot=deform, v=N)
            else:
                raise NotImplementedError(
                    f"Unsupported mode {self.mode!r}; expected 'gauss', 'unif' or 'rot'.")
            allsamples.append(samples_xy)

        return torch.cat(allsamples, dim=0)

    def create_dat(self):
        """Fill ``self.pdata`` with one star-bloom cloud per condition.

        ``self.pdata`` has shape (len(self.c_list), self.data_num, 2). The
        ``data_num`` points are split across the ``abs(self.N)`` blobs as
        evenly as possible: the first ``data_num % abs(N)`` blobs get one
        extra point, so ``data_num`` need not be a multiple of ``N``.

        Raises:
            ValueError: If ``self.N`` is 0.
        """
        n_abs = np.abs(self.N)
        if n_abs == 0:
            raise ValueError("N must be non-zero.")
        self.pdata = torch.zeros(len(self.c_list), self.data_num, 2)

        base, extra = divmod(self.data_num, int(n_abs))
        if extra == 0:
            per_blob = base
        else:
            per_blob = [base + 1] * extra + [base] * (int(n_abs) - extra)

        for i in range(len(self.c_list)):
            self.pdata[i] = self.starbloom(n_abs, self.std, self.c_list[i, 0],
                                           self.radmax * self.c_list[i, 1], per_blob)





