import torch
import torch.nn as nn
import sys


class FTTransformer(nn.Module):
    """Transformer classifier over per-feature tokens plus a learned [CLS] token.

    NOTE(review): a second ``class FTTransformer`` appears later in this module.
    At import time that later definition rebinds the name, so this version (the
    only one that calls ``_initialize_weights``) is not what callers receive.

    Token layout per row: [CLS], one token per categorical column (each column
    has its own ``nn.Embedding`` of width ``dim``), and, when continuous
    features are present, ONE extra token produced by projecting all continuous
    columns together with a single ``nn.Linear``. The [CLS] output of the
    encoder is passed through LayerNorm + Linear to give class logits.
    """

    def __init__(self, categories, num_continuous, dim=64, depth=3, heads=8, dropout=0.1, num_classes=2):
        super().__init__()
        self.num_categories = len(categories)
        self.num_continuous = num_continuous
        self.dim = dim

        self.category_embeddings = nn.ModuleList(
            [nn.Embedding(num_embeddings=size, embedding_dim=dim) for size in categories]
        )
        if num_continuous > 0:
            self.continuous_projection = nn.Linear(num_continuous, dim)
        else:
            self.continuous_projection = None
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        block = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=4 * dim,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=depth)

        head_norm = nn.LayerNorm(dim)
        head_linear = nn.Linear(dim, num_classes)
        self.classifier = nn.Sequential(head_norm, head_linear)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize every nn.Linear (Kaiming normal, fan_out) and zero its bias.

        ``self.modules()`` recurses, so Linear layers inside the transformer
        layers are re-initialized too. The BatchNorm1d branch never fires for
        this model (it contains no BatchNorm1d) but is kept for parity with the
        other models in this module.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x_categ, x_cont=None):
        """Return class logits of shape (batch, num_classes).

        ``x_categ`` is indexed column-wise (``x_categ[:, i]``) for each
        categorical embedding. The continuous token is added only when the
        model was built with ``num_continuous > 0`` AND ``x_cont`` is given;
        otherwise continuous input is ignored.
        """
        tokens = []
        for idx, embedding in enumerate(self.category_embeddings):
            tokens.append(embedding(x_categ[:, idx]))
        if x_cont is not None and self.num_continuous > 0:
            tokens.append(self.continuous_projection(x_cont))

        seq = torch.stack(tokens, dim=1)
        batch = seq.size(0)
        cls = self.cls_token.expand(batch, -1, -1)
        seq = torch.cat((cls, seq), dim=1)

        encoded = self.transformer(seq)
        # Position 0 is the [CLS] token prepended above.
        return self.classifier(encoded[:, 0])
    

