import os
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages; spatial size is preserved.

    Maps ``(B, in_channels, D, H, W)`` to ``(B, out_channels, D, H, W)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage reads in_channels, second re-reads out_channels; the
        # resulting Sequential indices (0..5) match the original layout.
        for source_channels in (in_channels, out_channels):
            layers.extend([
                nn.Conv3d(source_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out


class Generator(nn.Module):
    """3D U-Net whose output is averaged over the depth axis into a 2D map.

    Args:
        in_channels: channels of the 5D input ``(B, in_channels, D, H, W)``.
        out_channels: channels of the 4D output ``(B, out_channels, H, W)``.
        features: encoder widths, one per resolution level. Every level halves
            D, H and W with a 2x max-pool; the bottleneck uses ``2 * features[-1]``.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Encoding path: one DoubleConv per level.
        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Bottleneck doubles the channel count of the deepest level.
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoding path, stored flat as (upsample, DoubleConv) pairs. The DoubleConv
        # takes feature * 2 channels because the upsampled tensor is concatenated
        # with the matching skip connection.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose3d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(DoubleConv(feature * 2, feature))

        # 1x1x1 projection to the output channels (depth is collapsed in forward).
        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``(B, in_channels, D, H, W)`` to ``(B, out_channels, H, W)``."""
        skip_connections = []

        # Encoder
        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            # Pooling floors odd sizes, which is why the decoder may need to resize.
            x = F.max_pool3d(x, kernel_size=2, stride=2)

        x = self.bottleneck(x)

        # Decoder: even entries upsample, odd entries fuse with the skip connection.
        for up, fuse, skip in zip(self.decoder[0::2], self.decoder[1::2],
                                  reversed(skip_connections)):
            x = up(x)
            # Only triggered when a pooled dimension was odd; for sizes divisible
            # by 2 ** len(features) (e.g. depth 16 with 4 levels) it is a no-op.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:])
            x = torch.cat((skip, x), dim=1)
            x = fuse(x)

        x = self.final_conv(x)  # (B, out_channels, D, H, W)
        # Collapse the depth axis (dim 2) by averaging -> (B, out_channels, H, W).
        return torch.mean(x, dim=2)

    def save_gan(self, epoch, seed, optimiser_g, optimiser_v):
        """Write model + both optimiser states to ``output/gan/{seed}/{epoch}.pt``."""
        os.makedirs(f'output/gan/{seed}/', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model': self.state_dict(),
            'optimizer_g': optimiser_g.state_dict(),
            'optimizer_v': optimiser_v.state_dict()
        }, f'output/gan/{seed}/{epoch}.pt')

    def save(self, epoch, optimiser):
        """Write model + a single optimiser state to ``output/ar/{epoch}.pt``."""
        # Previously missing: torch.save fails if the directory does not exist.
        os.makedirs('output/ar/', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model': self.state_dict(),
            'optimizer': optimiser.state_dict()
        }, f'output/ar/{epoch}.pt')

    def load(self, epoch, method='gan', seed=None, map_location=None):
        """Restore model weights from a checkpoint written by ``save_gan`` or ``save``.

        Args:
            epoch: epoch number used in the checkpoint file name.
            method: checkpoint sub-directory under ``output/`` (``'gan'`` or ``'ar'``).
            seed: if given, read ``output/{method}/{seed}/{epoch}.pt`` (the layout
                ``save_gan`` writes); otherwise ``output/{method}/{epoch}.pt``.
            map_location: forwarded to ``torch.load``.

        Returns:
            ``(epoch, optimizer_g_state, optimizer_v_state)`` for ``save_gan``
            checkpoints, or ``(epoch, optimizer_state, None)`` for ``save`` ones.
        """
        directory = f'output/{method}' if seed is None else f'output/{method}/{seed}'
        checkpoint = torch.load(f'{directory}/{epoch}.pt', map_location=map_location)
        self.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint:
            # Single-optimiser checkpoint from save(); keep the 3-tuple shape.
            return checkpoint['epoch'], checkpoint['optimizer'], None
        return checkpoint['epoch'], checkpoint['optimizer_g'], checkpoint['optimizer_v']


class Conv3DEncoder(nn.Module):
    """Two 3x3x3 convolutions, each followed by ReLU, preserving (T, H, W).

    Maps ``(B, in_channels, T, H, W)`` to ``(B, hidden_channels, T, H, W)``.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        stages = []
        for source_channels in (in_channels, hidden_channels):
            stages.append(nn.Conv3d(source_channels, hidden_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv3d = nn.Sequential(*stages)

    def forward(self, x):
        encoded = self.conv3d(x)
        return encoded


class Residual2D(nn.Module):
    """Residual unit computing ``x + conv(relu(conv(x)))``; shape is preserved."""

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        delta = self.block(x)
        return x + delta


class UpsampleBlock(nn.Module):
    """Sub-pixel upsampler: conv to ``C * r**2`` channels, pixel-shuffle by ``r``, ReLU.

    Maps ``(B, C, H, W)`` to ``(B, C, r*H, r*W)`` with ``r = scale_factor``.
    """

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        expand = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.up = nn.Sequential(expand, nn.PixelShuffle(scale_factor), nn.ReLU(inplace=True))

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled


class SuperResGenerator(nn.Module):
    """Fuse a stack of ``history`` frames into one frame and upscale it 4x spatially.

    Args:
        history: number of input frames T. The temporal merge kernel spans exactly
            this many frames, so inputs must have ``T == history``.
        in_channels: channels of the input ``(B, in_channels, T, H, W)``.
        out_channels: channels of the output ``(B, out_channels, 4H, 4W)``.
        base_channels: internal feature width.
    """

    def __init__(self, history, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        self.encoder3d = Conv3DEncoder(in_channels, base_channels)
        # Kernel covers the whole time axis with no padding -> output T dimension is 1.
        self.temporal_merge = nn.Conv3d(base_channels, base_channels, kernel_size=(history, 1, 1))

        self.res_blocks = nn.Sequential(
            Residual2D(base_channels),
            Residual2D(base_channels),
            Residual2D(base_channels)
        )

        # Two x2 sub-pixel upsamplers -> x4 overall.
        self.upsample = nn.Sequential(
            UpsampleBlock(base_channels, scale_factor=2),
            UpsampleBlock(base_channels, scale_factor=2)
        )

        self.output = nn.Conv2d(base_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``(B, C, history, H, W)`` to ``(B, out_channels, 4H, 4W)``.

        Raises:
            ValueError: if ``x`` is not 5D or its time axis differs from ``history``.
                Otherwise the merge would leave T > 1, ``squeeze(2)`` would be a
                no-op, and the 2D layers would fail with an opaque error.
        """
        history = self.temporal_merge.kernel_size[0]
        if x.dim() != 5 or x.shape[2] != history:
            raise ValueError(
                f'expected input of shape (B, C, {history}, H, W), got {tuple(x.shape)}'
            )
        feat3d = self.encoder3d(x)  # (B, F, T, H, W)
        feat2d = self.temporal_merge(feat3d).squeeze(2)  # (B, F, H, W)
        feat2d = self.res_blocks(feat2d)
        upscaled = self.upsample(feat2d)
        return self.output(upscaled)  # (B, out_channels, 4H, 4W)

    def save_gan(self, epoch, seed, optimiser_g, optimiser_v):
        """Write model + both optimiser states to ``output/gan/{seed}/{epoch}.pt``."""
        os.makedirs(f'output/gan/{seed}/', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model': self.state_dict(),
            'optimizer_g': optimiser_g.state_dict(),
            'optimizer_v': optimiser_v.state_dict()
        }, f'output/gan/{seed}/{epoch}.pt')

    def save(self, epoch, optimiser):
        """Write model + a single optimiser state to ``output/ar/{epoch}.pt``."""
        # Previously missing: torch.save fails if the directory does not exist.
        os.makedirs('output/ar/', exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model': self.state_dict(),
            'optimizer': optimiser.state_dict()
        }, f'output/ar/{epoch}.pt')

    def load(self, epoch, method='gan', seed=None, map_location=None):
        """Restore model weights from a checkpoint written by ``save_gan`` or ``save``.

        Args:
            epoch: epoch number used in the checkpoint file name.
            method: checkpoint sub-directory under ``output/`` (``'gan'`` or ``'ar'``).
            seed: if given, read ``output/{method}/{seed}/{epoch}.pt`` (the layout
                ``save_gan`` writes); otherwise ``output/{method}/{epoch}.pt``.
            map_location: forwarded to ``torch.load``.

        Returns:
            ``(epoch, optimizer_g_state, optimizer_v_state)`` for ``save_gan``
            checkpoints, or ``(epoch, optimizer_state, None)`` for ``save`` ones.
        """
        directory = f'output/{method}' if seed is None else f'output/{method}/{seed}'
        checkpoint = torch.load(f'{directory}/{epoch}.pt', map_location=map_location)
        self.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint:
            # Single-optimiser checkpoint from save(); keep the 3-tuple shape.
            return checkpoint['epoch'], checkpoint['optimizer'], None
        return checkpoint['epoch'], checkpoint['optimizer_g'], checkpoint['optimizer_v']