
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from utils.trans_norm import TransNorm2d


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a one-pixel zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

class TestBlock(nn.Module):
    """Two-convolution residual block whose input is fused with a reference map.

    ``forward`` concatenates ``x`` and ``ref_x`` along the channel axis, so
    ``inplanes`` has to equal the channel count of that concatenation.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the conv/BN layers.

        Args:
            inplanes: channel count of the concatenated ``[x, ref_x]`` input.
            planes: channel count produced by both convolutions.
            stride: stride of the first convolution.
            downsample: optional module that produces the shortcut branch
                from the concatenated input; ``None`` keeps ``x`` itself.
        """
        super(TestBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x, ref_x):
        """Run the block.

        Args:
            x: main input tensor; used as the shortcut when no ``downsample``
                module is configured.
            ref_x: reference tensor concatenated to ``x`` on dim 1.

        Returns:
            ReLU of (conv/BN branch + shortcut).
        """
        shortcut = x
        fused = torch.cat((x, ref_x), dim=1)

        # conv -> BN -> ReLU -> conv -> BN on the fused input
        features = self.relu(self.bn1(self.conv1(fused)))
        features = self.bn2(self.conv2(features))

        # NOTE: the downsample module receives the *concatenated* tensor,
        # whereas the default shortcut is the un-concatenated ``x``.
        if self.downsample is not None:
            shortcut = self.downsample(fused)

        features += shortcut
        return self.relu(features)


class SampledAttention(nn.Module):
    ''' Sampled attention: draws one value row per query instead of averaging.

    The scaled dot-product scores are turned into a categorical distribution
    with softmax, one key index is sampled from it, and the matching row of
    ``v`` is returned.
    '''

    def __init__(self, temperature):
        super(SampledAttention, self).__init__()
        self.temperature = temperature

    def forward(self, q, k, v):
        ''' Sample one value per batch item.

        Args:
            q: queries, batch x 1 x d (dim 1 is squeezed before sampling, so
               a single query position per batch item is expected).
            k: keys, batch x n_keys x d.
            v: values, batch x n_keys x d_v.

        Returns:
            (output, scores): ``output`` is batch x 1 x d_v, the sampled row
            of ``v``; ``scores`` are the temperature-scaled dot products
            *before* softmax, not probabilities.
        '''
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        sampling_probs = F.softmax(scores, 2).squeeze(1)
        chosen = sampling_probs.multinomial(num_samples=1)
        # Broadcast the sampled key index over the value feature dimension.
        gather_index = chosen.unsqueeze(2).repeat(1, 1, v.size(-1))
        picked = torch.gather(v, 1, gather_index)
        return picked, scores


class ScaledDotProductAttention(nn.Module):
    ''' Scaled dot-product attention with dropout on the attention weights. '''

    def __init__(self, temperature, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, q, k, v):
        ''' Attend over ``k``/``v`` with queries ``q``.

        Args:
            q: queries, batch x len_q x d.
            k: keys, batch x len_k x d.
            v: values, batch x len_k x d_v.

        Returns:
            (output, attn): the weighted sum of values (batch x len_q x d_v)
            and the post-softmax, post-dropout attention weights.
        '''
        scaled_scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature
        weights = self.dropout(F.softmax(scaled_scores, 2))
        return torch.bmm(weights, v), weights


