import torch
import numpy as np
import pickle
import os
import torchvision
import random
# Directory containing this script. It is not used below: the dataset files
# are written relative to the current working directory.
cpath = os.path.dirname(__file__)


# Python's RNG picks the feature coordinates and the patch order of every
# sample, so seeding it makes both reproducible.
random.seed(0)
dim = 100  # length of one patch; every sample is two concatenated patches

# Number of samples per (class, feature) group, used for both splits.
n_11 = 200  # label 0, feature 1
n_12 = 100  # label 0, feature 2
n_21 = 100  # label 1, feature 3
n_22 = 50   # label 1, feature 4
data_noise_level = 0.1  # scale applied to the standard-normal noise patch

# Class 1 (stored with label 0): feature magnitudes and coordinates.
feature_size_1 = 4
feature_size_2 = 2
feature_entry_1 = random.randint(0, dim - 1)
feature_entry_2 = random.randint(0, dim - 1)

# Class 2 (stored with label 1): feature magnitudes and coordinates.
feature_size_3 = 1.5
feature_size_4 = 1
feature_entry_3 = random.randint(0, dim - 1)
feature_entry_4 = random.randint(0, dim - 1)

# NOTE: the four coordinates are drawn independently, so nothing here
# guarantees they are distinct if the seed or `dim` is changed.
feature_entries = [feature_entry_1, feature_entry_2, feature_entry_3, feature_entry_4]

# One row per sample group: (count, feature coordinate, magnitude, label).
# The row order fixes the sample order inside each dataset.
_groups = [
    (n_11, feature_entry_1, feature_size_1, 0),
    (n_12, feature_entry_2, feature_size_2, 0),
    (n_21, feature_entry_3, feature_size_3, 1),
    (n_22, feature_entry_4, feature_size_4, 1),
]

# Seed torch's RNG, which generates the noise patches.
seed = 1
torch.manual_seed(seed)


def _make_sample(feature_entry, feature_size, label):
    """Build one ``[input, label]`` sample.

    The input is a tensor of shape ``(1, 2 * dim)`` made of two patches:

    * a signal patch that is zero except for ``feature_size`` at
      ``feature_entry``;
    * a noise patch of ``data_noise_level``-scaled standard-normal values,
      zeroed at every coordinate listed in ``feature_entries``.

    A fair coin flip decides which patch comes first.

    Args:
        feature_entry: Coordinate inside the signal patch that carries the
            feature.
        feature_size: Value written at ``feature_entry``.
        label: Integer class label stored with the sample.

    Returns:
        A two-element list ``[input_tensor, label_tensor]``.
    """
    signal = torch.zeros(dim)
    signal[feature_entry] = feature_size

    noise = data_noise_level * torch.randn(dim)
    # Keep the noise patch silent on every feature coordinate so that only
    # the signal patch carries feature information.
    for entry in feature_entries:
        noise[entry] = 0

    # 0 -> signal patch first, 1 -> noise patch first.
    if random.randint(0, 1) == 0:
        patches = (signal, noise)
    else:
        patches = (noise, signal)
    return [torch.unsqueeze(torch.cat(patches, dim=0), 0), torch.tensor(label)]


def _build_dataset():
    """Generate one full split by walking ``_groups`` in order.

    Samples are produced strictly one after another, so the random streams
    are consumed in a fixed order and the result is reproducible for a given
    pair of seeds.

    Returns:
        A list of ``[input_tensor, label_tensor]`` samples.
    """
    samples = []
    for count, feature_entry, feature_size, label in _groups:
        for _ in range(count):
            samples.append(_make_sample(feature_entry, feature_size, label))
    return samples


# The train split is generated first and the test split second; both draw
# from the same RNG streams, so this order determines their contents.
dataset_train = _build_dataset()
dataset_test = _build_dataset()

torch.save(dataset_train, 'dataset_train.pt')
torch.save(dataset_test, 'dataset_test.pt')

print(feature_entries)