import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm


class Sine(Dataset):
    """Few-shot sine-wave regression tasks ``y = A * sin(x + Phi)``.

    ``n_tasks`` tasks are sampled once at construction, with amplitude
    ``A ~ U[0, 5]`` and phase ``Phi ~ U[0, 2*pi]``. Every task owns a fixed set
    of ``n_support + n_query`` inputs drawn from ``U[-5, 5]``. The first
    ``n_support`` points form the support set and the remaining ones the query set.

    Args:
        n_support: Number of support points per task.
        n_query: Number of query points per task.
        n_tasks: Number of distinct tasks to pre-generate.
        phase: Split name. When equal to ``"test"``, Gaussian noise with
            standard deviation ``noise`` is added to the support targets.
        noise: Standard deviation of the support-target noise (test phase only).
        device: Passed to ``Tensor.to`` for every returned tensor.
        n_episodes: Value reported by ``__len__``, i.e. how many episodes one
            pass over the dataset yields. Defaults to 100, the former hard-coded value.
    """

    def __init__(self, n_support, n_query, n_tasks, phase, noise=0, device=None, n_episodes=100):

        self.n_tasks = n_tasks
        self.n_s = n_support
        self.n_q = n_query
        self.noise = noise
        self.phase = phase
        self.device = device
        self.n_episodes = n_episodes

        # Lower / upper sampling bounds as [amplitude, phase].
        self.min_params = np.array([0, 0])
        self.max_params = np.array([5., 2*np.pi])

        self.A = np.random.uniform(self.min_params[0], self.max_params[0], size=n_tasks)
        self.Phi = np.random.uniform(self.min_params[1], self.max_params[1], size=n_tasks)

        # Column i holds the inputs / targets of task i, shape (n_support + n_query, n_tasks).
        self.X = np.random.uniform(-5, 5, size=(n_support + n_query, n_tasks))
        # A and Phi have shape (n_tasks,), so they broadcast across the rows of X
        # and every task is evaluated in a single vectorized call.
        self.Y = self.f(self.X, self.A, self.Phi)

    def f(self, x, a, phi):
        """Evaluate ``a * sin(x + phi)``; arrays broadcast with NumPy rules."""
        return a * np.sin(x + phi)

    def __len__(self):
        """Number of episodes yielded per pass (``n_episodes``)."""
        return self.n_episodes

    @staticmethod
    def _as_column(values):
        """Convert a 1-D numpy array into an ``(n, 1)`` float32 CPU tensor."""
        return torch.from_numpy(values).view(-1, 1).float()

    def _on_device(self, *tensors):
        """Move every tensor to ``self.device`` and return them as a tuple."""
        return tuple(t.to(self.device) for t in tensors)

    def __getitem__(self, idx):
        """Return ``(x_s, y_s, x_q, y_q)`` for a randomly chosen task.

        ``idx`` is ignored. Each call draws a task index with ``np.random``, so
        every item is an independently sampled episode. All tensors have shape
        ``(n, 1)`` and dtype float32.
        """

        i = np.random.randint(0, self.n_tasks)

        x_s = self._as_column(self.X[:self.n_s, i])
        x_q = self._as_column(self.X[self.n_s:, i])
        y_s = self._as_column(self.Y[:self.n_s, i])
        y_q = self._as_column(self.Y[self.n_s:, i])

        if self.phase == "test":
            # Only the support targets are corrupted; query targets stay clean.
            # The noise is added on CPU, before the move to self.device.
            y_s = y_s + torch.normal(torch.zeros(y_s.shape), torch.ones(y_s.shape)*self.noise)

        return self._on_device(x_s, y_s, x_q, y_q)

    def _sample_task_batch(self, A, Phi):
        """Draw fresh inputs from ``U[-5, 5]`` and evaluate the task ``(A, Phi)`` on them."""
        x = np.random.uniform(-5, 5, size=(self.n_s+self.n_q))
        y = self.f(x, A, Phi)
        return self._on_device(
            self._as_column(x[:self.n_s]),
            self._as_column(y[:self.n_s]),
            self._as_column(x[self.n_s:]),
            self._as_column(y[self.n_s:]),
        )

    def get_max_task_batch(self):
        """Return ``(x_s, y_s, x_q, y_q)`` for the task at the upper parameter bounds."""
        return self._sample_task_batch(self.max_params[0], self.max_params[1])

    def get_fixed_task_batch(self, A=2, Phi=np.pi*(7/4)):
        """Return ``(x_s, y_s, x_q, y_q)`` for a fixed task (defaults: ``A=2``, ``Phi=7*pi/4``)."""
        return self._sample_task_batch(A, Phi)


