import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import os

# Order CUDA devices by PCI bus ID and expose only device "0" to this process.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})


class linearRegression(nn.Module):
    """Bias-free linear map from ``dim`` input features to one output.

    The weight is re-initialised from a normal distribution with mean 0 and
    standard deviation ``sigma`` so the initialisation scale can be swept.
    """

    def __init__(self, dim, sigma):
        super().__init__()
        layer = nn.Linear(dim, 1, bias=False)
        # Replace nn.Linear's default init with N(0, sigma**2).
        nn.init.normal_(layer.weight, mean=0.0, std=sigma)
        self.linear = layer

    def forward(self, x):
        """Apply the linear map; the output's last dimension has size 1."""
        out = self.linear(x)
        return out


def sample_spherical(num_samples, dim):
    """Draw ``num_samples`` random points on the unit sphere in R^dim.

    Standard-normal vectors (from NumPy's global RNG) are normalised to unit
    Euclidean length row by row.

    Returns:
        A CPU ``torch.float`` tensor of shape ``(num_samples, dim)``.
    """
    gauss = np.random.normal(0.0, 1.0, size=(num_samples, dim))
    radii = np.linalg.norm(gauss, axis=1, keepdims=True)
    return torch.tensor(gauss / radii, dtype=torch.float)



# Experiment: for each initialisation scale sigma = 2**i, train a bias-free
# linear model with full-batch gradient descent on a noiseless linear target
# and record (a) the total Euclidean length of the weight trajectory,
# (b) the final training loss and (c) the test loss.
eta = 0.05
num_samples = 130
num_train = 100  # first num_train samples train the model, the rest test it
dim = 200

# Prefer the GPU (the script originally required it) but fall back to CPU so
# the experiment can still run on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `inputs` rather than `input`, so the builtin is not shadowed.
inputs = sample_spherical(num_samples, dim).to(device)
train_input = inputs[:num_train]
test_input = inputs[num_train:]
beta = torch.rand(size=(dim, 1)).to(device)
beta = beta / torch.norm(beta)  # unit-norm ground-truth weights
label = torch.matmul(inputs, beta)
train_label = label[:num_train]
test_label = label[num_train:]

criterion = nn.MSELoss()
sig_list = []
len_list = []
trainloss_list = []
testloss_list = []


for i in tqdm(range(-5, 5)):
    sigma = 2 ** i
    # `path_len` rather than `len`: a module-level `len` would shadow the
    # builtin for the whole module once this loop has run.
    path_len = torch.zeros((), device=device)
    error = 1
    sig_list.append(sigma)
    model = linearRegression(dim, sigma).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=eta)

    # Run GD until a step decreases the training loss by at most 1e-8.
    # A step that increases the loss (negative error) also stops the loop.
    while error > 1e-8:
        with torch.no_grad():
            # torch.cat inside parameters_to_vector copies, so this is a
            # snapshot that optimizer.step() will not overwrite.
            w_before = nn.utils.parameters_to_vector(model.parameters())
        optimizer.zero_grad()
        train_loss = criterion(model(train_input), train_label)
        train_loss.backward()
        optimizer.step()
        # Bookkeeping must stay outside autograd. Otherwise each iteration's
        # norm becomes a graph node chained through `path_len`, and memory
        # grows with the number of steps.
        with torch.no_grad():
            w_after = nn.utils.parameters_to_vector(model.parameters())
            path_len += torch.norm(w_after - w_before)
            train_loss_after = criterion(model(train_input), train_label)
        error = (train_loss.detach() - train_loss_after).item()

    len_list.append(path_len.detach().cpu().numpy())
    trainloss_list.append(train_loss_after.detach().cpu().numpy())
    with torch.no_grad():
        test_loss = criterion(model(test_input), test_label)
    testloss_list.append(test_loss.detach().cpu().numpy())



def linear_length():
    """Return the results gathered by the module-level experiment run.

    Returns:
        The tuple ``(len_list, trainloss_list, testloss_list)``. These are the
        module-level lists themselves, not copies, with one entry per sigma
        in the sweep.
    """
    results = (len_list, trainloss_list, testloss_list)
    return results
