import os
import logging
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import lru_cache
from torch_geometric.nn import GeneralConv
from sklearn.neighbors import NearestNeighbors

from llm_router.router.base_router import RouterInput
from llm_router.model.networks import get_mlp
from llm_router.model.utils import filter_knn_indices

logger = logging.getLogger(__name__)

@lru_cache(32)
def _prompt_to_model_edge(K, M):
    """Return the edges from K prompt nodes to M model nodes of a single graph.

    Prompt nodes are numbered ``0..K-1`` and model nodes ``K..K+M-1``. Every
    prompt is connected to every model, in prompt-major order:
    ``[[0, K], [0, K+1], ..., [K-1, K+M-1]]``.

    Args:
        K: number of prompt nodes.
        M: number of model nodes.

    Returns:
        int64 tensor of shape ``[K*M, 2]`` (``[0, 2]`` when K or M is 0),
        each row being ``[prompt_node, model_node]``. The result is cached
        and shared between callers, so it must not be modified in place.
    """
    # prompt ids: 0,0,..,0,1,1,..,1,...  (each repeated M times)
    prompt_ids = torch.arange(K).repeat_interleave(M)
    # model ids: K..K+M-1 cycled once per prompt
    model_ids = torch.arange(M).repeat(K) + K
    return torch.stack([prompt_ids, model_ids], dim=1)

@lru_cache(32)
def _construct_edge_index(B, K, M):
    """Build the ``edge_index`` for a batch of B disjoint prompt/model graphs.

    Each graph has K prompt nodes followed by M model nodes. Graphs are laid
    out consecutively, so graph ``b`` owns node ids ``b*(K+M)`` to
    ``(b+1)*(K+M)-1``. Within a graph, every prompt has an edge to every model,
    in the order produced by ``_prompt_to_model_edge``.

    Args:
        B: number of graphs in the batch.
        K: prompt nodes per graph.
        M: model nodes per graph.

    Returns:
        Contiguous int64 tensor of shape ``[2, B*K*M]``; row 0 holds prompt
        node ids and row 1 model node ids. ``B == 0`` yields ``[2, 0]``. The
        result is cached and shared between callers, so it must not be
        modified in place.
    """
    base = _prompt_to_model_edge(K, M)  # [K*M, 2], ids local to one graph
    offsets = torch.arange(B) * (K + M)  # first node id of each graph
    edges = base.unsqueeze(0) + offsets.view(-1, 1, 1)  # [B, K*M, 2]
    return edges.reshape(-1, 2).t().contiguous()

