import torch
import torch.nn as nn
from chip.models.unet import build_unet
from diffusers import UNet2DModel

class CNN_Block(nn.Module):
    """Convolution -> batch norm -> LeakyReLU unit.

    The convolution uses a 4x4 kernel, no bias term, and no explicit padding
    (``nn.Conv2d``'s default applies), so every block shrinks the spatial
    size even when ``stride`` is 1.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            bias=False,
            padding_mode="reflect",
        )
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2)
        # Kept as a single Sequential named ``conv`` so the parameter names
        # (conv.0.*, conv.1.*) stay stable for checkpoints.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Apply convolution, batch normalisation and LeakyReLU to ``x``."""
        activated = self.conv(x)
        return activated

class Discriminator(nn.Module):
    """Conditional discriminator that scores a pair of images.

    The two inputs are concatenated along the channel axis, passed through a
    strided convolutional stack that ends in a single-channel map, and that
    map is flattened and reduced to one score per sample by a bias-free
    linear layer.

    The linear head expects exactly ``12 ** 2`` features per sample, i.e. the
    final single-channel map must hold 144 values. With the default
    four-stage ``features`` this is the 12x12 map obtained from 512x512
    inputs; other input sizes or stage counts will not match the head.

    Args:
        in_channels: Channel count of *each* input image (the first
            convolution therefore receives ``in_channels * 2`` channels).
        features: Output widths of the successive stages. The first entry
            is the width of the initial convolution; each later entry adds
            one ``CNN_Block``.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2)
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Only the final stage uses stride 1. The decision is made by
            # position rather than by comparing widths, so an earlier stage
            # that happens to share the last stage's width keeps stride 4.
            stride = 1 if index == last_index else 4
            layers.append(CNN_Block(prev_channels, feature, stride=stride))
            prev_channels = feature
        # Collapse to a single-channel map that the linear head consumes.
        layers.append(
            nn.Conv2d(
                prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            )
        )

        # ``head`` is registered before ``model`` on purpose: this keeps the
        # parameter ordering of existing checkpoints/optimizer states.
        self.head = nn.Linear(12**2, 1, bias=False)

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the image pair ``(x, y)``.

        Args:
            x: The correct satellite image.
            y: The correct or fake image to be judged against ``x``.

        Returns:
            Tensor with one score per sample (last dimension of size 1).
        """
        pair = torch.cat([x, y], dim=1)
        feature_map = self.model(self.initial(pair))
        # Flatten everything but the batch axis before the linear head.
        return self.head(feature_map.reshape(len(feature_map), -1))


class GAN(nn.Module):
    """Module bundling a diffusers ``UNet2DModel`` with a ``Discriminator``.

    Args:
        in_channels: Per-image channel count forwarded to the discriminator.
        disc_features: Stage widths forwarded to the discriminator.
    """

    def __init__(self, in_channels=1, disc_features=(64, 128, 256, 512)):
        super().__init__()
        # NOTE(review): the UNet's input/output channel counts are fixed at 1
        # and do not follow ``in_channels``; only the discriminator uses that
        # argument. Confirm this is intended before passing in_channels != 1.
        self.unet = UNet2DModel(
            sample_size=512,  # the target image resolution
            in_channels=1,  # the number of input channels
            out_channels=1,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(64, 128, 128, 256),  # the number of output channels for each UNet block
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",  # a regular ResNet upsampling block
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
        self.discriminator = Discriminator(in_channels, disc_features)