from typing import Any, Dict, List, Tuple
import torch
from copy import deepcopy
from transformers import AutoModelForCausalLM, AutoTokenizer

from .util import get_tokenizer
from .melo import LORA
from ...util import nethook
import torch

import torch
import random
from collections import defaultdict
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import random

class EntityLinker:
    """Compare two pieces of text by the cosine similarity of model embeddings.

    An embedding is the mean over the sequence dimension of the last entry
    of ``hidden_states`` returned by ``model`` for the tokenized text.
    """

    def __init__(self, model, device,
                 tokenizer_path="models--meta-llama--Meta-Llama-3-8B-Instruct"):
        """Load the tokenizer and move ``model`` to ``device``.

        Args:
            model: Callable model that accepts the tokenizer output as keyword
                arguments plus ``output_hidden_states=True`` and returns an
                object exposing ``hidden_states``.
            device: Device the model and the tokenized inputs are moved to.
            tokenizer_path: Name or path handed to
                ``AutoTokenizer.from_pretrained``. Defaults to the path that
                was previously hard-coded.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Reuse the EOS token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = model
        self.device = device
        self.model.to(device)

    def _get_embedding(self, text):
        """Return the mean-pooled last hidden state of ``text`` as a float32 array.

        NOTE(review): the mean runs over every sequence position, padding
        included. Callers in this file presumably pass a single string, in
        which case ``padding=True`` adds no padding -- confirm before passing
        a batch of texts.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        # Upcast before pooling: NumPy has no bfloat16 dtype, so calling
        # ``.numpy()`` on a bf16 tensor raises TypeError, and pooling in
        # float32 is also more accurate for half-precision models.
        pooled = last_hidden.float().mean(dim=1).squeeze()
        return pooled.cpu().numpy()

    def is_same(self, subject1, subject2, threshold=0.8):
        """Return True when the cosine similarity of the two embeddings exceeds ``threshold``.

        A zero-norm embedding has no defined cosine similarity; in that case
        the subjects are reported as different.
        """
        embedding1 = np.asarray(self._get_embedding(subject1), dtype=np.float64)
        embedding2 = np.asarray(self._get_embedding(subject2), dtype=np.float64)
        denominator = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
        if denominator == 0:
            return False
        similarity = np.dot(embedding1, embedding2) / denominator
        return bool(similarity > threshold)

def allocate_requests_to_experts(requests, entity_linker, alpha=0.1, beta=1, tau=0.8):
    """Assign each request an expert index so that conflicting requests differ.

    Two requests conflict when
    ``alpha * |cos(prompt_i, prompt_j)| + beta * same_subject >= tau``,
    where the prompt embeddings come from ``entity_linker._get_embedding`` and
    ``same_subject`` is 1 if ``entity_linker.is_same`` reports the two
    subjects as the same entity, else 0. The resulting conflict graph is
    coloured greedily in request order; the colour of a request is its
    expert index.

    Args:
        requests: Sequence of dicts, each with a ``'prompt'`` and a
            ``'subject'`` entry.
        entity_linker: Object providing ``_get_embedding(text)`` and
            ``is_same(subject_a, subject_b)``.
        alpha: Weight of the prompt-embedding similarity term.
        beta: Weight of the same-subject term.
        tau: Conflict threshold (inclusive).

    Returns:
        A list with one non-negative int per request, the smallest index not
        used by an already-coloured conflicting request. Empty for empty input.
    """
    import logging
    from itertools import combinations

    log = logging.getLogger(__name__)

    n = len(requests)
    embeddings = [
        np.asarray(entity_linker._get_embedding(r['prompt']), dtype=np.float64)
        for r in requests
    ]
    # Norms are loop-invariant per request; compute them once.
    norms = [np.linalg.norm(h) for h in embeddings]

    # Adjacency sets of the conflict graph, indexed by request position.
    neighbours = [set() for _ in range(n)]
    for i, j in combinations(range(n), 2):
        same_subject = entity_linker.is_same(requests[i]['subject'], requests[j]['subject'])
        sub_conflict = 1 if same_subject else 0

        denominator = norms[i] * norms[j]
        if denominator > 0:
            proj_sim = abs(float(np.dot(embeddings[i], embeddings[j])) / denominator)
        else:
            # A zero-norm embedding has no direction. Treat the similarity as
            # 0 instead of NaN, which would make the threshold test below
            # false and hide a genuine subject conflict.
            proj_sim = 0.0

        total_conflict = alpha * proj_sim + beta * sub_conflict
        log.debug("total_conflict: %s", total_conflict)
        if total_conflict >= tau:
            neighbours[i].add(j)
            neighbours[j].add(i)

    # Greedy colouring in index order: each node takes the smallest colour
    # not already taken by one of its coloured neighbours.
    colors = [-1] * n
    for node in range(n):
        used_colors = {colors[nb] for nb in neighbours[node] if colors[nb] != -1}
        color = 0
        while color in used_colors:
            color += 1
        colors[node] = color
    return colors


