import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils.get_embed.constant import MODEL_MAP
from utils.data import modaility_map


def listnet_loss(y_true, y_pred):
    """ListNet (top-one) loss: cross entropy between two softmax distributions.

    Both score tensors are converted to float and normalised with a softmax
    over ``dim=1``; the loss is the cross entropy of the predicted
    distribution with respect to the true one, averaged over ``dim=0``.

    Args:
        y_true: Ground-truth scores; softmax is taken over ``dim=1``.
        y_pred: Predicted scores, same shape as ``y_true``.

    Returns:
        Scalar tensor holding the mean cross entropy.
    """
    # Target distribution over the items of each list.
    P_true = F.softmax(y_true.float(), dim=1)
    # Use log_softmax instead of log(softmax + eps): it stays exact when a
    # predicted probability underflows to 0, whereas the epsilon variant
    # silently caps every log-probability at log(1e-15) ~= -34.5.
    log_P_pred = F.log_softmax(y_pred.float(), dim=1)

    # Cross entropy per list, then mean over lists.
    loss = -torch.sum(P_true * log_P_pred, dim=1)

    return torch.mean(loss)

def cosine_similarity_loss(y_true, y_pred):
    """Return one minus the mean cosine similarity between two tensors.

    The similarity is computed with ``F.cosine_similarity`` using its default
    arguments, then averaged over all resulting entries.

    Args:
        y_true: Reference tensor.
        y_pred: Predicted tensor, compatible in shape with ``y_true``.

    Returns:
        Scalar tensor ``1 - mean(cosine_similarity(y_true, y_pred))``.
    """
    mean_similarity = F.cosine_similarity(y_true, y_pred).mean()
    return 1 - mean_similarity


def get_loss_func(head_model_type):
    """Return the loss callable used to train the given head model type.

    Every supported head ('ListNet', 'CNN', 'RNN') maps to ``listnet_loss``.

    Args:
        head_model_type: Name of the head model.

    Returns:
        The ``listnet_loss`` function.

    Raises:
        ValueError: If ``head_model_type`` is not a supported name.
    """
    supported_heads = ('ListNet', 'CNN', 'RNN')
    if head_model_type not in supported_heads:
        raise ValueError(f"Invalid loss type: {head_model_type}")
    return listnet_loss


def get_optimizer(optimizer_type, model_parameter, lr):
    """Build a torch optimizer of the requested kind.

    Args:
        optimizer_type: One of 'SGD', 'Adam' or 'Adagrad'.
        model_parameter: Parameters (or parameter groups) to optimise.
        lr: Learning rate passed to the optimizer.

    Returns:
        The constructed ``torch.optim`` optimizer. Adagrad is created with
        ``initial_accumulator_value=0``.

    Raises:
        ValueError: If ``optimizer_type`` is not a supported name.
    """
    if optimizer_type not in ('SGD', 'Adam', 'Adagrad'):
        raise ValueError(f"Invalid optimizer type: {optimizer_type}")

    if optimizer_type == 'Adagrad':
        return optim.Adagrad(model_parameter, lr=lr, initial_accumulator_value=0)

    optimizer_cls = optim.SGD if optimizer_type == 'SGD' else optim.Adam
    return optimizer_cls(model_parameter, lr=lr)


def get_model(head_model_type, device, cfg):
    """Instantiate the requested head model and move it to ``device``.

    Args:
        head_model_type: One of 'ListNet', 'CNN' or 'RNN'.
        device: Target device handed to ``Module.to``.
        cfg: Configuration object. The attributes read depend on the head:
            ``num_features`` and ``hidden_size`` for all heads; ``channel``,
            ``kernel_size`` and ``stride`` for 'CNN'; ``rnn_layers`` for
            'RNN'. 'ListNet' also receives ``cfg`` itself.

    Returns:
        The constructed model, already placed on ``device``.

    Raises:
        ValueError: If ``head_model_type`` is not a supported name.
    """
    if head_model_type == 'ListNet':
        head = ListNet(
            input_size=cfg.num_features,
            hidden_size=cfg.hidden_size,
            cfg=cfg,
        )
        return head.to(device)

    if head_model_type == 'CNN':
        head = CNN(
            channel=cfg.channel,
            kernel_size=cfg.kernel_size,
            stride=cfg.stride,
            num_features=cfg.num_features,
            hidden_size=cfg.hidden_size,
        )
        return head.to(device)

    if head_model_type == "RNN":
        head = RNN(
            input_size=cfg.num_features,
            hidden_size=cfg.hidden_size,
            num_layers=cfg.rnn_layers,
        )
        return head.to(device)

    raise ValueError(f"Invalid head model type: {head_model_type}")