class MLP(nn.Module):
    def __init__(self, cat_dims, num_continuous, hidden_dim=32, num_classes=2, num_layers=10, dropout=0.1):
        """
        MLP for tabular data with one Embedding per categorical feature and
        BatchNorm + ReLU (+ Dropout when ``dropout > 0``) after every hidden
        Linear layer.

        Args:
            cat_dims: cardinality of each categorical column; column i gets an
                embedding of width ``min(50, (cat_dims[i] + 1) // 2)``.
            num_continuous: number of continuous columns appended after the
                categorical embeddings.
            hidden_dim: width of every hidden layer.
            num_classes: number of output logits.
            num_layers: number of hidden Linear/BatchNorm/ReLU(/Dropout) blocks.
                ``0`` yields a single linear layer on the raw inputs.
            dropout: dropout probability; no Dropout module is added when 0.
        """
        super().__init__()

        self.num_categ = len(cat_dims)
        self.num_continuous = num_continuous

        # One embedding per categorical feature.
        self.embeddings = nn.ModuleList([
            nn.Embedding(cat_dim, min(50, (cat_dim + 1) // 2)) for cat_dim in cat_dims
        ])
        embedding_dim = sum(emb.embedding_dim for emb in self.embeddings)

        input_dim = embedding_dim + num_continuous

        layers = []
        current_dim = input_dim
        for _ in range(num_layers):
            layers.append(nn.Linear(current_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(p=dropout))
            current_dim = hidden_dim

        # Output layer without BatchNorm or Dropout. Uses current_dim (not
        # hidden_dim) so that num_layers == 0 maps input_dim straight to logits.
        layers.append(nn.Linear(current_dim, num_classes))

        self.model = nn.Sequential(*layers)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) Linear weights with zero bias; BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x_categ, x_cont=None):
        """Return logits of shape (batch, num_classes).

        Categorical embeddings come first, then ``x_cont``, concatenated along
        dim 1. ``x_cont`` may be None only when the model has no continuous
        features; ``x_categ`` is not indexed when there are no categorical
        features.

        Raises:
            ValueError: if ``x_cont`` is missing but continuous features are
                expected, or if there are no input features at all.
        """
        parts = [emb(x_categ[:, i]) for i, emb in enumerate(self.embeddings)]
        if x_cont is not None:
            parts.append(x_cont)
        elif self.num_continuous > 0:
            raise ValueError(
                f"x_cont is required: model expects {self.num_continuous} continuous features"
            )
        if not parts:
            raise ValueError("MLP received no input features")
        x = torch.cat(parts, dim=1) if len(parts) > 1 else parts[0]
        return self.model(x)



class ResMLP(nn.Module):
    """Residual MLP for tabular data.

    Categorical columns are embedded (column i gets width
    ``min(50, (cat_dims[i] + 1) // 2)``) and concatenated with the continuous
    features. ``num_layers`` blocks of Linear -> BatchNorm1d -> ReLU -> Dropout
    are then applied with residual connections, followed by a linear
    classifier.

    The first block maps ``input_dim`` to ``hidden_dim``; when those widths
    differ, its skip path goes through a bias-free linear projection so the
    residual sum is well-formed. When they match, the skip is an identity and
    the model has exactly the same parameters as a plain residual stack.
    """

    def __init__(self, cat_dims, num_continuous, hidden_dim=64, num_layers=5, num_classes=2, dropout=0.1):
        super().__init__()
        self.num_categ = len(cat_dims)
        self.num_continuous = num_continuous

        self.embeddings = nn.ModuleList([
            nn.Embedding(cat_dim, min(50, (cat_dim + 1) // 2)) for cat_dim in cat_dims
        ])
        embedding_dim = sum(emb.embedding_dim for emb in self.embeddings)
        input_dim = embedding_dim + num_continuous

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            self.layers.append(nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            ))

        # Skip path of the first block: identity only type-checks when the
        # widths match, otherwise project input_dim -> hidden_dim.
        if num_layers > 0 and input_dim != hidden_dim:
            self.input_shortcut = nn.Linear(input_dim, hidden_dim, bias=False)
        else:
            self.input_shortcut = nn.Identity()

        # With no residual blocks the classifier sees the raw input width.
        self.classifier = nn.Linear(hidden_dim if num_layers > 0 else input_dim, num_classes)
        self._initialize_weights()

    def forward(self, x_categ, x_cont=None):
        """Return logits of shape (batch, num_classes).

        ``x_cont`` may be None only when the model has no continuous features.

        Raises:
            ValueError: if ``x_cont`` is missing but continuous features are
                expected, or if there are no input features at all.
        """
        parts = [emb(x_categ[:, i]) for i, emb in enumerate(self.embeddings)]
        if x_cont is not None:
            parts.append(x_cont)
        elif self.num_continuous > 0:
            raise ValueError(
                f"x_cont is required: model expects {self.num_continuous} continuous features"
            )
        if not parts:
            raise ValueError("ResMLP received no input features")
        x = torch.cat(parts, dim=1) if len(parts) > 1 else parts[0]

        for i, layer in enumerate(self.layers):
            skip = self.input_shortcut(x) if i == 0 else x
            x = skip + layer(x)

        return self.classifier(x)

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) Linear weights with zero bias; BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class SAINT(nn.Module):
    """Transformer classifier over feature tokens with a learned [CLS] token.

    Each categorical column is embedded to width ``dim``; all continuous
    columns are projected together by one ``nn.Linear`` into a single extra
    token. The sequence [CLS] + tokens goes through ``depth`` encoder layers
    and the [CLS] output is classified with LayerNorm + Linear.

    NOTE(review): attention here runs only across the tokens of each row
    (``batch_first=True``, sequence = feature tokens); there is no attention
    across rows of the batch, so this is structurally the same as the
    FTTransformer in this module.
    """

    def __init__(self, cat_dims, num_continuous, dim=64, depth=3, heads=8, dropout=0.1, num_classes=2):
        super().__init__()
        self.embeddings = nn.ModuleList([
            nn.Embedding(cardinality, dim) for cardinality in cat_dims
        ])
        self.cont_proj = nn.Linear(num_continuous, dim) if num_continuous > 0 else None
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4,
                                                   dropout=dropout, activation='relu', batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.classifier = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
        self._initialize_weights()

    def forward(self, x_categ, x_cont=None):
        """Return logits of shape (batch, num_classes).

        The continuous token is added only when the model has a continuous
        projection AND ``x_cont`` is given, so passing ``x_cont`` (e.g. an
        empty tensor) to a model built with ``num_continuous == 0`` is safe.

        Raises:
            ValueError: if there is no token to attend over.
        """
        x_embed = [emb(x_categ[:, i]) for i, emb in enumerate(self.embeddings)]
        # cont_proj is None when num_continuous == 0; never call it then.
        if self.cont_proj is not None and x_cont is not None:
            x_embed.append(self.cont_proj(x_cont))
        if not x_embed:
            raise ValueError("SAINT received no input features")
        x = torch.stack(x_embed, dim=1)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls, x), dim=1)
        x = self.transformer(x)
        # Position 0 is the [CLS] token prepended above.
        return self.classifier(x[:, 0])

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) weights and zero bias for every nn.Linear, including those in the encoder."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
                
    
class FTTransformer(nn.Module):
    """Transformer classifier over per-feature tokens plus a learned [CLS] token.

    NOTE(review): this redefines an earlier ``class FTTransformer`` in this
    module and, being later, is the definition callers actually get. Unlike the
    earlier one it never re-initializes weights: every layer keeps the
    initialization performed by its own constructor.

    Token layout per row: [CLS], one token per categorical column (own
    ``nn.Embedding`` of width ``dim``), and, when continuous features are
    present, ONE token made by projecting all continuous columns together.
    """

    def __init__(self, categories, num_continuous, dim=64, depth=3, heads=8, dropout=0.1, num_classes=2):
        super().__init__()
        self.num_categories = len(categories)
        self.num_continuous = num_continuous
        self.dim = dim

        self.category_embeddings = nn.ModuleList(
            [nn.Embedding(num_embeddings=size, embedding_dim=dim) for size in categories]
        )
        if num_continuous > 0:
            self.continuous_projection = nn.Linear(num_continuous, dim)
        else:
            self.continuous_projection = None
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        block = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=4 * dim,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block, num_layers=depth)

        head_norm = nn.LayerNorm(dim)
        head_linear = nn.Linear(dim, num_classes)
        self.classifier = nn.Sequential(head_norm, head_linear)

    def forward(self, x_categ, x_cont=None):
        """Return class logits of shape (batch, num_classes).

        The continuous token is included only when the model was built with
        ``num_continuous > 0`` AND ``x_cont`` is given.
        """
        tokens = []
        for idx, embedding in enumerate(self.category_embeddings):
            tokens.append(embedding(x_categ[:, idx]))
        if x_cont is not None and self.num_continuous > 0:
            tokens.append(self.continuous_projection(x_cont))

        seq = torch.stack(tokens, dim=1)
        batch = seq.size(0)
        cls = self.cls_token.expand(batch, -1, -1)
        seq = torch.cat((cls, seq), dim=1)

        encoded = self.transformer(seq)
        # Position 0 is the [CLS] token prepended above.
        return self.classifier(encoded[:, 0])
    
    

