import random
import numpy as np
import torch

# Dataset of random sinusoidal tasks: y = A*sin(x + phi) + eps, eps ~ N(0, noise_std^2)
class SineDataset(torch.utils.data.Dataset):
    """Map-style dataset in which every item is a freshly sampled sine task.

    Each ``__getitem__`` call draws ``A ~ U(amplitude_range)`` and
    ``phi ~ U(phase_range)``. It evaluates ``y = A * sin(x + phi)`` on a fixed
    grid ``x`` of ``num_samples`` evenly spaced points spanning ``x_range``,
    then adds Gaussian noise with standard deviation ``noise_std``.

    The ``index`` argument is ignored. The dataset only reports a nominal
    ``length`` so samplers and DataLoaders know how many tasks make up an
    "epoch". Sampling uses torch's global RNG (``torch.rand`` /
    ``torch.randn_like``).

    Args:
        amplitude_range: ``(low, high)`` bounds for the amplitude ``A``.
        phase_range: ``(low, high)`` bounds for the phase ``phi``.
        noise_std: standard deviation of the additive noise.
        x_range: ``(start, end)`` of the input grid.
        num_samples: number of grid points per task.
        length: value reported by ``__len__``; defaults to the previous
            hard-coded 100000.
    """

    def __init__(self, amplitude_range, phase_range, noise_std, x_range, num_samples, length=100000):
        super().__init__()
        self.amplitude_range = list(amplitude_range)
        self.phase_range = list(phase_range)
        self.noise_std = noise_std
        self.length = length
        # The same grid tensor is returned with every item; treat it as read-only.
        self.x = torch.linspace(start=x_range[0], end=x_range[1], steps=num_samples)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``[x, y]`` for a newly sampled noisy sine task (``index`` unused)."""
        y = self.generate_label()
        y = y + torch.randn_like(input=y) * self.noise_std
        return [self.x, y]

    def generate_label(self):
        """Sample ``A`` and ``phi`` uniformly and return the noise-free ``A * sin(x + phi)``."""
        amp_low, amp_high = self.amplitude_range[0], self.amplitude_range[1]
        phase_low, phase_high = self.phase_range[0], self.phase_range[1]
        amplitude = torch.rand(size=(1,)) * (amp_high - amp_low) + amp_low
        phase = torch.rand(size=(1,)) * (phase_high - phase_low) + phase_low
        return amplitude * torch.sin(input=self.x + phase)

# Dataset of random linear tasks: y = a*x + b + eps, eps ~ N(0, noise_std^2)
class LineDataset(torch.utils.data.Dataset):
    """Map-style dataset in which every item is a freshly sampled linear task.

    Each ``__getitem__`` call draws ``a ~ U(slope_range)`` and
    ``b ~ U(intercept_range)``. It evaluates ``y = a * x + b`` on a fixed grid
    ``x`` of ``num_samples`` evenly spaced points spanning ``x_range``, then
    adds Gaussian noise with standard deviation ``noise_std``.

    The ``index`` argument is ignored. The dataset only reports a nominal
    ``length`` so samplers and DataLoaders know how many tasks make up an
    "epoch". Sampling uses torch's global RNG (``torch.rand`` /
    ``torch.randn_like``).

    Args:
        slope_range: ``(low, high)`` bounds for the slope ``a``.
        intercept_range: ``(low, high)`` bounds for the intercept ``b``.
        x_range: ``(start, end)`` of the input grid.
        num_samples: number of grid points per task.
        noise_std: standard deviation of the additive noise.
        length: value reported by ``__len__``; defaults to the previous
            hard-coded 100000.
    """

    def __init__(self, slope_range, intercept_range, x_range, num_samples, noise_std, length=100000):
        super().__init__()
        self.slope_range = list(slope_range)
        self.intercept_range = list(intercept_range)
        self.noise_std = noise_std
        self.length = length
        # The same grid tensor is returned with every item; treat it as read-only.
        self.x = torch.linspace(start=x_range[0], end=x_range[1], steps=num_samples)

    def __getitem__(self, index):
        """Return ``[x, y]`` for a newly sampled noisy line task (``index`` unused)."""
        y = self.generate_label()
        y = y + torch.randn_like(input=y) * self.noise_std
        return [self.x, y]

    def __len__(self):
        return self.length

    def generate_label(self):
        """Sample slope and intercept uniformly and return the noise-free ``a * x + b``."""
        slope_low, slope_high = self.slope_range[0], self.slope_range[1]
        icpt_low, icpt_high = self.intercept_range[0], self.intercept_range[1]
        slope = torch.rand(size=(1,)) * (slope_high - slope_low) + slope_low
        intercept = torch.rand(size=(1,)) * (icpt_high - icpt_low) + icpt_low
        return slope * self.x + intercept

def get_dataloaders_sine_line(nsamples):
    """Build train/test DataLoaders over a mix of sine and line regression tasks.

    A ``SineDataset`` and a ``LineDataset`` share the input grid ``[-5, 5]``
    with ``nsamples`` points and a noise std of 0.3. They are concatenated,
    and two shuffled loaders with ``batch_size=1`` are built over the SAME
    concatenated dataset. Tasks are sampled on the fly inside the datasets,
    so "test" is a second stream of random tasks, not a held-out split.

    Args:
        nsamples: number of grid points per task.

    Returns:
        ``(train_dl, test_dl)`` tuple of ``torch.utils.data.DataLoader``.
    """
    sine_tasks = SineDataset(
        amplitude_range=[0.1, 5],
        phase_range=[0, np.pi],
        noise_std=0.3,
        x_range=[-5, 5],
        num_samples=nsamples,
    )
    line_tasks = LineDataset(
        slope_range=[-3, 3],
        intercept_range=[-3, 3],
        noise_std=0.3,
        x_range=[-5, 5],
        num_samples=nsamples,
    )
    mixed_tasks = torch.utils.data.ConcatDataset(datasets=[sine_tasks, line_tasks])

    loader_kwargs = dict(dataset=mixed_tasks, batch_size=1, shuffle=True)
    train_loader = torch.utils.data.DataLoader(**loader_kwargs)
    test_loader = torch.utils.data.DataLoader(**loader_kwargs)
    return train_loader, test_loader

def split_support_query(eps_data, k_shot):
    """Randomly split an episode's points into a support set and a query set.

    Args:
        eps_data: sequence whose first two entries are the ``x`` and ``y``
            tensors of an episode. Any further entries are ignored. Points are
            expected along the LAST axis, e.g. ``(1, N)`` as produced by a
            ``batch_size=1`` DataLoader, ``(B, N)`` for several tasks sharing
            the split, or ``(N,)`` for a single unbatched item.
        k_shot: number of points drawn without replacement for the support set.

    Returns:
        dict with:

        - ``'x_t'``, ``'y_t'``: the ``k_shot`` support points, in the random
          order they were drawn.
        - ``'x_v'``, ``'y_v'``: the remaining points, in ascending index order.

        Tensors with 2+ dims are transposed (``.T``) so points become rows,
        e.g. ``(1, N)`` -> ``(k_shot, 1)``. 1-D inputs stay 1-D.

    Raises:
        ValueError: from ``random.sample`` if ``k_shot`` is negative or
            exceeds the number of points.

    Note:
        Indices come from Python's ``random`` module, so seed it with
        ``random.seed`` (not ``torch.manual_seed``) for reproducible splits.
    """
    x, y = eps_data[0], eps_data[1]
    # Count points along the point axis rather than with numel(): for a
    # (B, N) batch numel() is B*N, which produced out-of-range row indices.
    num_points = x.shape[-1]

    support_ids = random.sample(range(num_points), k_shot)
    support_lookup = set(support_ids)  # O(1) membership instead of list scans
    query_ids = [i for i in range(num_points) if i not in support_lookup]

    # Put the point axis first so rows can be selected. On 1-D tensors ``.T``
    # is a deprecated no-op, so it is skipped there.
    if x.dim() >= 2:
        x = x.T
    if y.dim() >= 2:
        y = y.T

    return {
        'x_t': x[support_ids],
        'y_t': y[support_ids],
        'x_v': x[query_ids],
        'y_v': y[query_ids],
    }
