from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging

import numpy as np
import torch
import torch.nn as nn

from .modules import MultiHeadSelfAttention

# original scorenet
class ScoreNet(nn.Module):
    """Three-layer MLP producing one scalar score per input feature vector.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear(1).
    """

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the width of
                both hidden layers.
            input_channel: size of the last dimension of the input features.
            **kwargs: accepted and unused.
        """
        super(ScoreNet, self).__init__()
        width = cfg.model['hidden_layers']
        # Attribute names double as state_dict keys, so they stay l1/l2/l3.
        self.l1 = nn.Linear(input_channel, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, 1)
        self.relu = nn.ReLU()

    def forward(self, x, s):
        """Score ``x``.

        Args:
            x: feature tensor whose last dimension is ``input_channel``.
            s: tensor added to the score after being multiplied by ``0.0``,
                so finite values do not change the result (presumably kept
                so ``s`` stays part of the graph -- TODO confirm).

        Returns:
            The unbounded output of the last linear layer (plus ``0.0 * s``).
        """
        hidden = self.relu(self.l2(self.relu(self.l1(x))))
        return self.l3(hidden) + 0.0 * s

    def init_weights(self):
        """Reset every ``nn.Linear``: Kaiming-uniform weights (a=1), zero biases."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_uniform_(layer.weight, a=1)
            nn.init.constant_(layer.bias, 0)

# original scorenet + one sigmoid layer
class ScoreNet_sigmoid(nn.Module):
    """Three-layer MLP whose single output is squashed by a sigmoid.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear(1) -> Sigmoid.
    """

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the width of
                both hidden layers.
            input_channel: size of the last dimension of the input features.
            **kwargs: accepted and unused.
        """
        super(ScoreNet_sigmoid, self).__init__()
        width = cfg.model['hidden_layers']
        # Attribute names double as state_dict keys, so they stay l1/l2/l3.
        self.l1 = nn.Linear(input_channel, width)
        self.l2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, s):
        """Score ``x`` and squash the result with a sigmoid.

        Args:
            x: feature tensor whose last dimension is ``input_channel``.
            s: tensor added after being multiplied by ``0.0``, so finite
                values do not change the result (presumably kept so ``s``
                stays part of the graph -- TODO confirm).

        Returns:
            ``sigmoid`` of the last linear layer's output (plus ``0.0 * s``).
        """
        hidden = self.relu(self.l2(self.relu(self.l1(x))))
        logits = self.l3(hidden)
        return self.sigmoid(logits) + 0.0 * s

    def init_weights(self):
        """Reset every ``nn.Linear``: Kaiming-uniform weights (a=1), zero biases."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_uniform_(layer.weight, a=1)
            nn.init.constant_(layer.bias, 0)

# original_scorenet + residual connection
class ScoreNetV2(nn.Module):
    """MLP score head whose prediction is added to an incoming score ``s``.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear(1), then ``+ s``.
    """

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the width of
                both hidden layers.
            input_channel: size of the last dimension of the input features.
            **kwargs: accepted and unused.
        """
        super(ScoreNetV2, self).__init__()
        width = cfg.model['hidden_layers']
        self.l1 = nn.Linear(input_channel, width)
        self.l2 = nn.Linear(width, width)
        # Two further hidden layers (l3, l4) are currently disabled, which is
        # why the output layer is named l5.
        self.l5 = nn.Linear(width, 1)
        self.relu = nn.ReLU()

    def forward(self, p, s):
        """Return the MLP output for ``p`` plus the residual input ``s``.

        Args:
            p: feature tensor whose last dimension is ``input_channel``.
            s: tensor added unchanged to the MLP output.
        """
        hidden = self.relu(self.l2(self.relu(self.l1(p))))
        return self.l5(hidden) + s

    def init_weights(self):
        """Reset every ``nn.Linear``: Kaiming-uniform weights (a=1), zero biases."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_uniform_(layer.weight, a=1)
            nn.init.constant_(layer.bias, 0)


