from dataclasses import replace
import ipdb
import torch
import numpy as np
from scipy.stats import truncnorm
import ipdb


class SinePolyDataset(torch.utils.data.Dataset):
    """Synthetic few-shot regression dataset of random sinusoids.

    Each of the ``size`` tasks draws an amplitude ``A ~ U[0.1, 5.0]`` and a
    phase ``b ~ U[0, pi]``. Query targets are ``A * sin(x + b)``. Support
    targets are ``A * cos(x + b)`` when ``shift`` is truthy; otherwise they
    use the same sine as the query. Inputs lie in ``[-5, 5]``: sampled
    uniformly, or an evenly spaced grid when ``constant_support`` is set. Only
    input feature 0 is used to compute the targets.
    """

    def __init__(self, S, dim, N, size, mode="", seed=0, noise_std=0, shift=0,
                 constant_support=False, query_size=50):
        """Pre-generate inputs, targets and parameters for every task.

        Args:
            S: number of support points per task.
            dim: input dimensionality (only feature 0 affects the targets).
            N: ``params_A`` gets at least ``N + 1`` columns. Columns beyond the
                generated parameters (amplitude, phase) stay zero.
            size: number of tasks.
            mode: stored as ``self.mode``; not used inside this class.
            seed: seed of the numpy ``Generator`` (``self.rng``) used for the
                inputs, the task parameters and ``sample_big_batch``.
            noise_std: std of Gaussian noise added to support targets on access.
            shift: if truthy, support targets use cosine instead of sine.
            constant_support: use an evenly spaced grid on [-5, 5], shared by
                every task, instead of uniformly sampled inputs.
            query_size: number of query points per task (default 50).
        """
        self.N = N
        self.size = size
        self.mode = mode
        self.dim = dim
        self.seed = seed
        self.noise_std = noise_std

        self.support_size = S
        self.query_size = query_size
        self.total_length = self.support_size + self.query_size

        # Generate dataset. Draw order (inputs first, then params) is kept
        # stable so a given seed reproduces the same tasks.
        rng = np.random.default_rng(seed)
        self.rng = rng

        if constant_support:
            x_support = self._grid_inputs(self.support_size)
            x_query = self._grid_inputs(self.query_size)
        else:
            x_support = rng.uniform(-5, 5, size=(self.size, self.support_size, self.dim))
            x_query = rng.uniform(-5, 5, size=(self.size, self.query_size, self.dim))

        params = self.get_params()

        if shift:
            y_support = self.y_function_support(x_support, params)
        else:
            print("No shift: support = query")
            y_support = self.y_function_query(x_support, params)
        y_query = self.y_function_query(x_query, params)

        # Column k holds generated parameter k (0: amplitude, 1: phase). Keep
        # N + 1 columns for compatibility, but never fewer than the number of
        # generated parameters (N < 1 used to raise IndexError here).
        params_A = np.zeros((size, max(N + 1, len(params))))
        for k, param in enumerate(params):
            params_A[:, k] = param[:, 0]

        self.x_support = x_support
        self.x_query = x_query

        self.y_support = y_support
        self.y_query = y_query

        self.params_A = params_A

    def _grid_inputs(self, num_points):
        """Return an evenly spaced grid on [-5, 5] tiled over tasks and features.

        The result has shape ``(size, num_points, dim)``; every task and every
        feature column holds the same grid.
        """
        grid = np.linspace(-5, 5, num_points)[None, :, None]
        return np.tile(grid, (self.size, 1, self.dim))

    def y_function_support(self, x, params):
        """Support targets ``A * cos(x[..., 0] + b)`` with shape ``(size, points)``."""
        x = x[:, :, 0]
        amplitude, phase = params[0], params[1]
        return amplitude * np.cos(x + phase)

    def y_function_query(self, x, params):
        """Query targets ``A * sin(x[..., 0] + b)`` with shape ``(size, points)``."""
        x = x[:, :, 0]
        amplitude, phase = params[0], params[1]
        return amplitude * np.sin(x + phase)

    def get_params(self, linspace=False):
        """Return ``[A, b]``, each of shape ``(size, 1)``.

        ``A`` lies in [0.1, 5.0] and ``b`` in [0, pi]. Both are drawn with
        ``self.rng``, or evenly spaced when ``linspace`` is true.
        """
        size = self.size
        rng = self.rng
        A_interval = [0.1, 5.0]
        b_interval = [0, np.pi]
        A = self.get_random_param(rng, A_interval, size, linspace)
        b = self.get_random_param(rng, b_interval, size, linspace)
        return [A, b]

    def get_random_param(self, rng, param_interval, size, linspace=False):
        """Draw ``size`` values for one parameter as a ``(size, 1)`` array.

        Args:
            rng: numpy ``Generator`` used for sampling and shuffling.
            param_interval: either ``[lo, hi]``, or ``[[lo1, hi1], [lo2, hi2]]``.
                In the second form, ``size // 2`` values come from the first
                sub-interval and the rest from the second. The rows are then
                shuffled.
            size: number of values to draw.
            linspace: use evenly spaced values instead of uniform samples.

        Raises:
            ValueError: if ``param_interval`` is neither 1-D nor 2-D.
        """
        param_interval = np.asarray(param_interval)
        if param_interval.ndim == 1:
            if linspace:
                return np.linspace(param_interval[0], param_interval[1], size)[:, None]
            return rng.uniform(param_interval[0], param_interval[1], size=(size, 1))

        if param_interval.ndim == 2:
            first_half = size // 2
            # The second half takes the remainder so odd sizes yield `size` rows.
            param1 = self.get_random_param(rng, param_interval[0], first_half, linspace)
            param2 = self.get_random_param(rng, param_interval[1], size - first_half, linspace)
            params = np.concatenate((param1, param2), axis=0)
            # Shuffle with the seeded generator (not the global numpy RNG)
            # so results are reproducible from `seed`.
            rng.shuffle(params)
            return params

        raise ValueError(
            f"param_interval must be 1-D or 2-D, got shape {param_interval.shape}"
        )

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(x_support, y_support, x_query, y_query, params)`` as float tensors.

        The shapes are ``(S, dim)``, ``(S, 1)``, ``(Q, dim)``, ``(Q, 1)`` and
        ``(num_params,)``. Support targets get fresh Gaussian noise on every
        access; query targets are noise-free.
        """
        y_support = self.y_support[idx]
        # NOTE(review): noise comes from the global numpy RNG, not self.rng, so
        # it is not controlled by `seed`.
        y_support = y_support + np.random.randn(*y_support.shape) * self.noise_std

        x_support = torch.from_numpy(self.x_support[idx]).float()
        y_support = torch.from_numpy(y_support).float().unsqueeze(-1)

        x_query = torch.from_numpy(self.x_query[idx]).float()
        y_query = torch.from_numpy(self.y_query[idx]).float().unsqueeze(-1)

        params = torch.from_numpy(self.params_A[idx]).float().squeeze(-1)
        return x_support, y_support, x_query, y_query, params

    def sample_big_batch(self, batch_size):
        """Yield up to ``batch_size`` distinct items, chosen without replacement via ``self.rng``."""
        batch_size = min(batch_size, self.size)
        random_indices = self.rng.choice(np.arange(self.size), size=batch_size, replace=False)
        for idx in random_indices:
            yield self.__getitem__(idx)
        