import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from model import MLP
from modules.model import MLP
from modules.mlp import MLP as MLP2


# Seed every random number generator this script relies on so repeated runs
# draw the same samples.
seed = 2
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Disable cuDNN auto-tuning and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Print numpy floats with three decimals and a leading space reserved for the sign.
np.set_printoptions(formatter={'float': '{: 0.3f}'.format})


def simple_dataset_linear(num_sample, num_feat, num_noise_feat):
    """Build a binary dataset labelled by a random linear scorer.

    Inputs are drawn uniformly from [0, 1). A freshly initialised
    ``nn.Linear`` layer scores each sample; the weights of the trailing
    ``num_noise_feat`` columns are scaled by 1e-7 so those columns have an
    almost negligible (but not exactly zero) influence on the score. A sample
    is labelled 1 when its score is strictly greater than the median score and
    0 otherwise.

    The generating weight matrix is printed to stdout.

    Args:
        num_sample: Total number of samples to generate (both halves).
        num_feat: Number of informative feature columns (the leading columns).
        num_noise_feat: Number of near-irrelevant feature columns (the
            trailing columns).

    Returns:
        A tuple ``(x_a, y_a, x_b, y_b)`` where ``x_a``/``y_a`` are the first
        ``num_sample // 2`` samples and ``x_b``/``y_b`` are the remainder.
        ``x_*`` are float tensors with ``num_feat + num_noise_feat`` columns
        and ``y_*`` are 1-D int64 tensors of 0/1 labels.
    """
    total_feat = num_feat + num_noise_feat
    x = torch.rand((num_sample, total_feat))
    fc = nn.Linear(total_feat, 1)
    torch.nn.init.xavier_uniform_(fc.weight)

    # The weight is a leaf tensor that requires grad; modifying a slice of it
    # in place while autograd is recording raises a RuntimeError, so the
    # damping of the noise columns has to happen under no_grad.
    with torch.no_grad():
        fc.weight[:, num_feat:] *= 1e-7

    print('gen weight:')
    print(fc.weight.detach().cpu().numpy())
    print('')

    with torch.no_grad():
        fc.eval()
        y = fc(x)
        # Threshold at the median score to obtain binary labels.
        y = (y > torch.median(y)).long().reshape(-1)

    half = num_sample // 2
    return x[:half], y[:half], x[half:], y[half:]


def simple_dataset_nonlinear(num_sample, num_feat, num_noise_feat):
    """Build a binary dataset labelled by an MLP, padded with shuffled noise columns.

    Informative features are drawn uniformly from [0, 1) and scored by a
    freshly constructed ``MLP2``; a sample is labelled 1 when its score is
    strictly greater than the median score. Gaussian noise columns, matched to
    the overall mean and standard deviation of the informative features, are
    appended and all columns are then randomly permuted.

    Args:
        num_sample: Number of samples to generate.
        num_feat: Number of informative feature columns fed to the MLP.
        num_noise_feat: Number of Gaussian noise columns to append.

    Returns:
        A tuple ``(x, y, selected)``: ``x`` is a float32 tensor with
        ``num_feat + num_noise_feat`` columns, ``y`` is an int64 tensor of 0/1
        labels (one per sample, assuming the MLP emits a single score per
        sample -- TODO confirm against ``modules.mlp``), and ``selected`` is a
        numpy array of the column positions in ``x`` that hold informative
        features.
    """
    feats = torch.rand((num_sample, num_feat))
    # NOTE(review): the positional sizes are passed straight to MLP2; their
    # exact meaning (presumably layer widths) is defined in modules.mlp.
    labeler = MLP2(num_feat, 512, 128, 512, 64, 64, 1, act_class=nn.LeakyReLU)

    scores = labeler(feats)
    above_median = scores > torch.median(scores)
    labels = above_median.long().reshape(-1).detach().numpy()

    # Statistics over all informative entries, used to scale the noise.
    feat_mean = torch.mean(feats).item()
    feat_std = torch.std(feats).item()

    feats_np = feats.detach().numpy()
    noise = np.random.randn(num_sample, num_noise_feat) * feat_std + feat_mean

    combined = np.concatenate([feats_np, noise], -1)
    # perm[j] is the original column placed at output position j, so output
    # positions whose source index is below num_feat are informative.
    perm = np.random.permutation(num_feat + num_noise_feat)
    selected = np.where(perm < num_feat)[0]
    shuffled = combined[:, perm]
    return torch.tensor(shuffled).float(), torch.tensor(labels), selected


def data_sample(x, y, batch_size, dev):
    """Draw a random mini-batch and move it to the given device.

    Row indices are drawn independently with ``np.random.randint``, so the
    same row may appear more than once in a batch.

    Args:
        x: Tensor of inputs, indexed along its first dimension.
        y: Tensor of targets, indexed with the same row indices as ``x``.
        batch_size: Number of rows to draw.
        dev: Device (or device string) passed to ``Tensor.to``.

    Returns:
        A tuple ``(batch_x, batch_y)`` of the sampled rows on ``dev``.
    """
    num_rows = x.shape[0]
    picks = np.random.randint(0, num_rows, batch_size)
    batch_x = x[picks].to(dev)
    batch_y = y[picks].to(dev)
    return batch_x, batch_y