class Poly(Dataset):
    """Few-shot polynomial regression tasks ``y = sum_j params[i, j] * x**j``.

    ``n_tasks`` tasks are sampled once at construction. Each has ``degree + 1``
    coefficients drawn from ``U[-5, 5]`` (``params[i, j]`` multiplies ``x**j``).
    Every task owns a fixed set of ``n_support + n_query`` inputs drawn from
    ``U[-1, 1]``. The first ``n_support`` points form the support set and the
    remaining ones the query set.

    Args:
        n_support: Number of support points per task.
        n_query: Number of query points per task.
        n_tasks: Number of distinct tasks to pre-generate.
        phase: Split name. When equal to ``"test"``, Gaussian noise with
            standard deviation ``noise`` is added to the support targets.
        degree: Polynomial degree.
        noise: Standard deviation of the support-target noise (test phase only).
        device: Passed to ``Tensor.to`` for every returned tensor.
        n_episodes: Value reported by ``__len__``, i.e. how many episodes one
            pass over the dataset yields. Defaults to 100, the former hard-coded value.
    """

    def __init__(self, n_support, n_query, n_tasks, phase, degree, noise=0, device=None, n_episodes=100):

        self.n_tasks = n_tasks
        self.n_s = n_support
        self.n_q = n_query
        self.noise = noise
        self.phase = phase
        self.device = device
        self.n_episodes = n_episodes

        # Row i holds the coefficients of task i, lowest power first.
        self.params = np.random.uniform(-5, 5, size=(n_tasks, degree+1))

        # Column i holds the inputs / targets of task i, shape (n_support + n_query, n_tasks).
        self.X = np.random.uniform(-1, 1, size=(n_support + n_query, n_tasks))
        self.Y = np.zeros((n_support + n_query, n_tasks))

        # params[:, j] has shape (n_tasks,) and broadcasts across the rows of X,
        # so all tasks are evaluated at once. Terms are accumulated in
        # increasing power, starting from the constant coefficient.
        self.Y[:] = self.params[:, 0]
        for j in range(1, degree+1):
            self.Y += self.params[:, j] * (self.X ** j)

    def __len__(self):
        """Number of episodes yielded per pass (``n_episodes``)."""
        return self.n_episodes

    @staticmethod
    def _as_column(values):
        """Convert a 1-D numpy array into an ``(n, 1)`` float32 CPU tensor."""
        return torch.from_numpy(values).view(-1, 1).float()

    def _on_device(self, *tensors):
        """Move every tensor to ``self.device`` and return them as a tuple."""
        return tuple(t.to(self.device) for t in tensors)

    def __getitem__(self, idx):
        """Return ``(x_s, y_s, x_q, y_q)`` for a randomly chosen task.

        ``idx`` is ignored. Each call draws a task index with ``np.random``, so
        every item is an independently sampled episode. All tensors have shape
        ``(n, 1)`` and dtype float32.
        """

        i = np.random.randint(0, self.n_tasks)

        x_s = self._as_column(self.X[:self.n_s, i])
        x_q = self._as_column(self.X[self.n_s:, i])
        y_s = self._as_column(self.Y[:self.n_s, i])
        y_q = self._as_column(self.Y[self.n_s:, i])

        if self.phase == "test":
            # Only the support targets are corrupted; query targets stay clean.
            # The noise is added on CPU, before the move to self.device.
            y_s = y_s + torch.normal(torch.zeros(y_s.shape), torch.ones(y_s.shape)*self.noise)

        return self._on_device(x_s, y_s, x_q, y_q)


if __name__ == '__main__':

    # Smoke test: build a training sine meta-dataset (5 support / 50 query
    # points, 10000 tasks) and iterate once over batches of 10 episodes.
    sine_data = Sine(5, 50, 10000, "train")
    loader = DataLoader(sine_data, batch_size=10, shuffle=False)
    for xs, ys, xq, yq in loader:
        pass
