import numpy as np
import torch
import os
from torch.utils.data import Dataset, DataLoader

# Directory the generated evaluation datasets are written to.
save_dir = "./eval_datasets/"
# Number of samples drawn for every generated dataset.
test_sample_num = 3000
# Distribution names; each one selects the ``generate_<name>`` method that is
# looked up on DistributionGenerator when the script runs.
distribution = [
    "normal",
    "lognormal",
    "uniform",
    "exponential",
    "beta",
]

class DistributionGenerator:
    """Draw batches of square random matrices from several distributions.

    Every ``generate_*`` method returns a ``(data_name, data)`` tuple:

    * ``data`` is a NumPy array of shape ``(test_sample_num, size, size)``
      where ``size = int(sqrt(2 ** n_qubits))``.
    * ``data_name`` encodes the distribution name, the parameters used and
      the qubit count (see ``_generate_name``).

    Sampling uses NumPy's global random state (``np.random.*``), so seed it
    with ``np.random.seed`` if reproducible output is required.
    """

    def __init__(self, test_sample_num, n_qubits=4):
        """Store the sample count and derive the per-sample matrix shape.

        Args:
            test_sample_num: Number of samples produced by each generator call.
            n_qubits: Number of qubits; each sample is a square matrix whose
                side is ``int(sqrt(2 ** n_qubits))``.
        """
        self.test_sample_num = test_sample_num
        self.n_qubits = n_qubits
        # For odd n_qubits, 2**n_qubits is not a perfect square and the side
        # length is truncated, so size*size < 2**n_qubits in that case.
        size = int((2**n_qubits)**0.5)
        self.shape = (size, size)

    def _sample_shape(self):
        """Return the full output shape ``(test_sample_num, size, size)``."""
        return (self.test_sample_num, *self.shape)

    def generate_normal(self, name, mean=0.3, std=0.5):
        """Sample from a normal distribution with the given mean and std."""
        data = np.random.normal(mean, std, self._sample_shape())
        return self._generate_name(name, mean=mean, std=std), data

    def generate_lognormal(self, name, mean=0, std=1):
        """Sample from a log-normal distribution.

        ``mean`` and ``std`` describe the underlying normal distribution, as
        in ``np.random.lognormal``.
        """
        data = np.random.lognormal(mean, std, self._sample_shape())
        return self._generate_name(name, mean=mean, std=std), data

    def generate_uniform(self, name, low=-1, high=1):
        """Sample uniformly from the half-open interval ``[low, high)``."""
        data = np.random.uniform(low, high, self._sample_shape())
        return self._generate_name(name, low=low, high=high), data

    def generate_exponential(self, name, rate=1):
        """Sample from an exponential distribution with rate ``rate``.

        Raises:
            ValueError: If ``rate`` is not strictly positive.
        """
        if rate <= 0:
            raise ValueError(f"rate must be positive, got {rate!r}")
        # np.random.exponential is parameterised by the scale (1 / rate), not
        # by the rate itself, so convert before sampling.
        data = np.random.exponential(1.0 / rate, self._sample_shape())
        return self._generate_name(name, rate=rate), data

    def generate_beta(self, name, a=1, b=1):
        """Sample from a beta distribution with shape parameters ``a``, ``b``."""
        data = np.random.beta(a, b, self._sample_shape())
        return self._generate_name(name, a=a, b=b), data

    def _generate_name(self, distribution_name, **params):
        """Build the dataset name for a distribution and its parameters.

        The result is ``"<distribution>_<k1><v1><k2><v2>..._<n>-qubits"``,
        e.g. ``"normal_mean0.3std0.5_4-qubits"``. Parameters are concatenated
        without separators, in keyword order.
        """
        encoded_params = "".join(f"{key}{value}" for key, value in params.items())
        # Use the instance's qubit count rather than a module-level global, so
        # the class also works when it is imported or used outside the script.
        return f"{distribution_name}_{encoded_params}_{self.n_qubits}-qubits"


class CustomDataset(Dataset):
    """Dataset wrapper that serves the entries of ``data`` as float tensors."""

    def __init__(self, data):
        """Keep a reference to ``data``; no copy or conversion happens here."""
        self.data = data

    def __len__(self):
        """Return the number of samples, i.e. ``len(data)``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return entry ``index`` converted from NumPy to a float32 tensor."""
        return torch.from_numpy(self.data[index]).float()

    def save_dataset_to_disk(self, path):
        """Serialise the wrapped data object to ``path`` with ``torch.save``."""
        torch.save(self.data, path)


if __name__ == "__main__":
    # Generate one dataset per (qubit count, distribution) pair and save each
    # as "<save_dir>/<name>.pt".

    # torch.save does not create missing parent directories, so make sure the
    # output directory exists before the first write.
    os.makedirs(save_dir, exist_ok=True)

    # NOTE: this loop variable is a module-level global named `n_qubits`;
    # check for code in this file that reads it by name before renaming it or
    # moving this loop into a function.
    for n_qubits in [4, 6, 8]:
        for dist in distribution:
            generator = DistributionGenerator(test_sample_num, n_qubits)
            generate_func = getattr(generator, f"generate_{dist}")
            name, data = generate_func(name=dist)

            # Insert a singleton axis: (N, H, W) -> (N, 1, H, W).
            data = data.reshape(data.shape[0], 1, *generator.shape)
            samples = CustomDataset(data)

            # A single batch covering the whole dataset, drawn through a
            # RandomSampler, yields all samples in shuffled order as one
            # float32 tensor.
            dataloader = DataLoader(
                samples,
                batch_size=len(samples),
                sampler=torch.utils.data.RandomSampler(samples),
                num_workers=0,
            )
            data_in_loader = next(iter(dataloader))

            # These synthetic samples carry no encoder parameters or digit
            # labels; both fields are filled with one NaN per sample.
            n_samples = len(data_in_loader)
            payload = {
                "images": data_in_loader,
                "encoder_params": [torch.nan] * n_samples,
                "digits": [torch.nan] * n_samples,
            }
            torch.save(payload, os.path.join(save_dir, f"{name}.pt"))
            print(f"Saved {name}.pt")

