import os
import logging
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import NearestNeighbors

from llm_router.router.base_router import RouterInput
from llm_router.model.networks import CAB, ISAB, get_mlp
from llm_router.model.utils import filter_knn_indices

logger = logging.getLogger(__name__)

class NestedISAB(nn.Module):
    """Apply ISAB attention along both inner axes of a 4-D tensor.

    The input is treated as ``[B, r, c, feat]``. ``self.col`` is run over the
    ``c`` axis (independently for every ``(B, r)`` pair), then ``self.row`` is
    run over the ``r`` axis (independently for every ``(B, c)`` pair). The
    result is returned in the original ``[B, r, c, ...]`` axis order.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, use_ln):
        super().__init__()

        # Both blocks share the same attention hyper-parameters; only the
        # input width differs because `row` consumes the output of `col`.
        attn_args = (num_heads, num_inds, use_ln)
        self.col = ISAB(dim_in, dim_out, *attn_args)
        self.row = ISAB(dim_out, dim_out, *attn_args)

    def forward(self, x):
        """Return ``x`` after the column pass followed by the row pass."""
        assert x.ndim == 4
        batch, n_rows, n_cols = x.shape[:3]

        # Column pass: fold the row axis into the batch so the set axis is c.
        col_sets = x.reshape([batch * n_rows, n_cols, -1])
        col_out = self.col(col_sets).reshape([batch, n_rows, n_cols, -1])

        # Row pass: swap r and c, then fold c into the batch so the set axis is r.
        row_sets = col_out.transpose(1, 2).reshape([batch * n_cols, n_rows, -1])
        row_out = self.row(row_sets).reshape([batch, n_cols, n_rows, -1])

        # Swap back to [B, r, c, feat].
        return row_out.transpose(1, 2)

class NestedCNPPredictor(nn.Module):
    """Predict per-model score and output-token values from benchmark neighbors.

    For each query embedding the k nearest benchmark entries are retrieved
    (cosine KNN). Their projected embeddings, scores and output-token values
    are processed by nested self-attention, and the query attends to the
    result through a cross-attention block. Two MLP heads then emit one score
    and one token value per (query, model) pair.

    ``config`` must provide: ``embed_dim``, ``proj_dim``, ``sa_hidden_dims``,
    ``num_heads``, ``num_inds``, ``use_ln``, ``pred_hidden_dims``,
    ``n_neighbors``, ``device`` and ``loss_weights``.

    ``initialize`` must be called before ``forward`` or ``predict``; it sets
    ``self.benchmark`` and ``self.knn``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # projection
        self.proj = nn.Linear(self.config.embed_dim, self.config.proj_dim)

        # self attention layers
        self_attention_layers = []
        # +2: one score feature and one token feature are concatenated to
        # every projected neighbor embedding before self-attention.
        input_dim = self.config.proj_dim + 2
        for d in self.config.sa_hidden_dims:
            self_attention_layers.append(
                NestedISAB(input_dim, d, self.config.num_heads, self.config.num_inds, self.config.use_ln)
            )
            input_dim = d
        self.sa_layers = nn.Sequential(*self_attention_layers)

        # cross attention layers
        self.ca_layer = CAB(
            self.config.proj_dim, self.config.proj_dim, input_dim, self.config.num_heads, self.config.use_ln
        )

        # prediction layers
        self.score_output = get_mlp(input_dim, self.config.pred_hidden_dims, 1)
        self.token_output = get_mlp(input_dim, self.config.pred_hidden_dims, 1)

    @property
    def name(self):
        return "nested_cnp_predictor"

    def initialize(self, trainset):
        """Store the benchmark from ``trainset`` and fit the KNN index on it.

        ``trainset.get_benchmark()`` must return a mapping with tensor entries
        ``"embeddings"``, ``"scores"`` and ``"output_tokens"``.
        """
        self.benchmark = trainset.get_benchmark()

        self.knn = NearestNeighbors(
            n_neighbors=self.config.n_neighbors + 1,  # extra one to exclude itself
            metric="cosine",
            leaf_size=30,
            n_jobs=-1,
        ).fit(self.benchmark["embeddings"].cpu().numpy())

    def _predict_from_neighbors(self, embeddings, indices, num_models):
        """Run the attention pipeline shared by ``forward`` and ``predict``.

        Args:
            embeddings: query embeddings, ``[B, embed_dim]``.
            indices: benchmark row indices of each query's neighbors, ``[B, k]``
                with ``k == config.n_neighbors``.
            num_models: number of models ``M`` (columns of the score matrix).

        Returns:
            ``(score_pred, token_pred)``, the raw outputs of the two MLP heads
            for the ``B*M`` flattened (query, model) pairs.
        """
        device = self.config.device
        k = self.config.n_neighbors
        B, M = embeddings.shape[0], num_models

        # Project once, then broadcast over the model axis. The projection acts
        # on the last dimension only, so this matches projecting M expanded
        # copies while doing 1/M of the work.
        embed = self.proj(embeddings.to(device)).unsqueeze(1).expand([-1, M, -1])  # [B,M,d]
        neigh_embed = self.proj(self.benchmark["embeddings"][indices].to(device))  # [B,k,d]
        neigh_embed = neigh_embed.unsqueeze(1).expand([-1, M, -1, -1])  # [B,M,k,d]
        neigh_scores = self.benchmark["scores"][indices].permute([0, 2, 1]).to(device)  # [B,M,k]
        neigh_tokens = self.benchmark["output_tokens"][indices].permute([0, 2, 1]).to(device)  # [B,M,k]

        # self attention preprocessing
        sa_input = torch.cat([
            neigh_embed.reshape(B, M, k, -1),
            neigh_scores.reshape(B, M, k, 1),
            neigh_tokens.reshape(B, M, k, 1),
        ], dim=-1)
        sa_output = self.sa_layers(sa_input)  # [B,M,k,d]

        # cross attention prediction: one single-element query per (sample, model)
        ca_output = self.ca_layer(
            embed.reshape([B * M, 1, -1]),
            neigh_embed.reshape([B * M, k, -1]),
            sa_output.reshape([B * M, k, -1]),
        )  # [B*M,1,d]

        # prediction
        pred_input = ca_output.squeeze(1)
        return self.score_output(pred_input), self.token_output(pred_input)

    def forward(self, batch):
        """Compute the training loss for one batch.

        ``batch`` must contain ``"embeddings"`` (``[B, embed_dim]``),
        ``"scores"`` and ``"output_tokens"`` (both ``[B, M]``), and ``"idx"``,
        which is handed to ``filter_knn_indices`` together with the ``k+1``
        retrieved neighbors (per the shape comments it yields ``[B, k]``
        indices; presumably it drops each sample's own benchmark row).

        Returns:
            Scalar tensor: smooth-L1 loss on scores times ``loss_weights[0]``
            plus smooth-L1 loss on tokens times ``loss_weights[1]``.
        """
        B, M = batch["scores"].shape

        scores = batch["scores"].to(self.config.device)  # [B,M]
        tokens = batch["output_tokens"].to(self.config.device)  # [B,M]

        # retrieve neighborhood
        _, indices = self.knn.kneighbors(batch["embeddings"].cpu().numpy())  # [B,k+1]
        indices = filter_knn_indices(indices, batch["idx"])  # [B,k]

        score_pred, token_pred = self._predict_from_neighbors(batch["embeddings"], indices, M)

        # loss
        score_loss = F.smooth_l1_loss(score_pred.reshape(B, M), scores) * self.config.loss_weights[0]
        token_loss = F.smooth_l1_loss(token_pred.reshape(B, M), tokens) * self.config.loss_weights[1]

        return score_loss + token_loss

    @torch.no_grad()
    def predict(self, router_input: "RouterInput"):
        """Predict score and token values of every model for a single query.

        Args:
            router_input: object whose ``embedding`` attribute is a 1-D tensor.

        Returns:
            ``(score_pred, token_pred)``: two lists of floats, each of length
            ``M = self.benchmark["scores"].shape[1]``. Lists are returned even
            when ``M == 1``.
        """
        M = self.benchmark["scores"].shape[1]
        query = router_input.embedding.unsqueeze(0)  # [1,embed_dim]

        # retrieve neighborhood
        _, indices = self.knn.kneighbors(query.cpu().numpy())  # [1,k+1]
        # The index returns k+1 neighbors; drop the last (farthest) one so
        # exactly k remain.
        indices = torch.from_numpy(indices)[:, :-1]  # [1,k]

        score_pred, token_pred = self._predict_from_neighbors(query, indices, M)

        # Flatten explicitly: a bare squeeze() would collapse to a 0-dim tensor
        # when M == 1, making tolist() return a float instead of a list.
        score_pred = score_pred.reshape(-1).cpu().tolist()
        token_pred = token_pred.reshape(-1).cpu().tolist()

        return score_pred, token_pred