import torch
import torch.nn as nn
import logging
from transformers import BertConfig

# Eight per-stage values, presumably a preset for an 8-stage ProteinVista
# trunk (kernel size / output channels per stage). ProteinVista's own
# ``kernel_sizes`` / ``depths`` parameters shadow these names inside its
# ``__init__``, so they are not used as its defaults.
kernel_sizes = [
    7, 5, 3, 3, 3, 3, 3, 3,
]
depths = [
    128, 256, 512, 1024, 1024, 2048, 2048, 2048,
]

class Bert(nn.Module):
    """Stack of native ``nn.TransformerEncoderLayer`` modules sized by a ``BertConfig``.

    Only the configuration object comes from ``transformers``; the encoder
    layers themselves are PyTorch's built-in ones (GELU, feed-forward width
    ``4 * hidden_size``, 6 attention heads). ``forward`` returns the encoder
    output at the first sequence position.

    Args:
        emb_dim: Model width (``hidden_size``); must be divisible by the
            6 attention heads for ``nn.MultiheadAttention`` to accept it.
        no_layers: Number of encoder layers.
    """

    def __init__(self, emb_dim, no_layers):
        super(Bert, self).__init__()
        self.config = BertConfig(
            hidden_size=emb_dim,
            num_hidden_layers=no_layers,
            num_attention_heads=6,
        )
        cfg = self.config
        self.hidden_size = cfg.hidden_size
        # ``num_layers`` and ``num_hidden_layers`` hold the same value; both
        # attributes are kept for callers that read either name.
        self.num_layers = cfg.num_hidden_layers
        self.num_heads = cfg.num_attention_heads
        self.num_hidden_layers = cfg.num_hidden_layers

        width = cfg.hidden_size
        heads = cfg.num_attention_heads
        self.transformer_layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                width,
                heads,
                dim_feedforward=4 * width,
                activation="gelu",
            )
            for _ in range(cfg.num_hidden_layers)
        )

    def forward(self, input_ids, attention_mask):
        """Run the encoder stack and return the first-position hidden state.

        Args:
            input_ids: Despite the name, a 3-D tensor that is permuted with
                ``(1, 0, 2)`` before the layers. It is presumably
                already-embedded tokens shaped ``(batch, seq, hidden_size)``;
                TODO confirm with callers.
            attention_mask: Forwarded unchanged to every layer as
                ``src_key_padding_mask``. NOTE(review): PyTorch treats
                ``True`` positions as padding to ignore. A Hugging Face-style
                mask (1 == keep) would be inverted here, so verify the
                caller's convention.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        # The layers are built with the default batch_first=False, so they
        # take (seq, batch, feature).
        hidden = input_ids.permute(1, 0, 2)
        for encoder_layer in self.transformer_layers:
            hidden = encoder_layer(hidden, src_key_padding_mask=attention_mask)

        first_position = hidden[0]
        return first_position.reshape(-1, self.hidden_size)

class Block3D(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU unit.

    The convolution is built without a bias term; the BatchNorm3d that
    follows it has its own learned shift.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution zero-padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        # Attribute names and registration order define the state_dict keys
        # (conv.*, bn.*), so they are kept stable.
        self.conv = nn.Conv3d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch-norm and (in-place) ReLU to ``x``."""
        return self.relu(self.bn(self.conv(x)))

    