class ListNet(nn.Module):
    """Three-layer MLP that scores each item of a list, with an optional backbone.

    When ``cfg`` provides a ``backbone`` name, the backbone is built through
    ``MODEL_MAP`` and all of its parameters are marked trainable; ``forward``
    then expects raw sequences and embeds them first. Without a backbone,
    ``forward`` expects already-computed feature tensors.

    Args:
        input_size: Size of the (flattened) feature vector fed to ``fc1``.
        hidden_size: Width of the two hidden layers.
        cfg: Optional configuration object. If it has a ``backbone``
            attribute that is not ``None``, ``cfg.seq_type`` is also read.
    """

    def __init__(self, input_size, hidden_size, cfg=None):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

        backbone = cfg.backbone if cfg and hasattr(cfg, 'backbone') else None
        self.backbone_name = backbone
        if backbone is None:
            # Always define these attributes: forward() tests ``self.backbone``
            # and nn.Module raises AttributeError for attributes never set.
            self.backbone = None
            self.seq_type = None
        else:
            self.backbone = MODEL_MAP(backbone, cfg.seq_type, cfg)
            self.seq_type = cfg.seq_type
            # 'LucaOne' keeps its network under model['model']; the other
            # backbones expose it directly as ``.model``.
            if backbone == 'LucaOne':
                trainable = self.backbone.model['model']
            else:
                trainable = self.backbone.model
            # Fine-tune the backbone together with the head.
            for param in trainable.parameters():
                param.requires_grad = True

    def batchseq2seqs(self, seqs):
        """Embed a list-major batch of sequences with the backbone.

        Args:
            seqs: Sequence of ``k`` entries, each holding one raw sequence per
                sample in the batch (i.e. indexed ``seqs[item][sample]``).

        Returns:
            Tensor of shape ``(bs, k, x.shape[1], x.shape[2])`` where ``x`` is
            the backbone output — assumes the backbone returns a 3-D tensor
            with one row per input sequence (TODO confirm against backbone).
        """
        # Use the first entry for the batch size so single-item lists work.
        bs = len(seqs[0])
        # Transpose to sample-major order and flatten: all k items of sample
        # 0, then all k items of sample 1, and so on.
        seqs_view = [
            modaility_map(self.seq_type, seq)
            for sample in zip(*seqs)
            for seq in sample
        ]
        # Follow the head's device instead of assuming the default CUDA device.
        x = self.backbone(seqs_view).to(self.fc1.weight.device)
        x = x.reshape(bs, x.shape[0] // bs, x.shape[1], x.shape[2])
        return x

    def forward(self, x):
        """Score every item of every list.

        Args:
            x: Raw sequences (see ``batchseq2seqs``) when a backbone is
                configured; otherwise a feature tensor of shape
                ``(bs, k, input_size)`` or ``(bs, k, L, D)`` with
                ``L * D == input_size``.

        Returns:
            Tensor of shape ``(bs, k)`` with one score per list item.
        """
        if self.backbone is not None:
            x = self.batchseq2seqs(x)  # x: List[Tuple[str]], (k, bs, L) -> torch.Tensor(bs, k, L, D)

        if x.dim() == 4:
            x = x.view(x.size(0), x.size(1), -1)  # x: (bs, k, L, D) -> (bs, k, L*D)
        x = F.relu(self.fc1(x))  # x: (bs, k, L*D) -> (bs, k, hidden_size)
        x = F.relu(self.fc2(x))  # x: (bs, k, hidden_size) -> (bs, k, hidden_size)
        x = self.fc3(x).squeeze(-1)  # x: (bs, k, hidden_size) -> (bs, k, 1) -> (bs, k)

        return x


class CNN(nn.Module):
    """1-D convolutional head that scores each item of a list independently.

    Each item's feature vector is treated as a single-channel signal, passed
    through one convolution and one max-pool, then through two linear layers
    that produce a scalar score.

    Args:
        channel: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride; also used as the max-pool kernel size.
        num_features: Length of each item's (flattened) feature vector.
        hidden_size: Width of the hidden linear layer.
    """

    def __init__(self, channel, kernel_size, stride, num_features, hidden_size):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=channel, kernel_size=kernel_size, stride=stride)
        self.pool = nn.MaxPool1d(kernel_size=stride)

        # Length after the convolution is (num_features - kernel_size) // stride + 1;
        # the pool (kernel == stride) divides that by ``stride`` again, and the
        # result is flattened across all ``channel`` feature maps.
        linear_input_size = (((num_features - kernel_size) // stride + 1) // stride) * channel
        print("linear_input_size", linear_input_size)
        self.fc = nn.Linear(linear_input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Score every item of every list.

        Args:
            x: Tensor of shape ``(bs, k, num_features)``, or a 4-D tensor
                ``(bs, k, L, D)`` that is flattened to ``(bs, k, L * D)``.

        Returns:
            Tensor of shape ``(bs, k)`` with one score per list item.
        """
        if x.dim() == 4:
            x = x.view(x.size(0), x.size(1), -1)
        batch_size, seq_len, feature_len = x.size()
        # Fold the list axis into the batch so each item is convolved alone.
        x = x.view(batch_size * seq_len, 1, feature_len)

        x = F.relu(self.conv1(x))
        x = self.pool(x)

        x = x.view(batch_size, seq_len, -1)
        x = F.relu(self.fc(x))
        # Drop only the trailing score axis; a bare squeeze() would also drop
        # the batch or list axis whenever it has size 1.
        x = self.fc2(x).squeeze(-1)
        return x


class RNN(nn.Module):
    """Recurrent head that scores the items of a list.

    The list axis is fed to ``nn.RNN`` as the sequence axis
    (``batch_first=True``), and every time step's hidden state is mapped to a
    scalar score by two linear layers.

    Args:
        input_size: Length of each item's (flattened) feature vector.
        hidden_size: Hidden size of the RNN and of the first linear layer.
        num_layers: Number of stacked RNN layers.
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Score every item of every list.

        Args:
            x: Tensor of shape ``(bs, k, input_size)``, or a 4-D tensor
                ``(bs, k, L, D)`` that is flattened to ``(bs, k, L * D)``.

        Returns:
            Tensor of shape ``(bs, k)`` with one score per list item.
        """
        if x.dim() == 4:
            x = x.view(x.size(0), x.size(1), -1)
        # Keep the per-step outputs; the final hidden state is not used.
        x, _ = self.rnn(x)
        x = F.relu(self.fc1(x))
        # Drop only the trailing score axis; a bare squeeze() would also drop
        # the batch or list axis whenever it has size 1.
        x = self.fc2(x).squeeze(-1)

        return x
