import os

import torch
import torch.nn as nn
import torch.utils.checkpoint

class DiscriminatorHead(nn.Module):
    """Per-position head mapping a (b, t, c, h, w) feature map to outputs.

    The head is two 1x1 Conv2d -> GroupNorm -> LeakyReLU stages, where the
    second stage has a residual connection. A 1x1 output projection follows.
    Frames are folded into the batch dimension, so every time step is
    processed independently.

    Args:
        input_channel: Number of channels ``c`` of the incoming features.
        output_channel: Number of output channels per spatial position.
        inner_channel: Width of the hidden stages. It must be divisible by
            ``num_groups`` (a GroupNorm requirement).
        num_groups: Number of groups used by each ``nn.GroupNorm``.
    """

    def __init__(self, input_channel, output_channel=1, inner_channel=1024,
                 num_groups=32):
        super().__init__()
        self.conv1 = self._conv_norm_act(input_channel, inner_channel,
                                         num_groups)
        self.conv2 = self._conv_norm_act(inner_channel, inner_channel,
                                         num_groups)
        self.conv_out = nn.Conv2d(inner_channel, output_channel, 1, 1, 0)

    @staticmethod
    def _conv_norm_act(in_channels, out_channels, num_groups):
        """Build one 1x1 Conv2d -> GroupNorm -> LeakyReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1, 0),
            nn.GroupNorm(num_groups, out_channels),
            # LeakyReLU (in place) is used instead of the GELU shown in the
            # paper, to save memory.
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the head to every frame.

        Args:
            x: Tensor of shape (b, t, c, h, w). It does not need to be
                contiguous.

        Returns:
            Tensor of shape (b, t, output_channel, h, w).
        """
        b, t, c, h, w = x.shape
        # Use reshape rather than view: view raises on non-contiguous inputs
        # (for example the result of a permute). For contiguous tensors,
        # reshape returns the same view without copying.
        x = x.reshape(b * t, c, h, w)
        x = self.conv1(x)
        x = self.conv2(x) + x  # residual connection around the second stage
        x = self.conv_out(x)
        return x.reshape(b, t, -1, h, w)


class Discriminator(nn.Module):
    """Bank of ``DiscriminatorHead`` groups, with one group per feature map.

    ``head_num`` groups are built. The first groups take input widths from
    ``adapter_channel_dims`` repeated ``total_layers // stride`` times. One
    final group follows, whose input width is ``adapter_out_dim``. Each group
    holds ``num_h_per_head`` independent heads.

    Args:
        stride: Divisor of ``total_layers``. The quotient is the number of
            times ``adapter_channel_dims`` is repeated (presumably one entry
            per tapped backbone layer -- TODO confirm against the feature
            extractor).
        num_h_per_head: Number of heads applied to each feature map.
        adapter_channel_dims: Sequence of channel widths to repeat. It is
            copied, never mutated.
        adapter_out_dim: Channel width of the last feature map.
        total_layers: Layer count that is divided by ``stride``.
    """

    def __init__(
        self,
        stride=8,
        num_h_per_head=1,
        adapter_channel_dims=(3072,),
        adapter_out_dim=16,
        total_layers=48,
    ):
        super().__init__()
        # Copy into a fresh list so neither the caller's sequence nor the
        # default is ever modified (this also lets tuples be passed).
        channel_dims = list(adapter_channel_dims) * (total_layers // stride)
        channel_dims.append(adapter_out_dim)
        self.stride = stride
        self.num_h_per_head = num_h_per_head
        self.head_num = len(channel_dims)
        self.heads = nn.ModuleList([
            nn.ModuleList([
                DiscriminatorHead(channel)
                for _ in range(self.num_h_per_head)
            ]) for channel in channel_dims
        ])

    def forward(self, features):
        """Apply every head to its feature map, using activation checkpointing.

        Args:
            features: Sequence of tensors, one per head group, in the same
                order as ``self.heads``. Each tensor is expected in the
                (b, t, c, h, w) layout that ``DiscriminatorHead`` unpacks.

        Returns:
            Flat list of head outputs, ordered by feature map and then by head
            within that map's group.

        Raises:
            ValueError: If the number of feature maps differs from the number
                of head groups.
        """
        if len(features) != len(self.heads):
            raise ValueError(
                f"expected {len(self.heads)} feature maps, got {len(features)}"
            )
        outputs = []
        for feature, group in zip(features, self.heads):
            for head in group:
                # Trade compute for memory. Each head's intermediate
                # activations are recomputed during backward instead of
                # being stored.
                outputs.append(
                    torch.utils.checkpoint.checkpoint(
                        head, feature, use_reentrant=False
                    )
                )  # (b f c h w)
        return outputs

    def save_pretrained(self, output_dir):
        """Save the state dict to ``<output_dir>/discriminator.pt``.

        The directory is created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(output_dir, "discriminator.pt"))