
import os, sys, math, random, itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.checkpoint import checkpoint

from models import TrainableModel
from utils import *

import pdb


class FiLM_spatial(nn.Module):
    """Spatially varying, additive FiLM-style conditioning.

    A conditioning map ``embed`` is bilinearly resized to the spatial size of
    the feature map being modulated. It is then projected to
    ``output_channel`` channels by a 3x3 convolution. Only the additive
    (``beta``) term is produced; the caller adds it to its own features.
    """

    def __init__(self, input_channel, output_channel):
        super().__init__()
        # 3x3 conv with padding=1 keeps H and W; maps embed channels -> beta channels.
        self.conv_beta = nn.Conv2d(input_channel, output_channel, 3, padding=1)
        # Not used by forward(); kept so the attribute still exists for callers.
        self.relu = torch.nn.ReLU()

    def forward(self, x_size, embed):
        """Return the additive modulation for a feature map of size ``x_size``.

        Args:
            x_size: size of the feature map to modulate (e.g. ``x.size()``).
                Only its last two entries (H, W) are read.
            embed: conditioning tensor. It must be 4-D (N, C, h, w) because
                bilinear ``F.interpolate`` is applied when its spatial size
                differs from (H, W).

        Returns:
            ``beta`` with ``output_channel`` channels and spatial size (H, W).
        """
        target_hw = tuple(x_size[-2:])
        # Compare BOTH spatial dims. The previous width-only check let a map
        # whose width matched but whose height differed slip through
        # un-resized, which either crashed or silently broadcast in `x + beta`.
        if tuple(embed.shape[-2:]) != target_hw:
            embed = F.interpolate(
                embed,
                size=target_hw,
                mode='bilinear',
                align_corners=False,
            )
        # No per-call debug print here: forward runs every training step.
        return self.conv_beta(embed)
    


class UNet_up_block(nn.Module):
    """Decoder stage of the U-Net.

    The coarser input ``x`` is optionally upsampled x2 (bilinear). It is then
    concatenated with the skip tensor along channels and passed through three
    3x3 conv + GroupNorm(8) layers. ReLU follows the first two layers. After
    the third layer, an optional FiLM ``beta`` and a learnable per-channel
    bias ``add_beta`` are added before the final ReLU.

    Args:
        prev_channel: channels of the skip tensor ``prev_feature_map``.
        input_channel: channels of ``x``.
        output_channel: channels produced by every conv. It must be divisible
            by 8 because each GroupNorm uses 8 groups.
        up_sample: if True, upsample ``x`` by 2 before concatenation.
        film_layer: if True, build a FiLM module whose output is added
            before the final ReLU.
        embed_channel: channels of the conditioning map fed to the FiLM module.
        spatial_film: pick ``FiLM_spatial`` (True) or ``FiLM`` (False).
    """

    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True, film_layer=False, embed_channel=None, spatial_film=True):
        super().__init__()
        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # Registers conv1, bn1, conv2, bn2, conv3, bn3 in that order, so
        # state_dict keys and parameter-init RNG consumption are unchanged.
        conv_inputs = (prev_channel + input_channel, output_channel, output_channel)
        for idx, c_in in enumerate(conv_inputs, start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, output_channel, 3, padding=1))
            setattr(self, f"bn{idx}", nn.GroupNorm(8, output_channel))
        if film_layer:
            film_cls = FiLM_spatial if spatial_film else FiLM
            self.film = film_cls(embed_channel, output_channel)
        self.relu = torch.nn.ReLU()
        self.up_sample = up_sample
        self.film_layer = film_layer
        # Learnable per-channel bias, zero at init, added just before the last ReLU.
        self.add_beta = nn.Parameter(torch.zeros(1, output_channel, 1, 1))

    def forward(self, prev_feature_map, x, embed=None):
        """Fuse ``x`` with the skip ``prev_feature_map``; ``embed`` is used only when FiLM is enabled."""
        coarse = self.up_sampling(x) if self.up_sample else x
        h = torch.cat((coarse, prev_feature_map), dim=1)
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = self.relu(norm(conv(h)))
        h = self.bn3(self.conv3(h))
        if self.film_layer and embed is not None:
            h = h + self.film(h.size(), embed)
        return self.relu(h + self.add_beta)


