import os
import logging
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from llm_router.router.base_router import RouterInput

logger = logging.getLogger(__name__)

class LinearMFPredictor(nn.Module):
    """Linear matrix-factorization predictor for per-model score and token count.

    A prompt embedding is linearly projected into the same ``mf_dim`` space as
    a learned per-model embedding. The element-wise product of the two is fed
    to two independent linear heads: one predicting a score and one predicting
    an output-token count for every candidate model.

    The ``config`` object is read for: ``n_models``, ``mf_dim``, ``embed_dim``,
    ``device`` and ``loss_weights`` (indexed at ``[0]`` for the score loss and
    ``[1]`` for the token loss).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # NOTE: layer creation order is kept stable so seeded initialisation
        # produces the same parameters as before.
        self.model_embed = nn.Embedding(self.config.n_models, self.config.mf_dim)
        self.proj = nn.Linear(self.config.embed_dim, self.config.mf_dim)
        self.score_pred = nn.Linear(self.config.mf_dim, 1)
        self.token_pred = nn.Linear(self.config.mf_dim, 1)

    @property
    def name(self):
        """Return the identifier string of this predictor."""
        return "linear_mf"

    def initialize(self, trainset):
        """No-op hook; this predictor needs no data-dependent initialisation."""
        pass

    def _predict_all(self, prompt_embed, n_models):
        """Run both heads for the first ``n_models`` models.

        Args:
            prompt_embed: prompt embeddings, expected as ``[B, embed_dim]``.
            n_models: number of model embeddings (indices ``0..n_models-1``)
                to score against.

        Returns:
            Tuple ``(score_pred, token_pred)``, each of shape ``[B, n_models]``.
        """
        device = self.config.device
        model_embed = self.model_embed(torch.arange(n_models, device=device))  # [M, d]
        prompt_proj = self.proj(prompt_embed.to(device))  # [B, d]

        # Broadcast [M, d] * [B, 1, d] -> [B, M, d].
        h = model_embed * prompt_proj.unsqueeze(1)

        # Drop only the trailing head dimension; a bare squeeze() would also
        # collapse B or M when either equals 1.
        return self.score_pred(h).squeeze(-1), self.token_pred(h).squeeze(-1)

    def forward(self, batch):
        """Compute the weighted training loss for one batch.

        Args:
            batch: mapping with keys ``"scores"`` ([B, M]), ``"output_tokens"``
                ([B, M]) and ``"embeddings"`` (expected ``[B, embed_dim]``).

        Returns:
            Scalar tensor: smooth-L1 score loss times ``loss_weights[0]`` plus
            smooth-L1 token loss times ``loss_weights[1]``.
        """
        B, M = batch["scores"].shape
        device = self.config.device

        score_pred, token_pred = self._predict_all(batch["embeddings"], M)

        # Cast targets to the prediction dtype so integer token counts do not
        # trip a dtype mismatch inside smooth_l1_loss.
        scores = batch["scores"].to(device=device, dtype=score_pred.dtype)  # [B, M]
        tokens = batch["output_tokens"].to(device=device, dtype=token_pred.dtype)  # [B, M]

        # reshape(B, M) is a no-op on well-formed input and doubles as a check
        # that the embeddings batch size matches the targets.
        score_loss = F.smooth_l1_loss(score_pred.reshape(B, M), scores) * self.config.loss_weights[0]
        token_loss = F.smooth_l1_loss(token_pred.reshape(B, M), tokens) * self.config.loss_weights[1]

        return score_loss + token_loss

    @torch.no_grad()
    def predict(self, router_input: RouterInput):
        """Predict score and token count of every model for a single prompt.

        Args:
            router_input: object whose ``embedding`` attribute is a single
                prompt embedding (expected 1-D, ``[embed_dim]``).

        Returns:
            Tuple ``(score_pred, token_pred)`` of Python lists, each of length
            ``config.n_models`` (a one-element list when ``n_models == 1``).
        """
        prompt_embed = router_input.embedding.unsqueeze(0)  # [1, embed_dim]
        score_pred, token_pred = self._predict_all(prompt_embed, self.config.n_models)

        # Flatten explicitly instead of squeeze(): with a single model squeeze()
        # yields a 0-d tensor and tolist() would return a bare float.
        return score_pred.reshape(-1).cpu().tolist(), token_pred.reshape(-1).cpu().tolist()