# Data generation.
# Signal lies along the first coordinate; the remaining directions carry only Gaussian noise.
import argparse
from cProfile import run
import numpy as np
import torch
import os
from matplotlib import pyplot as plt
# Command-line configuration for the experiment.
parser = argparse.ArgumentParser()
# Task parameter
parser.add_argument('-n', type=int, default=100,
                    help='number of samples drawn per trial')
parser.add_argument('-p', type=int, default=100,
                    help='feature dimension')
parser.add_argument('-snr', type=float, default=10,
                    help='signal strength: class means are +/- snr along the first coordinate')
parser.add_argument('-save_path', type=str, default='./',
                    help='directory where figures and stats are written')
parser.add_argument('-rho', type=float, default=0,
                    help='probability of flipping each observed label')

args = parser.parse_args()
print(args)

# Fail fast on values that would otherwise crash later (mu[0] needs p >= 1,
# torch.stack needs at least one sample) or make no sense as a probability.
if args.n < 1 or args.p < 1:
    parser.error('-n and -p must be positive integers')
if not 0 <= args.rho <= 1:
    parser.error('-rho must lie in [0, 1]')

# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs(args.save_path, exist_ok=True)

p = args.p
n = args.n
snr = args.snr
rho = args.rho


def _generate_data(n, p, snr, rho):
    """Draw one labelled sample set from a symmetric two-class Gaussian mixture.

    Each sample gets a clean label y in {-1, +1} from a fair coin and
    features ``y * mu + N(0, I_p)`` with ``mu = snr * e_1``. The *observed*
    label is flipped with probability ``rho``; the features always use the
    clean label, so ``rho`` acts as label noise.

    Returns:
        features: float64 tensor of shape (n, p).
        labels: int64 tensor of shape (n,) with the observed labels.
    """
    mu = np.zeros((p,))
    mu[0] = snr
    features = []
    labels = []
    for _ in range(n):
        # Per-sample draw order: label coin, flip coin, then the noise vector.
        label = 2 * (np.random.uniform() > 0.5) - 1
        observed = -label if np.random.uniform() < rho else label
        labels.append(observed)
        features.append(torch.tensor(mu * label + np.random.normal(size=(p,))))
    return torch.stack(features), torch.tensor(labels)


def _one_step_alignment(features, labels):
    """Return cos(angle) between the one-step estimator mean(y_i x_i) and e_1."""
    onestep = (features * labels.unsqueeze(dim=1)).mean(dim=0)
    return onestep[0] / onestep.norm()


def _train_logistic_sgd(features, labels, lr=1e-4, max_epochs=2000, tol=0.05):
    """Run per-sample SGD on the logistic loss, starting from the zero vector.

    Each epoch visits the samples once, in order, and steps on
    ``log(1 + exp(-y <w, x>))``. Training stops once the mean epoch loss is
    <= ``tol``, or after ``max_epochs`` epochs. With label noise the loss may
    never reach ``tol``, so the cap keeps the loop finite.

    Returns:
        hatmus: per-epoch ``w[0] / ||w||`` (0-d numpy arrays).
        losses: per-epoch mean loss (floats).
    """
    n_samples, dim = features.shape
    hatmu = torch.zeros((dim,), dtype=torch.float64, requires_grad=True)
    hatmus = []
    losses = []
    for _ in range(max_epochs):
        running_loss = 0.0
        for x, y in zip(features, labels):
            margin = y * torch.dot(hatmu, x)
            # logaddexp(0, z) == log(1 + exp(z)) without overflowing for large z.
            loss = torch.logaddexp(torch.zeros_like(margin), -margin)
            loss.backward()
            with torch.no_grad():
                hatmu -= lr * hatmu.grad
            hatmu.grad.zero_()
            running_loss += loss.item()
        hatmus.append((hatmu[0] / hatmu.norm()).detach().numpy())
        losses.append(running_loss / n_samples)
        if running_loss / n_samples <= tol:
            break
    return hatmus, losses


n_trials = 8
for trial in range(n_trials):
    # Data generation: the signal lies along the first coordinate only.
    features, labels = _generate_data(n, p, snr, rho)
    best = _one_step_alignment(features, labels)
    # Training
    hatmus, losses = _train_logistic_sgd(features, labels,
                                         lr=1e-4, max_epochs=2000, tol=0.05)

    plt.plot(hatmus, label='sgd')
    plt.plot(losses, label='loss')
    # Plot a plain float rather than a list of 0-d tensors.
    plt.plot([best.item()] * len(losses), label='one step')
    plt.legend()
    plt.savefig(os.path.join(args.save_path, 'fig' + str(trial)))
    plt.close()
    torch.save((best, hatmus, losses),
               os.path.join(args.save_path, 'stats' + str(trial)))
