import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal

class Rayleigh_Quotient(nn.Module):
    """Noisy Rayleigh-quotient objective ``f(w) = 0.5 * w^T A w / (w^T w)``.

    The weight ``w`` is stored as a ``(d, 1)`` column vector. Gradients are
    computed by hand in :meth:`cal_grad` (autograd is not used), together with
    a noisy gradient estimate obtained by perturbing ``A w`` with Gaussian noise.

    Args:
        mat: the matrix ``A``; ``mat.size(0)`` sets the dimension ``d``.
            NOTE(review): presumably square and symmetric — the formula in
            ``cal_grad`` is the exact gradient only for symmetric ``A``.
        noise_scale: multiple of the identity used as the covariance matrix of
            the gradient noise (per-coordinate std is ``sqrt(noise_scale)``).
        initial: Euclidean norm of the initial weight vector.
    """

    def __init__(self, mat, noise_scale=4.0, initial=10.0):
        # Must run first: nn.Module.__call__, repr(), .to() etc. rely on the
        # internal state it creates.
        super().__init__()
        self.dimension = mat.size(0)

        # Build tensors in A's dtype/device so later matmuls against A do not
        # fail for float64 / non-CPU matrices; other dtypes keep the previous
        # default-dtype behaviour.
        if mat.dtype in (torch.float32, torch.float64):
            dtype = mat.dtype
        else:
            dtype = torch.get_default_dtype()
        factory = {"dtype": dtype, "device": mat.device}

        # Initial direction: all ones with a dominant first coordinate,
        # rescaled to have norm `initial`.
        weight = torch.ones(self.dimension, 1, **factory)
        weight[0, 0] = 10
        self.weight = weight * initial / torch.norm(weight)

        self.A = mat
        self.noise = MultivariateNormal(
            torch.zeros(self.dimension, **factory),
            noise_scale * torch.eye(self.dimension, **factory),
        )

        # Placeholders, overwritten by forward() / cal_grad().
        self.true_grad = torch.zeros_like(self.weight)
        self.noisy_grad = torch.zeros_like(self.weight)
        self.loss = torch.zeros(1)
        self.term_1 = torch.zeros(1)
        self.term_2 = torch.zeros(1)
        self.ratio = torch.zeros(1)

    @torch.no_grad()
    def forward(self):
        """Evaluate the objective at the current weight.

        Stores the resulting ``(1, 1)`` tensor in ``self.loss`` and returns
        its value as a Python float.
        """
        w = self.weight
        self.loss = 0.5 * torch.matmul(torch.matmul(w.T, self.A), w) / (w * w).sum()
        return self.loss.item()

    def cal_grad(self):
        """Compute the exact and a noisy gradient of the objective at ``self.weight``.

        With ``P = I - w w^T / (w^T w)`` (projection orthogonal to ``w``), the
        gradient of ``0.5 * w^T A w / (w^T w)`` for symmetric ``A`` is
        ``P A w / (w^T w)``. The noisy variant adds ``||w|| * xi`` with
        ``xi ~ self.noise`` to ``A w`` before projecting.

        Returns:
            list: ``[true_grad, noisy_grad]`` — detached copies of the ``(d, 1)``
            tensors that are also stored on ``self``.
        """
        w = self.weight
        sq_norm = (w * w).sum()  # w^T w, shared by every term below
        eye = torch.eye(self.dimension, dtype=w.dtype, device=w.device)
        proj = eye - torch.matmul(w, w.T) / sq_norm
        res = torch.matmul(self.A, w)
        noise = self.noise.sample().reshape(self.dimension, 1)

        self.true_grad = torch.matmul(proj, res) / sq_norm
        self.noisy_grad = torch.matmul(proj, res + torch.norm(w) * noise) / sq_norm

        return [self.true_grad.clone().detach(), self.noisy_grad.clone().detach()]
