import os
import time
from abc import ABC
from datetime import timedelta

import ray
import torch
from tqdm import tqdm

from openrlhf.datasets import ILRPromptDataset
from openrlhf.trainer.ppo_utils import AdaptiveKLController, FixedKLController
from openrlhf.trainer.ppo_utils.ilr_experience_maker import RemoteExperienceMaker
from openrlhf.trainer.ray.ilr_launcher import ILRRayActorGroup
from openrlhf.utils import blending_datasets, get_tokenizer
from openrlhf.utils.deepspeed import ILRDeepspeedStrategy
from openrlhf.utils.logging_utils import init_logger
from openrlhf.utils.remote_rm_utils import remote_rm_fn_ray
from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call
# Module-level logger obtained from init_logger with this module's name.
logger = init_logger(__name__)


@ray.remote
class ILRTrainer(ABC):
    """
    Trainer for Proximal Policy Optimization (PPO) / REINFORCE++ / GRPO / RLOO and their variants.
    Single Controller with Multiple ActorGroups
    """

    def __init__(
        self,
        pretrain_llm1: str,
        pretrain_llm2: str,
        strategy: ILRDeepspeedStrategy,
        actor_model_group: ILRRayActorGroup,
        critic_model_group: ILRRayActorGroup,
        reward_model_group: ILRRayActorGroup,
        reference_model_group: ILRRayActorGroup,
        vllm_engines_llm1=None,
        vllm_engines_llm2=None,
        prompt_max_len: int = 120,
        dataloader_pin_memory: bool = True,
        prompt_split: str = "train",
        eval_split: str = "test",
        **generate_kwargs,
    ) -> None:
        """Build tokenizers, KL controllers, the experience maker, datasets and loggers.

        Args:
            pretrain_llm1: Passed to ``get_tokenizer`` to build the llm1 tokenizer.
            pretrain_llm2: Passed to ``get_tokenizer`` to build the llm2 tokenizer.
            strategy: Training strategy; ``strategy.args`` is the source of all
                hyper-parameters read by this trainer.
            actor_model_group: Remote group that the trainer drives for actor work.
            critic_model_group: Remote critic group, or ``None`` when no critic is used.
            reward_model_group: Forwarded to the experience maker.
            reference_model_group: Forwarded to the experience maker.
            vllm_engines_llm1: vLLM engines for llm1 (may be ``None``).
            vllm_engines_llm2: vLLM engines for llm2 (may be ``None``).
            prompt_max_len: Forwarded to the experience maker.
            dataloader_pin_memory: Stored on the instance.
            prompt_split: Dataset split used for training prompts.
            eval_split: Dataset split used for evaluation prompts.
            **generate_kwargs: Stored and later forwarded to sample generation.
        """
        super().__init__()

        self.strategy = strategy
        self.args = strategy.args
        args = self.args

        use_fast = not args.disable_fast_tokenizer
        self.tokenizer_llm1 = get_tokenizer(pretrain_llm1, None, "left", strategy, use_fast=use_fast)
        self.tokenizer_llm2 = get_tokenizer(pretrain_llm2, None, "left", strategy, use_fast=use_fast)
        self.actor_model_group = actor_model_group
        self.critic_model_group = critic_model_group
        self.reward_model_group = reward_model_group
        self.reference_model_group = reference_model_group
        self.dataloader_pin_memory = dataloader_pin_memory
        self.vllm_engines_llm1 = vllm_engines_llm1
        self.vllm_engines_llm2 = vllm_engines_llm2

        self.prompt_split = prompt_split
        self.eval_split = eval_split

        self.prompt_max_len = prompt_max_len
        self.generate_kwargs = generate_kwargs

        self.max_epochs = args.max_epochs
        self.remote_rm_url = args.remote_rm_url
        self.init_kl_coef_llm1 = args.init_kl_coef_llm1
        self.init_kl_coef_llm2 = args.init_kl_coef_llm2
        self.kl_target = args.kl_target
        self.kl_horizon = args.kl_horizon

        self.freezing_actor_steps = getattr(args, "freezing_actor_steps", -1)

        # Each LLM gets its own KL controller in both modes. The adaptive branch
        # previously read a non-existent ``self.init_kl_coef`` and never defined
        # the per-LLM controllers that the experience maker and ``fit`` rely on.
        if self.kl_target:
            self.kl_ctl_llm1 = AdaptiveKLController(self.init_kl_coef_llm1, self.kl_target, self.kl_horizon)
            self.kl_ctl_llm2 = AdaptiveKLController(self.init_kl_coef_llm2, self.kl_target, self.kl_horizon)
        else:
            self.kl_ctl_llm1 = FixedKLController(self.init_kl_coef_llm1)
            self.kl_ctl_llm2 = FixedKLController(self.init_kl_coef_llm2)

        self.experience_maker = RemoteExperienceMaker(
            self.actor_model_group,
            self.critic_model_group,
            self.reward_model_group,
            self.reference_model_group,
            self.tokenizer_llm1,
            self.tokenizer_llm2,
            self.prompt_max_len,
            self.kl_ctl_llm1,
            self.kl_ctl_llm2,
            self.strategy,
            self.remote_rm_url,
            vllm_engines_llm1=self.vllm_engines_llm1,
            vllm_engines_llm2=self.vllm_engines_llm2,
            packing_samples=args.packing_samples,
        )

        self.prepare_datasets('llm1')
        self.prepare_datasets('llm2')

        # wandb/tensorboard setting
        self._wandb = None
        self._tensorboard_llm1 = None
        self._tensorboard_llm2 = None
        self.generated_samples_table = None
        if args.use_wandb:
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # ``use_wandb`` doubles as the API key when none is configured.
                wandb.login(key=args.use_wandb)
            wandb.init(
                entity=args.wandb_org,
                project=args.wandb_project,
                group=args.wandb_group,
                name=args.wandb_run_name,
                config=args.__dict__,
                reinit=True,
            )

            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/epoch")
            wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)
            self.generated_samples_table = wandb.Table(columns=["global_step", "text", "reward"])

        # TensorBoard is only used when wandb is off and BOTH per-LLM log
        # directories are configured; each LLM writes to its own SummaryWriter.
        if args.use_tensorboard_llm1 and args.use_tensorboard_llm2 and self._wandb is None:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(args.use_tensorboard_llm1, exist_ok=True)
            log_dir = os.path.join(args.use_tensorboard_llm1, args.wandb_run_name)
            self._tensorboard_llm1 = SummaryWriter(log_dir=log_dir)

            os.makedirs(args.use_tensorboard_llm2, exist_ok=True)
            log_dir = os.path.join(args.use_tensorboard_llm2, args.wandb_run_name)
            self._tensorboard_llm2 = SummaryWriter(log_dir=log_dir)

    def fit(
        self,
    ) -> None:
        """Run the rollout / training loop for both LLMs over all episodes.

        For every pair of prompt batches (one per LLM) this generates
        experiences, appends them to the actor (and critic, if any) buffers,
        trains llm1 then llm2 via ``ppo_train``, and logs / checkpoints through
        ``save_logs_and_checkpoints``. On the first step after (re)start it also
        runs the actor group's ``eval_ability`` for each LLM.

        Side effects: rewrites ``args.eval_steps`` / ``args.save_steps`` when
        they are ``-1`` and closes the wandb / TensorBoard handles at the end.
        """
        args = self.args

        # Both dataloaders are consumed in lock-step, so they must be equally long.
        num_rollouts_per_episodes = len(self.prompts_dataloader_llm1)
        assert num_rollouts_per_episodes == len(
            self.prompts_dataloader_llm2
        ), "llm1 and llm2 prompt dataloaders must yield the same number of batches"

        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        # broadcast init checkpoint to vllm (only when both LLMs have a
        # checkpoint directory on disk and a set of vLLM engines)
        ckpt_path_llm1 = os.path.join(args.ckpt_path_llm1, "_actor")
        ckpt_path_llm2 = os.path.join(args.ckpt_path_llm2, "_actor")
        if (
            args.load_checkpoint
            and os.path.exists(ckpt_path_llm1)
            and self.vllm_engines_llm1 is not None
            and os.path.exists(ckpt_path_llm2)
            and self.vllm_engines_llm2 is not None
        ):
            for llm_mark in ("llm1", "llm2"):
                self._run_actor_with_vllm_awake(llm_mark, "broadcast_to_vllm")

        # Restore step and start_epoch
        consumed_samples = ray.get(self.actor_model_group.async_run_method(method_name="get_consumed_samples"))[0]
        finished_rollouts = consumed_samples // args.rollout_batch_size
        steps = finished_rollouts + 1
        flag_start_step = steps
        start_episode = finished_rollouts // num_rollouts_per_episodes
        consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

        for episode in range(start_episode, args.num_episodes):
            # Only the resumed episode skips already-consumed samples.
            episode_consumed = consumed_samples if episode == start_episode else 0
            self.prompts_dataloader_llm1.sampler.set_epoch(episode, consumed_samples=episode_consumed)
            self.prompts_dataloader_llm2.sampler.set_epoch(episode, consumed_samples=episode_consumed)
            pbar = tqdm(
                range(len(self.prompts_dataloader_llm1)),
                desc=f"Episode [{episode + 1}/{args.num_episodes}]",
                disable=False,
            )
            dataloader_iterator_llm2 = iter(self.prompts_dataloader_llm2)

            for _, rand_prompts_llm1, labels_llm1, levels_llm1 in self.prompts_dataloader_llm1:
                try:
                    _, rand_prompts_llm2, labels_llm2, levels_llm2 = next(dataloader_iterator_llm2)
                except StopIteration:
                    # llm2 ran out first: restart its iterator so llm1 still gets a partner batch.
                    dataloader_iterator_llm2 = iter(self.prompts_dataloader_llm2)
                    _, rand_prompts_llm2, labels_llm2, levels_llm2 = next(dataloader_iterator_llm2)

                # Eval llms' ability once, on the first step after (re)start.
                if steps == flag_start_step:
                    self._run_actor_with_vllm_awake(
                        "llm1",
                        "eval_ability",
                        eval_file=args.eval_ability_file,
                        llms=self.vllm_engines_llm1,
                        tokenizer=self.tokenizer_llm1,
                        step=steps,
                    )
                    self._run_actor_with_vllm_awake(
                        "llm2",
                        "eval_ability",
                        eval_file=args.eval_ability_file,
                        llms=self.vllm_engines_llm2,
                        tokenizer=self.tokenizer_llm2,
                        step=steps,
                    )

                experiences_llm1, experiences_llm2 = self.experience_maker.make_experience_list(
                    rand_prompts_llm1,
                    labels_llm1,
                    levels_llm1,
                    rand_prompts_llm2,
                    labels_llm2,
                    levels_llm2,
                    **self.generate_kwargs,
                )

                # Decode the first sequence of each LLM for console output and logging.
                sample0_llm1 = self.tokenizer_llm1.batch_decode(
                    experiences_llm1[0].sequences[0].unsqueeze(0), skip_special_tokens=False
                )
                print(sample0_llm1)
                sample0_llm2 = self.tokenizer_llm2.batch_decode(
                    experiences_llm2[0].sequences[0].unsqueeze(0), skip_special_tokens=False
                )
                print(sample0_llm2)

                refs = self.actor_model_group.async_run_method_batch(
                    method_name="append", experiences_llm1=experiences_llm1, experiences_llm2=experiences_llm2
                )
                if self.critic_model_group is not None:
                    # The previous code passed an undefined name ``experiences`` here.
                    # NOTE(review): assumes the critic's ``append`` accepts the same
                    # keyword arguments as the actor's - confirm against the critic actor.
                    refs.extend(
                        self.critic_model_group.async_run_method_batch(
                            method_name="append",
                            experiences_llm1=experiences_llm1,
                            experiences_llm2=experiences_llm2,
                        )
                    )
                ray.get(refs)

                # Train, log and checkpoint llm1 first, then llm2.
                per_llm = (
                    ("llm1", self.kl_ctl_llm1, sample0_llm1, experiences_llm1),
                    ("llm2", self.kl_ctl_llm2, sample0_llm2, experiences_llm2),
                )
                for llm_mark, kl_ctl, sample0, experiences in per_llm:
                    status = self.ppo_train(steps, llm_mark)
                    if "kl" in status:
                        kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
                    pbar.set_postfix(status)
                    # Add generated samples to status dictionary
                    status["generated_samples"] = [sample0[0], experiences[0].info["reward"][0]]
                    # logs/checkpoints
                    client_states = {"consumed_samples": steps * args.rollout_batch_size}
                    self.save_logs_and_checkpoints(args, steps, pbar, status, client_states, llm_mark)

                pbar.update()
                steps = steps + 1

            pbar.close()

        if self._wandb is not None:
            self._wandb.finish()
        if self._tensorboard_llm1 is not None:
            self._tensorboard_llm1.close()
        if self._tensorboard_llm2 is not None:
            self._tensorboard_llm2.close()

    def _run_actor_with_vllm_awake(self, llm_mark, method_name, **kwargs):
        """Run an actor-group method for one LLM, waking/sleeping its vLLM engines around it.

        When ``vllm_enable_sleep`` is set, the engines of ``llm_mark`` are woken
        before the call and put back to sleep after it has completed.
        """
        engines = self.vllm_engines_llm1 if llm_mark == "llm1" else self.vllm_engines_llm2
        sleep_enabled = self.strategy.args.vllm_enable_sleep

        if sleep_enabled:
            batch_vllm_engine_call(engines, "wake_up")
        ray.get(self.actor_model_group.async_run_method(method_name=method_name, llm_mark=llm_mark, **kwargs))
        if sleep_enabled:
            batch_vllm_engine_call(engines, "sleep")

    def ppo_train(self, global_steps, llm_mark):
        """Run one training round for the given LLM and return the merged status dict.

        Args:
            global_steps: Current global step; the actor is only trained once this
                exceeds ``self.freezing_actor_steps``.
            llm_mark: ``'llm1'`` or ``'llm2'``; selects the KL controller and the
                vLLM engines, and is forwarded to the actor group.

        Returns:
            A dict combining the first critic status (if a critic group exists)
            and the first actor status (if the actor was trained).

        Raises:
            ValueError: If ``llm_mark`` is neither ``'llm1'`` nor ``'llm2'``.
        """
        if llm_mark == 'llm1':
            kl_ctl = self.kl_ctl_llm1
            vllm_engines = self.vllm_engines_llm1
        elif llm_mark == 'llm2':
            kl_ctl = self.kl_ctl_llm2
            vllm_engines = self.vllm_engines_llm2
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(f"Unknown llm_mark {llm_mark!r}; expected 'llm1' or 'llm2'")

        args = self.strategy.args
        status = {}

        # trigger remote critic model training
        if self.critic_model_group is not None:
            # sync for deepspeed_enable_sleep
            if args.deepspeed_enable_sleep:
                ray.get(self.critic_model_group.async_run_method(method_name="reload_states"))

            critic_status_ref = self.critic_model_group.async_run_method(method_name="fit")

            # When models share GPUs (or sleep is enabled) wait for the critic
            # before starting the actor.
            if args.colocate_all_models or args.deepspeed_enable_sleep:
                status.update(ray.get(critic_status_ref)[0])
            if args.deepspeed_enable_sleep:
                ray.get(self.critic_model_group.async_run_method(method_name="offload_states"))

        # actor model training
        if global_steps > self.freezing_actor_steps:
            if args.deepspeed_enable_sleep:
                # Awaited (like the critic branch) so failures surface here.
                ray.get(self.actor_model_group.async_run_method(method_name="reload_states", llm_mark=llm_mark))

            actor_status_ref = self.actor_model_group.async_run_method(
                method_name="fit", kl_ctl=kl_ctl.value, llm_mark=llm_mark
            )
            status.update(ray.get(actor_status_ref)[0])

            if args.deepspeed_enable_sleep:
                # Awaited so the offload has finished before the vLLM engines wake up.
                ray.get(self.actor_model_group.async_run_method(method_name="offload_states", llm_mark=llm_mark))

            # 4. broadcast weights to vllm engines
            if vllm_engines is not None:
                if args.vllm_enable_sleep:
                    batch_vllm_engine_call(vllm_engines, "wake_up")

                ray.get(self.actor_model_group.async_run_method(method_name="broadcast_to_vllm", llm_mark=llm_mark))

                if args.vllm_enable_sleep:
                    batch_vllm_engine_call(vllm_engines, "sleep")

        # 5. wait remote critic model training done
        if self.critic_model_group and not args.colocate_all_models:
            status.update(ray.get(critic_status_ref)[0])

        return status

    def save_logs_and_checkpoints(
        self, args, global_step, step_bar, logs_dict=None, client_states=None, llm_mark=None
    ):
        """Log training status, optionally evaluate, and optionally save a checkpoint.

        Args:
            args: Namespace providing ``logging_steps``, ``eval_steps``,
                ``save_steps``, ``eval_temperature`` and ``eval_n_samples_per_prompt``.
            global_step: Current global training step.
            step_bar: Progress bar of the caller (not used here).
            logs_dict: Status values to log; may contain ``"generated_samples"``
                as a ``[text, reward]`` pair. The caller's dict is not modified.
            client_states: Extra state forwarded to the actor's ``save_checkpoint``.
            llm_mark: ``'llm1'`` or ``'llm2'``; selects the TensorBoard writer and
                is forwarded to checkpointing and evaluation.
        """
        # Work on copies: the defaults used to be shared mutable dicts and the
        # wandb branch pops from ``logs_dict``.
        logs_dict = {} if logs_dict is None else dict(logs_dict)
        client_states = {} if client_states is None else client_states

        # ``None`` for an unknown / missing mark simply disables TensorBoard output.
        now_tensorboard = {"llm1": self._tensorboard_llm1, "llm2": self._tensorboard_llm2}.get(llm_mark)

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None:
                # Add generated samples to wandb using Table
                if "generated_samples" in logs_dict:
                    # https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737
                    new_table = self._wandb.Table(
                        columns=self.generated_samples_table.columns, data=self.generated_samples_table.data
                    )
                    new_table.add_data(global_step, *logs_dict.pop("generated_samples"))
                    self.generated_samples_table = new_table
                    self._wandb.log({"train/generated_samples": new_table})
                # NOTE(review): llm1 and llm2 log under the same "train/*" keys in wandb.
                logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
                self._wandb.log(logs)
            # TensorBoard
            elif now_tensorboard is not None:
                for k, v in logs_dict.items():
                    if k == "generated_samples":
                        # Record generated samples in TensorBoard using simple text format
                        text, reward = v
                        formatted_text = f"Sample:\n{text}\n\nReward: {reward:.4f}"
                        now_tensorboard.add_text("train/generated_samples", formatted_text, global_step)
                    else:
                        now_tensorboard.add_scalar(f"train/{k}", v, global_step)

        # TODO: Add evaluation mechanism for PPO
        if global_step % args.eval_steps == 0 and self.eval_dataloader and len(self.eval_dataloader) > 0:
            self.evaluate(
                self.eval_dataloader,
                global_step,
                args.eval_temperature,
                args.eval_n_samples_per_prompt,
                llm_mark=llm_mark or "llm1",
            )
        # save ckpt
        # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            ref = self.actor_model_group.async_run_method(
                method_name="save_checkpoint", tag=tag, client_states=client_states, llm_mark=llm_mark
            )
            if self.critic_model_group is not None:
                ref.extend(self.critic_model_group.async_run_method(method_name="save_checkpoint", tag=tag))
            ray.get(ref)

    def _llm_resources(self, llm_mark):
        """Return ``(vllm_engines, tokenizer, tensorboard_writer)`` for ``llm_mark``."""
        if llm_mark == "llm1":
            return self.vllm_engines_llm1, self.tokenizer_llm1, self._tensorboard_llm1
        if llm_mark == "llm2":
            return self.vllm_engines_llm2, self.tokenizer_llm2, self._tensorboard_llm2
        raise ValueError(f"Unknown llm_mark {llm_mark!r}; expected 'llm1' or 'llm2'")

    def evaluate(self, eval_dataloader, global_step, temperature=0.6, n_samples_per_prompt=1, llm_mark="llm1"):
        """Evaluate model performance on eval dataset.

        Args:
            eval_dataloader: DataLoader yielding ``(datasources, prompts, labels)`` batches
            global_step: Current training step for logging
            temperature: Sampling temperature used for generation
            n_samples_per_prompt: Number of samples to generate per prompt for pass@k calculation
            llm_mark: ``'llm1'`` or ``'llm2'``; selects the vLLM engines, tokenizer and
                TensorBoard writer (the trainer has no un-suffixed variants of these)

        Raises:
            ValueError: If ``llm_mark`` is neither ``'llm1'`` nor ``'llm2'``.
        """
        start_time = time.time()
        logger.info(f"⏰ Evaluation start time: {time.strftime('%Y-%m-%d %H:%M:%S')}")

        vllm_engines, tokenizer, tensorboard = self._llm_resources(llm_mark)
        sleep_enabled = self.strategy.args.vllm_enable_sleep

        # vLLM wakeup when vllm_enable_sleep
        if sleep_enabled:
            batch_vllm_engine_call(vllm_engines, "wake_up")

        try:
            with torch.no_grad():
                # First collect all prompts and labels
                all_prompts = []
                all_labels = []
                all_datasources = []

                for datasources, prompts, labels in eval_dataloader:
                    all_prompts.extend(prompts)
                    all_labels.extend(labels)
                    all_datasources.extend(datasources)

                # Generate samples and calculate rewards
                generate_kwargs = self.generate_kwargs.copy()
                generate_kwargs["temperature"] = temperature
                generate_kwargs["n_samples_per_prompt"] = n_samples_per_prompt
                # NOTE(review): ``generate_samples`` is called with a single prompt/label
                # list; confirm this matches the ILR experience maker's signature.
                samples_list = self.experience_maker.generate_samples(all_prompts, all_labels, **generate_kwargs)
                queries_list = [
                    query
                    for s in samples_list
                    for query in tokenizer.batch_decode(s.sequences, skip_special_tokens=False)
                ]

                # duplicate prompts and labels for each sample
                all_prompts = [prompt for prompt in all_prompts for _ in range(n_samples_per_prompt)]
                all_labels = [label for label in all_labels for _ in range(n_samples_per_prompt)]

                # Calculate rewards
                r_refs = []
                if self.experience_maker.custom_reward_func:
                    # Let Ray automatically distribute the workload across available resources
                    batch_size = self.strategy.args.micro_rollout_batch_size
                    for start_idx in range(0, len(queries_list), batch_size):
                        end_idx = min(start_idx + batch_size, len(queries_list))
                        r_refs.append(
                            self.experience_maker.custom_reward_func.remote(
                                queries_list[start_idx:end_idx],
                                all_prompts[start_idx:end_idx],
                                all_labels[start_idx:end_idx],
                            )
                        )
                else:
                    # Distribute data across different remote reward function servers
                    num_servers = len(self.remote_rm_url)
                    batch_size = (len(queries_list) + num_servers - 1) // num_servers
                    for i, rm in enumerate(self.remote_rm_url):
                        start_idx = i * batch_size
                        end_idx = min((i + 1) * batch_size, len(queries_list))
                        r_refs.append(
                            remote_rm_fn_ray.remote(
                                rm,
                                queries=queries_list[start_idx:end_idx],
                                prompts=all_prompts[start_idx:end_idx],
                                labels=all_labels[start_idx:end_idx],
                            )
                        )

                # Reshape rewards to (num_prompts, n_samples_per_prompt)
                rewards = ray.get(r_refs)
                rewards = torch.cat(rewards, dim=0).reshape(-1, n_samples_per_prompt)

                # Collect local statistics for each data source
                pass_k_key = f"pass{n_samples_per_prompt}"
                global_metrics = {}  # {datasource: {"pass{n_samples_per_prompt}": 0, "pass1": 0, "count": 0}}

                for i, datasource in enumerate(all_datasources):
                    metrics = global_metrics.setdefault(datasource, {pass_k_key: 0, "pass1": 0, "count": 0})

                    # Calculate pass@k and pass@1
                    prompt_rewards = rewards[i]
                    if n_samples_per_prompt > 1:
                        metrics[pass_k_key] += prompt_rewards.max().float().item()
                    metrics["pass1"] += prompt_rewards.mean().float().item()
                    metrics["count"] += 1

                # Calculate global averages
                logs = {}
                for datasource, metrics in global_metrics.items():
                    logs[f"eval_{datasource}_{pass_k_key}"] = metrics[pass_k_key] / metrics["count"]
                    logs[f"eval_{datasource}_pass1"] = metrics["pass1"] / metrics["count"]

                # Log to wandb/tensorboard
                if self._wandb is not None:
                    logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": global_step}.items()}
                    self._wandb.log(logs)
                elif tensorboard is not None:
                    for k, v in logs.items():
                        tensorboard.add_scalar(f"eval/{k}", v, global_step)
        finally:
            # Put the engines back to sleep even if evaluation failed part-way.
            if sleep_enabled:
                batch_vllm_engine_call(vllm_engines, "sleep")

        end_time = time.time()
        duration = end_time - start_time
        time_str = str(timedelta(seconds=duration)).split(".")[0]
        logger.info(f"✨ Evaluation completed in {time_str}")

    def prepare_datasets(self, llm_mark):
        """Build the prompt dataloader for one LLM and the (shared) eval dataloader.

        Sets ``self.prompts_dataloader_llm1`` or ``self.prompts_dataloader_llm2``
        depending on ``llm_mark``, and overwrites ``self.eval_dataloader`` and
        ``self.max_steps`` on every call.

        Args:
            llm_mark: ``'llm1'`` or ``'llm2'``; selects which tokenizer is used and
                which dataloader attribute is assigned.

        Raises:
            ValueError: If ``llm_mark`` is neither ``'llm1'`` nor ``'llm2'``.
        """
        if llm_mark == 'llm1':
            tokenizer = self.tokenizer_llm1
        elif llm_mark == 'llm2':
            tokenizer = self.tokenizer_llm2
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError(f"Unknown llm_mark {llm_mark!r}; expected 'llm1' or 'llm2'")

        args = self.args
        strategy = self.strategy

        # prepare datasets
        train_data = blending_datasets(
            args.prompt_data,
            args.prompt_data_probs,
            strategy,
            args.seed,
            max_count=args.max_samples,
            dataset_split=self.prompt_split,
        )

        # Create train dataset
        train_data = train_data.select(range(min(args.max_samples, len(train_data))))
        prompts_dataset = ILRPromptDataset(train_data, tokenizer, strategy, input_template=args.input_template)
        prompts_dataloader = strategy.setup_dataloader(
            prompts_dataset,
            args.rollout_batch_size,
            True,
            True,
        )

        # Create eval dataset if eval data exists
        if getattr(args, "eval_dataset", None):
            # ``PromptDataset`` was referenced without being imported, and the
            # non-existent ``self.tokenizer`` was used instead of this LLM's tokenizer.
            # NOTE(review): assumes ``openrlhf.datasets`` exports ``PromptDataset`` - confirm.
            from openrlhf.datasets import PromptDataset

            eval_data = blending_datasets(
                args.eval_dataset,
                None,  # No probability sampling for eval datasets
                strategy,
                dataset_split=self.eval_split,
            )
            eval_data = eval_data.select(range(min(args.max_samples, len(eval_data))))
            eval_dataset = PromptDataset(eval_data, tokenizer, strategy, input_template=args.input_template)
            eval_dataloader = strategy.setup_dataloader(eval_dataset, 1, True, False)
        else:
            eval_dataloader = None

        if llm_mark == 'llm1':
            self.prompts_dataloader_llm1 = prompts_dataloader
        else:
            self.prompts_dataloader_llm2 = prompts_dataloader

        # Shared attributes: the most recent call wins.
        self.eval_dataloader = eval_dataloader
        self.max_steps = (
            len(prompts_dataset)
            * args.n_samples_per_prompt
            // args.train_batch_size
            * args.num_episodes
            * args.max_epochs
        )

    def get_max_steps(self):
        """Return the ``max_steps`` value computed by ``prepare_datasets``."""
        max_steps = self.max_steps
        return max_steps
