import clip
import torch
import numpy as np
from sentence_splitter import SentenceSplitter  

class DHMRMeasurer:
    """Compute data-hardness and model-responsiveness scaling factors.

    Data hardness compares demonstration responses with generated responses
    sentence by sentence using CLIP text embeddings. Model responsiveness
    compares the mean of the rewards lying closest to the overall mean with
    the overall mean itself. Both quantities are expressed as ratios of
    sigmoids.
    """

    def __init__(self, clip_model="", splitter_language='en'):
        """Load the CLIP model and build the sentence splitter.

        Args:
            clip_model: Model name or checkpoint path forwarded to
                ``clip.load``. The default empty string is passed through
                unchanged, so callers are expected to supply a real value.
            splitter_language: Language code forwarded to
                ``SentenceSplitter``.
        """
        # clip.load returns (model, preprocess); only the model is used here.
        self.clip_model, _ = clip.load(clip_model)
        self.clip_model.eval()

        self.splitter = SentenceSplitter(language=splitter_language)

    @staticmethod
    def _sigmoid(x):
        """Return the logistic sigmoid of a scalar or NumPy array."""
        return 1 / (1 + np.exp(-x))

    def _split_sentences(self, text):
        """Split ``text`` into a list of sentences with the configured splitter."""
        return self.splitter.split(text)

    def _encode_sentences(self, sentences):
        """Return L2-normalised CLIP text features, one row per sentence."""
        # nn.Module exposes no ``.device`` attribute, so take the device from
        # the model's own parameters.
        device = next(self.clip_model.parameters()).device
        # truncate=True keeps over-long sentences from raising in the tokenizer.
        tokens = clip.tokenize(sentences, truncate=True).to(device)
        with torch.no_grad():
            # Cast to float32 so the similarity math does not run in half
            # precision when the model was loaded in fp16.
            features = self.clip_model.encode_text(tokens).float()
        # Unit-length rows make the later dot product a cosine similarity.
        return torch.nn.functional.normalize(features, dim=-1)

    def compute_data_hardness(self, demo_responses, gen_responses):
        """Compute the data-hardness factor for each response pair.

        For every (demonstration, generation) pair, each demonstration
        sentence is matched with its most similar generated sentence (cosine
        similarity of CLIP text features). ``delta`` is one minus the mean of
        those best-match similarities.

        Args:
            demo_responses: Iterable of demonstration response strings.
            gen_responses: Iterable of generated response strings, paired
                positionally with ``demo_responses`` (``zip`` semantics).

        Returns:
            A 1-D NumPy array of ``sigmoid(delta_i) / sigmoid(mean(delta))``.
            Pairs in which either side yields no sentences are skipped, so the
            array may be shorter than the inputs. An empty array is returned
            when no pair could be scored.
        """
        deltas = []
        for demo_text, gen_text in zip(demo_responses, gen_responses):
            demo_sents = self._split_sentences(demo_text)
            gen_sents = self._split_sentences(gen_text)
            if not demo_sents or not gen_sents:
                continue

            demo_features = self._encode_sentences(demo_sents)
            gen_features = self._encode_sentences(gen_sents)

            # Rows: demonstration sentences; columns: generated sentences.
            sim_matrix = demo_features @ gen_features.T
            best_match = sim_matrix.max(dim=1).values

            deltas.append(1 - best_match.mean().item())

        if not deltas:
            # Nothing to score: avoid taking the mean of an empty array.
            return np.array([])

        deltas = np.array(deltas)
        return self._sigmoid(deltas) / self._sigmoid(deltas.mean())

    def compute_model_responsiveness(self, rewards, T=0.2):
        """Compute the model-responsiveness factor from a set of rewards.

        Rewards are ranked by squared deviation from their mean. The value at
        rank ``int(len(rewards) * T)`` is used as a threshold, and only the
        rewards whose squared deviation does not exceed it are kept.

        Args:
            rewards: Sequence of scalar rewards.
            T: Fraction used to pick the threshold rank. The rank is clamped
                to a valid index, so ``T >= 1`` keeps every reward.

        Returns:
            ``sigmoid(mean(kept rewards)) / sigmoid(mean(all rewards))``, or
            ``0.0`` when there are no rewards to evaluate.
        """
        rewards = np.asarray(rewards, dtype=float)
        count = len(rewards)
        if count == 0:
            return 0.0

        mean_reward = rewards.mean()
        squared_diff = (rewards - mean_reward) ** 2

        # Clamp so that T >= 1 (or a negative T) cannot index out of bounds.
        rank = max(0, min(int(count * T), count - 1))
        threshold = np.sort(squared_diff)[rank]

        filtered_rewards = rewards[squared_diff <= threshold]
        if len(filtered_rewards) == 0:
            # Only reachable when comparisons fail, e.g. NaN rewards.
            return 0.0

        return self._sigmoid(filtered_rewards.mean()) / self._sigmoid(mean_reward)