class GraphClassifier(nn.Module):
    """Route a prompt to a model with a GNN over a prompt/model graph.

    For every query prompt a small graph is built from:

    * the query prompt itself,
    * its ``config.n_neighbors`` nearest training prompts (cosine kNN over
      the benchmark embeddings), and
    * one node per candidate model.

    Every prompt node has an edge to every model node. Neighbor->model edges
    carry the neighbor's observed cost-adjusted quality for that model together
    with a mask value of 1. Query->model edges carry a quality of 0 and a mask
    of 0, because the query's outcome is what is being predicted.

    After ``config.gnn_layers`` rounds of ``GeneralConv`` + BatchNorm + ReLU,
    the query node is concatenated with each model node. The pair is scored by
    an MLP, giving one logit per model.

    NOTE(review): edges only go prompt -> model. Under PyG's default
    source->target flow, prompt nodes (including the query) therefore receive
    no messages and are updated only through GeneralConv's self term. Confirm
    this is intended.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # projection
        self.prompt_proj = nn.Linear(self.config.embed_dim, self.config.proj_dim)
        self.model_embed = nn.Embedding(self.config.n_models, self.config.proj_dim)
        # Edge features are (cost-adjusted quality, observed-mask) pairs.
        self.edge_proj = nn.Linear(2, self.config.proj_dim)

        self.gnn = nn.ModuleList([
            GeneralConv(
                in_channels=self.config.proj_dim,
                out_channels=self.config.proj_dim,
                in_edge_channels=self.config.proj_dim,
                directed_msg=True,
            )
            for _ in range(self.config.gnn_layers)
        ])

        self.bn = nn.ModuleList([
            nn.BatchNorm1d(self.config.proj_dim)
            for _ in range(self.config.gnn_layers)
        ])

        # Scores the concatenation [query node, model node] -> one logit.
        self.classifier = get_mlp(self.config.proj_dim * 2, self.config.pred_hidden_dims, 1)

    @property
    def name(self):
        return "graph_classifier"

    def initialize(self, trainset):
        """Attach the training benchmark and fit the neighbor index over it.

        Args:
            trainset: object whose ``get_benchmark()`` returns a mapping with
                ``"embeddings"``, ``"scores"`` and ``"costs"`` tensors, all
                indexable by training-prompt index.
        """
        self.benchmark = trainset.get_benchmark()

        # Costs are normalized by the largest training cost. Fall back to 1.0
        # when that is not positive (e.g. all models free) so the quality
        # computation does not divide by zero.
        max_cost = self.benchmark["costs"].max().item()
        if max_cost <= 0:
            logger.warning(
                "Non-positive maximum cost %s in benchmark; using 1.0 for cost normalization.",
                max_cost,
            )
            max_cost = 1.0
        self.max_cost = max_cost

        self.knn = NearestNeighbors(
            n_neighbors=self.config.n_neighbors+1, # extra one to exclude itself
            metric="cosine",
            leaf_size=30,
            n_jobs=-1,
        ).fit(self.benchmark["embeddings"].cpu().numpy())

    def _quality(self, scores, costs):
        """Return the cost-adjusted quality ``score - preference * cost / max_cost``.

        Applied elementwise; ``max_cost`` is set by :meth:`initialize`.
        """
        return scores - self.config.preference * costs / self.max_cost

    def _compute_logits(self, query_embeddings, indices):
        """Build the batched prompt/model graphs, run the GNN and score every model.

        Args:
            query_embeddings: raw query embeddings, presumably ``[B, embed_dim]``
                (TODO confirm); moved to ``config.device`` here.
            indices: ``[B, k]`` indices into the training benchmark of the
                neighbors attached to each query, with ``k == config.n_neighbors``.

        Returns:
            Logits of shape ``[B, M]``, one per candidate model.
        """
        device = self.config.device
        k = self.config.n_neighbors
        B = query_embeddings.shape[0]

        # node features
        embed = self.prompt_proj(query_embeddings.to(device))  # [B,d]
        neigh_embed = self.prompt_proj(self.benchmark["embeddings"][indices].to(device))  # [B,k,d]
        neigh_q = self._quality(
            self.benchmark["scores"][indices].to(device),
            self.benchmark["costs"][indices].to(device),
        )  # [B,k,M]
        M = neigh_q.shape[-1]
        n_nodes = k + 1 + M

        # Per-graph node order: k neighbors, then the query, then the M models.
        prompts = torch.cat([neigh_embed, embed.unsqueeze(1)], dim=1)  # [B,k+1,d]
        models = self.model_embed(torch.arange(self.config.n_models, device=device))
        models = models.unsqueeze(0).expand(B, -1, -1)  # [B,M,d]
        x = torch.cat([prompts, models], dim=1).reshape(B * n_nodes, -1)  # [B*(k+1+M),d]

        edge_index = _construct_edge_index(B, k + 1, M).to(device)

        # Edge features follow the same (graph, prompt, model) order as
        # edge_index. The query's own quality is unknown, so its edges carry
        # 0 and a mask of 0; neighbor edges carry observed quality and mask 1.
        edge_q = torch.cat([neigh_q, torch.zeros_like(neigh_q[:, :1])], dim=1)  # [B,k+1,M]
        edge_mask = torch.zeros_like(edge_q)
        edge_mask[:, :k] = 1.0
        edge_attr = torch.stack([edge_q, edge_mask], dim=-1).reshape(-1, 2)  # [B*(k+1)*M,2]
        edge_attr = self.edge_proj(edge_attr)  # [B*(k+1)*M,d]

        # message passing
        for conv, bn in zip(self.gnn, self.bn):
            x = F.relu(bn(conv(x, edge_index, edge_attr)))
        x = x.reshape(B, n_nodes, -1)

        # prediction: pair the query node (position k) with each model node
        target_prompt = x[:, k].unsqueeze(1).expand(B, M, -1)  # [B,M,d]
        target_model = x[:, -M:]  # [B,M,d]
        pred_input = torch.cat([target_prompt, target_model], dim=-1)
        return self.classifier(pred_input).squeeze(-1)  # [B,M]

    def forward(self, batch):
        """Compute the training loss for a batch of training prompts.

        Args:
            batch: mapping with ``"embeddings"`` (presumably ``[B, embed_dim]``),
                ``"scores"`` and ``"costs"`` (``[B, M]``), and ``"idx"``, which is
                passed to ``filter_knn_indices`` so a prompt is not used as its own
                neighbor.

        Returns:
            Scalar cross-entropy loss between the per-model logits and the
            model with the highest cost-adjusted quality.
        """
        device = self.config.device
        embeddings = batch["embeddings"]

        q = self._quality(batch["scores"].to(device), batch["costs"].to(device))  # [B,M]
        labels = torch.argmax(q, dim=1)

        # retrieve neighborhood (k+1 results, then drop the prompt itself)
        _, indices = self.knn.kneighbors(embeddings.cpu().numpy())  # [B,k+1]
        indices = filter_knn_indices(indices, batch["idx"])  # [B,k]

        logits = self._compute_logits(embeddings, indices)
        return F.cross_entropy(logits, labels)

    @torch.no_grad()
    def predict(self, router_input: RouterInput):
        """Return the index of the model predicted best for a single prompt.

        The module is put in eval mode for the duration of the call, so that
        BatchNorm uses its running statistics and does not update them. The
        previous train/eval mode is restored afterwards.

        Args:
            router_input: provides ``embedding``, presumably a 1-D tensor of
                size ``embed_dim`` (TODO confirm).

        Returns:
            Python int, the argmax over per-model logits.
        """
        was_training = self.training
        self.eval()
        try:
            query = router_input.embedding.unsqueeze(0)  # [1,embed_dim]
            _, indices = self.knn.kneighbors(query.cpu().numpy())  # [1,k+1]
            # The query is not expected to be in the training set, so keep the
            # k closest results and drop the farthest.
            indices = torch.from_numpy(indices)[:, :-1]  # [1,k]
            logits = self._compute_logits(query, indices)  # [1,M]
        finally:
            self.train(was_training)

        return torch.argmax(logits, dim=1).item()