import os
import torch
import numpy as np
from collections import  defaultdict
from verl import DataProto
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
from verl.single_controller.base.decorator import register, Dispatch
from verl.single_controller.base import Worker
from tensordict import TensorDict

class warper_sentence_model(Worker):
    """Worker wrapping a SentenceTransformer for text embedding and distance computation.

    Both public methods are registered with ``Dispatch.DP_COMPUTE_PROTO`` so the
    controller can split an incoming ``DataProto`` across data-parallel workers.
    """

    def __init__(self, path, tokenizer):
        """Join the process group (if needed) and load the sentence-transformer model.

        Args:
            path: Model name or local path passed to ``SentenceTransformer``.
            tokenizer: Stored on the instance as ``self.tokenizer``; not used by the
                methods of this class.
        """
        super().__init__()
        import torch.distributed
        # Initialise only once; ``self.rank`` / ``self.world_size`` are not assigned in
        # this class, so they are provided by the ``Worker`` base class.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(rank=self.rank, world_size=self.world_size, backend="nccl")
        print(f"warper_sentence_model current device number {torch.cuda.device_count()}")
        self.path = path
        self.model = SentenceTransformer(self.path, device="cuda")
        self.tokenizer = tokenizer

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def cal_distance(self, data: DataProto):
        """
        Compute distance between embeddings and positive example means.

        Args:
            data (DataProto): Input data containing:
                - batch['embeddings']: Tensor of shape [num_samples, embedding_dim] with input embeddings
                - batch['positive_mean']: Tensor of shape [num_samples, embedding_dim] (or any shape
                  that broadcasts against 'embeddings') with positive means
                - meta_info['metric']: Distance metric to use, either 'cosine' or 'euclidean' (default: 'cosine')

        Returns:
            DataProto: Output data containing:
                - batch['distance']: Tensor of shape [num_samples] with computed distances
                  ('cosine' -> 1 - cosine similarity, 'euclidean' -> L2 norm of the difference)

        Raises:
            ValueError: If ``meta_info['metric']`` is not 'cosine' or 'euclidean'.
        """
        metric = data.meta_info.get('metric', 'cosine')
        # Fail fast on a bad metric before touching any tensors.
        if metric not in ('cosine', 'euclidean'):
            raise ValueError(f"Unsupported distance metrics: {metric}")

        all_embeddings = data.batch['embeddings']
        positive_means = data.batch['positive_mean']

        # Put both operands on the embeddings' device and on a common dtype. The
        # promoted dtype matches what implicit promotion would pick, but making it
        # explicit avoids dtype-mismatch errors in the cosine kernel.
        common_dtype = torch.promote_types(all_embeddings.dtype, positive_means.dtype)
        all_embeddings = all_embeddings.to(dtype=common_dtype)
        positive_means = positive_means.to(device=all_embeddings.device, dtype=common_dtype)

        if metric == 'cosine':
            dist = 1 - F.cosine_similarity(all_embeddings, positive_means, dim=-1)
        else:  # 'euclidean' (validated above)
            dist = torch.norm(all_embeddings - positive_means, dim=-1)
        td = TensorDict({'distance': dist}, batch_size=[dist.shape[0]])
        return DataProto(batch=td)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def cal_avg_distance(self, data: DataProto):
        """
        Compute embeddings for input texts using the sentence transformer model.

        Note: despite its name, this method returns embeddings, not distances; the
        name is kept so existing callers keep working.

        Args:
            data (DataProto): Input data containing:
                - non_tensor_batch['text']: List or numpy array of input text strings
                - meta_info['encode_batch_size'] (optional): batch size forwarded to
                  ``SentenceTransformer.encode`` (default: 16)

        Returns:
            DataProto: Output data containing:
                - batch['embeddings']: float32 CPU tensor of shape [num_texts, embedding_dim]
                  with computed embeddings
        """
        resp_texts = data.non_tensor_batch['text']
        if isinstance(resp_texts, np.ndarray):
            # encode expects a Python list of strings; tolist() also converts np.str_ items.
            resp_texts = resp_texts.tolist()
        encode_batch_size = data.meta_info.get('encode_batch_size', 16)

        print("begin sentence model")
        response_embeddings = self.model.encode(resp_texts, batch_size=encode_batch_size)
        print("end sentence model")

        if isinstance(response_embeddings, np.ndarray):
            # as_tensor avoids an extra copy when the array is already float32.
            emb_tensor = torch.as_tensor(response_embeddings, dtype=torch.float32)
        else:
            # Keep the output on CPU/float32 regardless of which branch produced it.
            emb_tensor = response_embeddings.detach().to(device='cpu', dtype=torch.float32)

        if emb_tensor.ndim == 1 and emb_tensor.numel() == 0:
            # NOTE(review): an empty input may come back as a 1-D empty array; restore
            # the [0, embedding_dim] layout so downstream code sees a consistent rank.
            dim = self.model.get_sentence_embedding_dimension() or 0
            emb_tensor = emb_tensor.reshape(0, dim)

        td = TensorDict({'embeddings': emb_tensor}, batch_size=[emb_tensor.shape[0]])
        return DataProto(batch=td)