class ProteinVista(nn.Module):
    """3D-CNN encoder fused with a transformer-encoded small-molecule branch.

    A trunk of ``num_layers`` convolutional stages (built from ``Block3D``)
    is globally average-pooled. Its features are concatenated with the
    first-position output of ``Bert`` and passed through an MLP head.

    Args:
        channels: Channel count of the input volume ``(N, C, D, H, W)``.
        output_dim: Size of the final prediction.
        num_layers: Number of convolutional stages.
        kernel_sizes: Per-stage kernel sizes; needs ``num_layers`` entries.
        depths: Per-stage output channel counts; needs ``num_layers`` entries.
        hidden_layer_dim: Width of the MLP head's hidden layer.
        small_molecule_dim: Width of the small-molecule (Bert) branch.

    Raises:
        ValueError: If ``kernel_sizes`` or ``depths`` is missing or its
            length differs from ``num_layers``.
    """

    def __init__(
        self,
        channels=5,
        output_dim=1,
        num_layers=4,
        kernel_sizes=None,
        depths=None,
        hidden_layer_dim=256,
        small_molecule_dim=768,
    ):
        super().__init__()
        # Validate before building anything. The None defaults used to
        # surface as an opaque TypeError from ``len(None)``.
        if kernel_sizes is None or depths is None:
            raise ValueError("kernel_sizes and depths must be provided.")
        if len(kernel_sizes) != num_layers:
            raise ValueError("Length of kernel_sizes must equal num_layers.")
        if len(depths) != num_layers:
            raise ValueError("Length of depths must equal num_layers.")

        # Kept as the first submodule so parameter / state_dict ordering and
        # RNG-dependent initialisation are unchanged.
        self.smiles_bert = Bert(emb_dim=small_molecule_dim, no_layers=4)
        self.small_molecule_dim = small_molecule_dim

        self.layers = nn.ModuleList()
        in_channels = channels
        for i, (out_channels, kernel_size) in enumerate(zip(depths, kernel_sizes)):
            self.layers.append(
                self._make_stage(i, num_layers, in_channels, out_channels, kernel_size)
            )
            in_channels = out_channels

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))

        # After the loop, ``in_channels`` is the last stage's width. It equals
        # depths[-1] for num_layers >= 1, and is ``channels`` when there are
        # no stages (where depths[-1] would raise IndexError).
        self.fc = nn.Sequential(
            nn.Linear(in_channels + self.small_molecule_dim, hidden_layer_dim, bias=True),
            nn.BatchNorm1d(hidden_layer_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(hidden_layer_dim, output_dim),
        )

    @staticmethod
    def _make_stage(index, num_layers, in_channels, out_channels, kernel_size):
        """Build stage ``index`` of a ``num_layers``-stage convolutional trunk."""
        padding = kernel_size // 2  # "same" padding for odd kernel sizes
        if index == 0:
            # Stem: a single stride-2 block followed by 2x max-pooling.
            return nn.Sequential(
                Block3D(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=padding,
                ),
                nn.MaxPool3d(kernel_size=2, stride=2),
            )

        modules = [
            Block3D(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
            ),
            Block3D(
                out_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
            ),
        ]
        # Shallow trunks (<= 5 stages) pool after every stage. Deeper trunks
        # pool only after even-indexed stages, presumably to limit total
        # spatial downsampling.
        if num_layers <= 5 or index % 2 == 0:
            modules.append(nn.MaxPool3d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def forward(self, x, sm=None, sm_attn_mask=None, return_repr=False):
        """Encode ``x`` (and ``sm``) and return the head's prediction.

        Args:
            x: Input volume ``(N, C, D, H, W)``.
            sm: Input for the small-molecule ``Bert`` branch. It is presumably
                embeddings shaped ``(N, seq, small_molecule_dim)``; TODO
                confirm. Required unless ``return_repr`` is True.
            sm_attn_mask: Forwarded to ``Bert`` as its padding mask.
            return_repr: If True, return the pooled CNN features
                ``(N, last_stage_channels)`` without running the head.

        Raises:
            ValueError: If ``sm`` is None while the head expects
                small-molecule features.
        """
        for layer in self.layers:
            x = layer(x)

        x = torch.flatten(self.avgpool(x), 1)
        if return_repr:
            return x

        if sm is not None:
            sm = self.smiles_bert(sm, sm_attn_mask)
            x = torch.cat((x, sm), 1)
        elif self.small_molecule_dim:
            # Without this check the first Linear of ``fc`` fails with an
            # opaque shape-mismatch error.
            raise ValueError(
                "sm is required: the head expects CNN features plus "
                f"{self.small_molecule_dim} small-molecule features."
            )
        return self.fc(x)


class ProjectionHead(nn.Module):
    """Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        input_dim: Feature width of the input.
        hidden_dim: Width of the intermediate layer.
        output_dim: Width of the projected output.
    """

    def __init__(self, input_dim, hidden_dim=2048, output_dim=256):
        super().__init__()
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        # The Sequential indices (0, 1, 3) appear in the state_dict keys, so
        # the layer order is fixed.
        self.projection = nn.Sequential(*stack)

    def forward(self, x):
        """Project ``x`` through the MLP."""
        return self.projection(x)


def load_pretrained_model(current_model, pretrained_state_dict):
    """Copy compatible tensors from a pretrained state dict into a model.

    A leading ``"module."`` prefix on a key is removed; this is the prefix
    that ``nn.DataParallel`` / ``DistributedDataParallel`` wrappers add.
    Entries whose name is absent from ``current_model``, or whose shape
    differs, are skipped and logged at INFO level. The rest are loaded with
    ``strict=False``, so model parameters without a pretrained counterpart
    keep their current values; those are also reported at INFO level.

    Args:
        current_model: Module to load into (modified in place).
        pretrained_state_dict: Mapping of parameter/buffer names to tensors.

    Returns:
        ``current_model``, for chaining.
    """
    prefix = "module."
    current_state_dict = current_model.state_dict()
    new_state_dict = {}

    for name, param in pretrained_state_dict.items():
        # Strip only a *leading* wrapper prefix. The previous str.replace also
        # removed inner occurrences, e.g. "module.submodule.weight" became
        # "subweight", which was then wrongly reported as missing.
        if name.startswith(prefix):
            name = name[len(prefix):]

        target = current_state_dict.get(name)
        if target is None:
            logging.info("Parameter %s not found in current model.", name)
        elif target.shape != param.shape:
            logging.info("Skipping parameter %s due to shape mismatch.", name)
        else:
            new_state_dict[name] = param

    result = current_model.load_state_dict(new_state_dict, strict=False)
    if result.missing_keys:
        logging.info(
            "%d parameter(s) kept their current values: %s",
            len(result.missing_keys),
            ", ".join(result.missing_keys),
        )
    return current_model

