
import math
import os

import numpy as np
import pandas as pd
import torch
import torch.optim as optim


def quad_test(size, sigma, lr, batch_size, steps=1000):
    """Run minibatch SGD on a synthetic quadratic and report the weight-norm ratio.

    The dataset has ``size`` samples. The first ``sigma`` samples all
    penalise the same coordinate, ``w[0] ** 2``; every remaining sample ``i``
    penalises its own coordinate, ``w[i - sigma + 1] ** 2``. Each sample loss
    is scaled by ``size / sigma``, so the full-batch loss is
    ``w[0]**2 + (1/sigma) * sum_{j>=1} w[j]**2`` and the largest Hessian
    eigenvalue is 2.

    Randomness comes from the global NumPy RNG (minibatch selection) and the
    global torch RNG (weight initialisation); seed both for reproducibility.

    Args:
        size: Number of samples in the synthetic dataset.
        sigma: Number of samples that share coordinate 0 (``1 <= sigma <= size``).
        lr: SGD learning rate (no momentum, no weight decay).
        batch_size: Samples drawn without replacement per step
            (``1 <= batch_size <= size``).
        steps: Maximum number of SGD steps.

    Returns:
        ``torch.tensor([1000.])`` if the weight norm ever exceeds 1000 times
        its initial value, ``torch.tensor([0.])`` if it ever drops below 1/100
        of its initial value, otherwise a 0-dim tensor holding
        ``final_norm / initial_norm``. The result never carries autograd
        history.

    Raises:
        ValueError: If ``sigma`` or ``batch_size`` is outside ``[1, size]``.
    """
    if not 1 <= sigma <= size:
        raise ValueError(
            f"sigma must be in [1, size]; got sigma={sigma}, size={size}")
    if not 1 <= batch_size <= size:
        raise ValueError(
            f"batch_size must be in [1, size]; got batch_size={batch_size}, "
            f"size={size}")

    # Effective dimension of the space: the first `sigma` samples collapse
    # onto coordinate 0, the other `size - sigma` samples get one each.
    dim = size - sigma + 1

    # Normalize by size/sigma so that lambda_1(H) = 2 for the full-batch loss.
    scale = size / sigma
    population = np.arange(size)

    def random_loss(w_vec):
        """Return the mean sample loss over one random minibatch."""
        batch = np.random.choice(population, batch_size, replace=False)
        # Sample i < sigma maps to coordinate 0, otherwise to i - sigma + 1.
        coords = torch.as_tensor(np.maximum(batch - sigma + 1, 0),
                                 dtype=torch.long)
        return scale * w_vec[coords].pow(2).sum() / batch_size

    # Standard-normal initialisation (unit standard deviation).
    w_vec = torch.normal(0.0, 1.0, size=(dim,), requires_grad=True)

    # Norms are diagnostics only; take them on detached weights so no
    # autograd graph is built or retained for them.
    initial_norm = torch.norm(w_vec.detach())

    optimizer = optim.SGD([w_vec], lr=lr, momentum=0, weight_decay=0)

    for _ in range(steps):
        optimizer.zero_grad()
        random_loss(w_vec).backward()
        optimizer.step()

        # Stop early once the run has clearly diverged or clearly converged.
        cur_norm = torch.norm(w_vec.detach())
        if cur_norm > 1000 * initial_norm:
            return torch.tensor([1000.])
        if cur_norm < initial_norm / 100:
            return torch.tensor([0.])

    ratio = torch.norm(w_vec.detach()) / initial_norm

    # Clamp to the same sentinel used for divergence inside the loop.
    return torch.tensor([1000.]) if ratio > 1000 else ratio
    

# Sweep sigma and batch size at a fixed learning rate of 0.5, repeating each
# configuration three times, and store the median norm ratio per configuration.
output_path = 'results/sigma_vs_batch_lr=0.5.csv'

# Create the output directory before the (long) sweep starts, so a missing
# folder cannot discard all results when the CSV is written at the very end.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

result_list = []

for sigma in range(1, 50):
    for batch_size in range(1, 80):
        for rep in range(3):
            ratio = quad_test(size=100, sigma=sigma, lr=0.5, batch_size=batch_size, steps=1000)
            print('sigma = ', sigma, 'B = ', batch_size, 'ratio = ', ratio)
            # .item() turns the (0-dim or single-element) tensor into a float.
            result_list.append({'sigma':sigma, 'batch_size':batch_size, 'ratio':ratio.item()})

df = pd.DataFrame(result_list)
# Median over the repetitions of each (sigma, batch_size) pair, rounded to 2 d.p.
df = df.groupby(['sigma', 'batch_size'])['ratio'].apply(np.median).round(2)
df.to_csv(output_path)
print(df)