import torch
import torch.nn as nn
import json
import os
from typing import Dict, List, Tuple
from difflib import SequenceMatcher
from transformers import AutoTokenizer, AutoModel

class BGETokenizer:
    """Minimal tokenizer built from a Hugging Face ``tokenizer.json`` vocab.

    NOTE(review): ``encode`` lowercases and splits on whitespace, then looks
    each whole word up in the vocab. That is not sub-word tokenization, so
    words absent from the vocab as whole strings map to the unknown token;
    prefer ``AutoTokenizer`` when fidelity to the model matters.
    """

    def __init__(self, model_path: str):
        """Load the vocabulary and special tokens from ``model_path``.

        Args:
            model_path: Directory containing ``tokenizer.json`` and
                ``tokenizer_config.json``.

        Raises:
            ValueError: If the vocab is a list whose items are neither all
                strings nor all lists.
            KeyError: If a required key is missing from either JSON file.
        """
        # JSON is UTF-8; do not depend on the platform default encoding,
        # since vocab entries may be non-ASCII.
        with open(os.path.join(model_path, 'tokenizer.json'), 'r', encoding='utf-8') as f:
            self.tokenizer_data = json.load(f)
        with open(os.path.join(model_path, 'tokenizer_config.json'), 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        vocab = self.tokenizer_data['model']['vocab']

        print("\n=== Tokenizer  ===")
        print(f"Model vocab type: {type(vocab)}")
        if isinstance(vocab, list):
            # Guard the peek: an empty list has no first item.
            first_item = vocab[0] if vocab else None
        else:
            first_item = 'Not a list'
        print(f"First vocab item: {first_item}")

        if isinstance(vocab, list):
            if all(isinstance(item, str) for item in vocab):
                # Plain list of tokens: the list position is the token id.
                self.vocab = {token: idx for idx, token in enumerate(vocab)}
            elif all(isinstance(item, list) for item in vocab):
                # List of [token, ...] entries: position is the id, item[0] the token.
                self.vocab = {token[0]: idx for idx, token in enumerate(vocab)}
            else:
                raise ValueError(f"Unsupported vocab format: {type(vocab[0])}")
        else:
            # Already a {token: id} mapping.
            self.vocab = vocab

        # The numeric fallbacks are only used when a special token is not in the vocab.
        self.pad_token = self._special_token('pad_token')
        self.pad_token_id = self.vocab.get(self.pad_token, 0)
        self.cls_token = self._special_token('cls_token')
        self.cls_token_id = self.vocab.get(self.cls_token, 1)
        self.sep_token = self._special_token('sep_token')
        self.sep_token_id = self.vocab.get(self.sep_token, 2)
        self.unk_token = self._special_token('unk_token')
        self.unk_token_id = self.vocab.get(self.unk_token, 3)

        self.max_length = self.config['model_max_length']

    def _special_token(self, key: str):
        """Return special token ``key`` from the config as a plain string.

        ``tokenizer_config.json`` may store a special token either as a string
        or as a serialized AddedToken dict with a ``content`` field; a dict is
        unhashable and cannot be looked up in the vocab, so unwrap it.
        """
        token = self.config[key]
        if isinstance(token, dict):
            token = token.get('content')
        return token

    def encode(self, text: str) -> Dict[str, torch.Tensor]:
        """Encode ``text`` as one fixed-length sequence.

        The sequence is ``[CLS] word-ids... [SEP]``, padded with the pad id up
        to ``self.max_length``. When too long, the word ids are truncated so
        the CLS and SEP markers are kept.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` tensors of shape
            ``(1, max_length)``; the mask is 1 for real tokens and 0 for padding.
        """
        words = text.lower().split()
        word_ids = [self.vocab.get(word, self.unk_token_id) for word in words]

        # Reserve two slots for CLS + SEP so truncation never drops SEP.
        budget = max(self.max_length - 2, 0)
        token_ids = [self.cls_token_id] + word_ids[:budget] + [self.sep_token_id]
        token_ids = token_ids[:self.max_length]  # only bites if max_length < 2

        n_real = len(token_ids)
        padding = [self.pad_token_id] * (self.max_length - n_real)
        # Pad positions must be masked out, otherwise the encoder attends to them.
        attention_mask = [1] * n_real + [0] * len(padding)

        return {
            'input_ids': torch.tensor([token_ids + padding]),
            'attention_mask': torch.tensor([attention_mask])
        }

class BGEModel(nn.Module):
    """BGE encoder that returns CLS-pooled embeddings.

    NOTE(review): texts are tokenized with ``BGETokenizer`` (whitespace split
    + whole-word vocab lookup), not the model's own tokenizer, so embeddings
    from ``encode`` are unlikely to match the reference model;
    ``get_bge_model`` (AutoModel + AutoTokenizer) is the reliable path.
    """

    def __init__(self, model_path: str):
        """Load config, weights and tokenizer from ``model_path``.

        The module is put in eval mode and moved to CUDA when available.
        """
        super().__init__()
        with open(os.path.join(model_path, 'config.json'), 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # A weights checkpoint maps parameter names to tensors; it cannot be
        # wrapped in nn.ModuleDict (values must be Modules, names may not
        # contain '.'). Let transformers rebuild the architecture from the
        # directory's config and load the weights into it.
        self.model = AutoModel.from_pretrained(model_path)
        self.tokenizer = BGETokenizer(model_path)

        self.eval()
        if torch.cuda.is_available():
            self.cuda()

    def forward(self, input_ids, attention_mask=None):
        """Return the first-position (CLS) hidden state for each sequence."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[0] is the last hidden state; index 0 on the sequence axis is CLS.
        return outputs[0][:, 0]

    def encode(self, texts: List[str]) -> torch.Tensor:
        """Embed each text separately and stack the results row-wise.

        Returns an empty ``(0, hidden_size)`` tensor for an empty ``texts``.
        """
        # Follow wherever the weights actually live instead of assuming CUDA.
        device = next(self.parameters()).device
        embeddings = []
        with torch.no_grad():
            for text in texts:
                inputs = {k: v.to(device) for k, v in self.tokenizer.encode(text).items()}
                embeddings.append(self.forward(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask']
                ))

        if not embeddings:
            # torch.cat rejects an empty list.
            return torch.empty((0, self.config.get('hidden_size', 0)), device=device)
        return torch.cat(embeddings, dim=0)

# Machine-specific default kept for backward compatibility; pass an explicit
# path (or a hub id such as 'BAAI/bge-m3') on other machines.
DEFAULT_BGE_MODEL_PATH = '/home/yy/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/5617a9f'


def get_bge_model(model_path: str = DEFAULT_BGE_MODEL_PATH) -> Tuple[AutoModel, AutoTokenizer]:
    """Load the BGE model and its tokenizer.

    Args:
        model_path: Local model directory or anything else accepted by
            ``from_pretrained``. Defaults to the previously hard-coded path.

    Returns:
        ``(model, tokenizer)``; the model is in eval mode and on CUDA when
        CUDA is available.

    Raises:
        Whatever ``from_pretrained`` raises; the error is printed first.
    """
    try:
        print("\n=== Load model ===")
        print(f"Model path: {model_path}")

        print("Load tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        print("Load model...")
        model = AutoModel.from_pretrained(model_path)
        model.eval()

        if torch.cuda.is_available():
            print("Move model to GPU...")
            model = model.cuda()

        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

def compute_semantic_similarity(text1: str, text2: str, model_and_tokenizer: Tuple[AutoModel, AutoTokenizer]) -> float:
    """Return the cosine similarity of the CLS embeddings of two texts.

    Args:
        text1: First text.
        text2: Second text.
        model_and_tokenizer: ``(model, tokenizer)`` pair, e.g. from
            ``get_bge_model``.

    Returns:
        Cosine similarity of the two first-position hidden states. If any step
        raises, the error is printed and ``compute_string_similarity`` of the
        raw texts is returned instead (best-effort, never raises).
    """
    try:
        model, tokenizer = model_and_tokenizer
        # Send inputs to the device holding the model's weights; keying off
        # torch.cuda.is_available() breaks for a CPU model on a CUDA host.
        device = next(model.parameters()).device

        inputs1 = tokenizer(text1, return_tensors='pt', padding=True, truncation=True)
        inputs2 = tokenizer(text2, return_tensors='pt', padding=True, truncation=True)
        inputs1 = {k: v.to(device) for k, v in inputs1.items()}
        inputs2 = {k: v.to(device) for k, v in inputs2.items()}

        with torch.no_grad():
            outputs1 = model(**inputs1)
            outputs2 = model(**inputs2)

            # CLS pooling: sequence position 0 of the last hidden state.
            embeddings1 = outputs1.last_hidden_state[:, 0]
            embeddings2 = outputs2.last_hidden_state[:, 0]

            similarity = torch.nn.functional.cosine_similarity(embeddings1, embeddings2)[0].cpu().item()

        return float(similarity)
    except Exception as e:
        # Deliberate best-effort: degrade to a character-level score.
        print(f"Error computing semantic similarity: {str(e)}")
        print("Fallback to string similarity calculation")
        return compute_string_similarity(text1, text2)



def compute_string_similarity(text1: str, text2: str, *, autojunk: bool = False) -> float:
    """Return difflib's similarity ratio of two strings, in [0, 1].

    ``autojunk`` defaults to False. With difflib's heuristic on, any character
    whose duplicates exceed ~1% of a second string of 200+ characters is
    ignored when searching for matches, which drives ratios for long
    natural-language strings toward 0 (e.g. ``'b' + 'a' * 299`` vs
    ``'a' * 300`` scores 0.0). Pass ``autojunk=True`` for the old behavior.
    """
    return SequenceMatcher(None, text1, text2, autojunk=autojunk).ratio()

def normalize_reward_name(reward: str) -> str:
    """Normalize a reward item name for text comparison.

    Lowercases, removes every ``'_reward'`` substring, turns the remaining
    underscores into spaces and strips surrounding whitespace, e.g.
    ``'Goal_Reached_Reward'`` -> ``'goal reached'``.
    """
    return reward.lower().replace('_reward', '').replace('_', ' ').strip()

def get_reward_description(reward_item: str) -> str:
    """Build a one-sentence description of a reward item for embedding.

    Unlike ``normalize_reward_name`` this keeps the original casing and does
    not strip whitespace.
    """
    readable = reward_item.replace('_reward', '').replace('_', ' ')
    return f"This is a {reward_item} reward item used to evaluate the agent's performance on {readable}."

def normalize_reward_items(items: List[str]) -> List[str]:
    """Apply ``normalize_reward_name`` to every item, preserving order."""
    return [normalize_reward_name(item) for item in items]

def find_similar_groups(samples: List[Dict], *, threshold: float = 0.95, model_and_tokenizer=None):
    """Greedily group samples whose reward-item sets look alike.

    Each not-yet-grouped sample ``i`` seeds a group. Every later ungrouped
    sample ``j`` joins when ``max(string_similarity, semantic_similarity)``
    between ``i`` and ``j`` reaches ``threshold``. Membership is decided only
    against the seed, so groups are not guaranteed to be transitive. Each
    group with more than one member is printed with its pairwise similarity
    matrix.

    Args:
        samples: Sequence of dicts, each with a ``'reward_items'`` list of
            reward-name strings.
        threshold: Minimum similarity for joining a group (the previously
            hard-coded 0.95 by default).
        model_and_tokenizer: Optional ``(model, tokenizer)`` pair to reuse;
            loaded with ``get_bge_model()`` when omitted.

    Returns:
        List of sets of sample indices, one per multi-member group.
    """
    n = len(samples)
    if n < 2:
        # No pairs to compare, so skip loading the embedding model.
        return []

    if model_and_tokenizer is None:
        model_and_tokenizer = get_bge_model()

    # Per-sample texts never change; build them once instead of in the O(n^2) loops.
    name_texts = [
        " ".join(sorted(normalize_reward_name(r) for r in sample['reward_items']))
        for sample in samples
    ]
    desc_texts = [
        " ".join(sorted(get_reward_description(r) for r in sample['reward_items']))
        for sample in samples
    ]

    semantic_cache = {}

    def _semantic(a, b):
        # Memoized model similarity by ordered index pair (each call runs two forward passes).
        key = (a, b)
        if key not in semantic_cache:
            semantic_cache[key] = compute_semantic_similarity(desc_texts[a], desc_texts[b], model_and_tokenizer)
        return semantic_cache[key]

    visited = set()
    similar_groups = []

    for i in range(n):
        if i in visited:
            continue

        current_group = {i}
        visited.add(i)

        for j in range(i + 1, n):
            if j in visited:
                continue
            text_similarity = compute_string_similarity(name_texts[i], name_texts[j])
            # max(text, semantic) >= threshold already holds when the cheap
            # string score passes, so only consult the model otherwise.
            if text_similarity >= threshold or _semantic(i, j) >= threshold:
                current_group.add(j)
                visited.add(j)

        if len(current_group) > 1:
            similar_groups.append(current_group)
            print(f"\nSimilar reward group (group size: {len(current_group)}):")
            for idx in current_group:
                print(f"Sample {idx}: {', '.join(samples[idx]['reward_items'])}")

            print("\nSimilarity matrix (text | semantic):")
            group_list = sorted(current_group)
            for idx1 in group_list:
                text_similarities = []
                semantic_similarities = []
                for idx2 in group_list:
                    text_sim = compute_string_similarity(name_texts[idx1], name_texts[idx2])
                    semantic_sim = _semantic(idx1, idx2)
                    text_similarities.append(f"{text_sim:.3f}")
                    semantic_similarities.append(f"{semantic_sim:.3f}")

                print(f"Sample {idx1} (text): {' '.join(text_similarities)}")
                print(f"Sample {idx1} (semantic): {' '.join(semantic_similarities)}")

    return similar_groups