from torch import nn
import torch
from torch import nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import random

def seed_libraries(random_seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    Every generator receives its own offset from ``random_seed`` so the
    libraries do not all start from the same seed value.

    Args:
        random_seed: Base integer seed; offsets ``+1`` .. ``+5`` are derived
            from it.
    """
    python_seed, numpy_seed, torch_seed, cuda_all_seed, cuda_current_seed = (
        random_seed + offset for offset in range(1, 6)
    )

    random.seed(python_seed)
    np.random.seed(numpy_seed)
    torch.manual_seed(torch_seed)
    # Order matters: all CUDA devices are seeded first, then the current
    # device is re-seeded with its own value.
    torch.cuda.manual_seed_all(cuda_all_seed)
    torch.cuda.manual_seed(cuda_current_seed)

class GradWeight(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient in the backward pass.

    Usage: ``y = GradWeight.apply(x, weights)``. ``y`` has the same values as
    ``x``, while the gradient flowing back into ``x`` is multiplied by
    ``weights``.
    """

    @staticmethod
    def forward(ctx, x, weights):
        """Return a view of ``x`` and remember ``weights`` for the backward pass.

        Args:
            ctx: Autograd context object.
            x: Input tensor.
            weights: Multiplier applied to the incoming gradient in
                ``backward``; it must be usable in ``grad_output * weights``.

        Returns:
            ``x.view_as(x)`` -- a tensor with the same data and shape as ``x``.
        """
        ctx.weights = weights
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the weighted gradient for ``x`` and no gradient for ``weights``.

        Exactly one value is returned per ``forward`` input (``x``,
        ``weights``), as the autograd ``Function`` contract expects.
        """
        return grad_output * ctx.weights, None
            
class Normalize(nn.Module):
    """Standardize inputs with ``mean``/``std`` statistics stored as buffers.

    ``forward`` computes ``(inputs - mean) / std`` and ``denormalize`` inverts
    it. ``update`` refreshes both statistics incrementally from new values.
    """

    def __init__(self, mean, std):
        """Register the initial statistics as module buffers.

        Args:
            mean: Tensor subtracted from the inputs in ``forward``.
            std: Tensor the centred inputs are divided by in ``forward``.
        """
        super().__init__()
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def forward(self, inputs):
        """Return ``(inputs - mean) / std``.

        No epsilon is added, so zero entries in ``std`` yield inf/nan.
        """
        return (inputs - self.mean) / self.std

    def update(self, count, value_new, new_samples_count):
        """Fold a new value into the running ``mean`` and ``std``.

        Algorithm: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm

        Args:
            count: Total number of samples seen so far, *including* the new
                ones.
            value_new: Tensor holding the new value to incorporate
                (presumably the mean of the new samples when
                ``new_samples_count > 1`` -- confirm against the caller).
            new_samples_count: Number of samples represented by
                ``value_new``.

        On the very first update (no previous samples) only ``mean`` is set;
        ``std`` keeps its initial value.
        """
        count_old = count - new_samples_count
        if count_old == 0:
            # Clone so the in-place mean updates below never write through to
            # the tensor owned by the caller.
            self.mean = value_new.clone()
            return

        if count_old == 1:
            # A single previous sample has zero variance: discard whatever
            # initial std was registered. zeros_like keeps the buffer's
            # device and dtype instead of forcing a CUDA float tensor.
            self.std = torch.zeros_like(self.std)
        delta_before_update = value_new - self.mean
        self.mean += delta_before_update * new_samples_count / count
        delta_after_update = value_new - self.mean
        total_variance = (
            (self.std ** 2) * count_old
            + delta_before_update * delta_after_update * new_samples_count
        )
        self.std = (total_variance / count).sqrt()

    def denormalize(self, inputs):
        """Invert ``forward``: return ``inputs * std + mean``."""
        return inputs * self.std + self.mean

def set_requires_grad(model, requires_grad=True):
    """Set the ``requires_grad`` flag on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        requires_grad: Flag value assigned to each parameter.
    """
    for parameter in model.parameters():
        parameter.requires_grad = requires_grad


def loop_iterable(iterable):
    """Yield the items of ``iterable`` over and over again.

    ``iterable`` is re-iterated from scratch on every pass (items are not
    cached), so an object that produces a fresh order per iteration keeps
    doing so.

    The generator stops once a complete pass yields nothing -- i.e. for an
    empty iterable or an exhausted one-shot iterator -- instead of spinning
    forever without ever producing a value.

    Args:
        iterable: Any iterable; it should be re-iterable for endless looping.

    Yields:
        The items of ``iterable``, pass after pass.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return


# class GrayscaleToRgb:
#     """Convert a grayscale image to rgb"""
#     def __call__(self, image):
#         image = np.array(image)
#         image = np.dstack([image, image, image])
#         return Image.fromarray(image)


# class GradientReversalFunction(Function):
#     """
#     Gradient Reversal Layer from:
#     Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
#     Forward pass is the identity function. In the backward pass,
#     the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
#     """

#     @staticmethod
#     def forward(ctx, x, lambda_):
#         ctx.lambda_ = lambda_
#         return x.clone()

#     @staticmethod
#     def backward(ctx, grads):
#         lambda_ = ctx.lambda_
#         lambda_ = grads.new_tensor(lambda_)
#         dx = -lambda_ * grads
#         return dx, None


# class GradientReversal(torch.nn.Module):
#     def __init__(self, lambda_=1):
#         super(GradientReversal, self).__init__()
#         self.lambda_ = lambda_

#     def forward(self, x):
#         return GradientReversalFunction.apply(x, self.lambda_)