class TabNet_v1(nn.Module):
    """Deep MLP classifier with TabNet-like hyperparameters.

    Pipeline: categorical embeddings (column i has width
    ``min(50, (cat_dims[i] + 1) // 2)``) concatenated with continuous features
    -> BatchNorm1d -> Linear -> BatchNorm1d -> ``n_steps`` independent blocks of
    (Linear, ReLU, BatchNorm1d) x 2 applied one after another -> classifier.

    NOTE(review): ``gamma`` and ``epsilon`` are stored but never read by
    ``forward``; there is no attentive feature masking in this model.
    """

    def __init__(self, cat_dims, num_continuous, num_classes=2, hidden_dim=64, n_steps=3, gamma=1.5, epsilon=1e-15):
        super().__init__()
        self.cat_dims = cat_dims
        self.num_continuous = num_continuous
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.n_steps = n_steps
        self.gamma = gamma
        self.epsilon = epsilon

        # One embedding per categorical column.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(cardinality, min(50, (cardinality + 1) // 2)) for cardinality in cat_dims]
        )
        self.cat_emb_dim = sum(e.embedding_dim for e in self.embeddings)
        self.input_dim = self.cat_emb_dim + num_continuous

        self.initial_bn = nn.BatchNorm1d(self.input_dim)
        self.fc = nn.Linear(self.input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)

        # Each step is its own Sequential (weights are NOT shared across steps).
        self.shared_step = nn.ModuleList()
        for _ in range(n_steps):
            self.shared_step.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.BatchNorm1d(hidden_dim),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.BatchNorm1d(hidden_dim),
            ))

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out) Linear weights with zero bias; BatchNorm to identity."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x_categ, x_cont):
        """Return logits of shape (batch, num_classes).

        When ``cat_dims`` is empty, ``x_categ`` is ignored and an empty
        (batch, 0) placeholder on ``x_cont``'s device stands in for it.
        """
        if not self.cat_dims:
            x_cat = torch.empty((x_cont.size(0), 0), device=x_cont.device)
        else:
            x_cat = torch.cat(
                [embedding(x_categ[:, col]) for col, embedding in enumerate(self.embeddings)],
                dim=1,
            )

        h = torch.cat([x_cat, x_cont], dim=1)
        h = self.bn(self.fc(self.initial_bn(h)))
        for block in self.shared_step:
            h = block(h)
        return self.classifier(h)