# original_scorenet + residual connection + batch-level MHSA
class ScoreNetV3(nn.Module):
    """MLP features refined by single-head self-attention, then a sigmoid score.

    Layout: Linear -> ReLU -> Linear -> ReLU -> (MHSA + skip) -> Linear(1)
    -> Sigmoid.
    """

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the hidden
                width and the attention embedding size.
            input_channel: size of the last dimension of the input features.
            **kwargs: accepted and unused.
        """
        super(ScoreNetV3, self).__init__()
        width = cfg.model['hidden_layers']
        self.l1 = nn.Linear(input_channel, width)
        self.l2 = nn.Linear(width, width)
        self.MHSA = MultiHeadSelfAttention(embed_dim=width, num_heads=1)
        self.l3 = nn.Linear(width, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, p, s):
        """Score ``p``.

        Args:
            p: feature tensor whose last dimension is ``input_channel``.
            s: tensor added after being multiplied by ``0.0``, so finite
                values do not change the result.

        Returns:
            ``sigmoid`` of the final linear layer's output (plus ``s * 0.0``).
        """
        feats = self.relu(self.l2(self.relu(self.l1(p))))
        # A leading axis is added before attention and element 0 of the result
        # is taken afterwards; presumably this treats the rows of ``feats`` as
        # one sequence -- TODO confirm against MultiHeadSelfAttention.
        attended = self.MHSA(feats.unsqueeze(0))[0]
        logits = self.l3(attended + feats)  # skip connection around attention
        return self.sigmoid(logits) + s * 0.0

    def init_weights(self):
        """Reset every ``nn.Linear``: Kaiming-uniform weights (a=1), zero biases."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_uniform_(layer.weight, a=1)
            nn.init.constant_(layer.bias, 0)


# transformer scorenet: MLP embedding + 3-layer nn.TransformerEncoder + sigmoid rank head
class ScoreNetTransformer(nn.Module):
    """Score network: MLP embedding, 3-layer transformer encoder, sigmoid head."""

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the embedding
                width and ``cfg.model['feature_channels']`` the input width.
            input_channel: accepted but unused; the input width is read from
                ``cfg.model['feature_channels']`` instead.
            **kwargs: accepted and unused.
        """
        super(ScoreNetTransformer, self).__init__()
        d_model = cfg.model['hidden_layers']
        in_features = cfg.model['feature_channels']
        self.linear = nn.Sequential(
            nn.Linear(in_features, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): the encoder layer keeps PyTorch's default
        # (seq, batch, feature) layout, so attention runs along dim 0 of the
        # input -- confirm that is the intended axis.
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=1, dropout=0.2)
        self.encoder = nn.TransformerEncoder(layer, num_layers=3)
        self.rank_head = nn.Linear(d_model, 1)

    def forward(self, p, s):
        """Embed ``p``, run the encoder and return a sigmoid score.

        Args:
            p: feature tensor whose last dimension is
                ``cfg.model['feature_channels']``.
            s: tensor added after being multiplied by ``0.0``, so finite
                values do not change the result.
        """
        encoded = self.encoder(self.linear(p))
        return self.rank_head(encoded).sigmoid() + 0.0*s

    def init_weights(self):
        """Reset Linear and LayerNorm sub-modules.

        ``self.modules()`` walks nested sub-modules too, so Linear and
        LayerNorm layers inside the encoder are covered. Linear layers get
        Kaiming-uniform weights (a=1) and zero biases; LayerNorm layers get
        unit weights and zero biases.
        """
        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, a=1)
                nn.init.constant_(module.bias, 0)

class ScoreNetTransformerSmall(nn.Module):
    """Score network: MLP embedding, 1-layer transformer encoder, sigmoid head."""

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the embedding
                width and ``cfg.model['feature_channels']`` the input width.
            input_channel: accepted but unused; the input width is read from
                ``cfg.model['feature_channels']`` instead.
            **kwargs: accepted and unused.
        """
        super(ScoreNetTransformerSmall, self).__init__()
        d_model = cfg.model['hidden_layers']
        in_features = cfg.model['feature_channels']
        self.linear = nn.Sequential(
            nn.Linear(in_features, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): the encoder layer keeps PyTorch's default
        # (seq, batch, feature) layout, so attention runs along dim 0 of the
        # input -- confirm that is the intended axis.
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=1, dropout=0.2)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.rank_head = nn.Linear(d_model, 1)

    def forward(self, p, s):
        """Embed ``p``, run the encoder and return a sigmoid score.

        Args:
            p: feature tensor whose last dimension is
                ``cfg.model['feature_channels']``.
            s: tensor added after being multiplied by ``0.0``, so finite
                values do not change the result.
        """
        encoded = self.encoder(self.linear(p))
        return self.rank_head(encoded).sigmoid() + 0.0*s

    def init_weights(self):
        """Reset Linear and LayerNorm sub-modules.

        ``self.modules()`` walks nested sub-modules too, so Linear and
        LayerNorm layers inside the encoder are covered. Linear layers get
        Kaiming-uniform weights (a=1) and zero biases; LayerNorm layers get
        unit weights and zero biases.
        """
        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, a=1)
                nn.init.constant_(module.bias, 0)


