import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from copy import deepcopy
from timm.models.vision_transformer import VisionTransformer

class Boolean(nn.Module):
    """MLP whose output is squashed into [-1, 1] by a final Tanh.

    Layout: ``depth - 1`` hidden blocks of Linear -> BatchNorm1d -> ReLU,
    then Linear -> Tanh. When ``depth <= 1`` there are no hidden blocks and
    the network is just Linear(input_dim, output_dim) -> Tanh.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of every hidden block.
        output_dim: size of the output.
        depth: number of Linear layers in total (hidden blocks + output layer).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, depth):
        super().__init__()
        blocks = []
        fan_in = input_dim
        for _ in range(depth - 1):
            blocks.extend((
                nn.Linear(fan_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ))
            fan_in = hidden_dim
        blocks.extend((nn.Linear(fan_in, output_dim), nn.Tanh()))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the stack; ``x`` is presumably (batch, input_dim) — confirm at call sites."""
        return self.net(x)

class ResNet_wrapper(nn.Module):
    """Wrap a ResNet so ``forward`` can also return its intermediate activations.

    The wrapped model must expose ``conv1``, ``bn1``, ``layer1``-``layer4``
    (``layer4`` indexable, with at least two blocks) and a head named either
    ``linear`` or ``fc``.

    Args:
        model: the ResNet to wrap; stored as ``self.model`` (not copied).
        device: device the wrapper is moved to on construction. Defaults to
            ``'cuda'`` (the original hard-coded behaviour); pass ``None`` to
            leave the model on its current device, e.g. on CPU-only machines.
    """

    def __init__(self, model, device='cuda'):
        super(ResNet_wrapper, self).__init__()
        self.model = model
        if device is not None:
            self.to(device)

    def forward(self, x, get_all_features=False):
        """Run the wrapped ResNet on ``x``.

        Args:
            x: input batch; presumably (batch, channels, H, W) images — confirm
                against the wrapped model's ``conv1``.
            get_all_features: when True, also return the intermediate features.

        Returns:
            The head's output, or ``(output, all_embeddings)`` when
            ``get_all_features`` is True. ``all_embeddings`` maps
            'l1'-'l3' -> outputs of ``layer1``-``layer3``,
            'l4'/'l5' -> outputs of ``layer4[0]``/``layer4[1]``,
            'l6' -> output of ``layer4[2]`` when ``layer4`` has a third block,
            and the following key ('l7', or 'l6' for a two-block ``layer4``)
            -> the globally average-pooled, flattened embedding.
        """
        model = self.model
        all_embeddings = {}
        # NOTE(review): the stem is conv1 -> bn1 -> relu with no max-pool; a
        # `maxpool` attribute on the wrapped model (as in torchvision ResNets)
        # is not applied here — confirm this is intended.
        out = F.relu(model.bn1(model.conv1(x)))
        out = model.layer1(out)
        all_embeddings['l1'] = out
        out = model.layer2(out)
        all_embeddings['l2'] = out
        out = model.layer3(out)
        all_embeddings['l3'] = out
        layer4 = model.layer4
        out = layer4[0](out)
        all_embeddings['l4'] = out
        out = layer4[1](out)
        all_embeddings['l5'] = out

        # Check for the optional third block explicitly. A bare `except: pass`
        # would also silently skip the block when it raises a real error
        # (e.g. a shape mismatch) and return wrong features.
        has_third_block = len(layer4) > 2
        if has_third_block:
            out = layer4[2](out)
            all_embeddings['l6'] = out

        out = F.adaptive_avg_pool2d(out, (1, 1))
        embeddings = out.view(out.size(0), -1)
        all_embeddings['l7' if has_third_block else 'l6'] = embeddings

        # Pick the head by attribute name (e.g. torchvision ResNets use `fc`)
        # rather than try/except, so an error raised inside `linear` is not
        # masked by a follow-up AttributeError on `fc`.
        head = model.linear if hasattr(model, 'linear') else model.fc
        out = head(embeddings)

        if get_all_features:
            return out, all_embeddings
        return out