class UNet_down_block(nn.Module):
    """Encoder stage of the U-Net.

    Applies three 3x3 conv + GroupNorm(8) layers, with ReLU after the first
    two. After the third layer, an optional FiLM ``beta`` and a learnable
    per-channel bias ``add_beta`` are added before the final ReLU. An
    optional 2x2 max-pool then halves H and W.

    Args:
        input_channel: channels of the incoming tensor.
        output_channel: channels produced by every conv. It must be divisible
            by 8 because each GroupNorm uses 8 groups.
        down_size: if True, finish with a 2x2 / stride-2 max-pool.
        film_layer: if True, build a FiLM module whose output is added
            before the final ReLU.
        embed_channel: channels of the conditioning map fed to the FiLM module.
        spatial_film: pick ``FiLM_spatial`` (True) or ``FiLM`` (False).
    """

    def __init__(self, input_channel, output_channel, down_size=True, film_layer=False, embed_channel=None, spatial_film=True):
        super().__init__()
        # Registers conv1, bn1, conv2, bn2, conv3, bn3 in that order, so
        # state_dict keys and parameter-init RNG consumption are unchanged.
        conv_inputs = (input_channel, output_channel, output_channel)
        for idx, c_in in enumerate(conv_inputs, start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, output_channel, 3, padding=1))
            setattr(self, f"bn{idx}", nn.GroupNorm(8, output_channel))
        if film_layer:
            film_cls = FiLM_spatial if spatial_film else FiLM
            self.film = film_cls(embed_channel, output_channel)
        self.max_pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.down_size = down_size
        self.film_layer = film_layer
        # Learnable per-channel bias, zero at init, added just before the last ReLU.
        self.add_beta = nn.Parameter(torch.zeros(1, output_channel, 1, 1))

    def forward(self, x, embed=None):
        """Encode ``x``; ``embed`` is used only when FiLM is enabled."""
        h = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            h = self.relu(norm(conv(h)))
        h = self.bn3(self.conv3(h))
        if self.film_layer and embed is not None:
            h = h + self.film(h.size(), embed)
        h = self.relu(h + self.add_beta)
        return self.max_pool(h) if self.down_size else h

class UNet_up_block_proxy(nn.Module):
    """Lightweight decoder stage.

    Optionally upsamples ``x`` x2 (bilinear), concatenates it with the skip
    tensor along channels, and applies a single 3x3 conv + GroupNorm(8) + ReLU.

    Args:
        prev_channel: channels of the skip tensor ``prev_feature_map``.
        input_channel: channels of ``x``.
        output_channel: conv output channels (must be divisible by 8 for GroupNorm).
        up_sample: if True, upsample ``x`` by 2 before concatenation.
    """

    def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
        super().__init__()
        self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        fused_channels = prev_channel + input_channel
        self.conv1 = nn.Conv2d(fused_channels, output_channel, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, output_channel)
        self.relu = torch.nn.ReLU()
        self.up_sample = up_sample

    def forward(self, prev_feature_map, x):
        """Return ReLU(GN(conv([up(x), prev_feature_map])))."""
        coarse = self.up_sampling(x) if self.up_sample else x
        merged = torch.cat((coarse, prev_feature_map), dim=1)
        return self.relu(self.bn1(self.conv1(merged)))

# class UNet_up_block(nn.Module):
#     def __init__(self, prev_channel, input_channel, output_channel, up_sample=True):
#         super().__init__()
#         self.up_sampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
#         self.down_sampling = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
#         self.conv1 = nn.Conv2d(prev_channel + input_channel, output_channel, 3, padding=1)
#         self.bn1 = nn.GroupNorm(8, output_channel)
#         self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
#         self.bn2 = nn.GroupNorm(8, output_channel)
#         self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
#         self.bn3 = nn.GroupNorm(8, output_channel)
#         self.relu = torch.nn.ReLU()
#         self.up_sample = up_sample

#     def forward(self, prev_feature_map, x):
#         if self.up_sample:
#             x = self.up_sampling(x)
#         else:
#             prev_feature_map = self.down_sampling(prev_feature_map)
#         x = torch.cat((x, prev_feature_map), dim=1)
#         x = self.relu(self.bn1(self.conv1(x)))
#         x = self.relu(self.bn2(self.conv2(x)))
#         x = self.relu(self.bn3(self.conv3(x)))
#         return x


