import torch
from torch import nn
import numpy as np
from torch.nn import init
import torch.nn.functional as F

import math


class FlowPickNet(nn.Module):
    """Conv encoder/decoder producing a dense map resized to a fixed spatial size.

    ``trunk`` downsamples the input with four unpadded 5x5 convolutions
    (strides 2, 2, 2, 1). ``head`` applies a 3x3 conv, a x2 bilinear upsample
    and a final 3x3 conv to ``outchannels`` channels. The head output is then
    bilinearly resized to ``out_size``.

    Args:
        inchannels: number of channels of the input tensor.
        im_w: stored as ``self.im_w``; not read by this module.
        outchannels: number of output channels (default 2).
        use_tanh: if True, apply ``tanh`` to the head output before resizing.
        use_pool: if True, average-pool the resized map with a kernel equal to
            ``out_size``, reducing it to a 1x1 spatial output.
        out_size: ``(H, W)`` (or a single int) the head output is resized to.
            Defaults to ``(20, 20)``, the previously hard-coded size.

    Note: the input must be large enough that every unpadded conv leaves a
    positive spatial size (e.g. 96x96 works).
    """

    def __init__(self, inchannels, im_w, outchannels=2, use_tanh=False, use_pool=False,
                 out_size=(20, 20)):
        super(FlowPickNet, self).__init__()
        self.trunk = nn.Sequential(nn.Conv2d(inchannels, 32, 5, 2),
                                   nn.ReLU(True),
                                   nn.Conv2d(32, 32, 5, 2),
                                   nn.ReLU(True),
                                   nn.Conv2d(32, 32, 5, 2),
                                   nn.ReLU(True),
                                   nn.Conv2d(32, 32, 5, 1),
                                   nn.ReLU(True))
        self.head = nn.Sequential(nn.Conv2d(32, 32, 3, 1),
                                  nn.ReLU(True),
                                  nn.UpsamplingBilinear2d(scale_factor=2),
                                  nn.Conv2d(32, outchannels, 3, 1))

        self.im_w = im_w
        if isinstance(out_size, int):
            out_size = (out_size, out_size)
        self.out_size = tuple(out_size)
        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.use_pool = use_pool
        if use_pool:
            # Kernel equals the resized map, i.e. a global average per channel.
            self.pool = nn.AvgPool2d(kernel_size=self.out_size)

    def forward(self, x):
        """Return the head output, optionally tanh-squashed, resized to ``out_size``
        and optionally globally average-pooled."""
        out = self.head(self.trunk(x))
        if self.use_tanh:
            out = self.tanh(out)
        # Previously a fresh nn.Upsample(mode="bilinear") was built on every call;
        # for bilinear mode its unset align_corners resolves to False, so passing
        # it explicitly here gives identical results without per-call allocation.
        out = F.interpolate(out, size=self.out_size, mode="bilinear", align_corners=False)
        if self.use_pool:
            out = self.pool(out)
        return out



class FlowPickSplit(nn.Module):
    """Single-channel variant of the pick network with a fixed 20x20 output.

    ``trunk``: four unpadded 5x5 convs to 32 channels (strides 2, 2, 2, 1),
    each followed by ReLU. ``head``: 3x3 conv + ReLU, x2 bilinear upsample,
    3x3 conv to one channel. ``upsample`` then resizes the map to 20x20.

    Args:
        inchannels: number of channels of the input tensor.
        im_w: stored as ``self.im_w``; not read by this module.
        second: stored as ``self.second``; not read by ``forward``.
    """

    def __init__(self, inchannels, im_w, second=False):
        super().__init__()
        encoder = []
        for in_ch, step in ((inchannels, 2), (32, 2), (32, 2), (32, 1)):
            encoder.extend((nn.Conv2d(in_ch, 32, 5, step), nn.ReLU(True)))
        self.trunk = nn.Sequential(*encoder)
        self.head = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1),
            nn.ReLU(True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 1, 3, 1),
        )

        self.im_w = im_w
        self.second = second
        self.upsample = nn.Upsample(size=(20, 20), mode="bilinear")

    def forward(self, x):
        """Encode ``x``, decode to one channel and resize to 20x20."""
        return self.upsample(self.head(self.trunk(x)))


# --------------------------------------------------------
# FlowNetSmall: learned flow model, plus the layer-building helpers it uses.