class ViTWrapper(nn.Module):
    """Wrap a timm-style ViT so ``forward`` can also return per-block token features.

    The wrapped model must expose ``patch_embed``, an iterable ``blocks`` and
    ``norm``. The optional pre-block stages ``_pos_embed``, ``patch_drop`` and
    ``norm_pre`` are applied, in that order, when the model has them.

    Args:
        model: the ViT to wrap; stored as ``self.model`` (not copied).
        device: device the wrapper is moved to on construction. Defaults to
            ``'cuda'`` (the original hard-coded behaviour); pass ``None`` to
            leave the model on its current device.
    """

    def __init__(self, model, device='cuda'):
        super(ViTWrapper, self).__init__()
        self.model = model
        if device is not None:
            self.to(device)

    def _prepare_tokens(self, x):
        """Patch-embed ``x`` and apply whichever pre-block stages the model defines."""
        out = self.model.patch_embed(x)
        # timm's VisionTransformer.forward_features runs `_pos_embed` (prepends
        # the class/prefix tokens and adds positional embeddings), `patch_drop`
        # and `norm_pre` between patch embedding and the blocks. Skipping them
        # feeds position-less tokens without a class token to the blocks, so
        # index 0 of the output would be a patch token, not the class token.
        for stage_name in ('_pos_embed', 'patch_drop', 'norm_pre'):
            stage = getattr(self.model, stage_name, None)
            if stage is not None:
                out = stage(out)
        return out

    def forward(self, x, get_all_features=False):
        """Run the ViT up to and including its final ``norm``.

        Args:
            x: model input (images for a timm ViT — confirm against ``patch_embed``).
            get_all_features: when True, also return the per-block features.

        Returns:
            The normed token sequence, or ``(tokens, all_embeddings)`` when
            ``get_all_features`` is True. ``all_embeddings`` maps 'l{k}' to the
            output of block k (1-based), and 'l{num_blocks + 1}' (e.g. 'l13'
            for 12 blocks) to the normed output.
        """
        all_embeddings = {}
        out = self._prepare_tokens(x)
        for idx, block in enumerate(self.model.blocks):
            out = block(out)
            all_embeddings[f'l{idx + 1}'] = out

        out = self.model.norm(out)
        # Key follows the number of blocks instead of a hard-coded 'l13'.
        all_embeddings[f'l{len(all_embeddings) + 1}'] = out

        if get_all_features:
            return out, all_embeddings
        return out

class ResNetTails(nn.Module):
    """Run the tail of a ResNet from a chosen stage and return the pooled embedding.

    Stages are numbered from 1 in the order
    ``layer1, layer2, layer3, layer4[0], layer4[1]`` and, when ``layer4`` has a
    third block, ``layer4[2]``. That is the same order ``ResNet_wrapper`` uses for
    its 'l1'.. keys, so a feature stored under 'lk' continues with
    ``start_at=k + 1``.

    Args:
        model: ResNet exposing ``layer1``-``layer4`` (``layer4`` indexable with
            at least two blocks); a deep copy is stored.
        start_at: 1-based index of the first stage to run; stages with a
            smaller index are skipped (0 and 1 both run every stage).
    """

    def __init__(self, model, start_at=0):
        super(ResNetTails, self).__init__()
        self.model = deepcopy(model)
        self.start_at = start_at
        layer4 = self.model.layer4
        # Plain list (not nn.ModuleList): the stages are already registered
        # under self.model, so state_dict keys stay unchanged.
        self.layer_list = [
            self.model.layer1,
            self.model.layer2,
            self.model.layer3,
            layer4[0],
            layer4[1],
        ]
        # Check for the optional third block explicitly instead of a bare
        # `except`, which would also swallow unrelated errors.
        if len(layer4) > 2:
            self.layer_list.append(layer4[2])

    def forward(self, z):
        """Run stages ``start_at``.. on ``z``, then global-average-pool and flatten.

        ``z`` is presumably the (batch, C, H, W) output of the stage before
        ``start_at`` — confirm at call sites. Returns a (batch, C') tensor.
        """
        for idx, layer in enumerate(self.layer_list):
            if self.start_at > (idx + 1):
                continue
            z = layer(z)
        z = F.adaptive_avg_pool2d(z, (1, 1))
        return z.view(z.size(0), -1)

