from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

from ares.trainers.base_trainer import BaseTrainer, register_trainer
from ares.trainers.components.data.data_processor import DataProcessor
from ares.trainers.components.data_context import dump_data_context, make_batch
from ares.trainers.components.evaluate.online_sample_evaluator import OnlineSampleEvaluator
from ares.trainers.components.sample.online_offline_sampler import OnlineOfflineSampler
from ares.trainers.components.timeline.timeline_drawer import TimelineDrawer, timeline_decorator
from ares.trainers.components.trainer_component_factory import init_grader, init_advantage_calculator, init_sampler
from ares.trainers.components.grader.prm_grader import PRMGrader
from ares.trainers.components.roles.default_roles_manager import DefaultRolesManager
from ares.trainers.components.runtime.default_runtime import DefaultRuntime
from ares.utils.logger import logger
from ares.utils.utils import jaccard_similarity


@register_trainer
class GRPOTrainer(BaseTrainer):
    def __init__(self, config):
        """Build the runtime, roles, data pipeline and helper components used for GRPO training.

        Args:
            config: Trainer configuration. This constructor reads ``config.roles`` and the
                ``train`` section (``pre_generate_data``, ``epochs``, ``query_size``).

        Note:
            ``self.config``, ``self.tokenizer`` and ``self.writer`` are used but not assigned
            here; presumably they are set up by ``BaseTrainer.__init__`` (verify in the base class).
        """
        super().__init__(config)
        self.runtime = DefaultRuntime(config)
        self.roles = DefaultRolesManager(config, self.runtime.start_round)

        # With pre-generation the current round's actor samples data for the next round, so data
        # is loaded one round ahead and the data processor has to start one round later.
        data_round_offset = 1 if self.config.train.pre_generate_data else 0
        self.data_processor = DataProcessor(
            config, self.roles.get_all_roles(), self.runtime.start_round + data_round_offset
        )

        self.sampler = init_sampler(config, self.roles.get_all_roles(), self.tokenizer, self.runtime)
        self.grader = init_grader(config, self.roles.get_all_roles())
        self.advantage_calculator = init_advantage_calculator(config)
        self.evaluator = OnlineSampleEvaluator(config, self.roles.get_all_roles(), self.writer, self.tokenizer)
        self.use_prm = isinstance(self.grader, PRMGrader)

        # One worker thread per configured role
        role_count = len(config.roles.keys())
        self.thread_pool = ThreadPoolExecutor(max_workers=role_count)
        self.pre_generate_data = self.config.train.pre_generate_data

        self.timeline_drawer = TimelineDrawer.get_instance()
        self.timeline_drawer.activate()

        if config.train.epochs is not None:
            # An epoch count overrides the configured rounds; checkpoint and evaluation
            # intervals are aligned to one epoch each.
            rounds_per_epoch = self.data_processor.train_loader_len // config.train.query_size
            train_config = self.config.train
            train_config.rounds = rounds_per_epoch * config.train.epochs
            train_config.ckpt_interval = rounds_per_epoch
            train_config.eval_interval = rounds_per_epoch
            logger.info(f'Recalculated rounds based on epochs: {self.config.train.rounds}')

    def learn(self):
        """Run the training loop from ``runtime.start_round`` up to ``config.train.rounds``.

        Each round trains once via ``learn_one_round``, updates the roles, and then
        checkpoints / evaluates / flushes the timeline according to the configured
        intervals before reporting the merged statistics.
        """
        self.trainer_stats.start()
        # A run starting at round 0 gets an evaluation before any training
        if self.runtime.start_round == 0:
            self.evaluator.evaluate(0)

        # Timeline events are flushed every round (interval of 1)
        timeline_interval = 1

        for current_round in range(self.runtime.start_round, self.config.train.rounds):
            logger.info(f'learn one round | Round {current_round}')
            self.runtime.update_current_round(current_round)
            actor_updated, round_stats = self.learn_one_round()
            self.roles.update_roles(actor_updated, current_round)

            finished_rounds = current_round + 1

            # Checkpoint on the configured interval and always after the final round
            checkpoint_due = finished_rounds % self.config.train.ckpt_interval == 0
            if checkpoint_due or finished_rounds == self.config.train.rounds:
                self.__save_for_recovery()

            evaluation_due = finished_rounds % self.config.train.eval_interval == 0
            if evaluation_due:
                self.evaluator.evaluate(finished_rounds)

            timeline_due = finished_rounds % timeline_interval == 0
            if timeline_due or finished_rounds == self.config.train.rounds:
                self.timeline_drawer.show_and_clear_events(self.config.train.output_dir, current_round)

            # Merge the trainer-level stats into the round stats before reporting
            round_stats.update(self.trainer_stats.get_stats())
            self.__report_stats(round_stats, current_round)

    def learn_one_round(self):
        """Run one GRPO round: sample, infer log-probs, grade, compute advantages and train the actor.

        Steps, in order:
            1. Obtain this round's query data. With pre-generation enabled the data sampled during
               the previous round is used, and sampling for the next round is submitted to the
               thread pool; otherwise the data is sampled synchronously.
            2. Launch actor and sft inference on the micro-batched inputs.
            3. Optionally submit grading to the thread pool.
            4. Wait for inference and grading, compute advantages and dump the data context.
            5. Train the actor on consecutive groups of micro batches.
            6. Merge the extra statistics and, if enabled, wait for next-round sampling.

        Returns:
            tuple: ``(updated, actor_stats)`` taken from the dp0 result of the last training call,
            with the statistics from ``complete_stats`` merged into ``actor_stats``. When no
            training batch was built, ``updated`` is ``False`` and ``actor_stats`` contains only
            the extra statistics.
        """
        current_round = self.runtime.current_round
        self.runtime.monitor_stats_dict = {}
        # Sampling
        if self.pre_generate_data:
            # Current round data has been pre-sampled in the previous round
            if current_round == 0:
                # Round 0 has no pre-sampled data, need to use current round actor to generate current round data
                origin_data_list = self.data_processor.load_train_batch_data()
                self.runtime.pre_generate_query_data_list = self.sampler.sample(current_round, origin_data_list)
            query_data_list = self.runtime.pre_generate_query_data_list
            # Use current round actor to pre-sample for next round
            origin_data_list = self.data_processor.load_train_batch_data()
            sample_future = self.thread_pool.submit(self.sampler.sample, current_round + 1, origin_data_list)
        else:
            # Use current round actor to complete current round sampling
            origin_data_list = self.data_processor.load_train_batch_data()
            query_data_list = self.sampler.sample(current_round, origin_data_list)
        # Asynchronous inference
        infer_start_time = datetime.now()
        infer_input_list = make_batch(query_data_list, self.config.train.micro_batch_size, self.tokenizer.pad_token_id)
        policy_old_logprobs_futures = self.roles.actor.infer(current_round, infer_input_list)
        ref_logprobs_futures = self.roles.sft.infer(current_round, infer_input_list)
        # Asynchronous grading
        grade_future = None
        if isinstance(self.sampler, OnlineOfflineSampler):
            # NOTE(review): the comment that used to sit here said the online retry process has
            # already used the grader, so no re-scoring is needed - yet this is the branch that
            # submits grading. Confirm whether this isinstance check is intended or inverted.
            grade_future = self.thread_pool.submit(self.grader.grade, current_round, query_data_list)
        # Wait for asynchronous processes to complete
        policy_old_logprobs_list = policy_old_logprobs_futures.get_all_dp_result()
        ref_logprobs = ref_logprobs_futures.get_all_dp_result()
        infer_end_time = datetime.now()
        # Actor and sft inference run concurrently, so both share the same time window
        self.timeline_drawer.record_event(
            label_name='actor', start_time=infer_start_time, end_time=infer_end_time, annotation='infer'
        )
        self.timeline_drawer.record_event(
            label_name='sft', start_time=infer_start_time, end_time=infer_end_time, annotation='infer'
        )
        if grade_future:
            grade_future.result()
        # Calculate advantage
        self.advantage_calculator.compute_advantage(query_data_list)
        # Save results
        dump_data_context(current_round, query_data_list, self.config.train.output_dir)
        # Build training input (rebuilt after grading/advantage computation) and attach the
        # log-probs inferred above to the micro batch at the same position
        train_input_list = make_batch(query_data_list, self.config.train.micro_batch_size, self.tokenizer.pad_token_id)
        for batch_index, train_input in enumerate(train_input_list):
            train_input['policy_old_logprobs'] = policy_old_logprobs_list[batch_index]
            train_input['ref_logprobs'] = ref_logprobs[batch_index]
        # Training: each train call receives global_batch_size // micro_batch_size micro batches
        num_batches = self.config.train.global_batch_size // self.config.train.micro_batch_size
        split_input_list = [train_input_list[i : i + num_batches] for i in range(0, len(train_input_list), num_batches)]
        # Defaults keep the return value well-defined when there is nothing to train on;
        # without them the code below would fail on unassigned local variables.
        updated = False
        actor_stats = {}
        if not split_input_list:
            logger.info(f'Round {current_round}: no training batches were built, actor was not trained')
        for sub_input_list in split_input_list:
            train_start_time = datetime.now()
            train_future = self.roles.actor.train(current_round, sub_input_list)
            # Only the result of the last training call is kept
            updated, actor_stats = train_future.get_dp0_result()
            train_end_time = datetime.now()
            self.timeline_drawer.record_event(
                label_name='actor', start_time=train_start_time, end_time=train_end_time, annotation='train'
            )

        extra_stats = self.complete_stats(query_data_list)
        actor_stats.update(extra_stats)
        self.runtime.query_data_list = query_data_list
        if self.pre_generate_data:
            # Wait for pre-sampling of next round to complete
            self.runtime.pre_generate_query_data_list = sample_future.result()
        return updated, actor_stats

    def complete_stats(self, query_data_list):
        # Calculate statistics: average valid token length, step count, ngram, mean/std reward across different queries
        # Answer token length (excluding padding)
        div_n_gram = self.config.train.div_n_gram
        answer_valid_token_length_list = []
        step_num_list = []
        div_ratio_list = []
        mean_reward_list = []
        std_reward_list = []
        pass_k_list = []
        all_true_sample_cnt = []
        useful_sample_cnt = []
        mean_adv_list = []
        max_len_list = []
        response_similarity_list_top1000 = []
        response_similarity_list_top5000 = []
        response_similarity_list = []
        for query_data in query_data_list:
            mean_reward = query_data.mean_reward
            std_reward = query_data.std_reward
            mean_reward_list.append(mean_reward if mean_reward else 0)
            std_reward_list.append(std_reward if std_reward else 0)
            pass_k_list.append(1 if mean_reward != 0 else 0)
            all_true_sample_cnt.append(1 if mean_reward == 1 else 0)
            useful_sample_cnt.append(1 if std_reward != 0 else 0)
            # Calculate response similarity
            tmp_response_similarity_list_top1000 = []
            tmp_response_similarity_list_top5000 = []
            tmp_response_similarity_list = []
            tmp_response_list = [sample_data.answer_str for sample_data in query_data.sample_data_list]
            for i in range(0, len(tmp_response_list) - 1):
                for j in range(i + 1, len(tmp_response_list)):
                    tmp_response_similarity_list_top1000.append(
                        jaccard_similarity(tmp_response_list[i][:1000], tmp_response_list[j][:1000])
                    )
                    tmp_response_similarity_list_top5000.append(
                        jaccard_similarity(tmp_response_list[i][:5000], tmp_response_list[j][:5000])
                    )
                    tmp_response_similarity_list.append(jaccard_similarity(tmp_response_list[i], tmp_response_list[j]))
            response_similarity_list_top1000.append(
                sum(tmp_response_similarity_list_top1000) / len(tmp_response_similarity_list_top1000)
                if tmp_response_similarity_list_top1000
                else 0
            )
            response_similarity_list_top5000.append(
                sum(tmp_response_similarity_list_top5000) / len(tmp_response_similarity_list_top5000)
                if tmp_response_similarity_list_top5000
                else 0
            )
            response_similarity_list.append(
                sum(tmp_response_similarity_list) / len(tmp_response_similarity_list)
                if tmp_response_similarity_list
                else 0
            )
            for sample_data in query_data.sample_data_list:
                if not sample_data.use_prm:
                    mean_adv_list.append(sample_data.advantage)
                else:
                    if sample_data.step_advantage_list:
                        mean_adv_list.extend(sample_data.step_advantage_list)
                # Count non-padding tokens in answer
                answer_input_ids = sample_data.answer_input_ids
                valid_token_num = 0
                for token_id in answer_input_ids:
                    if token_id != self.tokenizer.pad_token_id:
                        valid_token_num += 1
                answer_valid_token_length_list.append(valid_token_num)
                # Count sequences that reach max token length
                if (
                    answer_input_ids[-1] != self.tokenizer.eos_token_id
                    and answer_input_ids[-1] != self.tokenizer.pad_token_id
                ):
                    max_len_list.append(1)
                else:
                    max_len_list.append(0)
                # Count step numbers
                step_num_list.append(
                    len(sample_data.prm_processed_step_index_list)
                ) if sample_data.prm_processed_step_index_list else 0
                # Calculate ngram statistics - not all strategies have been decoded, decode if needed
                answer_str = (
                    sample_data.answer_str
                    if sample_data.answer_str
                    else self.tokenizer.decode(
                        sample_data.answer_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
                    )
                )
                if len(answer_str) <= div_n_gram:
                    div_ratio_list.append(1.0)
                else:
                    ngram_list = []
                    for start in range(len(answer_str) - div_n_gram + 1):
                        ngram_list.append(answer_str[start : start + div_n_gram])
                    div_ratio = 1.0 * len(set(ngram_list)) / len(ngram_list)
                    div_ratio_list.append(div_ratio)
        stats = {
            'output_len': sum(answer_valid_token_length_list) / len(answer_valid_token_length_list)
            if answer_valid_token_length_list
            else 0,
            'step_num': sum(step_num_list) / len(step_num_list) if step_num_list else 0,
            'div_ratio': sum(div_ratio_list) / len(div_ratio_list) if div_ratio_list else 0,
            'mean_reward': sum(mean_reward_list) / len(mean_reward_list) if mean_reward_list else 0,
            'std_reward': sum(std_reward_list) / len(std_reward_list) if std_reward_list else 0,
            'mean_adv': sum(mean_adv_list) / len(mean_adv_list) if mean_adv_list else 0,
            'pass_k': sum(pass_k_list) / len(std_reward_list) if pass_k_list else 0,
            'all_true_ratio': sum(all_true_sample_cnt) / len(all_true_sample_cnt) if all_true_sample_cnt else 0,
            'useful_sample_ratio': sum(useful_sample_cnt) / len(useful_sample_cnt) if useful_sample_cnt else 0,
            'max_len_ratio': sum(max_len_list) / len(max_len_list) if max_len_list else 0,
            'response_similarity_list_top1000': sum(response_similarity_list_top1000)
            / len(response_similarity_list_top1000)
            if response_similarity_list_top1000
            else 0,
            'response_similarity_list_top5000': sum(response_similarity_list_top5000)
            / len(response_similarity_list_top5000)
            if response_similarity_list_top5000
            else 0,
            'response_similarity_list': sum(response_similarity_list) / len(response_similarity_list)
            if response_similarity_list
            else 0,
        }
        if self.runtime.monitor_stats_dict.keys():
            for key, value in self.runtime.monitor_stats_dict.items():
                stats[key] = value
        return stats

    @timeline_decorator(label_name_list=['trainer'], annotation='save for recovery')
    def __save_for_recovery(self):
        """Persist the runtime state, then ask the roles manager to save for the current round."""
        runtime = self.runtime
        # The runtime is written out first; the roles are saved afterwards for the same round
        runtime.save_to_file()
        self.roles.save_for_recovery(runtime.current_round)

    def __report_stats(self, stats, current_round):
        """Log the round's statistics and write each entry as a scalar under ``GRPOTrainer/``.

        Args:
            stats: Mapping of statistic name to value.
            current_round: Round index, used as the step passed to the writer.
        """
        logger.info(f'round{current_round} train stats: {stats}')
        for stat_name in stats:
            self.writer.add_scalar(f'GRPOTrainer/{stat_name}', stats[stat_name], current_round)