
import math
import os

import numpy as np
import pandas as pd
import torch
import torch.optim as optim


def quad_test(size, sigma, lr, batch_size, steps=1000):
    """Run SGD on a synthetic quadratic and report the weight-norm ratio.

    The objective is the average of ``size`` per-sample losses. Samples
    ``0 .. sigma - 1`` all penalise coordinate 0, while each sample
    ``i >= sigma`` penalises its own coordinate ``i - sigma + 1``, so the
    weight vector has ``size - sigma + 1`` entries. Every per-sample loss is
    ``(size / sigma) * w_k ** 2`` for the coordinate ``k`` that sample owns.

    Args:
        size: Number of per-sample losses.
        sigma: Number of samples that share coordinate 0.
        lr: Learning rate handed to ``torch.optim.SGD`` (no momentum, no
            weight decay).
        batch_size: Number of samples drawn without replacement per step.
        steps: Maximum number of SGD steps.

    Returns:
        A tensor holding ``||w_final|| / ||w_initial||``, detached from the
        autograd graph. It is ``torch.tensor([1000.])`` as soon as the norm
        exceeds 1000x its initial value (or stops being finite),
        ``torch.tensor([0.])`` as soon as it drops below 1/1000 of its
        initial value, and a 0-dim tensor with the measured ratio otherwise.
    """
    # Samples 0..sigma-1 collapse onto one coordinate, hence the reduced
    # dimension of the weight vector.
    dim = size - sigma + 1

    # With this scale the loss averaged over all samples is
    # w_0^2 + (1/sigma) * sum_{k>=1} w_k^2, so for sigma >= 1 the largest
    # Hessian eigenvalue is 2.
    scale = size / sigma

    def sample_loss(w_vec, i):
        """Return the loss of sample ``i`` for weights ``w_vec``."""
        coord = 0 if i < sigma else i - sigma + 1
        return scale * pow(w_vec[coord], 2)

    def random_loss(w_vec):
        """Return the mean loss over a freshly drawn mini-batch."""
        indices = np.random.choice(np.arange(size), batch_size, replace=False)
        # Built-in sum() starts from the int 0 and adds the terms in draw
        # order, matching a plain accumulation loop term for term.
        return sum(sample_loss(w_vec, index) for index in indices) / batch_size

    # Standard-normal initial weights (unit standard deviation).
    w_vec = torch.normal(0, 1.0, size=(dim,), requires_grad=True)

    # Norms are diagnostics only, so keep them out of the autograd graph.
    with torch.no_grad():
        initial_norm = torch.norm(w_vec)

    optimizer = optim.SGD([w_vec], lr=lr, momentum=0, weight_decay=0)

    for _ in range(steps):
        optimizer.zero_grad()
        error = random_loss(w_vec)
        error.backward()
        optimizer.step()

        with torch.no_grad():
            cur_norm = torch.norm(w_vec)

        # A NaN norm fails every ordered comparison, so it has to be tested
        # explicitly; treat any non-finite norm as divergence.
        if not torch.isfinite(cur_norm) or cur_norm > 1000 * initial_norm:
            return torch.tensor([1000.])
        if cur_norm < initial_norm / 1000:
            return torch.tensor([0.])

    with torch.no_grad():
        ratio = torch.norm(w_vec) / initial_norm

    return torch.tensor([1000.]) if ratio > 1000 else ratio
    

# Sweep learning rate and batch size, recording the norm ratio that
# quad_test reports for every repetition.
SIZE = 100
SIGMA = 5
output_path = f'results/batch_vs_lr_sigma={SIGMA}.csv'

# Create the output directory before the (long) sweep so that a missing
# directory cannot discard the results at the very end.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

result_list = []

for lr_recip in range(0, 50, 1):
    # Learning rates fall geometrically from 1, halving every 24 grid points.
    # Alternative schedule: lr = 2/pow(lr_recip, 0.5)
    lr = pow(0.5, lr_recip / 24)
    for batch_exp in range(0, 20, 1):
        # Batch sizes grow roughly as 2^(k/3). Truncation maps the first
        # three exponents to 1 and the next two to 2, so those batch sizes
        # receive proportionally more repetitions.
        batch_size = int(pow(2, batch_exp / 3))
        for rep in range(5):
            ratio = quad_test(size=SIZE, sigma=SIGMA, lr=lr, batch_size=batch_size, steps=1000)
            print('lr = ', lr, 'B = ', batch_size, 'ratio = ', ratio)
            result_list.append({'lr': lr, 'batch_size': batch_size, 'ratio': ratio.item()})

# Summarise each (lr, batch_size) cell by the median ratio over its runs.
df = pd.DataFrame(result_list)
df = df.groupby(['lr', 'batch_size'])['ratio'].apply(np.median).round(2)
# Alternative aggregate: df.groupby(['lr', 'batch_size'])['ratio'].apply(max).round(2)
df.to_csv(output_path)
print(df)