from rouge_score import rouge_scorer


def soft_clip(x: float, thres=5, coeff=0.1):
    """Pass values below ``thres`` through unchanged and damp anything above it.

    Args:
        x: value to clip.
        thres: knee point; values strictly below it are returned as-is.
        coeff: slope applied to the excess ``x - thres`` past the knee.

    Returns:
        ``x`` when ``x < thres``, otherwise ``thres + coeff * (x - thres)``.
    """
    return x if x < thres else thres + coeff * (x - thres)


def rouge_to_reward(rouge, n=1, eps=1e-6):
    """Map a similarity score to the reward ``rouge / (1 - rouge) ** n``.

    The reward is 0 at ``rouge == 0`` and grows steeply as ``rouge``
    approaches 1.

    Args:
        rouge: similarity score, expected in [0, 1] (e.g. a ROUGE F-measure).
        n: exponent on the ``1 - rouge`` denominator; larger values make the
            reward steeper near 1.
        eps: lower bound on ``1 - rouge``. It keeps an exact match
            (``rouge == 1``) from dividing by zero, and caps the reward at
            ``1 / eps ** n``. Pass a larger value for a tighter cap.

    Returns:
        The finite reward value.
    """
    # Identical texts give an F-measure of exactly 1.0; without the floor the
    # denominator would be 0 and this would raise ZeroDivisionError.
    return rouge / max(1 - rouge, eps) ** n


class RougeEvaluator:
    """Thin wrapper around ``rouge_score.rouge_scorer.RougeScorer``."""

    def __init__(self, rouge_types=('rouge1', 'rouge2', 'rougeL'), use_stemmer=True):
        """
        Args:
            rouge_types: ROUGE variants to compute (default: rouge1, rouge2, rougeL).
            use_stemmer: whether the scorer stems tokens before matching.
        """
        self.scorer = rouge_scorer.RougeScorer(list(rouge_types), use_stemmer=use_stemmer)

    def evaluate(self, generated_text, reference_text):
        """Score ``generated_text`` against ``reference_text``.

        Returns:
            The dict produced by ``RougeScorer.score``: one
            ``Score(precision, recall, fmeasure)`` per configured ROUGE type.
        """
        # RougeScorer.score expects (target, prediction). Passing the generated
        # text first would silently swap precision and recall. fmeasure is
        # symmetric and is unaffected either way.
        return self.scorer.score(reference_text, generated_text)