class AttentionBasedTransfer(nn.Module):
    """Fuse a feature map with reference feature maps via attention.

    Every (batch, channel) slice of ``input1`` is flattened into a single
    query vector of length ``block_hidden_dim ** 2``. The matching slices of
    each reference map are the keys/values. The attended result is projected
    back, added to ``input1`` and layer-normalized over the last dimension.
    """

    def __init__(self, tg_layer_id, block_hidden_dim, n_head=1, type='avg', dropout=0.0):
        """Build projections, the attention module and the output layers.

        Args:
            tg_layer_id: identifier stored on the module (not used here).
            block_hidden_dim: side length of the square feature maps; the
                flattened size ``block_hidden_dim ** 2`` is kept in
                ``self.block_hidden_dim``.
            n_head: number of attention heads.
            type: ``'avg'`` selects ``ScaledDotProductAttention``; any other
                value selects ``SampledAttention``. (The name shadows the
                builtin but is part of the keyword interface.)
            dropout: dropout probability for the attention weights (``'avg'``
                mode) and for the output projection.
        """
        super(AttentionBasedTransfer, self).__init__()
        self.tg_layer_id = tg_layer_id
        self.block_hidden_dim = block_hidden_dim * block_hidden_dim
        self.in_dim = round(block_hidden_dim / n_head)
        self.n_head = n_head

        proj_dim = n_head * self.in_dim
        self.w_qs = torch.nn.Linear(self.block_hidden_dim, proj_dim, bias=False)
        self.w_ks = torch.nn.Linear(self.block_hidden_dim, proj_dim, bias=False)
        self.w_vs = torch.nn.Linear(self.block_hidden_dim, proj_dim, bias=False)

        init_std = np.sqrt(2.0 / (self.block_hidden_dim * 2))
        for projection in (self.w_qs, self.w_ks, self.w_vs):
            torch.nn.init.normal_(projection.weight, mean=0, std=init_std)

        temperature = np.power(self.in_dim, 0.5)
        if type == 'avg':
            self.attention = ScaledDotProductAttention(temperature=temperature, dropout=dropout)
        else:
            self.attention = SampledAttention(temperature=temperature)

        self.fc = torch.nn.Linear(proj_dim, self.block_hidden_dim)
        # Pass only normalized_shape: LayerNorm's second positional argument
        # is ``eps``, so supplying the size twice would set eps to the feature
        # size and heavily damp the normalization. The parameter shapes
        # (weight/bias of length block_hidden_dim) are unaffected.
        self.layer_norm = torch.nn.LayerNorm(block_hidden_dim)
        torch.nn.init.xavier_normal_(self.fc.weight)
        self.dropout = torch.nn.Dropout(dropout)
        # Second return value of the attention module from the last forward().
        self.attn = None

    def forward(self, input1, ref_inputs):
        """Attend from ``input1`` over ``ref_inputs`` and fuse the result.

        Args:
            input1: tensor whose first two dims are (batch, channel) and whose
                remaining dims hold ``self.block_hidden_dim`` elements; the
                result is viewed back as batch x channel x side x side.
            ref_inputs: sequence of reference tensors; they are stacked on a
                new dim 2, so each needs the same shape as ``input1``.

        Returns:
            ``layer_norm(input1 + attended)``, same shape as ``input1``.
        """
        batch_size = input1.size(0)
        nchannel = input1.size(1)
        n_ref = len(ref_inputs)
        n_rows = batch_size * nchannel

        # reshape (rather than view) also accepts non-contiguous inputs.
        context = torch.stack(ref_inputs, dim=2).reshape(n_rows, n_ref, self.block_hidden_dim)
        query = input1.reshape(n_rows, 1, self.block_hidden_dim)

        # Project and split into heads.
        q = self.w_qs(query).view(n_rows, 1, self.n_head, self.in_dim)        # (b*c) x 1 x n x dq
        k = self.w_ks(context).view(n_rows, n_ref, self.n_head, self.in_dim)  # (b*c) x lk x n x dk
        v = self.w_vs(context).view(n_rows, n_ref, self.n_head, self.in_dim)  # (b*c) x lv x n x dv

        # Fold the head axis into the batch axis for bmm-based attention.
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, 1, self.in_dim)      # (n*b*c) x 1 x dq
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, n_ref, self.in_dim)  # (n*b*c) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, n_ref, self.in_dim)  # (n*b*c) x lv x dv

        input2, attn = self.attention(q, k, v)
        self.attn = attn

        # Unfold heads and concatenate them along the feature axis.
        input2 = input2.view(self.n_head, n_rows, 1, self.in_dim)
        input2 = input2.permute(1, 2, 0, 3).contiguous().view(n_rows, 1, -1)  # (b*c) x 1 x (n*dv)
        input2 = self.dropout(self.fc(input2))

        side = int(np.sqrt(self.block_hidden_dim))
        input2 = input2.view(batch_size, nchannel, side, side)

        y = input1 + input2
        y = self.layer_norm(y)
        return y


class Transnorm(nn.Module):
    """Apply ``TransNorm2d`` to two batches jointly and keep the first one."""

    def __init__(self, tg_layer_id, src_layer_id, planes):
        """Store the layer ids and create a ``TransNorm2d(planes)`` layer."""
        super(Transnorm, self).__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.tn = TransNorm2d(planes)

    def forward(self, input1, input2):
        """Normalize ``input1`` together with ``input2``.

        The two tensors are concatenated along the batch axis, passed through
        ``self.tn``, and only the leading ``input1.size(0)`` rows are returned.
        """
        n_first = input1.size(0)
        stacked = torch.cat((input1, input2), dim=0)
        normalized = self.tn(stacked)
        return normalized[:n_first]


