from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math
import os
from os import path
from os.path import join as pjoin
from matplotlib import image
from matplotlib.pyplot import cla

import torch
import torch.nn as nn
import numpy as np

from DualGCN_modules import DualGCN_parallel, fu_DualGCN_Spatial_fist
from Unet import UNet
import torchvision
from networks import *
from model import *

class gnet(nn.Module):
    """Two-layer convolutional head mapping 6 input channels to 3 output channels.

    Layer 1: 3x3 conv 6 -> 16 channels, then ReLU (``self.conv``).
    Layer 2: 3x3 conv 16 -> 3 channels, then ReLU (``self.conve``).
    Both convs use stride 1 and padding 1, so height and width are preserved.
    Because the last layer ends in ReLU, every output value is >= 0.
    """

    def __init__(self):
        super(gnet, self).__init__()
        hidden = 16
        self.conv = nn.Sequential(
            nn.Conv2d(6, hidden, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.conve = nn.Sequential(
            nn.Conv2d(hidden, 3, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, input):
        """Apply both conv+ReLU stages to ``input`` (expects 6 channels on dim 1)."""
        return self.conve(self.conv(input))

class f_net(nn.Module):
    """Residual encoder/decoder network built from DualGCN_parallel stages.

    Pipeline:
      1. Two conv-BN-ReLU layers lift ``x`` from ``inchannels`` to ``interchannels``.
      2. Five encoder stages (``encode0`` .. ``encode4``) run in sequence; each
         stage's output is kept as a skip connection.
      3. A ``middle`` stage.
      4. Five decoder stages, deepest first. Each one concatenates the running
         features with the matching encoder output along dim 1, fuses them back
         to ``interchannels`` with a 1x1 conv (``conv1x1_k``), then applies
         ``decode_k``.
      5. Two conv-BN-ReLU layers project to ``outchannels``, and ``x`` is added
         back as a residual.

    Args:
        inchannels: number of channels of the input ``x``.
        interchannels: feature width used throughout. ``interchannels // 4`` is
            passed as the second DualGCN_parallel argument; its meaning is
            defined in DualGCN_modules (not visible here).
        outchannels: channels produced before the residual addition. They must
            be compatible with ``inchannels`` for ``out + x`` to work.

    NOTE(review): the final layer ends in ReLU, so the residual added to ``x``
    is always non-negative. Confirm this is intended.
    """

    def __init__(self, inchannels=3, interchannels=64, outchannels=3):
        super(f_net, self).__init__()

        def conv_bn_relu(cin, cout):
            # A conv bias would be redundant with BatchNorm's own learned shift.
            # nn.ReLU's only argument is ``inplace``, so pass it by keyword
            # rather than a (truthy) channel count.
            return nn.Sequential(nn.Conv2d(cin, cout, 3, padding=1, bias=False),
                                 nn.BatchNorm2d(cout),
                                 nn.ReLU(inplace=True))

        def dual_gcn():
            return DualGCN_parallel(interchannels, interchannels // 4, interchannels)

        def fuse():
            # 1x1 conv that merges [decoder features, encoder skip] (2x channels)
            # back down to interchannels.
            return nn.Conv2d(interchannels * 2, interchannels, 1, bias=False)

        # Modules are created in the same order as before, so parameter
        # initialisation (RNG consumption) and state_dict keys are unchanged.
        self.conv_en0 = conv_bn_relu(inchannels, interchannels)
        self.conv_en1 = conv_bn_relu(interchannels, interchannels)

        self.encode0 = dual_gcn()
        self.encode1 = dual_gcn()
        self.encode2 = dual_gcn()
        self.encode3 = dual_gcn()
        self.encode4 = dual_gcn()
        self.middle = dual_gcn()

        self.conv1x1_4 = fuse()
        self.decode4 = dual_gcn()
        self.conv1x1_3 = fuse()
        self.decode3 = dual_gcn()
        self.conv1x1_2 = fuse()
        self.decode2 = dual_gcn()
        self.conv1x1_1 = fuse()
        self.decode1 = dual_gcn()
        self.conv1x1_0 = fuse()
        self.decode0 = dual_gcn()
        # Alternative stage: fu_DualGCN_Spatial_fist(interchannels) can be used
        # in place of each DualGCN_parallel(...) stage above.

        self.conv_de0 = conv_bn_relu(interchannels, interchannels)
        self.conv_de1 = conv_bn_relu(interchannels, outchannels)

    def forward(self, x):
        """Run the encoder/decoder on ``x`` and return ``decoded + x``."""
        feat = self.conv_en1(self.conv_en0(x))

        # Encoder: keep every stage output for the skip connections.
        skips = []
        for encode in (self.encode0, self.encode1, self.encode2,
                       self.encode3, self.encode4):
            feat = encode(feat)
            skips.append(feat)

        feat = self.middle(feat)

        # Decoder: stage k consumes the encoder-k skip, deepest (k = 4) first.
        stages = (
            (self.conv1x1_4, self.decode4),
            (self.conv1x1_3, self.decode3),
            (self.conv1x1_2, self.decode2),
            (self.conv1x1_1, self.decode1),
            (self.conv1x1_0, self.decode0),
        )
        for (fuse, decode), skip in zip(stages, reversed(skips)):
            feat = decode(fuse(torch.cat([feat, skip], dim=1)))

        out = self.conv_de1(self.conv_de0(feat))
        return out + x


class TL_net(nn.Module):
    """Combine a base model with higher-order correction terms (Taylor-series style).

    The output is ``sum_k T_k / k!`` for k = 0 .. ``tf_order``, where:
      * ``T_0 = f_model(x)``
      * ``T_k = g_model(cat([x, T_{k-1}], dim=1)) + (k - 1) * T_{k-1}`` for k >= 1.

    Args:
        tf_order: number of higher-order terms added on top of ``T_0`` (stored
            as ``self.tl_layer_num``). With 0 or a negative value, only ``T_0``
            contributes.
    """

    def __init__(self, tf_order):
        super(TL_net, self).__init__()
        # Construction order (f_model first, then g_model) is kept so that
        # parameter initialisation consumes the RNG in the same sequence.
        # ``Net`` comes from a wildcard import; its definition is not visible here.
        # Earlier variants used f_net(inchannels=3, interchannels=64, outchannels=3)
        # and UNet(in_channels=6, num_classes=3, init_features=16) instead.
        self.f_model = Net(channels=72)
        self.g_model = gnet()
        self.tl_layer_num = tf_order

    def forward(self, x):
        """Return ``(tl_out, f_x)``.

        ``tl_out`` is the factorial-weighted sum of all terms; ``f_x`` is the
        zeroth-order term ``f_model(x)`` on its own.
        """
        # A single-channel input is tiled to 3 channels along dim 1.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        f_x = self.f_model(x)

        term = f_x
        factorial = 1  # k! for the current term k
        tl_out = torch.zeros_like(f_x)
        tl_out += term / factorial
        for order in range(self.tl_layer_num):
            factorial *= order + 1
            g_in = torch.cat([x, term], dim=1)
            # NOTE(review): the multiplier starts at 0, so T_1 receives no
            # contribution from T_0. Confirm this matches the intended recurrence.
            term = self.g_model(g_in) + order * term
            tl_out += term / factorial

        return tl_out, f_x