import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import random

# Report at import time whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

def setup_seed(seed):
    """Seed every random number generator used by this module.

    The same ``seed`` is handed to PyTorch (CPU generator and all CUDA
    devices), NumPy and the standard-library ``random`` module. Afterwards
    the cuDNN ``deterministic`` flag is set to True.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

#%%

def create_onetask_regression_data(demon_num, test_num, input_dim, output_dim=1, w=None):
    """Sample demonstration and test tokens for one linear-regression task.

    Inputs are drawn uniformly from [-1, 1) and labels are ``y = w @ x``.
    Every token is one column: an input stacked on top of its label.

    Args:
        demon_num: number of demonstration tokens (columns).
        test_num: number of test tokens (columns).
        input_dim: dimension of each input ``x``.
        output_dim: dimension of each label ``y``; only used to draw ``w``
            when ``w`` is None.
        w: optional weight matrix of shape [output_dim, input_dim]. When
            omitted, a float32 matrix is drawn from a standard normal.

    Returns:
        Tuple ``(w, token_demon, token_test)``. ``token_demon`` has shape
        [input_dim + w.shape[0], demon_num] and ``token_test`` has shape
        [input_dim + w.shape[0], test_num].
    """
    # Identity check: `w == None` would dispatch to Tensor.__eq__ when a
    # tensor is passed instead of testing for the "not provided" sentinel.
    if w is None:
        w = torch.normal(mean=0, std=1, size=(output_dim, input_dim), dtype=torch.float32)

    def sample_tokens(count):
        # x: [input_dim, count]; w @ x: [output_dim, count]
        x = torch.empty(size=(input_dim, count), dtype=torch.float32).uniform_(-1, 1)
        return torch.cat([x, w @ x], dim=0)

    # Demonstrations are sampled before test points so the random stream is
    # consumed in the same order as before.
    token_demon = sample_tokens(demon_num)
    token_test = sample_tokens(test_num)
    return w, token_demon, token_test

def create_multitask_regression_data(task_num, demon_num, test_num, input_dim, output_dim):
    """Generate and concatenate data for several linear-regression tasks.

    Each task is produced by ``create_onetask_regression_data`` with a
    freshly drawn weight matrix.

    Args:
        task_num: number of tasks; values below 1 still yield one task.
        demon_num: demonstration tokens per task.
        test_num: test tokens per task.
        input_dim: dimension of each input.
        output_dim: dimension of each label.

    Returns:
        Tuple ``(w, token_demon, token_test)``: the per-task weights stacked
        along dim 0, and the per-task demonstration / test tokens
        concatenated along dim 1 (the token axis), in task order.
    """
    # At least one task is always generated, matching the historical
    # behaviour for task_num < 1.
    tasks = [
        create_onetask_regression_data(demon_num, test_num, input_dim, output_dim)
        for _ in range(max(task_num, 1))
    ]
    weights, demons, tests = zip(*tasks)
    # Concatenate once at the end instead of growing the tensors inside the
    # loop, which re-copied all accumulated data on every iteration.
    w = torch.cat(weights, dim=0)
    token_demon = torch.cat(demons, dim=1)
    token_test = torch.cat(tests, dim=1)
    return w, token_demon, token_test

#%%

def create_onetask_data(demon_num, test_num, input_dim, output_dim=1, w=None, datatype="linear"):
    """Sample demonstration and test tokens for one task of a given family.

    Inputs are drawn uniformly from [-1, 1). The label depends on
    ``datatype``:

    * ``"linear"``: ``y = w @ x``
    * ``"cos"``: ``x`` is first rescaled to [0, pi), then ``y = cos(w @ x)``
    * ``"exp"``: ``y = exp(w @ x)``

    Args:
        demon_num: number of demonstration tokens (columns).
        test_num: number of test tokens (columns).
        input_dim: dimension of each input ``x``.
        output_dim: dimension of each label ``y``; only used to draw ``w``
            when ``w`` is None.
        w: optional weight matrix of shape [output_dim, input_dim]. When
            omitted, a float32 matrix is drawn from a standard normal.
        datatype: one of ``"linear"``, ``"cos"`` or ``"exp"``.

    Returns:
        Tuple ``(w, token_demon, token_test)``; the tokens stack ``x`` on top
        of ``y`` and are split into the first ``demon_num`` and the last
        ``test_num`` columns.

    Raises:
        ValueError: if ``datatype`` is not a supported name.
    """
    # Reject unknown families up front; previously `y` was simply left
    # unbound and the function failed later with an UnboundLocalError.
    if datatype not in ("linear", "cos", "exp"):
        raise ValueError(
            f"Unsupported datatype {datatype!r}; expected 'linear', 'cos' or 'exp'"
        )

    # Identity check: `w == None` would dispatch to Tensor.__eq__ when a
    # tensor is passed instead of testing for the "not provided" sentinel.
    if w is None:
        w = torch.normal(mean=0, std=1, size=(output_dim, input_dim), dtype=torch.float32)

    x = torch.empty(size=(input_dim, demon_num + test_num), dtype=torch.float32).uniform_(-1, 1)
    if datatype == "linear":
        y = w @ x  # y: [output_dim, demon_num + test_num]
    elif datatype == "cos":
        x = (x * torch.pi + torch.pi) / 2  # map [-1, 1) onto [0, pi)
        y = torch.cos(w @ x)
    else:  # "exp"
        y = torch.exp(w @ x)

    token = torch.cat([x, y], dim=0)  # [input_dim + output_dim, demon_num + test_num]
    token_demon, token_test = torch.split(
        token, split_size_or_sections=[demon_num, test_num], dim=1
    )
    return w, token_demon, token_test


def create_multitask_data(task_num, demon_num, test_num, input_dim, output_dim, datatype="linear"):
    """Generate and concatenate data for several tasks of one family.

    Each task is produced by ``create_onetask_data`` with a freshly drawn
    weight matrix and the given ``datatype``.

    Args:
        task_num: number of tasks; values below 1 still yield one task.
        demon_num: demonstration tokens per task.
        test_num: test tokens per task.
        input_dim: dimension of each input.
        output_dim: dimension of each label.
        datatype: task family forwarded to ``create_onetask_data``.

    Returns:
        Tuple ``(w, token_demon, token_test)``: the per-task weights stacked
        along dim 0, and the per-task demonstration / test tokens
        concatenated along dim 1 (the token axis), in task order.
    """
    # At least one task is always generated, matching the historical
    # behaviour for task_num < 1.
    tasks = [
        create_onetask_data(demon_num, test_num, input_dim, output_dim, datatype=datatype)
        for _ in range(max(task_num, 1))
    ]
    weights, demons, tests = zip(*tasks)
    # Concatenate once at the end instead of growing the tensors inside the
    # loop, which re-copied all accumulated data on every iteration.
    w = torch.cat(weights, dim=0)
    token_demon = torch.cat(demons, dim=1)
    token_test = torch.cat(tests, dim=1)
    return w, token_demon, token_test



class token_Data(Dataset):
    """Dataset exposing the columns of a token matrix as samples.

    ``tokens`` is indexed as ``tokens[:, index]``, i.e. one sample per
    column. Rows from ``input_dim`` onward are treated as the label.

    Args:
        w: weight matrix associated with the tokens; stored but not used by
            the dataset itself.
        tokens: 2-D tensor whose columns are the samples.
        input_dim: number of leading rows that hold the input.
        output_dim: label dimension; stored but not used by the dataset
            itself.
    """

    def __init__(self, w, tokens, input_dim, output_dim):
        # `super(Dataset, self)` would start the MRO lookup *after* Dataset
        # and skip its initializer; the zero-argument form targets the
        # direct parent correctly.
        super().__init__()
        self.w = w
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.tokens = tokens

    def __getitem__(self, index):
        """Return ``(data, label)`` for the token at column ``index``.

        ``data`` is the full token column (input rows AND label rows);
        ``label`` is only the rows from ``input_dim`` onward.
        """
        # NOTE(review): data deliberately includes the label rows; an
        # input-only variant would be self.tokens[: self.input_dim, index].
        data = self.tokens[:, index]
        label = self.tokens[self.input_dim:, index]
        return data, label

    def __len__(self):
        """Return the number of tokens (columns)."""
        return self.tokens.shape[1]