# class UNet_down_block(nn.Module):
#     def __init__(self, input_channel, output_channel, down_size=True):
#         super().__init__()
#         self.conv1 = nn.Conv2d(input_channel, output_channel, 3, padding=1)
#         self.bn1 = nn.GroupNorm(8, output_channel)
#         self.conv2 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
#         self.bn2 = nn.GroupNorm(8, output_channel)
#         self.conv3 = nn.Conv2d(output_channel, output_channel, 3, padding=1)
#         self.bn3 = nn.GroupNorm(8, output_channel)
#         self.max_pool = nn.MaxPool2d(2, 2)
#         self.relu = nn.ReLU()
#         self.down_size = down_size

#     def forward(self, x):
#         x = self.relu(self.bn1(self.conv1(x)))
#         x = self.relu(self.bn2(self.conv2(x)))
#         x = self.relu(self.bn3(self.conv3(x)))
#         if self.down_size:
#             x = self.max_pool(x)
#         return x


class UNet_adapt(TrainableModel):
    """U-Net whose encoder and decoder stages can be conditioned on ``embed``.

    Encoder: a 16-channel full-resolution stem (``down1``, no pooling, no
    FiLM), then ``downsample`` stages. Stage ``lvl`` maps 2**(4+lvl) to
    2**(5+lvl) channels and max-pools by 2. The bottleneck applies three
    conv + GroupNorm + ReLU layers at 2**(4+downsample) channels. Decoder
    stages mirror the encoder using the saved skips. The head is a 3x3 conv
    + GroupNorm + ReLU, then a 1x1 conv to ``out_channels``.

    Because each stage max-pools and each decoder stage upsamples x2, input
    H and W should be divisible by 2**downsample so the skip concatenations
    line up.

    Args:
        downsample: number of pooling stages.
        in_channels: input image channels.
        out_channels: output channels.
        film_layer: enable FiLM conditioning in every encoder/decoder stage.
        embed_channel: channels of the conditioning map passed as ``embed``.
        spatial_film: choose ``FiLM_spatial`` (True) or ``FiLM`` (False).
    """

    def __init__(self,  downsample=6, in_channels=3, out_channels=3, film_layer=False, embed_channel=128, spatial_film=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample

        # Full-resolution stem: 16 channels, no pooling, no FiLM.
        self.down1 = UNet_down_block(in_channels, 16, False)
        self.down_blocks = nn.ModuleList(
            UNet_down_block(
                2 ** (4 + lvl), 2 ** (5 + lvl), True,
                film_layer=film_layer, embed_channel=embed_channel, spatial_film=spatial_film,
            )
            for lvl in range(downsample)
        )

        # Bottleneck convs; registers mid_conv1, bn1, mid_conv2, bn2, mid_conv3, bn3 in order.
        width = 2 ** (4 + downsample)
        for idx in (1, 2, 3):
            setattr(self, f"mid_conv{idx}", nn.Conv2d(width, width, 3, padding=1))
            setattr(self, f"bn{idx}", nn.GroupNorm(8, width))

        self.up_blocks = nn.ModuleList(
            UNet_up_block(
                2 ** (4 + lvl), 2 ** (5 + lvl), 2 ** (4 + lvl),
                film_layer=film_layer, embed_channel=embed_channel, spatial_film=spatial_film,
            )
            for lvl in range(downsample)
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x, embed=None):
        """Run the conditioned U-Net; ``embed`` is forwarded to every encoder/decoder stage."""
        feats = self.down1(x)
        skips = [feats]
        for lvl in range(self.downsample):
            feats = self.down_blocks[lvl](feats, embed=embed)
            skips.append(feats)

        bottleneck = ((self.mid_conv1, self.bn1), (self.mid_conv2, self.bn2), (self.mid_conv3, self.bn3))
        for conv, norm in bottleneck:
            feats = self.relu(norm(conv(feats)))

        # Deepest decoder stage first; skips[lvl] is the encoder output at the matching resolution.
        for lvl in reversed(range(self.downsample)):
            feats = self.up_blocks[lvl](skips[lvl], feats, embed=embed)

        feats = self.relu(self.last_bn(self.last_conv1(feats)))
        return self.last_conv2(feats)

    def loss(self, pred, target):
        """Placeholder objective: a zero scalar on ``pred``'s device (``target`` is ignored)."""
        zero = torch.tensor(0.0, device=pred.device)
        return zero, (zero.detach(),)

class UNet(TrainableModel):
    """Plain (unconditioned) U-Net.

    Encoder: a 16-channel full-resolution stem, then ``downsample`` stages.
    Stage i maps 2**(4+i) to 2**(5+i) channels and max-pools by 2. The
    bottleneck applies three conv + GroupNorm + ReLU layers at
    2**(4+downsample) channels. Decoder stages mirror the encoder using the
    saved skips. The head is a 3x3 conv + GroupNorm + ReLU, then a 1x1 conv
    to ``out_channels``.

    Args:
        downsample: number of pooling stages.
        in_channels: input image channels.
        out_channels: output channels.
        up_sample: per-decoder-stage flags for x2 upsampling of the coarser
            input. The accepted forms are:
            - None (default): all True.
            - A single bool.
            - A one-element sequence: broadcast to every stage. The old
              default ``[True]`` crashed with IndexError for downsample > 1.
            - A sequence of at least ``downsample`` flags, indexed by stage.
              Extra entries are ignored, as before.
    """

    def __init__(self,  downsample=6, in_channels=3, out_channels=3, up_sample=None):
        super().__init__()

        self.in_channels, self.out_channels, self.downsample = in_channels, out_channels, downsample
        self.down1 = UNet_down_block(in_channels, 16, False)
        self.down_blocks = nn.ModuleList(
            [UNet_down_block(2**(4+i), 2**(5+i), True) for i in range(0, downsample)]
        )

        bottleneck = 2**(4 + downsample)
        self.mid_conv1 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn1 = nn.GroupNorm(8, bottleneck)
        self.mid_conv2 = nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn2 = nn.GroupNorm(8, bottleneck)
        self.mid_conv3 = torch.nn.Conv2d(bottleneck, bottleneck, 3, padding=1)
        self.bn3 = nn.GroupNorm(8, bottleneck)

        up_flags = self._expand_up_sample(up_sample, downsample)
        self.up_blocks = nn.ModuleList(
            [UNet_up_block(2**(4+i), 2**(5+i), 2**(4+i), up_sample=up_flags[i]) for i in range(0, downsample)]
        )

        self.last_conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.last_bn = nn.GroupNorm(8, 16)
        self.last_conv2 = nn.Conv2d(16, out_channels, 1, padding=0)
        self.relu = nn.ReLU()

    @staticmethod
    def _expand_up_sample(up_sample, downsample):
        """Normalise ``up_sample`` into a list of at least ``downsample`` flags.

        Raises:
            ValueError: if a sequence with more than one flag is shorter than
                ``downsample`` (ambiguous which stages it refers to).
        """
        if up_sample is None:
            return [True] * downsample
        if isinstance(up_sample, bool):
            return [up_sample] * downsample
        flags = list(up_sample)
        if len(flags) == 1:
            return flags * downsample
        if len(flags) < downsample:
            raise ValueError(
                f"up_sample has {len(flags)} entries but downsample={downsample}; "
                "pass a single flag or one flag per decoder stage"
            )
        return flags

    def forward(self, x):
        """Run the U-Net on ``x``.

        With all ``up_sample`` flags True, H and W should be divisible by
        2**downsample. Each encoder stage max-pools by 2 and each decoder
        stage upsamples by 2, so otherwise the skip concatenations do not
        line up.
        """
        x = self.down1(x)
        xvals = [x]
        for i in range(0, self.downsample):
            x = self.down_blocks[i](x)
            xvals.append(x)

        x = self.relu(self.bn1(self.mid_conv1(x)))
        x = self.relu(self.bn2(self.mid_conv2(x)))
        x = self.relu(self.bn3(self.mid_conv3(x)))

        # Deepest decoder stage first; xvals[i] is the encoder output at the matching resolution.
        for i in reversed(range(0, self.downsample)):
            x = self.up_blocks[i](xvals[i], x)

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)
        return x

    def loss(self, pred, target):
        """Placeholder objective: a zero scalar on ``pred``'s device (``target`` is ignored)."""
        loss = torch.tensor(0.0, device=pred.device)
        return loss, (loss.detach(),)



