import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.distributions import MultivariateNormal
import math
import random


class LogisticClassification(torch.nn.Module):
    """Multinomial logistic regression: one linear layer trained with cross-entropy.

    The module keeps a reference to the whole training set so it can either
    draw random minibatches (``compute_loss``) or evaluate the full-batch
    objective (``compute_full_loss`` / ``get_full_gradient``).
    """

    def __init__(self, data, label, batch_size=10, in_features=28 * 28, num_classes=10):
        """
        Args:
            data (torch.Tensor): training inputs, indexed along dim 0 and fed
                straight into ``nn.Linear``, so the last dimension must equal
                ``in_features`` (presumably shape ``(N, in_features)``, e.g.
                flattened 28x28 images — confirm against the caller).
            label (torch.Tensor): targets aligned with ``data`` along dim 0;
                presumably integer class indices (``torch.long``) as expected
                by ``nn.CrossEntropyLoss``.
            batch_size (int): number of samples drawn by ``sample_minibatch``.
            in_features (int): input dimension of the linear layer.
            num_classes (int): number of output classes (logits).

        Note: ``data`` and ``label`` are stored as plain attributes, not
        buffers, so ``module.to(device)`` does not move them.
        """
        super().__init__()

        self.data = data
        self.label = label
        self.batch_size = batch_size

        self.linear = nn.Linear(in_features, num_classes)
        self.criteria = nn.CrossEntropyLoss()

    def forward(self, minibatch_data):
        """Return the class logits ``self.linear(minibatch_data)``."""
        return self.linear(minibatch_data)

    def sample_minibatch(self):
        """Draw ``batch_size`` distinct samples uniformly at random.

        Sampling is without replacement and uses Python's ``random`` module
        (seed it with ``random.seed`` for reproducibility). Raises
        ``ValueError`` if ``batch_size`` exceeds the number of samples.

        Returns:
            tuple: ``(minibatch_data, minibatch_label)``.
        """
        # random.sample accepts a range directly; no need to materialize a list.
        minibatch_index = random.sample(range(self.data.shape[0]), self.batch_size)
        return self.data[minibatch_index], self.label[minibatch_index]

    def compute_loss(self):
        """Return the cross-entropy loss on a freshly sampled minibatch.

        The scalar result is attached to the autograd graph of ``self.linear``;
        the caller must invoke ``.backward()`` on it to populate gradients.
        """
        minibatch_data, minibatch_label = self.sample_minibatch()
        output = self.forward(minibatch_data)
        return self.criteria(output, minibatch_label)

    def compute_full_loss(self):
        """Return the cross-entropy loss over the entire stored dataset.

        Like ``compute_loss``, the result is attached to the graph of
        ``self.linear``; no gradients are computed here.
        """
        # forward() requires an input; evaluate on the full dataset.
        output = self.forward(self.data)
        return self.criteria(output, self.label)

    def get_full_gradient(self, weight, bias):
        """Return the full-batch loss and its exact gradient at given parameters.

        The parameters are copied (under ``no_grad``) into a scratch
        ``nn.Linear``, so neither ``self.linear`` nor the passed tensors
        accumulate gradients. No noise is added.

        Args:
            weight (torch.Tensor): expected shape ``(num_classes, in_features)``.
            bias (torch.Tensor): expected shape ``(num_classes,)``.

        Returns:
            tuple: ``(loss, grad)`` where ``loss`` is the scalar mean loss over
            ``self.data`` and ``grad`` has shape
            ``(num_classes, in_features + 1)``: the weight gradient with the
            bias gradient appended as the last column.
        """
        # Scratch layer sized like the model; its random init is overwritten below.
        new_linear = nn.Linear(self.linear.in_features, self.linear.out_features)

        with torch.no_grad():
            new_linear.weight.copy_(weight)
            new_linear.bias.copy_(bias)

        new_linear.zero_grad()
        loss = self.criteria(new_linear(self.data), self.label)
        loss.backward()

        return loss, torch.cat([new_linear.weight.grad, new_linear.bias.grad.unsqueeze(1)], dim=1)
        
