import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

class randomfeature(nn.Module):
    """Two-layer ReLU network: a square hidden layer and a scalar readout.

    Both layers are bias-free. The hidden layer (``linear``) keeps the
    default ``nn.Linear`` initialisation; the readout (``output``) weights
    are redrawn from a zero-mean normal distribution with standard
    deviation ``sigma``.

    Args:
        dim: Input size, also used as the hidden width.
        sigma: Standard deviation for the readout weight initialisation.
    """

    def __init__(self, dim, sigma):
        super().__init__()
        # Construction order matters for reproducibility under a fixed seed:
        # hidden layer first, then readout, then the readout re-initialisation.
        self.linear = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, 1, bias=False)
        nn.init.normal_(self.output.weight, std=sigma)

    def forward(self, x):
        """Return ``output(relu(linear(x)))``."""
        hidden = self.linear(x)
        activated = F.relu(hidden)
        return self.output(activated)


class teacher(nn.Module):
    """Bias-free two-layer ReLU network with default ``nn.Linear`` weights.

    Args:
        dim: Input size, also used as the hidden width. The network maps
            a ``dim``-sized last axis to a single output value.
    """

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, 1, bias=False)

    def forward(self, x):
        """Return ``output(relu(linear(x)))``."""
        hidden = self.linear(x)
        activated = F.relu(hidden)
        return self.output(activated)


def sample_spherical(num_samples, dim):
    """Draw ``num_samples`` unit-norm vectors of length ``dim``.

    Standard-normal rows are drawn with NumPy's global random state and each
    row is divided by its Euclidean norm.

    Returns:
        A ``torch.float`` tensor of shape ``(num_samples, dim)``.
    """
    gaussian = np.random.normal(size=(num_samples, dim))
    row_norms = np.linalg.norm(gaussian, axis=1, keepdims=True)
    directions = gaussian / row_norms
    # NumPy produces float64; cast down to torch's default float32.
    return torch.from_numpy(directions).float()



# --- Experiment configuration -------------------------------------------
eta = 0.05         # SGD learning rate for the random-feature models
num_samples = 130  # total number of sampled points (train + test)
dim = 200          # input dimension, also the hidden width of both networks
num_train = 100    # the first num_train samples form the training split

# Prefer the GPU selected above, but fall back to the CPU so the script does
# not crash on machines without a usable CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Named `inputs` (not `input`) so the builtin `input` is not shadowed.
inputs = sample_spherical(num_samples, dim).to(device)
train_input = inputs[:num_train]
test_input = inputs[num_train:]

# Labels come from a fixed, randomly initialised teacher network; no
# gradients are needed because the teacher is never trained.
pre_model = teacher(dim).to(device)
with torch.no_grad():
    label = pre_model(inputs)
train_label = label[:num_train]
test_label = label[num_train:]

criterion = nn.MSELoss()

# One entry per sigma value is appended to each list by the sweep below.
len_list = []
trainloss_list = []
testloss_list = []


# Sweep the readout initialisation scale sigma = 2**i for i in [-5, 4].
# For each sigma, train only the readout layer with SGD + momentum and record:
#   * the accumulated length of the parameter trajectory,
#   * the final training loss,
#   * the test loss of the trained model.
for i in tqdm(range(-5, 5)):
    sigma = 2 ** i
    # Place the model on the same device as the data instead of hard-coding
    # .cuda(), so this loop follows wherever the tensors were created.
    model = randomfeature(dim, sigma).to(train_input.device)
    # Freeze the hidden layer: only the readout weights are optimised.
    model.linear.weight.requires_grad = False
    optimizer = torch.optim.SGD(model.parameters(), lr=eta, momentum=0.9)

    # Sum of ||w_{t+1} - w_t|| over all steps. Named `path_length` rather
    # than `len` so the builtin `len` is not shadowed.
    path_length = torch.zeros((), device=train_input.device)
    error = 1

    # Keep stepping while a step reduces the training loss by more than 1e-8.
    # NOTE: this also stops as soon as a step fails to decrease the loss
    # (a negative `error` ends the loop too).
    while error > 1e-8:
        # torch.cat copies, so this is a snapshot of the weights before the
        # step. Detaching keeps the bookkeeping out of the autograd graph,
        # which would otherwise grow with every iteration.
        w_before = torch.cat(
            [w.detach().reshape(-1) for w in model.parameters()], -1
        )

        optimizer.zero_grad()
        train_loss = criterion(model(train_input), train_label)
        train_loss.backward()
        optimizer.step()

        with torch.no_grad():
            w_after = torch.cat(
                [w.reshape(-1) for w in model.parameters()], -1
            )
            path_length = path_length + torch.norm(w_after - w_before)
            train_loss_after = criterion(model(train_input), train_label)
            error = train_loss.detach() - train_loss_after

    with torch.no_grad():
        test_loss = criterion(model(test_input), test_label)

    len_list.append(path_length.cpu().numpy())
    trainloss_list.append(train_loss_after.cpu().numpy())
    testloss_list.append(test_loss.cpu().numpy())


def randomfeature_length():
    """Return the result lists filled by the module-level sigma sweep.

    Returns:
        The tuple ``(len_list, trainloss_list, testloss_list)``; the
        module-level list objects themselves are returned, not copies.
    """
    results = (len_list, trainloss_list, testloss_list)
    return results