class CrossStitch(nn.Module):
    """Learned two-weight linear blend of two equally shaped tensors."""

    def __init__(self, tg_layer_id, src_layer_id, planes):
        """Store the layer ids and initialise both mixing weights to 0.5.

        ``planes`` is accepted but not used by this module.
        """
        super(CrossStitch, self).__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.wt = nn.Parameter(torch.ones(2) * 0.5)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input1, input2, relu=False):
        """Return ``wt[0] * input1 + wt[1] * input2``.

        Args:
            input1: first tensor.
            input2: second tensor; must have the same shape as ``input1``.
            relu: when true, apply ReLU to the blended result.
        """
        assert input1.shape == input2.shape
        w_first, w_second = self.wt[0], self.wt[1]
        blended = w_first * input1 + w_second * input2
        return self.relu(blended) if relu else blended


class LinearXStitch(nn.Module):
    """Blend two feature maps with per-sample gates predicted crosswise.

    Each input is globally average-pooled to a per-channel vector. The gate
    applied to ``input1`` is predicted from ``input2``'s pooled features
    (``su``) and the gate applied to ``input2`` from ``input1``'s (``tu``).
    """

    def __init__(self, tg_layer_id, src_layer_id, tg_feat_size=48, src_feat_size=48):
        """Create the two gate predictors.

        Args:
            tg_layer_id: identifier stored on the module.
            src_layer_id: identifier stored on the module.
            tg_feat_size: channel count of ``input1`` (input size of ``tu``).
            src_feat_size: channel count of ``input2`` (input size of ``su``).
        """
        super(LinearXStitch, self).__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.su = nn.Linear(src_feat_size, 1)
        self.tu = nn.Linear(tg_feat_size, 1)
        # Bias of 1 on both gate predictors.
        nn.init.constant_(self.su.bias, 1.0)
        nn.init.constant_(self.tu.bias, 1.0)

    def forward(self, input1, input2):
        """Return ``su(pool(input2)) * input1 + tu(pool(input1)) * input2``.

        Args:
            input1: tensor of shape batch x tg_feat_size x H x W.
            input2: tensor of shape batch x src_feat_size x H' x W'.
        """
        # Pool over the full (H, W) window so non-square maps are supported;
        # for square maps this is the same kernel as ``size(2)`` alone.
        f1 = F.avg_pool2d(input1, (input1.size(2), input1.size(3))).view(-1, input1.size(1))
        f2 = F.avg_pool2d(input2, (input2.size(2), input2.size(3))).view(-1, input2.size(1))
        gate1 = self.su(f2).reshape(input1.size(0), 1, 1, 1)
        gate2 = self.tu(f1).reshape(input2.size(0), 1, 1, 1)
        return gate1 * input1 + gate2 * input2


class LinearStitch(nn.Module):
    """Blend two feature maps with per-sample gates predicted from themselves.

    Each input is globally average-pooled to a per-channel vector. The gate
    applied to ``input1`` is predicted from its own pooled features (``tu``)
    and the gate applied to ``input2`` from its own (``su``).
    """

    def __init__(self, tg_layer_id, src_layer_id, tg_feat_size=48, src_feat_size=48):
        """Create the two gate predictors.

        Args:
            tg_layer_id: identifier stored on the module.
            src_layer_id: identifier stored on the module.
            tg_feat_size: channel count of ``input1`` (input size of ``tu``).
            src_feat_size: channel count of ``input2`` (input size of ``su``).
        """
        super(LinearStitch, self).__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.su = nn.Linear(src_feat_size, 1)
        self.tu = nn.Linear(tg_feat_size, 1)
        # Bias of 1 on both gate predictors.
        nn.init.constant_(self.su.bias, 1.0)
        nn.init.constant_(self.tu.bias, 1.0)

    def forward(self, input1, input2):
        """Return ``tu(pool(input1)) * input1 + su(pool(input2)) * input2``.

        Args:
            input1: tensor of shape batch x tg_feat_size x H x W.
            input2: tensor of shape batch x src_feat_size x H' x W'.
        """
        # Pool over the full (H, W) window so non-square maps are supported;
        # for square maps this is the same kernel as ``size(2)`` alone.
        f1 = F.avg_pool2d(input1, (input1.size(2), input1.size(3))).view(-1, input1.size(1))
        f2 = F.avg_pool2d(input2, (input2.size(2), input2.size(3))).view(-1, input2.size(1))
        gate1 = self.tu(f1).reshape(input1.size(0), 1, 1, 1)
        gate2 = self.su(f2).reshape(input2.size(0), 1, 1, 1)
        return gate1 * input1 + gate2 * input2