import warnings
import numpy as np
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
from nltk.corpus import stopwords
from functools import lru_cache
import math
from collections import Counter
import nltk
from sentence_transformers import SentenceTransformer
import torch
from tqdm import tqdm

# Preload the NLTK resources used by this module. NLTK >= 3.9 looks up the
# tokenizer and tagger models under new identifiers ('punkt_tab' and
# 'averaged_perceptron_tagger_eng'), so both the legacy and the current names
# are requested; downloading only the legacy ones leaves word_tokenize() and
# pos_tag() raising LookupError on recent releases.
for _resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'stopwords',
):
    nltk.download(_resource, quiet=True)

# Create a set of English stopwords for efficient lookup
STOPWORDS = set(stopwords.words('english'))

class Agent:
    """A single participant of a multi-agent environment.

    Args:
        name: Identifier of the agent; ``AgentAnalyzer.analyze_role_clarity``
            uses it as the key of its result dictionary.
        profile: Free-text description of the agent; this is the text that
            ``AgentAnalyzer`` tokenizes and embeds.
        goal: Optional goal of the agent. Defaults to an empty string.
    """

    def __init__(self, name, profile, goal=""):
        self.name = name
        self.profile = profile
        # The goal used to be accepted and then silently dropped; keep it so
        # callers that pass one can read it back.
        self.goal = goal

class MultiAgentEnvironment:
    """Bundle a team of agents with the task they are meant to work on.

    Args:
        agents: The agents taking part; stored as given, without copying.
        task: The task assigned to the team; stored as given.
    """

    def __init__(self, agents, task):
        self.agents, self.task = agents, task