class ScoreNetTransformerExSmall(nn.Module):
    """Score network: MLP embedding, 1-layer transformer encoder, sigmoid head.

    NOTE(review): as written, the layer configuration matches
    ``ScoreNetTransformerSmall`` exactly -- confirm whether a smaller model was
    intended.
    """

    def __init__(self, cfg, input_channel, **kwargs):
        """Create the layers.

        Args:
            cfg: config object; ``cfg.model['hidden_layers']`` is the embedding
                width and ``cfg.model['feature_channels']`` the input width.
            input_channel: accepted but unused; the input width is read from
                ``cfg.model['feature_channels']`` instead.
            **kwargs: accepted and unused.
        """
        super(ScoreNetTransformerExSmall, self).__init__()
        d_model = cfg.model['hidden_layers']
        in_features = cfg.model['feature_channels']
        self.linear = nn.Sequential(
            nn.Linear(in_features, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        # NOTE(review): the encoder layer keeps PyTorch's default
        # (seq, batch, feature) layout, so attention runs along dim 0 of the
        # input -- confirm that is the intended axis.
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=1, dropout=0.2)
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.rank_head = nn.Linear(d_model, 1)

    def forward(self, p, s):
        """Embed ``p``, run the encoder and return a sigmoid score.

        A "self-only" attention mask (each position attending only to itself)
        was prototyped for this model but is disabled; the encoder runs
        without a mask.

        Args:
            p: feature tensor whose last dimension is
                ``cfg.model['feature_channels']``.
            s: tensor added after being multiplied by ``0.0``, so finite
                values do not change the result.
        """
        encoded = self.encoder(self.linear(p))
        return self.rank_head(encoded).sigmoid() + 0.0*s

    def init_weights(self):
        """Reset Linear and LayerNorm sub-modules.

        ``self.modules()`` walks nested sub-modules too, so Linear and
        LayerNorm layers inside the encoder are covered. Linear layers get
        Kaiming-uniform weights (a=1) and zero biases; LayerNorm layers get
        unit weights and zero biases.
        """
        for module in self.modules():
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_uniform_(module.weight, a=1)
                nn.init.constant_(module.bias, 0)


def get_score_net(cfg, input_channel, is_train):
    """Build the score network currently selected for experiments.

    The architecture is hard-coded to ``ScoreNetTransformerExSmall``; swap the
    class below to try ``ScoreNetTransformer``, ``ScoreNetTransformerSmall``
    or ``ScoreNetV2`` instead.

    Args:
        cfg: config object forwarded to the network; ``cfg.model['init_weights']``
            is consulted only when ``is_train`` is truthy.
        input_channel: forwarded to the network constructor.
        is_train: when truthy and ``cfg.model['init_weights']`` is truthy, the
            network's ``init_weights()`` is called before returning.

    Returns:
        The constructed network.
    """
    net = ScoreNetTransformerExSmall(cfg, input_channel)
    if not is_train:
        return net
    if cfg.model['init_weights']:
        net.init_weights()
    return net