class CompressRM:
    def __init__(self, alpha=0.5, beta=0.5, theta=1, n=1):
        """
        Reward model for the prompt-compression task.

        Args:
            alpha: weight of the response-similarity reward.
            beta: weight of the query-similarity reward (also scales
                ``rouge_between_two_LLMs``).
            theta: weight of the soft-clipped compression-ratio reward.
            n: exponent passed to ``rouge_to_reward``.
        """

        self.rouge = RougeEvaluator()
        self.alpha = alpha
        self.beta = beta
        self.theta = theta
        self.n = n

    def __call__(self,
                 ref_q_ts,
                 com_q_ts,
                 ref_qs,
                 com_qs,
                 ref_rs,
                 com_rs,
                 return_details=False):
        """
        Compute per-sample rewards that combine compression and ROUGE similarity.

        For each sample:
            reward = theta * soft_clip(len(ref_q_t) / len(com_q_t), thres=5, coeff=0.1)
                     + alpha * rouge_to_reward(mean(R1, R2, RL F1 of responses), n)
                     + beta  * rouge_to_reward(R1 F1 of queries, n)

        NOTE: ``cr_reward`` is the "index_reward" (originally called "compress_reward").

        Args:
            ref_q_ts: reference questions (tensors); only ``len`` of each is used.
            com_q_ts: compressed questions (tensors); only ``len`` of each is used.
            ref_qs: reference question texts.
            com_qs: compressed question texts.
            ref_rs: reference responses.
            com_rs: compressed responses.
            return_details: if True, also return the per-sample components.

        Returns:
            ``rewards`` (list), or, when ``return_details`` is True, the tuple
            ``(rewards, cr_rewards, response_rouge1s, response_rouge2s,
            response_rougeLs, query_rouges)``.
        """
        rewards = []
        cr_rewards = []  # theta term
        response_rouge1s = []  # alpha term components
        response_rouge2s = []
        response_rougeLs = []
        query_rouges = []  # beta term input
        for ref_q_t, com_q_t, ref_q, com_q, ref_r, com_r in zip(
                ref_q_ts, com_q_ts, ref_qs, com_qs, ref_rs, com_rs):
            # Compression reward: ratios above 5 are damped by soft_clip.
            # NOTE(review): an empty com_q_t raises ZeroDivisionError here.
            compress_ratio = len(ref_q_t) / len(com_q_t)
            cr_reward = soft_clip(compress_ratio, thres=5, coeff=0.1)

            # Query similarity uses ROUGE-1 F1 only.
            query_similarity = self.rouge.evaluate(com_q, ref_q)['rouge1'].fmeasure
            # Score the response pair once and reuse every ROUGE variant,
            # instead of re-running the scorer per variant.
            response_scores = self.rouge.evaluate(com_r, ref_r)
            response_rouge1 = response_scores['rouge1'].fmeasure
            response_rouge2 = response_scores['rouge2'].fmeasure
            response_rougeL = response_scores['rougeL'].fmeasure
            response_similarity = (response_rouge1 + response_rouge2 + response_rougeL) / 3

            query_similarity_reward = rouge_to_reward(query_similarity, self.n)
            response_similarity_reward = rouge_to_reward(response_similarity, self.n)

            reward = ((cr_reward * self.theta)
                      + (response_similarity_reward * self.alpha)
                      + (query_similarity_reward * self.beta))
            rewards.append(reward)

            cr_rewards.append(cr_reward)
            response_rouge1s.append(response_rouge1)
            response_rouge2s.append(response_rouge2)
            response_rougeLs.append(response_rougeL)
            query_rouges.append(query_similarity)

        if return_details:
            return rewards, cr_rewards, response_rouge1s, response_rouge2s, response_rougeLs, query_rouges
        return rewards

    def rouge_between_two_LLMs(self, com_q_1, com_q_2):
        """Return ``beta``-scaled ROUGE-L F1 between element-wise paired texts.

        Args:
            com_q_1: first list of texts (presumably compressed questions from one LLM).
            com_q_2: second list of texts, paired element-wise with ``com_q_1``.

        Returns:
            A list of ``rougeL_f1 * beta``, one per pair; ``zip`` truncates to
            the shorter input.
        """
        rouge_rewards = []
        for com_1, com_2 in zip(com_q_1, com_q_2):
            rouge_l = self.rouge.evaluate(com_1, com_2)['rougeL'].fmeasure
            rouge_rewards.append(rouge_l * self.beta)
        return rouge_rewards



class CompressRM_old:
    """Legacy reward: clipped compression ratio plus response ROUGE-L F1."""

    def __init__(self):
        self.rouge = RougeEvaluator()
        # Weight applied to the clipped compression ratio.
        self.gamma = 0.5

    def __call__(self, ref_qs, com_qs, ref_rs, com_rs, return_details=False):
        """
        Calculates rewards for compression and ROUGE scores for multiple text pairs.

        Each reward is ``gamma * clip(len(ref_q) / len(com_q)) + rougeL_f1(com_r, ref_r)``.
        The clip maps ratios below 1 to 0 and caps ratios above 5 at 5.

        Returns:
            ``rewards``, or ``(rewards, raw_compress_ratios, rougeL_f1s)`` when
            ``return_details`` is True.
        """
        rewards, ratios, rouge_ls = [], [], []
        for ref_q, com_q, ref_r, com_r in zip(ref_qs, com_qs, ref_rs, com_rs):
            ratio = len(ref_q) / len(com_q)
            # Expansion (ratio < 1) earns nothing; gains beyond 5x are capped.
            clipped = 0 if ratio < 1 else min(ratio, 5)

            rouge_l = self.rouge.evaluate(com_r, ref_r)['rougeL'].fmeasure
            rewards.append(clipped * self.gamma + rouge_l)

            # Details report the raw (unclipped) ratio.
            ratios.append(ratio)
            rouge_ls.append(rouge_l)

        if not return_details:
            return rewards
        return rewards, ratios, rouge_ls


if __name__ == "__main__":
    # Quick manual check of RougeEvaluator on a toy sentence pair.
    evaluator = RougeEvaluator()
    scores = evaluator.evaluate("This is a generated summary.", "This is a reference summary.")
    # e.g. {'rouge1': Score(precision=0.8, recall=0.8, fmeasure=0.8000000000000002)...}
    print(scores)

    # Access individual scores:
    rouge1_f1 = scores['rouge1'].fmeasure
    rougeL_recall = scores['rougeL'].recall
    print(f"rouge1 F1: {rouge1_f1:.4f}")
    print(f"rougeL recall: {rougeL_recall:.4f}")
