import torch
import matplotlib.pyplot as plt
import random
import copy
import torch.optim.lr_scheduler as lr_scheduler
from Trace import Covariance
from Visualization import funcaverage

# Scale (standard deviation) of the Gaussian noise that NoisyGD adds to the
# gradient at every step.
noise = 4.0
def NoisyGD(input, LossFunctions, eps, lr, decay_rate, seed, noise_std=None):
    """Run full-batch gradient descent with Gaussian noise added to the gradient.

    At every iteration the average of all ``LossFunctions`` is evaluated at the
    current point, its gradient is perturbed with i.i.d. Gaussian noise, and a
    plain SGD step is taken. The learning rate used at iteration ``k`` is
    ``lr * decay_rate ** k``.

    Args:
        input: Starting point. It is deep-copied, so the caller's tensor is
            never modified. It must require gradients, because the copy is
            optimised in place by ``torch.optim.SGD``.
        LossFunctions: Non-empty sequence of callables, each mapping the
            current point to a scalar loss tensor.
        eps: Number of iterations to run.
        lr: Initial learning rate.
        decay_rate: Multiplicative learning-rate decay applied per iteration.
        seed: Currently unused; kept so existing call sites keep working.
        noise_std: Standard deviation of the gradient noise. When ``None``
            (the default) the module-level ``noise`` value is used.

    Returns:
        A tuple ``(x, hit_fraction, Traj, H, AccCov)`` where ``x`` is the final
        point, ``hit_fraction`` is the fraction of iterations after which
        ``x[0] > 0`` (``0.0`` when ``eps`` is 0), ``Traj`` has one column per
        visited point (``eps + 1`` columns, the initial point first), and
        ``H`` and ``AccCov`` are the same all-zero placeholder matrix.

    Raises:
        ValueError: If ``LossFunctions`` is empty.
    """
    num_losses = len(LossFunctions)
    if num_losses == 0:
        raise ValueError("LossFunctions must contain at least one loss function")
    if noise_std is None:
        noise_std = noise

    x = copy.deepcopy(input)
    dim = x.numel()
    # NOTE: covariance accumulation is currently disabled, so this stays zero.
    AccCov = torch.zeros(dim, dim)
    optimizer = torch.optim.SGD([x], lr=lr)
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: decay_rate ** epoch
    )
    PositiveHits = 0
    Traj = torch.zeros(dim, eps + 1)

    for ep in range(eps):
        # Record the point *before* the update; the final point is stored
        # after the loop.
        Traj[:, ep] = x.detach().reshape(-1)
        optimizer.zero_grad()
        # Full-batch loss: average over every loss function.
        loss = sum(loss_fn(x) for loss_fn in LossFunctions) / num_losses
        loss.backward()
        # Perturb the gradient before the optimiser consumes it.
        x.grad = x.grad + torch.randn_like(x.grad) * noise_std
        optimizer.step()
        scheduler.step()
        if x[0] > 0:
            PositiveHits += 1

    Traj[:, eps] = x.detach().reshape(-1)
    # H is a placeholder that aliases AccCov.
    H = AccCov
    # Guard the eps == 0 case, which would otherwise divide by zero.
    hit_fraction = PositiveHits / eps if eps > 0 else 0.0
    return x, hit_fraction, Traj, H, AccCov