import torch.nn as nn
import torch
import torch.nn.functional as F

# Some code from https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py


class RouteNorm(nn.Module):
    """Batch-norm style layer whose scale and shift come from a reference tensor.

    The input is standardised per channel (batch statistics while training,
    running statistics in eval mode).  When ``affine`` is enabled, two
    ``nn.Linear(affine_size, 1)`` heads turn a per-channel summary of the
    reference tensor into one multiplier each and the output is
    ``mean_su(ref) * normalised_x + var_su(ref) * ref_x``.

    Args:
        num_features: Number of channels; sizes the running statistics.
        affine_size: Input width of the two linear heads.  In ``'pool'`` mode
            the summary has one entry per sample, so this must equal the
            batch size of the first forward call.  In any other mode the
            summary is the flattened reference and this must equal
            ``batch * height * width``.
        mode: ``'pool'`` average-pools the reference; any other value
            flattens it.
        eps: Added to the standard deviation before dividing.
        momentum: Weight of the newest batch in the running statistics.
        affine: Whether to create and apply the two linear heads.
        track_running_stats: Whether to keep running mean/variance buffers.
    """

    def __init__(self, num_features, affine_size=64, mode='pool', eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(RouteNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.mode = mode
        self.momentum = momentum
        self.affine = affine
        # Batch size seen by the first forward call; -1 until then.
        self.batch_size = -1
        self.track_running_stats = track_running_stats
        if self.affine:
            self.affine_size = affine_size  # typically batch_size
            self.mean_su = nn.Linear(self.affine_size, 1)
            self.var_su = nn.Linear(self.affine_size, 1)
        else:
            self.register_parameter('mean_su', None)
            self.register_parameter('var_su', None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset the running statistics and re-initialise the linear head weights."""
        self.reset_running_stats()
        if self.affine:
            # The scale head starts from uniform random weights, the shift
            # head from zero weights; the biases keep nn.Linear's default.
            nn.init.uniform_(self.mean_su.weight)
            nn.init.constant_(self.var_su.weight, 0.0)

    def forward(self, x, ref_x=None):
        """Normalise ``x`` and modulate it with statistics of ``ref_x``.

        Args:
            x: 4D input tensor, indexed as (batch, channel, dim2, dim3).
            ref_x: Optional 4D reference tensor; defaults to ``x``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4D, or if (non-pool mode) the
                flattened reference is wider than
                ``batch_size * x.size(2) * x.size(3)``.
        """
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))
        if self.batch_size < 0:
            self.batch_size = x.size(0)

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1

        # Prepare Reference feature
        if ref_x is None:
            ref_x = x
        inchannel = ref_x.size(1)
        ref_y = ref_x
        if self.affine and ref_x.size(0) < self.batch_size:
            # Final batch, contains very few examples in this batch: pad the
            # batch dimension with zeros so the summary keeps its width.
            padding = torch.zeros(self.batch_size - ref_x.size(0), ref_x.size(1), ref_x.size(2),
                                  ref_x.size(3), dtype=ref_x.dtype, device=ref_x.device)
            ref_y = torch.cat([ref_x, padding], dim=0)
        # Channel-first layout.  The transpose has to come after the padding:
        # the previous code dropped it for padded batches, handing a
        # batch-first tensor to the linear heads.
        ref_y = ref_y.transpose(0, 1)

        if self.mode == 'pool':
            # nChannels x Batch Size (for square feature maps -- the pooling
            # kernel is taken from dim 2 only).
            ref_y = F.avg_pool2d(ref_y, ref_y.size(2)).view(-1, ref_y.size(1))
        else:
            ref_y = ref_y.contiguous().view(inchannel, -1)  # nChannels x affine_size
            new_size = self.batch_size * x.size(2) * x.size(3)
            if ref_y.size(1) < new_size:  # TODO: Best way to create mean_su with prealloted values
                ref_y = torch.cat([ref_y, torch.zeros(ref_y.size(0), new_size - ref_y.size(1), dtype=ref_y.dtype,
                                                      device=ref_y.device)], dim=1)
            if ref_y.size(1) > new_size:
                raise ValueError(
                    'Invalid Input size for RouteNorm mode.'
                    + ' Please change the src_chan_vec in the ShareableResidualLayer.')

        y = x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        # Without tracked statistics there is nothing to read in eval mode,
        # so fall back to the statistics of the current batch.
        use_running = (not self.training) and self.running_mean is not None
        if use_running:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        else:
            if self.training and self.track_running_stats:
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
                    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * sigma2
            y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1)**.5 + self.eps)

        if self.affine:
            scale = self.mean_su(ref_y).view(-1, 1)
            shift = self.var_su(ref_y).view(-1, 1)
            # The second term re-injects the (unpadded) reference itself.
            y = scale * y + shift * ref_x.transpose(0, 1).contiguous().view(inchannel, -1)
        return y.view(return_shape).transpose(0, 1)


class RouteNormOther(nn.BatchNorm2d):
    """BatchNorm2d variant whose per-channel scale/shift are dot products with the reference.

    ``weight`` and ``bias`` are vectors of length ``batch * dim2 * dim3``.
    Each channel of the (flattened) reference tensor is multiplied with them
    via ``torch.mv`` to obtain that channel's scale and shift.

    NOTE(review): the vectors are re-created (and ``reset_parameters`` is
    called) every time ``batch * dim2 * dim3`` of the input changes, which
    discards previously learned values; an optimizer constructed earlier
    presumably does not see the new parameters -- confirm with callers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(RouteNormOther, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        # Batch size the weight vector was sized for; -1 until the first forward.
        self.batch_size = -1

    def forward(self, x, ref_x=None):
        """Normalise ``x`` per channel and modulate it with ``ref_x``.

        Args:
            x: 4D input tensor, indexed as (batch, channel, dim2, dim3).
            ref_x: Optional 4D reference tensor; defaults to ``x``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4D.
        """
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))
        new_size = x.size(0) * x.size(2) * x.size(3)
        # ``weight`` is None when the layer was built with affine=False; it is
        # created here on demand instead of crashing on ``None.size``.
        if self.weight is None or self.weight.size(0) != new_size:
            self.weight = nn.Parameter(torch.zeros(new_size, dtype=x.dtype, device=x.device))
            self.bias = nn.Parameter(torch.zeros(new_size, dtype=x.dtype, device=x.device))
            self.batch_size = x.size(0)
            self.affine = True
            self.reset_parameters()
        if self.batch_size < 0:
            # The initial weight already had the right length, so the branch
            # above never recorded a batch size.
            self.batch_size = x.size(0)
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
        if ref_x is None:
            ref_x = x

        if self.affine and ref_x.size(0) < self.batch_size:
            # Final batch, contains very few examples in this batch
            padding = torch.zeros(self.batch_size - ref_x.size(0), ref_x.size(1), ref_x.size(2),
                                  ref_x.size(3), dtype=ref_x.dtype, device=ref_x.device)
            ref_x = torch.cat([ref_x, padding], dim=0)
        y = x.transpose(0, 1)
        ref_y = ref_x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        ref_y = ref_y.contiguous().view(ref_x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        # Without tracked statistics there is nothing to read in eval mode,
        # so fall back to the statistics of the current batch.
        use_running = (not self.training) and self.running_mean is not None
        if use_running:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        else:
            if self.training and self.track_running_stats:
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
                    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * sigma2
            y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1)**.5 + self.eps)

        scale = torch.mv(ref_y, self.weight).view(-1, 1)
        shift = torch.mv(ref_y, self.bias).view(-1, 1)
        y = scale * y + shift
        return y.view(return_shape).transpose(0, 1)

class RouteNormlikeBN(nn.BatchNorm2d):
    """BatchNorm2d variant whose scale/shift are dot products with the pooled reference.

    ``weight`` and ``bias`` are vectors with one entry per sample of the
    batch.  The reference tensor is average-pooled to one value per
    (sample, channel) pair, and each channel's row is multiplied with the
    vectors via ``torch.mv`` to obtain that channel's scale and shift.

    NOTE(review): the vectors are re-created (and ``reset_parameters`` is
    called) every time the batch size of the input changes, which discards
    previously learned values; an optimizer constructed earlier presumably
    does not see the new parameters -- confirm with callers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(RouteNormlikeBN, self).__init__(num_features, eps, momentum, affine, track_running_stats)

    def forward(self, x, ref_x=None):
        """Normalise ``x`` per channel and modulate it with pooled ``ref_x``.

        Args:
            x: 4D input tensor, indexed as (batch, channel, dim2, dim3).
            ref_x: Optional 4D reference tensor; defaults to ``x``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4D.
        """
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))
        new_size = x.size(0)
        # ``weight`` is None when the layer was built with affine=False; it is
        # created here on demand instead of crashing on ``None.size``.
        if self.weight is None or self.weight.size(0) != new_size:
            self.weight = nn.Parameter(torch.zeros(new_size, dtype=x.dtype, device=x.device))
            self.bias = nn.Parameter(torch.zeros(new_size, dtype=x.dtype, device=x.device))
            self.affine = True
            self.reset_parameters()
        # After the check above the weight length always equals the batch
        # size, so record it unconditionally; previously it stayed undefined
        # when the initial weight happened to have the right length.
        self.affine_size = new_size

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
        if ref_x is None:
            ref_x = x
        # One pooled value per (sample, channel), assuming square feature
        # maps -- the pooling kernel is taken from dim 2 only.
        ref_x = F.avg_pool2d(ref_x, ref_x.size(2)).view(-1, ref_x.size(1))
        ref_pr = ref_x.transpose(0, 1)
        y = x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        # Without tracked statistics there is nothing to read in eval mode,
        # so fall back to the statistics of the current batch.
        use_running = (not self.training) and self.running_mean is not None
        if use_running:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        else:
            if self.training and self.track_running_stats:
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
                    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * sigma2
            y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1)**.5 + self.eps)

        if self.affine and ref_pr.size(1) < self.affine_size:
            # Final batch, contains very few examples in this batch
            padding = torch.zeros(ref_pr.size(0), self.affine_size - ref_pr.size(1), dtype=ref_pr.dtype,
                                  device=ref_pr.device)
            ref_pr = torch.cat([ref_pr, padding], dim=1)
        scale = torch.mv(ref_pr, self.weight).view(-1, 1)
        shift = torch.mv(ref_pr, self.bias).view(-1, 1)
        y = scale * y + shift
        return y.view(return_shape).transpose(0, 1)

class RouteNormFull(nn.BatchNorm2d):
    """BatchNorm2d variant with scale/shift predicted by lazily built linear heads.

    On the first forward call two ``nn.Linear(batch, 1)`` heads (``mean_su``
    and ``var_su``) are created, sized by that call's batch dimension.  They
    map the average-pooled reference (one value per sample, per channel) to a
    per-channel scale and shift that are applied to the normalised input.

    NOTE(review): because the heads do not exist before the first forward
    call, a state dict saved after training presumably cannot be loaded into
    a freshly constructed layer -- confirm with callers.

    Args:
        num_features: Number of channels; forwarded to ``nn.BatchNorm2d``.
        eps: Added to the standard deviation before dividing.
        momentum: Weight of the newest batch in the running statistics.
        affine: Forwarded to ``nn.BatchNorm2d``.  The flag is switched on by
            the first forward call in any case.
        track_running_stats: Whether to keep running mean/variance buffers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False, track_running_stats=True):
        # The parent constructor already stores the hyper-parameters and
        # registers the running statistics; registering them a second time
        # via register_parameter(..., None) clashes with those buffers.
        super(RouteNormFull, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        # Input width of the linear heads; 0 until they are built.
        self.affine_size = 0

    def reset_running_stats(self):
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self, input=None):
        """Reset running statistics and (re)initialise the linear heads.

        Args:
            input: Optional tensor; when given, fresh heads sized by
                ``input.size(0)`` are created on its device and dtype.
        """
        self.reset_running_stats()
        if not self.affine:
            return
        if input is not None:
            self.affine_size = input.size(0)
            self.mean_su = nn.Linear(self.affine_size, 1).to(device=input.device, dtype=input.dtype)
            self.var_su = nn.Linear(self.affine_size, 1).to(device=input.device, dtype=input.dtype)
        if getattr(self, 'mean_su', None) is None:
            # Called from the parent constructor with affine=True: the heads
            # are only built on the first forward call, nothing to init yet.
            return
        nn.init.uniform_(self.mean_su.weight)
        nn.init.constant_(self.var_su.weight, 0.0)

    def forward(self, x, ref_x=None):
        """Normalise ``x`` per channel and modulate it with pooled ``ref_x``.

        Args:
            x: 4D input tensor, indexed as (batch, channel, dim2, dim3).
            ref_x: Optional 4D reference tensor; defaults to ``x``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4D.
        """
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))
        if getattr(self, 'mean_su', None) is None:
            # Build the heads before counting the batch; the reset inside
            # would otherwise wipe the counter increment of this call.
            self.affine = True
            self.reset_parameters(x)
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1

        if ref_x is None:
            ref_x = x
        # One pooled value per (sample, channel), assuming square feature
        # maps -- the pooling kernel is taken from dim 2 only.
        ref_x = F.avg_pool2d(ref_x, ref_x.size(2)).view(-1, ref_x.size(1))
        ref_pr = ref_x.transpose(0, 1)
        y = x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        # Without tracked statistics there is nothing to read in eval mode,
        # so fall back to the statistics of the current batch.
        use_running = (not self.training) and self.running_mean is not None
        if use_running:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        else:
            if self.training and self.track_running_stats:
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
                    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * sigma2
            y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1)**.5 + self.eps)

        if self.affine and ref_pr.size(1) < self.affine_size:
            # Final batch, contains very few examples in this batch
            padding = torch.zeros(ref_pr.size(0), self.affine_size - ref_pr.size(1), dtype=ref_pr.dtype,
                                  device=ref_pr.device)
            ref_pr = torch.cat([ref_pr, padding], dim=1)

        if self.affine:
            scale = self.mean_su(ref_pr).view(-1, 1)
            shift = self.var_su(ref_pr).view(-1, 1)
            y = scale * y + shift
        return y.view(return_shape).transpose(0, 1)

