"""
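environment1:
Math environment that combines the math-verify correctness reward with a response-similarity reward.
(The original description line was mis-encoded; this wording is reconstructed from the code below.)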
"""
import contextlib
import io
import logging
from typing import Any, Optional, TypedDict

import ray
import torch
from math_verify.errors import TimeoutException
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES
from nemo_rl.environments.interfaces import (
    EnvironmentInterface,
    EnvironmentReturn,
)
from nemo_rl.environments.metrics import (
    calculate_pass_rate_per_prompt,
)
from nemo_rl.environments.utils import chunk_list_to_workers




import torch
from sentence_transformers import SentenceTransformer, util
# TF-IDF related imports (alternative to the sentence-embedding model)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from itertools import islice
import gzip
import numpy as np
import jsonlines
from tqdm import tqdm
from typing import List
# NOTE: the embedding model is loaded once (see MathEnvironment.__init__) instead of on every call.



def cal_advantages(similarity_reward_list: list[float], G):
    """Z-score-normalize rewards within each group of ``G`` consecutive entries.

    Args:
        similarity_reward_list: Flat sequence of rewards (a list of floats or
            an existing tensor). Its length must be a multiple of ``G`` and
            entries belonging to the same group must be adjacent.
        G: Number of entries per group.

    Returns:
        A float32 tensor of shape ``(len(similarity_reward_list) // G, G)``
        holding ``(reward - group_mean) / (group_std + 1e-8)``, where
        ``group_std`` is the population (biased) standard deviation.

    Raises:
        RuntimeError: If the number of rewards is not divisible by ``G``.
    """
    # as_tensor accepts both lists and tensors; calling torch.tensor() on an
    # existing tensor emits a copy-construct UserWarning.
    reward = torch.as_tensor(similarity_reward_list, dtype=torch.float32)
    reward = reward.reshape(-1, G)  # one row per group

    mean = reward.mean(dim=1, keepdim=True)
    std = reward.std(dim=1, keepdim=True, unbiased=False)
    eps = 1e-8  # avoids division by zero when all rewards in a group are equal

    return (reward - mean) / (std + eps)



class MathEnvConfig(TypedDict):
    num_workers: int
    stop_strings: Optional[list[str]]  # Default stop strings for this env
    reward_type: str 


@contextlib.contextmanager
def _mute_output():
    """Discard everything written to stdout and stderr inside the ``with`` body.

    Both streams are redirected to throw-away in-memory buffers and restored
    when the block exits, whether normally or through an exception.
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(contextlib.redirect_stdout(io.StringIO()))
        stack.enter_context(contextlib.redirect_stderr(io.StringIO()))
        yield


@ray.remote
class HFVerifyWorker:
    """Ray actor that grades model responses with the math-verify metric."""

    def __init__(self) -> None:
        # Only CRITICAL records from the math_verify logger get through.
        logging.getLogger("math_verify").setLevel(logging.CRITICAL)

        # Use Latex and plain math extraction from predictions
        # https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets
        gold_targets = (LatexExtractionConfig(),)
        pred_targets = (ExprExtractionConfig(), LatexExtractionConfig())
        self.verify_func = math_metric(
            gold_extraction_target=gold_targets,
            pred_extraction_target=pred_targets,
        )

    def _score_one(self, response: str, ground_truth: str) -> float:
        """Score a single response; any failure is reported as 0.0."""
        try:
            boxed_truth = "\\boxed{" + ground_truth + "}"
            with _mute_output():
                try:
                    score, _ = self.verify_func([boxed_truth], [response])
                # TimeoutException subclasses BaseException, so a plain
                # `except Exception` would miss it, and math-verify itself
                # does not catch it.
                except (Exception, TimeoutException):
                    score = 0.0
            return float(score)
        except Exception:
            return 0.0

    def verify(
        self, pred_responses: list[str], ground_truths: list[str]
    ) -> list[float]:
        """Verify the correctness of the predicted responses against the ground truth.

        Args:
            pred_responses: list[str]. The predicted responses from the LLM.
            ground_truths: list[str]. The ground truth responses.

        Returns:
            list[float]. The rewards for each predicted response.
        """
        return [
            self._score_one(response, ground_truth)
            for response, ground_truth in zip(pred_responses, ground_truths)
        ]


class MathEnvironmentMetadata(TypedDict):
    """Per-sample metadata passed to ``MathEnvironment.step``."""

    # Reference answer; the verifier wraps it in \boxed{...} before comparing
    # it with the model response.
    ground_truth: str


@ray.remote(max_restarts=-1, max_task_retries=-1)
class MathEnvironment(EnvironmentInterface):
    """Math environment rewarding correctness and/or similarity to sibling responses.

    Correctness is graded in parallel by a pool of ``HFVerifyWorker`` actors.
    The similarity signal is the cosine similarity between each response's
    sentence embedding and the mean embedding of all responses generated for
    the same prompt. It is z-scored per prompt and positive values are zeroed
    before being combined according to ``cfg["reward_type"]``.
    """

    # Group sizes that were hard-coded for the two known batch sizes; kept so
    # those batches behave exactly as before.
    _LEGACY_GROUP_SIZES = {512: 16, 300: 10}

    def __init__(self, cfg: MathEnvConfig):
        """Spawn the verification workers and load the embedding model on CPU.

        Args:
            cfg: Environment configuration; ``num_workers`` and ``reward_type``
                are read by this class.
        """
        self.cfg = cfg
        self.num_workers = cfg["num_workers"]
        self.workers = [
            HFVerifyWorker.options(  # type: ignore # (decorated with @ray.remote)
                runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM}
            ).remote()
            for _ in range(self.num_workers)
        ]
        # NOTE(review): machine-specific model path; consider making it configurable.
        self._model = SentenceTransformer('/home/cwy/LLM/all-MiniLM-L6-v2').to("cpu")
        print(f"cfg: {cfg}")

    def compute_similarity_to_center(self, sentences: List[str]) -> List[float]:
        """Return each sentence's cosine similarity to the mean embedding.

        Args:
            sentences: Texts to embed together.

        Returns:
            One float per sentence, in input order; ``[]`` for empty input.
        """
        if not sentences:
            return []

        embeddings = self._model.encode(sentences, convert_to_tensor=True)  # (n, dim)
        center = torch.mean(embeddings, dim=0)  # (dim,)
        # cos_sim treats the 1-D center as a single row -> result is (n, 1).
        return util.cos_sim(embeddings, center).flatten().tolist()

    def compute_z_score(self, input):
        """Standardize ``input`` as ``(x - mean) / (std + 1e-9)``.

        Args:
            input: Array-like of numbers.

        Returns:
            A numpy array of z-scores computed over all elements.
        """
        values = np.array(input)
        mean = np.mean(values)
        std = np.std(values) + 1e-9  # guards against zero variance
        return (values - mean) / std

    def duplicate_sentences_ngram(self, text, n=10):
        """Return the highest occurrence count of any word n-gram in ``text``.

        Args:
            text: Text to inspect; it is converted with ``str`` and split on
                whitespace.
            n: Number of words per n-gram.

        Returns:
            The count of the most frequent n-gram, or 0 when the text has
            fewer than ``n`` words.
        """
        words = str(text).split()
        counts: dict[str, int] = {}
        for gram in zip(*[islice(words, i, None) for i in range(n)]):
            phrase = ' '.join(gram)
            counts[phrase] = counts.get(phrase, 0) + 1
        return max(counts.values(), default=0)

    def cal_compress_ratio(self, text):
        """Return ``gzip-compressed size / original size`` of ``text`` in UTF-8 bytes.

        Args:
            text: String to measure.

        Returns:
            The size ratio. Empty text returns 1.0 instead of dividing by
            zero. NOTE(review): 1.0 is an arbitrary neutral value — confirm
            the desired behaviour for empty responses.
        """
        raw = text.encode('utf-8')
        if not raw:
            return 1.0
        return len(gzip.compress(raw)) / len(raw)

    def shutdown(self) -> None:
        """Kill all verification worker actors."""
        for worker in self.workers:
            ray.kill(worker)

    def _infer_group_size(self, user_prompt_batch: List[str]) -> int:
        """Determine how many consecutive responses belong to one prompt.

        The two previously hard-coded batch sizes keep their fixed group
        sizes. For any other batch the group size is taken to be the number
        of leading entries whose user prompt equals the first one; this is a
        heuristic that relies on generations for a prompt being adjacent.

        Raises:
            ValueError: If the batch is empty or the inferred group size does
                not divide the batch size.
        """
        batch_size = len(user_prompt_batch)
        if batch_size == 0:
            raise ValueError("Cannot infer the group size of an empty batch.")

        legacy_size = self._LEGACY_GROUP_SIZES.get(batch_size)
        if legacy_size is not None:
            return legacy_size

        first_prompt = user_prompt_batch[0]
        group_size = next(
            (
                index
                for index, prompt in enumerate(user_prompt_batch)
                if prompt != first_prompt
            ),
            batch_size,
        )
        if batch_size % group_size != 0:
            raise ValueError(
                f"Inferred group size {group_size} does not divide "
                f"batch size {batch_size}."
            )
        return group_size

    def step(  # type: ignore[override]
        self,
        message_log_batch: list[list[dict[str, str]]],
        metadata: list[MathEnvironmentMetadata],
    ) -> EnvironmentReturn:
        """Runs a step in the math environment.

        Args:
            message_log_batch: list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM.
                Responses generated for the same prompt must be adjacent.
            metadata: list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness.

        Returns:
            EnvironmentReturn with:
                - observations: one "Environment: correct"/"Environment: incorrect" message per sample
                - metadata: the input metadata, unchanged
                - next_stop_strings: a list of None, one per sample
                - rewards: combined reward selected by cfg["reward_type"], shape (batch, 1)
                - correctness_rewards: verifier scores, shape (batch,)
                - similarity_reward: raw cosine similarities, shape (batch,)
                - terminateds: ones with the same shape as rewards

        Raises:
            ValueError: If the per-prompt group size cannot be determined.
            NotImplementedError: If cfg["reward_type"] is not a known formula.
        """
        user_prompt_batch = [
            "".join(
                interaction["content"]
                for interaction in conversation
                if interaction["role"] == "user"
            )
            for conversation in message_log_batch
        ]

        # Extract the assistant's responses from the message history
        # Each message list should have at least one assistant response
        assistant_response_batch = [
            "".join(
                interaction["content"]
                for interaction in conversation
                if interaction["role"] == "assistant"
            )
            for conversation in message_log_batch
        ]

        # Similarity is computed independently inside each prompt's group of
        # generations, then flattened back to batch order.
        num_generations_per_prompt = self._infer_group_size(user_prompt_batch)
        similarity_scores: list[float] = []
        for start in range(
            0, len(assistant_response_batch), num_generations_per_prompt
        ):
            group = assistant_response_batch[
                start : start + num_generations_per_prompt
            ]
            similarity_scores.extend(self.compute_similarity_to_center(group))

        ground_truths = [g["ground_truth"] for g in metadata]

        chunked_assistant_response_batch = chunk_list_to_workers(
            assistant_response_batch, self.num_workers
        )
        chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers)

        # Process each chunk in parallel
        futures = [
            self.workers[i].verify.remote(chunk, ground_truth_chunk)
            for i, (chunk, ground_truth_chunk) in enumerate(
                zip(chunked_assistant_response_batch, chunked_ground_truths)
            )
        ]

        results = ray.get(futures)

        # flatten the results
        results = [item for sublist in results for item in sublist]
        observations = [
            {
                "role": "environment",
                "content": "Environment: correct"
                if result
                else "Environment: incorrect",
            }
            for result in results
        ]

        correctness_rewards = torch.tensor(results).cpu()
        # Raw cosine similarities are returned unmodified for logging.
        similarity_reward_raw = torch.tensor(similarity_scores).cpu()

        # Per-prompt z-score, then zero every positive value so only responses
        # at or below their group's mean similarity keep a (non-positive)
        # signal. NOTE(review): the original note here cites
        # https://arxiv.org/pdf/2506.01347 — confirm the intended scheme.
        similarity_reward = cal_advantages(
            similarity_scores, G=num_generations_per_prompt
        )
        similarity_reward = similarity_reward.clamp(max=0.0).reshape(-1, 1)

        # Column view so the "+correct"/"*correct" formulas combine
        # element-wise with the (batch, 1) similarity term instead of
        # broadcasting (batch,) against (batch, 1) into (batch, batch).
        correctness_column = correctness_rewards.view(-1, 1)

        reward_type = self.cfg['reward_type']
        if reward_type == "(1+sim)/2":
            rewards = (1 + similarity_reward) / 2
        elif reward_type == "sim":
            rewards = similarity_reward
        elif reward_type == "(1-sim)/2":
            rewards = (1 - similarity_reward) / 2
        elif reward_type == "(1+sim)/2+correct":
            rewards = correctness_column + (1 + similarity_reward) / 2
        elif reward_type == "(1-sim)/2+correct":
            rewards = correctness_column + (1 - similarity_reward) / 2
        elif reward_type == "(1+sim)/2*correct":
            rewards = correctness_column * (1 + similarity_reward) / 2
        elif reward_type == "(1-sim)/2*correct":
            rewards = correctness_column * (1 - similarity_reward) / 2
        else:
            raise NotImplementedError(f"Unknown reward_type: {reward_type!r}")

        done = torch.ones_like(rewards).cpu()

        next_stop_strings = [None] * len(message_log_batch)

        return EnvironmentReturn(
            observations=observations,
            metadata=metadata,
            next_stop_strings=next_stop_strings,
            rewards=rewards,
            correctness_rewards=correctness_rewards,
            similarity_reward=similarity_reward_raw,
            terminateds=done,
        )

    def global_post_process_and_metrics(
        self, batch: BatchedDataDict[Any]
    ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]:
        """Computes metrics for this environment given a global rollout batch.

        Every rank will run this function, so you're free to use distributed
        calculations if you'd prefer for heavy metrics.

        Args:
            batch: Rollout batch; the "rewards", "is_end", "generation_lengths",
                "prompt_lengths" and "text" entries are read, and "rewards" is
                overwritten in place.

        Returns:
            The updated batch and a dict of scalar metrics.
        """
        batch["rewards"] = (
            batch["rewards"] * batch["is_end"]
        )  # set a reward of 0 for any incorrectly ended sequences

        # Samples whose (masked) reward is exactly 1 count as correct here.
        correct_mask = batch["rewards"] == 1
        if correct_mask.any():
            correct_solution_generation_lengths = (
                (batch["generation_lengths"] - batch["prompt_lengths"])[correct_mask]
                .float()
                .mean()
                .item()
            )
        else:
            correct_solution_generation_lengths = 0

        metrics = {
            # "table": table, TODO @sahilj WIP
            "accuracy": batch["rewards"].mean().item(),
            "pass@samples_per_prompt": calculate_pass_rate_per_prompt(
                batch["text"], batch["rewards"]
            ),
            "fraction_of_samples_properly_ended": batch["is_end"].float().mean().item(),
            "num_problems_in_batch": batch["is_end"].shape[0],
            "generation_lengths": batch["generation_lengths"].float().mean().item(),
            "prompt_lengths": batch["prompt_lengths"].float().mean().item(),
            "correct_solution_generation_lengths": correct_solution_generation_lengths,
        }

        return batch, metrics
