import torch
import torch.nn as nn
import torch.nn.functional as F

from xad.models.bases import ConceptNN


class Concept_CNN64(ConceptNN):
    """Siamese-style CNN that scores concepts from an image pair.

    Both inputs are passed through the same convolutional encoder; the two
    1024-dimensional feature vectors are concatenated and mapped to
    ``n_concepts`` logits by a small MLP head.

    The first fully connected layer expects 4096 flattened features, i.e.
    256 channels x 4 x 4 after four stride-2 convolutions, which is what a
    64x64 input produces.

    Args:
        concepts: Forwarded to ``ConceptNN.__init__``. ``self.n_concepts`` is
            read here but not assigned in this class (presumably set by the
            base class -- confirm against ``ConceptNN``).
        rep_dim: Width of the hidden layer of the classification head.
        grayscale: If true, the first convolution takes 1 input channel
            instead of 3.
    """

    def __init__(self, concepts: int, rep_dim=256, grayscale=False):
        super().__init__(concepts)
        use_bias = True
        self.rep_dim = rep_dim

        # Channel progression of the four stride-2 conv stages.
        in_channels = 1 if grayscale else 3
        channels = (in_channels, 32, 64, 128, 256)

        encoder_layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            encoder_layers.extend([
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=use_bias),
                nn.BatchNorm2d(c_out, affine=use_bias),
                nn.LeakyReLU(inplace=True),
            ])

        # Fully connected tail: 256 * 4 * 4 = 4096 flattened features -> 1024.
        feature_dim = 1024
        encoder_layers.extend([
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, feature_dim, bias=use_bias),
            nn.BatchNorm1d(feature_dim, affine=use_bias),
            nn.LeakyReLU(),
        ])
        self.encoder = nn.Sequential(*encoder_layers)

        # The head consumes the concatenation of both encodings (2 * 1024).
        self.linear = nn.Sequential(  # TODO fuse earlier?
            nn.Linear(2 * feature_dim, self.rep_dim, bias=use_bias),
            nn.BatchNorm1d(self.rep_dim, affine=use_bias),
            nn.LeakyReLU(),
            nn.Linear(self.rep_dim, self.n_concepts),
        )

    def parameterize(self):
        """Do nothing; this model defines no extra parameterization step."""
        return None

    def forward(self, x, x_hat):
        """Return concept logits for the pair ``(x, x_hat)``.

        Each input is encoded separately with the shared encoder (``x``
        first, then ``x_hat``), the features are concatenated along
        dimension 1, and the result is passed through the head.
        """
        encoded = [self.encoder(inp) for inp in (x, x_hat)]
        return self.linear(torch.cat(encoded, dim=1))


class Concept_Projector(ConceptNN):
    """MLP that scores concepts from a pair of feature vectors.

    Both inputs go through the same two-layer encoder (``in_dim`` -> 512 ->
    256); the two 256-dimensional outputs are concatenated and mapped to
    ``n_concepts`` logits by a small MLP head.

    Args:
        concepts: Forwarded to ``ConceptNN.__init__``. ``self.n_concepts`` is
            read here but not assigned in this class (presumably set by the
            base class -- confirm against ``ConceptNN``).
        in_dim: Size of each input feature vector.
        rep_dim: Width of the hidden layer of the classification head.
        grayscale: Accepted but not used by this class.
    """

    def __init__(self, concepts: int, in_dim=512, rep_dim=256, grayscale=False):
        super().__init__(concepts)
        use_bias = True
        self.in_dim = in_dim
        self.rep_dim = rep_dim

        # Shared encoder: two Linear -> BatchNorm -> LeakyReLU stages.
        hidden_dim, feature_dim = 512, 256
        encoder_layers = []
        for n_in, n_out in ((self.in_dim, hidden_dim), (hidden_dim, feature_dim)):
            encoder_layers.extend([
                nn.Linear(n_in, n_out, bias=use_bias),
                nn.BatchNorm1d(n_out, affine=use_bias),
                nn.LeakyReLU(),
            ])
        self.encoder = nn.Sequential(*encoder_layers)

        # The head consumes the concatenation of both encodings (2 * 256).
        self.linear = nn.Sequential(  # TODO fuse earlier?
            nn.Linear(2 * feature_dim, self.rep_dim, bias=use_bias),
            nn.BatchNorm1d(self.rep_dim, affine=use_bias),
            nn.LeakyReLU(),
            nn.Linear(self.rep_dim, self.n_concepts),
        )

    def forward(self, x, x_hat):
        """Return concept logits for the pair ``(x, x_hat)``.

        Each input is encoded separately with the shared encoder (``x``
        first, then ``x_hat``), the features are concatenated along
        dimension 1, and the result is passed through the head.
        """
        encoded = [self.encoder(inp) for inp in (x, x_hat)]
        return self.linear(torch.cat(encoded, dim=1))

    def parameterize(self):
        """Do nothing; this model defines no extra parameterization step."""
        return None