class RidgeClassification(torch.nn.Module):
    """Least-squares classifier: one linear layer regressed onto one-hot labels.

    The loss is ``nn.MSELoss`` between the layer outputs and the one-hot
    encoding of the integer labels.

    NOTE(review): no L2 penalty is applied in this class despite the "Ridge"
    name; presumably it is supplied by the optimizer (e.g. weight decay) —
    confirm against the caller.
    """

    def __init__(self, data, label, batch_size=10, in_features=28 * 28, num_classes=10):
        """
        Args:
            data (torch.Tensor): training inputs, indexed along dim 0 and fed
                straight into ``nn.Linear``, so the last dimension must equal
                ``in_features`` (presumably shape ``(N, in_features)``, e.g.
                flattened 28x28 images — confirm against the caller).
            label (torch.Tensor): integer class indices aligned with ``data``
                along dim 0 (``F.one_hot`` requires an integer tensor with
                values in ``[0, num_classes)``).
            batch_size (int): number of samples drawn by ``sample_minibatch``.
            in_features (int): input dimension of the linear layer.
            num_classes (int): number of output classes (and one-hot width).

        Note: ``data`` and ``label`` are stored as plain attributes, not
        buffers, so ``module.to(device)`` does not move them.
        """
        super().__init__()

        self.data = data
        self.label = label
        self.batch_size = batch_size

        self.linear = nn.Linear(in_features, num_classes)
        self.criteria = nn.MSELoss()

    def forward(self, minibatch_data):
        """Return the raw class scores ``self.linear(minibatch_data)``."""
        return self.linear(minibatch_data)

    def _one_hot_targets(self, labels):
        """Encode integer ``labels`` as float one-hot rows matching the output width."""
        return F.one_hot(labels, num_classes=self.linear.out_features).float()

    def sample_minibatch(self):
        """Draw ``batch_size`` distinct samples uniformly at random.

        Sampling is without replacement and uses Python's ``random`` module
        (seed it with ``random.seed`` for reproducibility). Raises
        ``ValueError`` if ``batch_size`` exceeds the number of samples.

        Returns:
            tuple: ``(minibatch_data, minibatch_label)``.
        """
        # random.sample accepts a range directly; no need to materialize a list.
        minibatch_index = random.sample(range(self.data.shape[0]), self.batch_size)
        return self.data[minibatch_index], self.label[minibatch_index]

    def compute_loss(self):
        """Return the MSE loss against one-hot targets on a fresh minibatch.

        The scalar result (mean over all output elements, ``MSELoss``'s default
        reduction) is attached to the autograd graph of ``self.linear``; the
        caller must invoke ``.backward()`` on it to populate gradients.
        """
        minibatch_data, minibatch_label = self.sample_minibatch()
        output = self.forward(minibatch_data)
        return self.criteria(output, self._one_hot_targets(minibatch_label))

    def compute_full_loss(self):
        """Return the MSE loss against one-hot targets over the entire dataset.

        Like ``compute_loss``, the result is attached to the graph of
        ``self.linear``; no gradients are computed here.
        """
        # forward() requires an input; evaluate on the full dataset.
        output = self.forward(self.data)
        return self.criteria(output, self._one_hot_targets(self.label))

    def get_full_gradient(self, weight, bias):
        """Return the full-batch loss and its exact gradient at given parameters.

        The parameters are copied (under ``no_grad``) into a scratch
        ``nn.Linear``, so neither ``self.linear`` nor the passed tensors
        accumulate gradients. No noise is added.

        Args:
            weight (torch.Tensor): expected shape ``(num_classes, in_features)``.
            bias (torch.Tensor): expected shape ``(num_classes,)``.

        Returns:
            tuple: ``(loss, grad)`` where ``loss`` is the scalar mean loss over
            ``self.data`` and ``grad`` has shape
            ``(num_classes, in_features + 1)``: the weight gradient with the
            bias gradient appended as the last column.
        """
        # Scratch layer sized like the model; its random init is overwritten below.
        new_linear = nn.Linear(self.linear.in_features, self.linear.out_features)

        with torch.no_grad():
            new_linear.weight.copy_(weight)
            new_linear.bias.copy_(bias)

        new_linear.zero_grad()
        loss = self.criteria(new_linear(self.data), self._one_hot_targets(self.label))
        loss.backward()

        return loss, torch.cat([new_linear.weight.grad, new_linear.bias.grad.unsqueeze(1)], dim=1)
        