def _pad_and_batch_tokens(list_of_tokens, device, pad_token_id=0):
    """Right-pad every request's tensors to a common length and concatenate them.

    Args:
        list_of_tokens: Non-empty list of dicts holding ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        device: Device every tensor is moved to before padding.
        pad_token_id: Fill value for the padded ``input_ids`` positions.

    Returns:
        A dict with the same three keys. Each value is the concatenation
        (``torch.cat`` along its default first dimension) of the per-request
        tensors after padding their last dimension to the longest
        ``input_ids``/``labels`` length found in ``list_of_tokens``.
    """
    max_length = max(
        max(tokens["input_ids"].shape[-1], tokens["labels"].shape[-1])
        for tokens in list_of_tokens
    )

    def right_pad(tensor, fill_value):
        # Append ``fill_value`` along the last dimension up to ``max_length``.
        tensor = tensor.to(device)
        filler = torch.full(
            (*tensor.shape[:-1], max_length - tensor.shape[-1]),
            fill_value, dtype=tensor.dtype, device=device,
        )
        return torch.cat([tensor, filler], dim=-1)

    return {
        "input_ids": torch.cat(
            [right_pad(tokens["input_ids"], pad_token_id) for tokens in list_of_tokens]
        ),
        # Padded positions are masked out of attention.
        "attention_mask": torch.cat(
            [right_pad(tokens["attention_mask"], 0) for tokens in list_of_tokens]
        ),
        # -100 is the conventional "ignore" label value for loss computation.
        "labels": torch.cat(
            [right_pad(tokens["labels"], -100) for tokens in list_of_tokens]
        ),
    }


def apply_melo_to_model(
        model: AutoModelForCausalLM,
        tok: AutoTokenizer,
        requests: List[Dict],
        hparams: "CAKEHyperParams",
        copy=True,
        return_orig_weights=False,
        keep_original_weight=False,
        **kwargs: Any,
) -> Tuple[AutoModelForCausalLM, Dict[str, Any]]:
    """Allocate the requests to experts and run one batched LORA edit.

    Steps: build an ``EntityLinker`` around ``model``, compute the expert
    allocation for ``requests``, wrap ``model`` in ``LORA`` (unless it already
    is one), tokenize every request, pad and batch the tokens, and call
    ``editor.edit_batch``.

    Args:
        model: Model to edit, or an existing ``LORA`` editor to reuse.
        tok: Tokenizer forwarded to the per-request tokenize function.
        requests: Edit requests; each must provide the ``'prompt'`` and
            ``'subject'`` entries read by ``allocate_requests_to_experts``.
        hparams: Hyper-parameters; ``hparams.device`` is used as the CUDA
            device index. The annotation is a string because the class is not
            imported in this module.
        copy, return_orig_weights, keep_original_weight, **kwargs: Accepted
            for interface compatibility; not used by this function.

    Returns:
        ``(editor, weights_copy)`` where ``weights_copy`` is always an empty
        dict.

    Raises:
        ValueError: If ``requests`` is empty.
    """
    import logging

    if not requests:
        # Previously this surfaced later as an opaque "max() arg is an empty
        # sequence" ValueError raised while padding.
        raise ValueError("apply_melo_to_model requires at least one request")

    LOG = logging.getLogger(__name__)

    weights_copy = {}
    device = torch.device(f'cuda:{hparams.device}')
    # ``get_tokenizer`` returns the callable used below to turn one request
    # into a dict of tensors.
    tokenizer = get_tokenizer(hparams)

    entity_linker = EntityLinker(model, device)
    # List of expert-allocation indices, one per request.
    expert_allocation = allocate_requests_to_experts(requests, entity_linker)
    LOG.info('expert_allocation is: %s', expert_allocation)

    if isinstance(model, LORA):
        editor = model
    else:
        editor = LORA(model, hparams, tokenizer, expert_allocation)

    list_of_tokens = [tokenizer(request, tok, device) for request in requests]
    batched_tokens = _pad_and_batch_tokens(list_of_tokens, device)

    editor.to(device)
    editor.edit_batch(batched_tokens, expert_allocation)

    return editor, weights_copy