class RouteNormOld(nn.BatchNorm2d):
    """BatchNorm2d variant with scale/shift predicted by lazily built linear heads.

    On the first forward call two ``nn.Linear(batch, 1)`` heads (``mean_su``
    and ``var_su``) are created, sized by that call's batch dimension.  They
    map the average-pooled reference (one value per sample, per channel) to a
    per-channel scale and shift that are applied to the normalised input.

    NOTE(review): because the heads do not exist before the first forward
    call, a state dict saved after training presumably cannot be loaded into
    a freshly constructed layer -- confirm with callers.

    Args:
        num_features: Number of channels; forwarded to ``nn.BatchNorm2d``.
        eps: Added to the standard deviation before dividing.
        momentum: Weight of the newest batch in the running statistics.
        affine: Forwarded to ``nn.BatchNorm2d``.  The flag is switched on by
            the first forward call in any case.
        track_running_stats: Whether to keep running mean/variance buffers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False, track_running_stats=True):
        # The parent constructor already stores the hyper-parameters and
        # registers the running statistics; registering them a second time
        # via register_parameter(..., None) clashes with those buffers.
        super(RouteNormOld, self).__init__(num_features, eps, momentum, affine, track_running_stats)
        # Input width of the linear heads; 0 until they are built.
        self.affine_size = 0

    def reset_running_stats(self):
        """Reset running mean to 0, running variance to 1 and the batch counter to 0."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self, input=None):
        """Reset running statistics and (re)initialise the linear heads.

        Args:
            input: Optional tensor; when given, fresh heads sized by
                ``input.size(0)`` are created on its device and dtype.
        """
        self.reset_running_stats()
        if not self.affine:
            return
        if input is not None:
            self.affine_size = input.size(0)
            self.mean_su = nn.Linear(self.affine_size, 1).to(device=input.device, dtype=input.dtype)
            self.var_su = nn.Linear(self.affine_size, 1).to(device=input.device, dtype=input.dtype)
        if getattr(self, 'mean_su', None) is None:
            # Called from the parent constructor with affine=True: the heads
            # are only built on the first forward call, nothing to init yet.
            return
        nn.init.uniform_(self.mean_su.weight)
        nn.init.constant_(self.var_su.weight, 0.0)

    def forward(self, x, ref_x=None):
        """Normalise ``x`` per channel and modulate it with pooled ``ref_x``.

        Args:
            x: 4D input tensor, indexed as (batch, channel, dim2, dim3).
            ref_x: Optional 4D reference tensor; defaults to ``x``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4D.
        """
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))
        if getattr(self, 'mean_su', None) is None:
            # Build the heads before counting the batch; the reset inside
            # would otherwise wipe the counter increment of this call.
            self.affine = True
            self.reset_parameters(x)
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1

        if ref_x is None:
            ref_x = x
        # One pooled value per (sample, channel), assuming square feature
        # maps -- the pooling kernel is taken from dim 2 only.
        ref_x = F.avg_pool2d(ref_x, ref_x.size(2)).view(-1, ref_x.size(1))
        ref_pr = ref_x.transpose(0, 1)
        y = x.transpose(0, 1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        # Without tracked statistics there is nothing to read in eval mode,
        # so fall back to the statistics of the current batch.
        use_running = (not self.training) and self.running_mean is not None
        if use_running:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        else:
            if self.training and self.track_running_stats:
                with torch.no_grad():
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
                    self.running_var = (1 - self.momentum) * self.running_var + self.momentum * sigma2
            y = y - mu.view(-1, 1)
            y = y / (sigma2.view(-1, 1)**.5 + self.eps)

        if self.affine and ref_pr.size(1) < self.affine_size:
            # Final batch, contains very few examples in this batch
            padding = torch.zeros(ref_pr.size(0), self.affine_size - ref_pr.size(1), dtype=ref_pr.dtype,
                                  device=ref_pr.device)
            ref_pr = torch.cat([ref_pr, padding], dim=1)

        if self.affine:
            scale = self.mean_su(ref_pr).view(-1, 1)
            shift = self.var_su(ref_pr).view(-1, 1)
            y = scale * y + shift
        return y.view(return_shape).transpose(0, 1)

def compare_bn(bn1, bn2):
    """Print every difference between the state of two normalisation layers.

    Always compares ``running_mean`` and ``running_var``.  When both layers
    report ``affine``, it also compares ``weight``/``bias`` and the weights
    and biases of the ``mean_su``/``var_su`` heads.  A tensor that is missing
    (or ``None``) on either layer is skipped, so layers that keep their
    learnable state in only one of those places can be compared as well.
    Tensors of different shape count as a difference.

    Args:
        bn1: First layer.
        bn2: Second layer.

    Returns:
        ``True`` when no difference was found, ``False`` otherwise.
    """
    def lookup(layer, dotted_name):
        # Resolve e.g. 'mean_su.weight'; None as soon as a link is missing.
        target = layer
        for part in dotted_name.split('.'):
            target = getattr(target, part, None)
            if target is None:
                return None
        return target

    names = ['running_mean', 'running_var']
    if bn1.affine and bn2.affine:
        names += ['weight', 'bias',
                  'mean_su.weight', 'mean_su.bias',
                  'var_su.weight', 'var_su.bias']

    err = False
    for name in names:
        first = lookup(bn1, name)
        second = lookup(bn2, name)
        if first is None or second is None:
            continue
        if first.shape != second.shape or not torch.allclose(first, second):
            print('Diff in {}: {} vs {}'.format(name, first, second))
            err = True

    if not err:
        print('All parameters are equal!')
    return not err

def test_route_norm():
    """Check that two RouteNorm layers sharing one state dict stay in lock-step.

    Two layers are built, the state of one is copied into the other, and both
    are fed the same random batches -- ten in training mode, then ten in eval
    mode.  After every batch their state is compared with ``compare_bn`` and
    their outputs must match.

    Raises:
        AssertionError: If the two layers produce different outputs.
    """
    batch_size = 10
    # In its default 'pool' mode RouteNorm feeds one pooled value per sample
    # into Linear(affine_size, 1) heads, so affine_size has to equal the
    # batch size used below (the default of 64 does not fit batches of 10).
    my_bn = RouteNorm(3, affine_size=batch_size, affine=True)
    bn = RouteNorm(3, affine_size=batch_size, affine=True)

    compare_bn(my_bn, bn)  # freshly initialised layers are expected to differ
    # Copy the learnable state so both layers start out identical
    my_bn.load_state_dict(bn.state_dict())
    compare_bn(my_bn, bn)

    def run_batches():
        # Feed ten randomly scaled and shifted batches through both layers.
        for _ in range(10):
            scale = torch.randint(1, 10, (1,)).float()
            bias = torch.randint(-10, 10, (1,)).float()
            x = torch.randn(batch_size, 3, 100, 100) * scale + bias
            out1 = my_bn(x)
            out2 = bn(x)
            compare_bn(my_bn, bn)

            # The comparison result used to be computed and thrown away.
            assert torch.allclose(out1, out2), 'RouteNorm outputs diverged'
            print('Max diff: ', (out1 - out2).abs().max())

    # Run train
    run_batches()

    # Run eval
    my_bn.eval()
    bn.eval()
    run_batches()


#test_route_norm()
