import os
import logging
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from llm_router.router.base_router import RouterInput
from llm_router.model.networks import get_mlp

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class MLPMFClassifier(nn.Module):
    """Router that scores every candidate model for a prompt and picks one.

    Each candidate model has a learned embedding. A prompt embedding is
    linearly projected to the same width, multiplied element-wise with every
    model embedding, and the resulting interaction vectors are passed through
    an MLP that emits one logit per model.

    Training treats routing as classification: the target for a prompt is the
    model maximising ``score - preference * cost / max_cost``.

    The ``config`` object must expose ``n_models``, ``mf_dim``, ``embed_dim``,
    ``hidden_dims``, ``device`` and ``preference``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # One learned vector per candidate model.
        self.model_embed = nn.Embedding(self.config.n_models, self.config.mf_dim)
        # Brings the prompt embedding to the same width as the model vectors.
        self.proj = nn.Linear(self.config.embed_dim, self.config.mf_dim)
        # Turns each (prompt, model) interaction vector into a single logit.
        self.classifier = get_mlp(self.config.mf_dim, self.config.hidden_dims, 1)

    @property
    def name(self):
        """Short identifier of this router."""
        return "mlp_mf"

    def initialize(self, trainset):
        """Cache the training benchmark and its largest cost.

        Must be called before :meth:`forward`, which uses the cached maximum
        cost to normalise the cost penalty.

        Args:
            trainset: Object whose ``get_benchmark()`` returns a mapping with a
                ``"costs"`` tensor.
        """
        self.benchmark = trainset.get_benchmark()

        self.max_cost = self.benchmark["costs"].max().item()

    def _model_logits(self, prompt_embed, n_models):
        """Return one logit per (prompt, model) pair.

        Args:
            prompt_embed: Prompt embeddings, expected as [B, embed_dim] and
                already on ``config.device``.
            n_models: Number of candidate models M to score (the first M rows
                of the model embedding table).

        Returns:
            Tensor of logits, [B, M] given the shapes above.
        """
        # Create the indices directly on the target device (no CPU round trip).
        model_ids = torch.arange(n_models, device=self.config.device)
        model_vecs = self.model_embed(model_ids)  # [M, d]
        prompt_vecs = self.proj(prompt_embed).unsqueeze(1)  # [B, 1, d]
        interaction = model_vecs * prompt_vecs  # broadcast to [B, M, d]
        return self.classifier(interaction).squeeze(-1)  # [B, M]

    def forward(self, batch):
        """Compute the routing classification loss for one batch.

        Args:
            batch: Mapping with tensors ``"embeddings"`` (prompt embeddings),
                ``"scores"`` and ``"costs"`` (both 2-D, one column per model).

        Returns:
            Scalar cross-entropy loss between the predicted logits and the
            index of the best model under the cost-penalised score.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called yet.
        """
        max_cost = getattr(self, "max_cost", None)
        if max_cost is None:
            raise RuntimeError(
                "MLPMFClassifier.initialize(trainset) must be called before forward()"
            )

        device = self.config.device
        prompt_embed = batch["embeddings"].to(device)
        scores = batch["scores"].to(device)
        costs = batch["costs"].to(device)
        _, n_models = scores.shape

        # A zero maximum cost would turn the penalty into 0/0 = NaN and corrupt
        # the labels; fall back to unscaled costs in that case.
        cost_scale = max_cost if max_cost > 0 else 1.0
        utility = scores - self.config.preference * costs / cost_scale
        labels = torch.argmax(utility, dim=1)

        logits = self._model_logits(prompt_embed, n_models)

        return F.cross_entropy(logits, labels)

    @torch.no_grad()
    def predict(self, router_input: "RouterInput"):
        """Pick the model index with the highest logit for a single prompt.

        Args:
            router_input: Object whose ``embedding`` attribute is the prompt
                embedding tensor without a batch dimension.

        Returns:
            The selected model index as a Python ``int``.
        """
        # Add a batch dimension of size 1 before scoring.
        prompt_embed = router_input.embedding.unsqueeze(0).to(self.config.device)
        logits = self._model_logits(prompt_embed, self.config.n_models)  # [1, M]

        return torch.argmax(logits, dim=1).item()