class ModelSingleton:
    """Lazily create and share one ``SentenceTransformer`` model."""

    # Shared model; stays ``None`` until ``get_instance`` is first called.
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared model, building it on the first call.

        Returns:
            The ``SentenceTransformer('paraphrase-MiniLM-L6-v2')`` instance
            stored on the class.
        """
        model = cls._instance
        if model is not None:
            return model

        model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
        cls._instance = model
        return model

class AgentAnalyzer:
    """Score the agents of a multi-agent environment from their profiles.

    Profiles and the task are turned into sentence embeddings with the model
    provided by ``ModelSingleton``. Embeddings are memoised per instance in
    ``embedding_cache`` (keyed by the embedded text).

    Args:
        multi_agent_environment: Object exposing ``agents`` (each with
            ``name`` and ``profile`` attributes) and ``task``.
    """

    def __init__(self, multi_agent_environment):
        self.multi_agent_environment = multi_agent_environment
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embedding_model = None
        self.embedding_cache = {}
        # Install the warning filter before the model is loaded.
        # NOTE(review): presumably the filtered warning is emitted while the
        # model/tokenizer is loaded or used - confirm.
        self._suppress_specific_warnings()
        self._initialize_model()
        self._init_prototypes()
        self._precompute_agent_data()

    def _suppress_specific_warnings(self):
        """Ignore warnings whose message mentions ``clean_up_tokenization_spaces``."""
        warnings.filterwarnings("ignore", message=".*clean_up_tokenization_spaces.*")

    def _initialize_model(self):
        """Fetch the shared embedding model and move it to ``self.device`` (once)."""
        if self.embedding_model is None:
            self.embedding_model = ModelSingleton.get_instance().to(self.device)

    def _init_prototypes(self):
        """Embed the fixed word lists the scores are measured against."""
        prototype_words = {
            'skill': "skill expertise proficiency competence",
            'complexity': "complex challenging difficult advanced",
            'simplicity': "simple easy straightforward basic",
            'capability': "expert experienced skilled proficient advanced",
            'limitation': "beginner junior learning novice"
        }
        self.prototypes = {key: self.get_embedding(words) for key, words in prototype_words.items()}

    def _precompute_agent_data(self):
        """Collect agent profiles and their nouns; embeddings stay lazy."""
        self.agent_profiles = [agent.profile for agent in self.multi_agent_environment.agents]
        self.agent_nouns = [self._extract_nouns(profile) for profile in self.agent_profiles]
        self.agent_embeddings = None

    def _compute_agent_embeddings(self):
        """Return the list of profile embeddings, computing it on first use."""
        if self.agent_embeddings is None:
            self.agent_embeddings = [self.get_embedding(profile) for profile in self.agent_profiles]
        return self.agent_embeddings

    def _encode(self, texts):
        """Run the embedding model on a string or a list of strings."""
        self._initialize_model()
        with torch.no_grad():
            return self.embedding_model.encode(texts, convert_to_tensor=True).to(self.device)

    def get_embedding(self, text):
        """Return the embedding of ``text``, using the per-instance cache.

        The cache is a plain dict on the instance rather than
        ``functools.lru_cache``: ``lru_cache`` on a method keeps every
        instance alive and rejects unhashable arguments such as the list of
        profiles passed by ``analyze_role_clarity``.

        Args:
            text: A single string, or an iterable of strings.

        Returns:
            The cached/encoded tensor for a single string. For an iterable,
            the per-text embeddings stacked with ``torch.stack`` in input
            order; texts not yet cached are encoded in one model call.

        Raises:
            ValueError: If an empty iterable is passed.
        """
        if isinstance(text, str):
            embedding = self.embedding_cache.get(text)
            if embedding is None:
                embedding = self._encode(text)
                self.embedding_cache[text] = embedding
            return embedding

        texts = list(text)
        if not texts:
            raise ValueError("get_embedding() needs at least one text to embed")
        # dict.fromkeys de-duplicates while keeping order, so each missing
        # text is encoded only once.
        missing = [item for item in dict.fromkeys(texts) if item not in self.embedding_cache]
        if missing:
            self.embedding_cache.update(zip(missing, self._encode(missing)))
        return torch.stack([self.embedding_cache[item] for item in texts])

    @staticmethod
    def _extract_nouns(text):
        """Return the lower-cased, non-stopword nouns of ``text`` joined by spaces."""
        tokens = word_tokenize(text)
        pos_tags = pos_tag(tokens)
        return " ".join([word.lower() for word, pos in pos_tags if pos.startswith('NN') and word.lower() not in STOPWORDS])

    @staticmethod
    def _cosine_distance(first, second):
        """Return ``1 - cosine_similarity(first, second)`` as a Python float."""
        similarity = torch.nn.functional.cosine_similarity(first.unsqueeze(0), second.unsqueeze(0))
        return 1 - similarity.item()

    def calculate_role_differentiation(self):
        """Score how dissimilar the agent profiles are from one another.

        The cosine distance is averaged over every unordered pair of agents
        and squashed with a logistic curve (steepness 10, midpoint 0.6).
        With fewer than two agents there are no pairs and the average is
        taken as 0.

        Returns:
            A float in (0, 1).
        """
        embeddings = self._compute_agent_embeddings()
        n = len(embeddings)
        total_dissimilarity = 0.0
        count = 0
        for i in range(n):
            for j in range(i + 1, n):
                total_dissimilarity += self._cosine_distance(embeddings[i], embeddings[j])
                count += 1

        avg_dissimilarity = total_dissimilarity / count if count > 0 else 0

        k, x0 = 10, 0.6
        return 1 / (1 + math.exp(-k * (avg_dissimilarity - x0)))

    def _role_clarity_report(self, profile, profile_embedding):
        """Build the clarity score dictionary for one profile."""
        dep_score = self._calculate_dep_score(profile)
        entropy_score = self._calculate_entropy(profile)
        skill_score = self._calculate_skill_score(profile_embedding)
        return {
            "profile": profile,
            "dep_score": dep_score,
            "entropy_score": entropy_score,
            "skill_score": skill_score,
            "total_score": (dep_score + entropy_score + skill_score) / 3
        }

    def calculate_role_clarity(self, profile):
        """Return the clarity scores of a single profile text.

        Returns:
            Dict with keys ``profile``, ``dep_score``, ``entropy_score``,
            ``skill_score`` and ``total_score`` (mean of the three scores).
        """
        return self._role_clarity_report(profile, self.get_embedding(profile))

    def _calculate_skill_score(self, profile_embedding):
        """Return the cosine distance between a profile and the 'skill' prototype.

        NOTE(review): this is ``1 - similarity``, so a profile closer to the
        skill vocabulary scores lower, and the value can exceed 1 for
        negatively correlated embeddings - confirm this is intended.
        """
        return self._cosine_distance(profile_embedding, self.prototypes['skill'])

    @staticmethod
    def _calculate_dep_score(text):
        """Return mean tokens per sentence divided by 10, capped at 1.0.

        Blank text (or text without sentences) scores 0.0 instead of the NaN
        that averaging an empty list would produce.
        """
        if not text.strip():
            return 0.0
        sentences = nltk.sent_tokenize(text)
        if not sentences:
            return 0.0
        mean_length = np.mean([len(word_tokenize(sent)) for sent in sentences])
        return float(min(mean_length / 10, 1.0))

    @staticmethod
    def _calculate_entropy(text):
        """Return the Shannon entropy (bits) of the token distribution / 4, capped at 1.0.

        Blank text scores 0.0.
        """
        if not text.strip():
            return 0.0
        tokens = word_tokenize(text.lower())
        if not tokens:
            return 0.0
        word_freq = Counter(tokens)
        total_words = sum(word_freq.values())
        entropy = -sum((count / total_words) * math.log2(count / total_words) for count in word_freq.values())
        return min(entropy / 4, 1.0)

    def calculate_overall_task_role_alignment(self):
        """Score how well the team as a whole matches the task.

        Combines the mean cosine similarity between the task and every agent
        profile (weight 0.7) with ``1 - |task complexity - mean agent
        capability|`` (weight 0.3).

        Returns:
            The weighted score as a float.
        """
        task = self.multi_agent_environment.task
        task_embedding = self.get_embedding(task)

        # Go through the lazy accessor: the attribute is None until first
        # computed and is reset by clear_cache().
        agent_embeddings = torch.stack(self._compute_agent_embeddings())
        similarities = torch.nn.functional.cosine_similarity(task_embedding.unsqueeze(0), agent_embeddings)
        profile_similarity_score = torch.mean(similarities).item()

        task_complexity = self.assess_task_complexity(task)
        team_capabilities = [self.assess_agent_capability(profile) for profile in self.agent_profiles]
        team_capability = np.mean(team_capabilities)
        capability_match_score = 1 - abs(task_complexity - team_capability)

        return float(0.7 * profile_similarity_score + 0.3 * capability_match_score)

    def assess_task_complexity(self, text):
        """Return ``(d_complexity - d_simplicity + 1) / 2`` for ``text``.

        ``d_*`` are cosine distances to the 'complexity' and 'simplicity'
        prototypes.

        NOTE(review): because distances (not similarities) are used, text
        closer to the 'complexity' prototype yields a lower value - confirm
        the orientation is intended.
        """
        text_embedding = self.get_embedding(text)
        complexity_similarity = self._cosine_distance(text_embedding, self.prototypes['complexity'])
        simplicity_similarity = self._cosine_distance(text_embedding, self.prototypes['simplicity'])
        return (complexity_similarity - simplicity_similarity + 1) / 2

    def assess_agent_capability(self, text):
        """Return ``(d_capability - d_limitation + 1) / 2`` for ``text``.

        ``d_*`` are cosine distances to the 'capability' and 'limitation'
        prototypes.

        NOTE(review): same orientation caveat as ``assess_task_complexity`` -
        text closer to the 'capability' prototype yields a lower value.
        """
        text_embedding = self.get_embedding(text)
        capability_score = self._cosine_distance(text_embedding, self.prototypes['capability'])
        limitation_score = self._cosine_distance(text_embedding, self.prototypes['limitation'])
        return (capability_score - limitation_score + 1) / 2

    def analyze_role_clarity(self):
        """Return the clarity scores of every agent, keyed by agent name.

        Profiles are embedded in batches; agents sharing a name overwrite
        one another in the result.

        Returns:
            Dict mapping ``agent.name`` to the dictionary described in
            ``calculate_role_clarity``.
        """
        agents = list(self.multi_agent_environment.agents)
        profiles = [agent.profile for agent in agents]
        batch_size = 10  # Adjust based on your GPU memory
        results = {}

        for start in range(0, len(profiles), batch_size):
            stop = start + batch_size
            batch = profiles[start:stop]
            batch_embeddings = self.get_embedding(batch)

            for agent, profile, embedding in zip(agents[start:stop], batch, batch_embeddings):
                results[agent.name] = self._role_clarity_report(profile, embedding)

            # Release cached CUDA blocks after each batch
            self.clear_cuda_memory()

        return results

    def analyze_agents(self):
        """Run all analyses, then drop the cached embeddings.

        Returns:
            Dict with keys ``overall_role_differentiation``,
            ``overall_task_role_alignment`` and ``role_clarity``.
        """
        results = {
            'overall_role_differentiation': self.calculate_role_differentiation(),
            'overall_task_role_alignment': self.calculate_overall_task_role_alignment(),
            'role_clarity': self.analyze_role_clarity()
        }
        self.clear_cache()
        return results

    def clear_cache(self):
        """Forget memoised embeddings and release cached CUDA memory.

        The prototype embeddings in ``self.prototypes`` are kept.
        """
        self.embedding_cache.clear()  # Drop memoised text embeddings
        self.agent_embeddings = None  # Clear precomputed embeddings
        self.clear_cuda_memory()

    def clear_cuda_memory(self):
        """Call ``torch.cuda.empty_cache()`` when CUDA is available."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()