# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agonistic model initialization with huggingface
"""

import os
import uuid
from copy import deepcopy
from pprint import pprint
from typing import Dict, Type

import numpy as np
import ray
import torch
from omegaconf import OmegaConf, open_dict
from torch.utils.data import RandomSampler, SequentialSampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm

from verl import DataProto
# [Modified] import marked_timer from verl.utils.debug
from verl.utils.debug import marked_timer
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.base import Worker
from verl.single_controller.ray.base import create_colocated_worker_cls
# [Modified] updated import paths for metric_utils and reduce_metrics
from verl.trainer.ppo.metric_utils import (
    compute_data_metrics,
    compute_throughout_metrics,
    compute_timing_metrics,
)
from verl.utils.metric import reduce_metrics
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn, RLHFDatasetCurriculum
from verl.trainer.ppo.ray_trainer import (
    RayPPOTrainer, 
    AdvantageEstimator, 
    apply_kl_penalty, 
    compute_advantage, 
    compute_response_mask
)
# [Modified] import the new reward computation function
from verl.trainer.ppo.reward import compute_reward

from .util import compute_language_statics_metrics

WorkerType = Type[Worker]

# Pre-existing helper functions; they are used as-is without modification.
def prompt_response_tensor_shift(prompt_tensor, response_tensor, shift, pad_token_id):
    """
    Move the last ``shift`` prompt tokens to the front of the response.

    Both tensors are shifted right by ``shift`` positions along their first
    dimension. The prompt is left-filled with ``pad_token_id``. The response
    is left-filled with the tail of the original prompt, and its last
    ``shift`` entries are dropped, so neither tensor changes length.

    Args:
        prompt_tensor: prompt tensor of one sample (presumably 1-D token ids;
            TODO confirm against callers).
        response_tensor: response tensor of the same sample.
        shift: number of positions to move; values below 1 are a no-op.
        pad_token_id: fill value for the vacated leading prompt positions.

    Returns:
        ``(prompt, response)``. When ``shift < 1`` the input objects are
        returned unchanged. Otherwise new tensors are returned and the inputs
        are not modified.
    """
    if shift < 1:
        return prompt_tensor, response_tensor

    shifted_prompt = prompt_tensor.clone()
    shifted_prompt[shift:] = prompt_tensor[:-shift]
    shifted_prompt[:shift] = pad_token_id

    shifted_response = response_tensor.clone()
    shifted_response[shift:] = response_tensor[:-shift]
    # Tokens pushed out of the prompt become the start of the response.
    shifted_response[:shift] = prompt_tensor[-shift:]

    return shifted_prompt, shifted_response

def _shift_right(tensor, shift, fill_value):
    """
    Return a copy of ``tensor`` shifted right by ``shift`` positions along dim 0.

    ``shift`` must be >= 1. The vacated leading positions are set to
    ``fill_value`` and the last ``shift`` entries are dropped. The input is
    not modified.
    """
    shifted = tensor.clone()
    shifted[shift:] = tensor[:-shift]
    shifted[:shift] = fill_value
    return shifted


def shift_batch_tensor(batch, pad_token_id):
    """
    Shift every sample of ``batch`` right by its own prefix length, in place.

    For sample ``i`` the shift is read from
    ``batch.non_tensor_batch['extra_info'][i]['prefix_length']``. Samples with
    a shift below 1 are left untouched. For the others:

    * ``prompts`` / ``responses`` go through ``prompt_response_tensor_shift``,
      so the last ``shift`` prompt tokens become the first response tokens;
    * ``input_ids`` is shifted right and left-filled with ``pad_token_id``;
    * ``attention_mask`` and ``position_ids`` are shifted right and
      left-filled with 0.

    NOTE(review): the shift is applied along the first dimension of each
    per-sample tensor. Multi-row position ids (e.g. shape (3, seq_len)) would
    be shifted along the wrong axis. Confirm only 1-D position ids reach here.

    Args:
        batch: ``DataProto`` whose tensor batch holds ``prompts``,
            ``responses``, ``input_ids``, ``attention_mask`` and
            ``position_ids``.
        pad_token_id: token id used to fill the vacated prompt/input positions.

    Returns:
        The same ``batch`` object, modified in place.
    """
    tensors = batch.batch
    for i in range(len(batch)):
        # Index the containers directly rather than materialising a fresh
        # per-sample ``batch[i]`` view for every field that is read.
        shift = batch.non_tensor_batch['extra_info'][i]['prefix_length']
        if shift < 1:
            continue

        prompts, responses = prompt_response_tensor_shift(
            tensors['prompts'][i], tensors['responses'][i], shift, pad_token_id
        )
        input_ids = _shift_right(tensors['input_ids'][i], shift, pad_token_id)
        attention_mask = _shift_right(tensors['attention_mask'][i], shift, 0)
        position_ids = _shift_right(tensors['position_ids'][i], shift, 0)

        # All shifted rows are computed from the untouched originals before
        # anything is written back.
        tensors['input_ids'][i] = input_ids
        tensors['attention_mask'][i] = attention_mask
        tensors['position_ids'][i] = position_ids
        tensors['responses'][i] = responses
        tensors['prompts'][i] = prompts

    return batch


class RayMUCLPPOTrainer(RayPPOTrainer):
    """
    PPO trainer that adds a prefix-ratio curriculum on top of ``RayPPOTrainer``.

    Compared with the base trainer, this class:

    * builds an ``RLHFDatasetCurriculum`` training set (unless one is passed in),
      logs its ``current_prefix_ratio`` and advances it with
      ``update_ratio_length()`` after every training step;
    * re-aligns every rollout with ``shift_batch_tensor`` using the sample's
      ``extra_info['prefix_length']`` before log-probs, rewards and advantages
      are computed.

    Note that this trainer runs on the driver process on a single CPU/GPU node.
    """

    def _create_dataloader(self, train_dataset, val_dataset, collate_fn_arg, train_sampler):
        """
        Build the datasets and dataloaders and resolve the total step count.

        Args:
            train_dataset: training dataset, or ``None`` to build an
                ``RLHFDatasetCurriculum`` from ``config.data.train_files``.
            val_dataset: validation dataset, or ``None`` to build an
                ``RLHFDataset`` from ``config.data.val_files``.
            collate_fn_arg: collate function for both dataloaders, or ``None``
                to use ``rl_dataset.collate_fn``.
            train_sampler: sampler for the training dataloader, or ``None`` for
                a seeded ``RandomSampler`` (when ``data.shuffle``) or a
                ``SequentialSampler``.

        Sets ``self.train_dataset``, ``self.val_dataset``,
        ``self.train_dataloader``, ``self.val_dataloader`` and
        ``self.total_training_steps``, and writes the step count into the actor
        and critic optimizer configs.
        """
        # Only build the curriculum dataset when the caller did not supply one.
        if train_dataset is None:
            self.train_dataset = RLHFDatasetCurriculum(parquet_files=self.config.data.train_files,
                                                    tokenizer=self.tokenizer,
                                                    processor=self.processor,
                                                    prompt_key=self.config.data.prompt_key,
                                                    image_key=self.config.data.get('image_key', 'images'),
                                                    max_prompt_length=self.config.data.max_prompt_length,
                                                    filter_prompts=True,
                                                    return_raw_chat=self.config.data.get('return_raw_chat', False),
                                                    truncation=self.config.data.get('truncation', 'error'),
                                                    filter_overlong_prompts=self.config.data.filter_overlong_prompts,
                                                    init_prefix_ratio=self.config.data.get('init_prefix_ratio', 0.8),
                                                    prefix_type=self.config.data.get('prefix_type', 'none'),
                                                    use_group_uid=self.config.data.get('use_group_uid', False),
                                                    no_think=self.config.data.get('no_think', False))
        else:
            self.train_dataset = train_dataset

        assert self.train_dataset.truncation == self.config.data.get(
            'truncation', 'error'
        ), f'dataset truncation {self.train_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'

        if train_sampler is None:
            if self.config.data.shuffle:
                train_dataloader_generator = torch.Generator()
                train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
                sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
            else:
                sampler = SequentialSampler(data_source=self.train_dataset)
        else:
            sampler = train_sampler

        # Fall back to the default collate function when none was given.
        effective_collate_fn = collate_fn_arg if collate_fn_arg is not None else collate_fn

        # NOTE(review): with num_workers > 0, each DataLoader worker process
        # holds its own copy of the dataset. Calls to
        # self.train_dataset.update_ratio_length() in fit() may not reach the
        # workers until their iterator is recreated. Confirm how
        # RLHFDatasetCurriculum applies the current prefix ratio.
        self.train_dataloader = StatefulDataLoader(dataset=self.train_dataset,
                                                   batch_size=self.config.data.train_batch_size,
                                                   num_workers=self.config.data.get("dataloader_num_workers", 8),
                                                   drop_last=True,
                                                   collate_fn=effective_collate_fn,
                                                   sampler=sampler)

        if val_dataset is None:
            self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
                                           tokenizer=self.tokenizer,
                                           processor=self.processor,
                                           prompt_key=self.config.data.prompt_key,
                                           image_key=self.config.data.get('image_key', 'images'),
                                           max_prompt_length=self.config.data.max_prompt_length,
                                           filter_prompts=True,
                                           return_raw_chat=self.config.data.get('return_raw_chat', False),
                                           truncation=self.config.data.get('truncation', 'error'),
                                           filter_overlong_prompts=self.config.data.filter_overlong_prompts)
        else:
            self.val_dataset = val_dataset

        assert self.val_dataset.truncation == self.config.data.get(
            'truncation', 'error'
        ), f'dataset truncation {self.val_dataset.truncation} must be the same as config {self.config.data.get("truncation", "error")}'

        # Without an explicit val_batch_size, validate on the whole set at once.
        val_batch_size = self.config.data.val_batch_size if self.config.data.val_batch_size is not None else len(self.val_dataset)

        self.val_dataloader = StatefulDataLoader(
            dataset=self.val_dataset,
            batch_size=val_batch_size,
            num_workers=self.config.data.get("dataloader_num_workers", 8),
            shuffle=False,
            drop_last=False,
            collate_fn=effective_collate_fn)

        assert len(self.train_dataloader) >= 1
        # The validation dataloader may contain more than one batch.
        assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"

        print(f'Size of train dataloader: {len(self.train_dataloader)}')

        epoch_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
        total_training_steps = epoch_training_steps
        if self.config.trainer.total_training_steps is not None:
            total_training_steps = self.config.trainer.total_training_steps

        # fit() stops at whichever comes first: the step budget (is_last_step)
        # or the end of the last epoch. Size the curriculum schedule to that
        # effective length so a configured trainer.total_training_steps also
        # reaches the dataset.
        if hasattr(self.train_dataset, 'total_training_steps'):
            self.train_dataset.total_training_steps = min(epoch_training_steps, total_training_steps)

        self.total_training_steps = total_training_steps
        print(f'Total training steps: {self.total_training_steps}')

        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
            self.config.critic.optim.total_training_steps = total_training_steps

    def fit(self):
        """
        Run the PPO training loop with the prefix-ratio curriculum.

        Each step does the following:

        1. Generate ``rollout.n`` samples per prompt.
        2. Shift each sample by its prefix length (``shift_batch_tensor``).
        3. Compute old/reference log-probs and values.
        4. Compute rewards and advantages.
        5. Update the critic and actor.
        6. Optionally validate and checkpoint.
        7. Log metrics and advance the curriculum via
           ``self.train_dataset.update_ratio_length()``.

        Returns after the step where ``global_steps`` reaches
        ``total_training_steps``, or when all epochs are exhausted. Returns
        early after the initial validation when ``trainer.val_only`` is set.
        """
        from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
        from verl.utils.tracking import Tracking

        logger = Tracking(project_name=self.config.trainer.project_name,
                          experiment_name=self.config.trainer.experiment_name,
                          default_backend=self.config.trainer.logger,
                          config=OmegaConf.to_container(self.config, resolve=True))

        self.global_steps = 0
        # NOTE(review): nothing here re-advances the curriculum ratio to the
        # restored global_steps. Confirm RLHFDatasetCurriculum state survives
        # a resume.
        self._load_checkpoint()

        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            print("Performing validation before training...")
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get('val_only', False):
                return

        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
        self.global_steps += 1
        last_val_metrics = None
        # Longest step seen so far; used by the ESI checkpoint check below.
        self.max_steps_duration = 0
        rollout_n = self.config.actor_rollout_ref.rollout.n

        for epoch in range(self.config.trainer.total_epochs):
            print(f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}")
            for batch_dict in self.train_dataloader:
                metrics = {}
                timing_raw = {}

                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # One uid per prompt, assigned before the interleaved repeat
                # below, so every rollout of the same prompt shares it.
                # NOTE(review): this overwrites any "uid" the dataset produced
                # (e.g. with data.use_group_uid). Confirm that is intended.
                batch.non_tensor_batch["uid"] = np.array(
                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                )

                gen_batch = self._get_gen_batch(batch)
                gen_batch.meta_info["global_steps"] = self.global_steps
                # Sampled rollouts use a copy repeated rollout.n times. The
                # un-repeated gen_batch is kept for the greedy REMAX baseline,
                # which must line up with the un-repeated `batch`.
                gen_batch_repeated = gen_batch.repeat(repeat_times=rollout_n, interleave=True)

                is_last_step = self.global_steps >= self.total_training_steps

                with marked_timer('step', timing_raw):
                    with marked_timer('gen', timing_raw, color="red"):
                        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_repeated)

                    if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
                        if self.reward_fn is None:
                            raise ValueError("A reward_fn is required for REMAX advantage estimation.")

                        with marked_timer('gen_max', timing_raw, color="purple"):
                            gen_baseline_batch = deepcopy(gen_batch)
                            gen_baseline_batch.meta_info['do_sample'] = False
                            gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)

                            # In verl, union merges into `batch` itself
                            # (NOTE(review): confirm for the pinned version).
                            # The baseline keys are therefore popped again so
                            # they cannot clash with the sampled rollout merged
                            # below.
                            # NOTE(review): the baseline is scored on unshifted
                            # greedy responses, while sampled responses are
                            # scored after shift_batch_tensor. Confirm this is
                            # the intended comparison.
                            batch = batch.union(gen_baseline_output)
                            reward_baseline_tensor = self.reward_fn(batch)
                            reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
                            batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
                            batch.batch['reward_baselines'] = reward_baseline_tensor

                            del gen_baseline_batch, gen_baseline_output

                    # Align prompt-level data with the n interleaved rollouts.
                    batch = batch.repeat(repeat_times=rollout_n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    # Curriculum step: move each sample's prefix tokens from the
                    # prompt into the response (see shift_batch_tensor).
                    batch = shift_batch_tensor(batch, self.tokenizer.pad_token_id)

                    batch.batch['response_mask'] = compute_response_mask(batch)

                    if self.config.trainer.balance_batch:
                        self._balance_batch(batch, metrics=metrics)

                    batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()

                    with marked_timer('old_log_prob', timing_raw, color="blue"):
                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
                        batch = batch.union(old_log_prob)

                    if self.use_reference_policy:
                        with marked_timer('ref', timing_raw, color="olive"):
                            # The reference policy lives either in its own worker
                            # group or inside the actor workers (ref_in_actor).
                            if not self.ref_in_actor:
                                ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                            else:
                                ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    if self.use_critic:
                        with marked_timer('values', timing_raw, color="cyan"):
                            values = self.critic_wg.compute_values(batch)
                            batch = batch.union(values)

                    with marked_timer('adv', timing_raw, color="brown"):
                        if self.use_rm:
                            reward_tensor = self.rm_wg.compute_rm_score(batch)
                            batch = batch.union(reward_tensor)

                        # reward_extra_infos_dict is currently not used here.
                        reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
                        batch.batch['token_level_scores'] = reward_tensor

                        # Either apply an in-reward KL penalty or use the raw
                        # scores directly as rewards.
                        if self.config.algorithm.use_kl_in_reward:
                            batch, kl_metrics = apply_kl_penalty(
                                batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
                            )
                            metrics.update(kl_metrics)
                        else:
                            batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

                        batch = compute_advantage(batch,
                                                  adv_estimator=self.config.algorithm.adv_estimator,
                                                  gamma=self.config.algorithm.gamma,
                                                  lam=self.config.algorithm.lam,
                                                  num_repeat=rollout_n,
                                                  config=self.config.algorithm)

                    if self.use_critic:
                        with marked_timer('update_critic', timing_raw, color="pink"):
                            critic_output = self.critic_wg.update_critic(batch)
                        critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
                        metrics.update(critic_output_metrics)

                    # The actor is only updated once the critic warm-up is over.
                    if self.config.trainer.critic_warmup <= self.global_steps:
                        with marked_timer('update_actor', timing_raw, color="red"):
                            actor_output = self.actor_rollout_wg.update_actor(batch)
                        actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
                        metrics.update(actor_output_metrics)

                    if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
                        (is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
                        with marked_timer('testing', timing_raw, color="green"):
                            val_metrics: dict = self._validate()
                            if is_last_step:
                                last_val_metrics = val_metrics
                        metrics.update(val_metrics)

                    # Save on the regular cadence, on the last step, or when
                    # should_save_ckpt_esi signals that the ESI deadline is near.
                    # That signal is judged from max_steps_duration and
                    # trainer.esi_redundant_time.
                    esi_close_to_expiration = should_save_ckpt_esi(
                        max_steps_duration=self.max_steps_duration,
                        redundant_time=self.config.trainer.esi_redundant_time,
                    )
                    if self.config.trainer.save_freq > 0 and (
                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
                    ):
                        with marked_timer('save_checkpoint', timing_raw, color="green"):
                            self._save_checkpoint()

                steps_duration = timing_raw["step"]
                self.max_steps_duration = max(self.max_steps_duration, steps_duration)

                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                n_gpus = self.resource_pool_manager.get_n_gpus()
                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

                # Curriculum-specific metrics.
                metrics.update(compute_language_statics_metrics(batch=batch, tokenizer=self.tokenizer))
                metrics.update({'prompt_length/prefix_ratio': self.train_dataset.current_prefix_ratio})

                logger.log(data=metrics, step=self.global_steps)
                print(f"Current prefix ratio: {self.train_dataset.current_prefix_ratio}")

                if is_last_step:
                    pprint(f'Final validation metrics: {last_val_metrics}')
                    progress_bar.close()
                    return

                # Advance the curriculum schedule for the next step.
                self.train_dataset.update_ratio_length()

                progress_bar.update(1)
                self.global_steps += 1