
import importlib
from functools import partial
from packaging.version import Version
from typing import Iterable, Dict

import torch
from torch import nn
import torch.distributed
from megatron.core.optimizer import OptimizerConfig
from megatron.core import parallel_state as mpu
from megatron.core import ModelParallelConfig
from verl.utils.megatron_utils import get_model_config
from megatron.core.pipeline_parallel import get_forward_backward_func

from megatron.core.distributed import finalize_model_grads


from megatron.core.optimizer import DistributedOptimizer

from omegaconf import OmegaConf
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl import DataProto
from verl.trainer.ppo.core_algos import compute_policy_loss, kl_penalty, agg_loss
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, masked_mean, broadcast_dict_tensor, split_dict_tensor_into_batches

# Public API of this module: only the Megatron PPO actor class is exported.
__all__ = ['MegatronPPOActor']


class MegatronPPOActor(BasePPOActor):
    """PPO actor backed by Megatron-Core model chunks.

    Provides log-prob computation (:meth:`compute_log_prob`) and the PPO policy update
    (:meth:`update_policy`). Both run through Megatron's forward/backward schedule
    (:meth:`forward_backward_batch`).
    """

    def __init__(self, config, model_config, megatron_config: ModelParallelConfig, actor_module: nn.ModuleList,
                 actor_optimizer: DistributedOptimizer, actor_optimizer_config: OptimizerConfig):
        """Store the actor's model chunks, optimizer and configs.

        Args:
            config: actor config. Read via ``.get`` and attribute access, e.g.
                ``ppo_micro_batch_size_per_gpu``, ``clip_ratio``, ``use_kl_loss``.
            model_config: model config of the actor.
            megatron_config: Megatron model-parallel config; ``sequence_parallel`` is read.
            actor_module: list of model chunks. Its length is passed as ``vpp_size``
                when building the micro-batch generator.
            actor_optimizer: Megatron distributed optimizer stepped in :meth:`update_policy`.
            actor_optimizer_config: optimizer config, stored on the instance.
        """
        super().__init__(config)
        self._validate_config(config)
        self.model_config = model_config
        self.megatron_config = megatron_config
        self.actor_module = actor_module
        self.actor_optimizer: DistributedOptimizer = actor_optimizer
        self.actor_optimizer_config = actor_optimizer_config

        self.optimizer_step_args = OmegaConf.create({
            'skip_grad': None,
            'overlap_dp_param_comm': False,
            'overlap_dp_grad_comm': False,
            'gradient_accumulation_steps': 1,
            'sequence_parallel': self.megatron_config.sequence_parallel,
            'DDP_impl': 'local',
            'layernorm_allreduce_bucket_threshold': 0,
            'pipeline_model_parallel_split_rank': None,
            'reduce_grads_use_alltoall': False
        })

        # Install Megatron's gradient-finalization hook on the config returned for the
        # first chunk, so the forward/backward schedule invokes it after backward.
        # NOTE(review): presumably this config object is shared by all chunks — confirm.
        model_parallel_config = get_model_config(self.actor_module[0])
        model_parallel_config.finalize_model_grads_func = finalize_model_grads

    def _validate_config(self, config) -> None:
        """Reject settings this actor does not support.

        ``ulysses_sequence_parallel_size`` must be absent or equal to 1.
        """
        assert config.get('ulysses_sequence_parallel_size', 1) == 1, \
            'MegatronPPOActor does not support ulysses_sequence_parallel_size != 1'

    def compute_log_prob(self, data: DataProto) -> torch.Tensor:
        """Return per-token log-probs of ``data.batch['responses']`` under the current actor.

        Behavior depends on the config and the batch contents:

        - If ``config.recompute_old_log_prob`` is False (default True) and the batch
          already contains ``old_log_probs``, that tensor is returned without running
          the model.
        - Otherwise the model runs forward-only. The last pipeline stage concatenates
          the per-micro-batch log-probs and casts them to float32. It then broadcasts
          them to the rest of the pipeline group, so every rank returns a float32
          tensor of shape ``(batch_size, response_length)``.
        """
        data.batch = data.batch.contiguous()

        def compute_logprobs_fn(output, data):
            # Keep the logits aligned with the response tokens. They are shifted by one,
            # since the logits at position t score the token at position t + 1.
            response = data['responses']
            response_length = response.size(1)
            logits = output[:, -response_length - 1:-1].contiguous()
            log_probs = vocab_parallel_log_probs_from_logits(logits, response)
            return {'log_probs': log_probs}

        recompute_old_log_prob = self.config.get('recompute_old_log_prob', True)
        if not recompute_old_log_prob and 'old_log_probs' in data.batch.keys():
            # Reuse the log-probs already attached to the batch instead of re-running
            # the model (without this return, `log_probs` below would be unbound).
            return data.batch['old_log_probs']

        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids']
        batch = data.select(batch_keys=select_keys).batch
        input_ids = batch['input_ids']
        batch_size = input_ids.size(0)
        response_length = batch['responses'].size(1)
        with torch.no_grad():
            output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn)
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # The post-processed outputs are consumed on the last pipeline stage.
                log_probs = torch.cat([o['log_probs'] for o in output], dim=0)
                log_probs = log_probs.to(torch.float32)
            else:
                # Receive buffer for the broadcast from the last stage.
                log_probs = torch.empty(size=(batch_size, response_length),
                                        dtype=torch.float32,
                                        device=input_ids.device)

            # Share the last stage's log-probs with every rank in the pipeline group.
            torch.distributed.broadcast(tensor=log_probs,
                                        src=mpu.get_pipeline_model_parallel_last_rank(),
                                        group=mpu.get_pipeline_model_parallel_group(),
                                        async_op=False)

        torch.cuda.empty_cache()
        return log_probs

    def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
        """Return an iterator over PPO mini-batches of ``data``.

        Only the tensors used by the policy loss are kept. ``ref_log_prob`` is added
        when ``config.use_kl_loss`` is set. ``DataProto.make_iterator`` is called with
        ``config.ppo_mini_batch_size``, ``config.ppo_epochs`` and ``config.shuffle``.
        """
        select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
        if self.config.use_kl_loss:
            select_keys.append('ref_log_prob')
        data = data.select(batch_keys=select_keys)
        return data.make_iterator(mini_batch_size=self.config.ppo_mini_batch_size,
                                  epochs=self.config.ppo_epochs,
                                  dataloader_kwargs={'shuffle': self.config.shuffle})

    def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None):
        """Run Megatron's forward (and, unless ``forward_only``, backward) schedule over ``data``.

        Before running, the batch tensors are broadcast from the last pipeline rank to
        the pipeline group. ``attention_mask`` is cast to bool in place on
        ``data.batch``. The batch is then split into micro-batches of size
        ``data.meta_info['micro_batch_size']``, falling back to
        ``config.ppo_micro_batch_size_per_gpu``.

        Args:
            data: batch with ``input_ids``, ``attention_mask``, ``position_ids`` and
                ``responses``. Training also needs ``old_log_probs``, ``advantages`` and,
                with ``use_kl_loss``, ``ref_log_prob``.
            forward_only: if True, no policy loss is computed and only outputs are collected.
            post_process_fn: forward-only hook ``fn(logits, micro_batch) -> dict``. When
                None, ``{'logits': logits}`` is collected instead.

        Returns:
            Whatever Megatron's ``forward_backward_func`` returns. Here that is the
            per-micro-batch dicts produced by the loss callback: post-processed outputs
            in forward-only mode, loss statistics otherwise.
            NOTE(review): presumably empty on non-last pipeline stages — confirm.
        """
        # Imported lazily, as in the original forward step, but once per call rather
        # than once per micro-batch.
        from verl.models.mcore import gptmodel_forward

        broadcast_dict_tensor(data.batch,
                              src=mpu.get_pipeline_model_parallel_last_rank(),
                              group=mpu.get_pipeline_model_parallel_group())

        data.batch['attention_mask'] = data.batch['attention_mask'].to(bool)

        if data.meta_info.get('micro_batch_size', None) is not None:
            batch_size = data.meta_info['micro_batch_size']
        else:
            batch_size = self.config.ppo_micro_batch_size_per_gpu
        batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size)

        n_micro_batch = len(batches)
        seq_len = batches[0]['input_ids'].shape[1]

        forward_backward_func = get_forward_backward_func()

        # Loss hyper-parameters are identical for every micro-batch; build them once.
        if forward_only:
            meta_info = None
        else:
            meta_info = {
                'clip_ratio': self.config.clip_ratio,
                'entropy_coeff': self.config.entropy_coeff,
                'clip_ratio_c': self.config.get('clip_ratio_c', 3.0),
            }

        def loss_func(output, data, meta_info):
            """Megatron loss callback: map one micro-batch's output to ``(loss, stats)``."""
            if forward_only:
                # No loss is needed; return a dummy loss with the collected output.
                if post_process_fn is None:
                    return 1.0, {'logits': output}
                return 1.0, post_process_fn(output, data)

            responses = data['responses']
            response_length = responses.size(1)
            attention_mask = data['attention_mask']
            response_mask = attention_mask[:, -response_length:]
            old_log_prob = data['old_log_probs']
            advantages = data['advantages']

            clip_ratio = meta_info['clip_ratio']
            # Asymmetric clip bounds fall back to the symmetric clip_ratio when unset.
            clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio
            clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio
            clip_ratio_c = meta_info['clip_ratio_c']
            entropy_coeff = meta_info['entropy_coeff']
            loss_agg_mode = self.config.loss_agg_mode

            # Logits aligned with the response tokens (position t scores token t + 1).
            logits = output[:, -response_length - 1:-1].contiguous()
            # Keep an untouched copy for the entropy term. Presumably
            # vocab_parallel_log_probs_from_logits modifies its input in place (as
            # Megatron's vocab-parallel cross entropy does) — confirm before dropping this clone.
            logits_back = logits.clone()
            log_prob = vocab_parallel_log_probs_from_logits(logits, responses)
            logits = logits_back
            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(old_log_prob=old_log_prob,
                                                                                  log_prob=log_prob,
                                                                                  advantages=advantages,
                                                                                  response_mask=response_mask,
                                                                                  cliprange=clip_ratio,
                                                                                  cliprange_low=clip_ratio_low,
                                                                                  cliprange_high=clip_ratio_high,
                                                                                  clip_ratio_c=clip_ratio_c,
                                                                                  loss_agg_mode=loss_agg_mode)
            entropy = vocab_parallel_entropy(logits)
            entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
            # The entropy term is subtracted, so higher entropy lowers the loss.
            policy_loss = pg_loss - entropy_loss * entropy_coeff

            stats = {
                'actor/entropy_loss': entropy_loss.detach().item(),
                'actor/pg_loss': pg_loss.detach().item(),
                'actor/pg_clipfrac': pg_clipfrac.detach().item(),
                'actor/ppo_kl': ppo_kl.detach().item(),
                'actor/pg_clipfrac_lower': pg_clipfrac_lower.detach().item()
            }

            if self.config.use_kl_loss:
                ref_log_prob = data['ref_log_prob']
                kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
                kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
                policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
                # Stored as plain scalars like the other stats. Wrapping them in
                # one-element lists would make update_policy collect lists of lists.
                stats['actor/kl_loss'] = kl_loss.detach().item()
                stats['actor/kl_coef'] = self.config.kl_loss_coef

            return policy_loss, stats

        def forward_step(batch_iter, model):
            """Megatron forward-step callback: run ``model`` on the next micro-batch."""
            batch = next(batch_iter)
            output = gptmodel_forward(model,
                                      batch['input_ids'],
                                      batch['attention_mask'],
                                      batch['position_ids'],
                                      sequence_parallel=self.megatron_config.sequence_parallel)
            return output, partial(loss_func, data=batch, meta_info=meta_info)

        batch_generator = make_batch_generator(batches, vpp_size=len(self.actor_module))

        # Each micro-batch is described to Megatron as a single sequence of
        # batch_size * seq_len tokens (micro_batch_size=1).
        # NOTE(review): confirm this matches the layout gptmodel_forward produces.
        return forward_backward_func(
            forward_step_func=forward_step,
            data_iterator=batch_generator,
            model=self.actor_module,
            num_microbatches=n_micro_batch,
            seq_length=batch_size * seq_len,
            micro_batch_size=1,
            forward_only=forward_only,
        )

    def update_policy(self, dataloader: Iterable[DataProto]) -> Dict:
        """Run one optimizer step for each mini-batch yielded by ``dataloader``.

        Returns:
            Dict mapping each metric name to the list of values collected from every
            micro-batch stats dict returned by :meth:`forward_backward_batch`.

        Raises:
            NotImplementedError: if ``actor_optimizer.step()`` reports an unsuccessful
                update. Skipped updates are not handled.
        """
        metrics = {}
        for data in dataloader:
            self.actor_optimizer.zero_grad()
            # Also reset each chunk's Megatron DDP gradient buffer.
            for chunk in self.actor_module:
                chunk.zero_grad_buffer()

            metric_micro_batch = self.forward_backward_batch(data)
            for metric in metric_micro_batch:
                append_to_dict(metrics, metric)

            update_successful, _, _ = self.actor_optimizer.step()
            if not update_successful:
                raise NotImplementedError('actor optimizer step was not successful; '
                                          'handling of skipped updates is not implemented')

        torch.cuda.empty_cache()

        return metrics