class Sparsemax(nn.Module):
    """Sparsemax activation: Euclidean projection of the input onto the simplex.

    Like softmax it returns non-negative values summing to 1 along ``dim``, but
    entries below a data-dependent threshold ``tau`` are exactly zero.

    Args:
        dim: axis along which to normalize (default: last axis).
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        # Work on the last axis; the axis is moved back before returning.
        z = input.transpose(self.dim, -1)
        # Sparsemax is invariant to adding a constant, so subtracting the max
        # only keeps values in a safe numeric range.
        z = z - z.max(dim=-1, keepdim=True)[0]
        z_sorted = torch.sort(z, dim=-1, descending=True)[0]
        # 1..K, matched to the input's dtype/device (avoids shadowing builtin range).
        ks = torch.arange(1, z.size(-1) + 1, device=z.device, dtype=z.dtype)
        cumulative = torch.cumsum(z_sorted, dim=-1)
        # The condition holds for a prefix of the sorted values, so counting
        # the True entries gives the support size k(z).
        in_support = (1 + ks * z_sorted) > cumulative
        k = in_support.sum(dim=-1, keepdim=True)
        tau = (torch.gather(cumulative, dim=-1, index=k - 1) - 1) / k.to(z.dtype)
        output = torch.clamp(z - tau, min=0)
        return output.transpose(self.dim, -1)


class FeatureTransformer(nn.Module):
    """Two stacked stages of Linear -> ReLU -> BatchNorm1d -> Dropout.

    The first stage maps ``input_dim`` to ``hidden_dim``; the second keeps
    ``hidden_dim``. Modules live in ``self.block`` at indices 0..7.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        stages = []
        for in_features in (input_dim, hidden_dim):
            stages += [
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(dropout),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages; output has ``hidden_dim`` features."""
        out = self.block(x)
        return out


class AttentiveTransformer(nn.Module):
    """Turn features into a sparse mask scaled by a prior.

    Computes ``Sparsemax(Dropout(BatchNorm1d(Linear(x))) * prior)``. Here
    ``hidden_dim`` is the OUTPUT width, i.e. the size of the produced mask, and
    ``prior`` must broadcast against a tensor of that width.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.sparsemax = Sparsemax()

    def forward(self, x, prior):
        """Return the sparsemax mask for ``x`` weighted by ``prior``."""
        logits = self.dropout(self.bn(self.fc(x)))
        return self.sparsemax(logits * prior)


class TabNetEncoder(nn.Module):
    """Sequential-attention encoder in the style of TabNet.

    At each step an AttentiveTransformer produces a sparse mask over the input
    features from (a) the features produced by the previous step (the initial
    transform for the first step) and (b) a prior. The masked input goes
    through that step's FeatureTransformer, and the step outputs are summed.

    Args:
        input_dim: number of input features (width of ``x`` and of each mask).
        hidden_dim: width of every FeatureTransformer output and of the result.
        n_steps: number of decision steps.
        gamma: relaxation factor; after each step ``prior *= (gamma - mask)``.
        dropout: dropout for every FeatureTransformer / AttentiveTransformer.
    """

    def __init__(self, input_dim, hidden_dim, n_steps=3, gamma=1.5, dropout=0.0):
        super().__init__()
        self.initial_transform = FeatureTransformer(input_dim, hidden_dim, dropout=dropout)
        self.steps = nn.ModuleList([
            FeatureTransformer(input_dim, hidden_dim, dropout=dropout) for _ in range(n_steps)
        ])
        self.attentives = nn.ModuleList([
            AttentiveTransformer(hidden_dim, input_dim, dropout=dropout) for _ in range(n_steps)
        ])
        self.n_steps = n_steps
        self.gamma = gamma

    def forward(self, x):
        """Encode a 2-D ``x`` of shape (B, input_dim) into (B, hidden_dim).

        Returns the sum of the per-step transforms, or zeros when
        ``n_steps == 0``.
        """
        batch_size, num_features = x.size()
        prior = torch.ones(batch_size, num_features, device=x.device, dtype=x.dtype)
        # Features that drive the first mask; afterwards each step's output
        # drives the next step's mask (previously this was never updated, so
        # every mask came from the initial transform).
        attention_input = self.initial_transform(x)
        # Tensor accumulator (not int 0) so n_steps == 0 still returns a tensor.
        output = torch.zeros_like(attention_input)

        for transform, attention in zip(self.steps, self.attentives):
            mask = attention(attention_input, prior)
            transformed = transform(x * mask)
            output = output + transformed
            # Features selected heavily now are down-weighted in later steps.
            prior = prior * (self.gamma - mask)
            attention_input = transformed

        return output


class TabNet_v2(nn.Module):
    """Tabular classifier: categorical embeddings + continuous features -> TabNetEncoder -> Linear.

    Column i of ``x_categ`` is embedded with width
    ``min(50, (cat_dims[i] + 1) // 2)``; the embeddings are concatenated in
    front of ``x_cont`` to form the encoder input of width ``self.input_dim``.
    """

    def __init__(self, cat_dims, num_continuous, hidden_dim=64, n_steps=3, gamma=1.5, num_classes=2, dropout=0.0):
        super().__init__()
        self.embeddings = nn.ModuleList(
            [nn.Embedding(cardinality, min(50, (cardinality + 1) // 2)) for cardinality in cat_dims]
        )
        self.cat_emb_dim = sum(e.embedding_dim for e in self.embeddings)
        self.num_continuous = num_continuous
        self.input_dim = self.cat_emb_dim + num_continuous

        self.encoder = TabNetEncoder(self.input_dim, hidden_dim, n_steps=n_steps, gamma=gamma, dropout=dropout)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x_categ, x_cont):
        """Return logits of shape (batch, num_classes).

        With no categorical columns, ``x_categ`` is ignored and an empty
        (batch, 0) placeholder on ``x_cont``'s device is used instead.
        """
        if not self.embeddings:
            x_cat = torch.empty((x_cont.size(0), 0), device=x_cont.device)
        else:
            x_cat = torch.cat(
                [embedding(x_categ[:, col]) for col, embedding in enumerate(self.embeddings)],
                dim=1,
            )
        features = torch.cat([x_cat, x_cont], dim=1)
        encoded = self.encoder(features)
        return self.classifier(encoded)