class ViTTails(nn.Module):
    """Run the tail of a ViT from a chosen block and return the token-0 feature.

    Args:
        model: ViT exposing ``blocks``, ``norm``, ``fc_norm`` and ``head_drop``;
            a deep copy is stored.
        start_at: 1-based index of the first block to run; blocks with a
            smaller index are skipped (0 and 1 both run every block).
    """

    def __init__(self, model, start_at=0):
        super(ViTTails, self).__init__()
        self.model = deepcopy(model)
        self.start_at = start_at
        # Use every block rather than a hard-coded 12, so ViTs of other depths
        # neither fail (fewer blocks) nor silently drop blocks (more blocks).
        # Plain list (not nn.ModuleList): the blocks are already registered
        # under self.model, so state_dict keys stay unchanged.
        self.layer_list = list(self.model.blocks)

    def forward(self, z):
        """Run blocks ``start_at``.., then ``norm``, ``fc_norm`` and ``head_drop``.

        ``z`` is presumably (batch, num_tokens, dim) with the class token at
        position 0 — confirm at call sites. Returns ``z[:, 0, :]`` after the
        final stages, i.e. a (batch, dim) tensor.
        """
        for idx, layer in enumerate(self.layer_list):
            if self.start_at > (idx + 1):
                continue
            z = layer(z)
        z = self.model.norm(z)
        z = self.model.fc_norm(z)
        z = self.model.head_drop(z)
        return z[:, 0, :]


class Separable(nn.Module):
    """Separable critic: score(y_i, x_j) = <proj2(f2(y_i)), proj1(f1(x_j))>.

    Args:
        f1: encoder applied to ``x``; its output's last dimension must be
            ``final_size``.
        final_size: feature size produced by ``f1``.
        dataset, test_mode: accepted for call-site compatibility; unused here.
    """

    def __init__(self, f1, final_size=512, dataset="cifar10", test_mode="class"):
        super().__init__()
        embed_dim = 512
        self._f1 = f1
        # depth=1, so mlp() builds a single Linear(1, embed_dim); the hidden
        # width argument is not used.
        self._f2 = mlp(1, embed_dim * 4, embed_dim, 1)
        # Likewise a single Linear(final_size, embed_dim).
        self.proj1 = mlp(final_size, 2048, embed_dim, 1)
        self.proj2 = nn.Identity()

    def forward(self, x, y):
        """Return the matrix of pairwise scores between ``y`` and ``x``.

        Presumably ``y`` is (B_y, 1) and ``f1(x)`` is (B_x, final_size), giving
        a (B_y, B_x) result — confirm at call sites.
        """
        y_side = self.proj2(self._f2(y))
        x_side = self.proj1(self._f1(x))
        return y_side @ x_side.t()

def mlp(input_dim, hidden_dim, output_dim, depth):
    """Build a ReLU MLP as an ``nn.Sequential``.

    ``depth == 1`` gives a single Linear(input_dim, output_dim). Any other
    depth gives ``max(depth, 2)`` Linear layers (input -> hidden ... hidden ->
    output) with a ReLU between consecutive layers and none after the last.
    """
    if depth == 1:
        widths = [input_dim, output_dim]
    else:
        widths = [input_dim] + [hidden_dim] * max(depth - 1, 1) + [output_dim]

    layers = []
    for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        if position:
            layers.append(nn.ReLU())
        layers.append(nn.Linear(fan_in, fan_out))
    return nn.Sequential(*layers)

def infonce_lower_bound(scores):
    """Return the InfoNCE estimate, in bits, for a matrix of critic scores.

    Computes ``log N + mean_i(scores[i, i]) - mean_i(logsumexp_j scores[i, j])``
    with ``N = scores.size(0)``, then divides by ``ln 2``. Presumably ``scores``
    is square with the positive pairs on the diagonal — confirm at call sites.
    """
    positive = scores.diag().mean()
    normalizer = scores.logsumexp(dim=1)
    log_batch = torch.tensor(scores.size(0)).float().log()
    bound_nats = (log_batch + (positive - normalizer)).mean()
    return bound_nats / math.log(2)