def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """Build a "same"-padded conv block: Conv2d -> [BatchNorm2d] -> LeakyReLU(0.1).

    With ``batchNorm`` truthy the conv has no bias (BatchNorm supplies the
    shift) and a BatchNorm2d layer is inserted; otherwise the conv keeps its
    bias and feeds the activation directly.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            bias=not batchNorm,
        )
    ]
    if batchNorm:
        layers.append(nn.BatchNorm2d(out_planes))
    layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)

def predict_flow(in_planes):
    """Return a 3x3, stride-1, size-preserving conv mapping ``in_planes`` channels to 2 flow channels."""
    flow_head = nn.Conv2d(
        in_channels=in_planes,
        out_channels=2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
    )
    return flow_head

def deconv(in_planes, out_planes, ksize=3):
    """Build a stride-2 upsampling block: ConvTranspose2d(padding=1) -> LeakyReLU(0.1).

    The transposed conv maps a spatial size ``h`` to ``2*h - 4 + ksize``:
    ``2*h - 1`` for the default ``ksize=3`` and exactly ``2*h`` for ``ksize=4``.
    """
    upconv = nn.ConvTranspose2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=ksize,
        stride=2,
        padding=1,
        bias=True,
    )
    return nn.Sequential(upconv, nn.LeakyReLU(negative_slope=0.1, inplace=True))

class FlowNetSmall(nn.Module):
    """Reduced-width encoder/decoder flow network that regresses a 2-channel flow map.

    Encoder: ``conv1`` .. ``conv6`` each halve the spatial size (stride 2);
    ``conv3_1`` .. ``conv6_1`` are stride-1 refinements. Decoder: at levels
    5..2 the encoder skip feature, the deconvolved coarser feature and the
    upsampled coarser flow are concatenated, and a 2-channel flow is predicted
    from the result. Only the finest prediction (``flow2``, at 1/4 of the
    input resolution) is returned, bilinearly upsampled by 4.

    Input-size constraint: ``deconv5``..``deconv3`` and ``upsampled_flow6_to_5``
    .. ``upsampled_flow4_to_3`` use ``kernel_size=3, stride=2, padding=1``,
    which maps a size ``h`` to ``2*h - 1``. ``deconv2`` and
    ``upsampled_flow3_to_2`` use ``kernel_size=4``, which maps ``h`` to ``2*h``.
    The skip concatenations therefore only line up when every spatial
    dimension ``s`` of the input satisfies ``ceil(s / 4) % 16 == 2``, e.g. 72,
    136, 200 or 264. The output then has spatial size ``4 * ceil(s / 4)``.
    Other sizes, including 384, make ``torch.cat`` fail with a size mismatch.

    Args:
        input_channels: channels of the input tensor (default 12).
        batchNorm: if True, encoder convs use BatchNorm2d and no conv bias.
    """

    def __init__(self, input_channels = 12, batchNorm=True):
        super(FlowNetSmall, self).__init__()

        # Channel widths per encoder stage. Wider settings used previously:
        # [16, 32, 64, 128, 256] and [64, 128, 256, 512, 1024].
        fs = [8, 16, 32, 64, 128]
        self.batchNorm = batchNorm

        # Encoder. Each stride-2 conv maps a spatial size s to ceil(s / 2),
        # e.g. conv1: (384 + 2*3 - 7) // 2 + 1 = 192. The *_1 convs keep the size.
        self.conv1   = conv(self.batchNorm, input_channels, fs[0], kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm, fs[0], fs[1], kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, fs[1], fs[2], kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, fs[2], fs[2])
        self.conv4   = conv(self.batchNorm, fs[2], fs[3], stride=2)
        self.conv4_1 = conv(self.batchNorm, fs[3], fs[3])
        self.conv5   = conv(self.batchNorm, fs[3], fs[3], stride=2)
        self.conv5_1 = conv(self.batchNorm, fs[3], fs[3])
        self.conv6   = conv(self.batchNorm, fs[3], fs[4], stride=2)
        self.conv6_1 = conv(self.batchNorm, fs[4], fs[4])

        # Decoder features. Inputs of deconv4..deconv2 are the previous level's
        # concatenation: skip feature + deconvolved feature + 2 flow channels.
        self.deconv5 = deconv(fs[4], fs[3])
        self.deconv4 = deconv(fs[3] + fs[3] + 2, fs[2])
        self.deconv3 = deconv(fs[3] + fs[2] + 2, fs[1])
        self.deconv2 = deconv(fs[2] + fs[1] + 2, fs[0], ksize=4)

        # Per-level flow heads; their input widths match the concatenations in forward().
        self.predict_flow6 = predict_flow(fs[4])
        self.predict_flow5 = predict_flow(fs[3] + fs[3] + 2)
        self.predict_flow4 = predict_flow(fs[3] + fs[2] + 2)
        self.predict_flow3 = predict_flow(fs[2] + fs[1] + 2)
        self.predict_flow2 = predict_flow(fs[1] + fs[0] + 2)

        # Flow upsamplers. ConvTranspose2d output size is
        # (h - 1) * stride - 2 * padding + (kernel - 1) + 1, i.e. 2h - 1 for
        # kernel 3 and 2h for kernel 4 (stride 2, padding 1).
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 3, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 3, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 3, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Xavier-uniform weights and U(0, 1) biases for every (transposed) conv.
        # Conv2d and ConvTranspose2d are disjoint classes, so one combined check
        # visits the modules in the same order as two separate checks would.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                if m.bias is not None:
                    init.uniform_(m.bias)
                init.xavier_uniform_(m.weight)

        # Final x4 upsample of flow2. align_corners=False is what nn.Upsample
        # resolves an unset align_corners to in bilinear mode.
        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Predict a 2-channel flow map for ``x``.

        The spatial size of ``x`` must satisfy the constraint in the class
        docstring. Intermediate flows (``flow3`` .. ``flow6``) are not returned.
        """
        # Encoder.
        out_conv1 = self.conv1(x)
        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # Level 6: coarsest flow estimate.
        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = self.upsampled_flow6_to_5(flow6)
        out_deconv5 = self.deconv5(out_conv6)

        # Levels 5..2: concat(skip feature, decoded feature, upsampled flow),
        # predict flow, then upsample both flow and features to the next level.
        concat5     = torch.cat((out_conv5, out_deconv5, flow6_up), 1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = self.upsampled_flow5_to_4(flow5)
        out_deconv4 = self.deconv4(concat5)

        concat4     = torch.cat((out_conv4, out_deconv4, flow5_up), 1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = self.upsampled_flow4_to_3(flow4)
        out_deconv3 = self.deconv3(concat4)

        concat3     = torch.cat((out_conv3, out_deconv3, flow4_up), 1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = self.upsampled_flow3_to_2(flow3)
        out_deconv2 = self.deconv2(concat3)

        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
        flow2   = self.predict_flow2(concat2)

        # flow2 is at 1/4 of the input resolution; bring it back up.
        return self.upsample1(flow2)