

import torch
import torch.distributed as dist

from torch import nn

from megatron.core import parallel_state as mpu
from megatron.core import DistributedDataParallel as LocalDDP
from megatron.core.transformer.module import Float16Module
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from verl.utils.megatron_utils import get_model, unwrap_model
from verl.utils.memory_buffer import (
    build_memory_buffer,
    build_memory_reference_from_module,
    get_weight_buffer_meta_from_module,
)


class AllGatherPPModel:
    """Keep a copy of every pipeline-parallel (PP) stage's model on each PP rank.

    This rank's own stage is built with ``get_model``'s default wrapping and kept in
    ``this_rank_models``; an unwrapped ``nn.ModuleList`` view of it is stored in
    ``pp_models[pp_rank]``. Every other stage is built with ``wrap_with_ddp=False``,
    switched to ``eval()``, backed by per-stage memory buffers (built with
    ``build_memory_buffer`` / ``build_memory_reference_from_module``) and offloaded to
    CPU until ``load_params_to_cuda`` + ``allgather_params`` are called.
    """

    def __init__(self, model_provider) -> None:
        """Build the model chunks for every pipeline stage on this rank.

        Args:
            model_provider: forwarded to ``get_model`` once per pipeline stage.
        """
        self._pp_group = mpu.get_pipeline_model_parallel_group()
        self._pp_rank = mpu.get_pipeline_model_parallel_rank()
        self._pp_size = mpu.get_pipeline_model_parallel_world_size()
        self._vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
        # Model chunks per stage: the virtual-pipeline size when set, else one chunk.
        self._model_chunk_size = self._vpp_size or 1

        # _pp_models[r]: nn.ModuleList of (unwrapped) model chunks of pipeline stage r.
        self._pp_models = [None] * self.pp_size

        # Visit every stage, but swap this rank's own stage to the end so it is
        # built last; the loop therefore finishes with mpu set back to self.pp_rank.
        rank_list = list(range(self.pp_size))
        rank_list[self.pp_rank], rank_list[-1] = rank_list[-1], rank_list[self.pp_rank]
        self._this_rank_models = None

        # memory_buffers[r]: result of build_memory_buffer for stage r.
        self.memory_buffers = [None] * self.pp_size
        try:
            for cur_pp_rank in rank_list:
                print(
                    'create pp model', f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, '
                    f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB')

                # get_model is assumed to build the stage that mpu currently reports,
                # hence the temporary pipeline-rank switch.
                mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
                if cur_pp_rank != self.pp_rank:
                    models = get_model(model_provider, wrap_with_ddp=False)
                    models = nn.ModuleList(models)
                    assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
                    self.pp_models[cur_pp_rank] = models
                else:
                    models = get_model(model_provider)
                    assert len(models) == self._model_chunk_size, f"{len(models)} != {self._model_chunk_size}"
                    self._this_rank_models = nn.ModuleList(models)
                    self.pp_models[cur_pp_rank] = nn.ModuleList(unwrap_model(models, (torchDDP, LocalDDP)))

                self._build_param_buffer(cur_pp_rank)
                # maintain_weight is True only for this rank's own stage (presumably so
                # its real weights are copied into the new buffer; verify in memory_buffer).
                self._build_param_references(cur_pp_rank, maintain_weight=cur_pp_rank == self.pp_rank)

                # Other stages are inference-only copies: eval mode, parked on CPU.
                if cur_pp_rank != self.pp_rank:
                    for model in self.pp_models[cur_pp_rank]:
                        model.eval()
                    self._offload_params_to_cpu(cur_pp_rank)
        finally:
            # Restore the real pipeline rank even if building some stage raised.
            mpu.set_pipeline_model_parallel_rank(self.pp_rank)

    def _build_param_buffer(self, pp_rank):
        """Allocate the memory buffers for stage ``pp_rank`` from its weight metadata."""
        model = self.pp_models[pp_rank]
        weight_buffer_meta = get_weight_buffer_meta_from_module(model)
        self.memory_buffers[pp_rank] = build_memory_buffer(weight_buffer_meta)

    def _build_param_references(self, pp_rank, maintain_weight=False):
        """(Re)bind stage ``pp_rank``'s parameters to its memory buffers.

        Args:
            pp_rank: pipeline stage whose model/buffers are linked.
            maintain_weight: forwarded to ``build_memory_reference_from_module``.
        """
        model = self.pp_models[pp_rank]
        build_memory_reference_from_module(model, self.memory_buffers[pp_rank], maintain_weight=maintain_weight)

    def _load_params_to_cuda(self, pp_rank, to_empty=False):
        """Move stage ``pp_rank``'s buffers to the current CUDA device.

        Args:
            pp_rank: a stage other than this rank's own stage.
            to_empty: if True, allocate uninitialized CUDA storage instead of copying.

        Raises:
            AssertionError: if ``pp_rank`` is this rank's own stage.
        """
        assert pp_rank != self.pp_rank, f"unexpected to load current pp rank [{pp_rank}] back to cuda"
        for buffer in self.memory_buffers[pp_rank].values():
            if not to_empty:
                buffer.data = buffer.data.to(torch.cuda.current_device(), non_blocking=True)
            else:
                buffer.data = torch.empty_like(buffer.data, device='cuda')
        # Buffers now live in new storage; re-point the parameters at them.
        self._build_param_references(pp_rank)

    def _offload_params_to_cpu(self, pp_rank, to_empty=False):
        """Move stage ``pp_rank``'s buffers to CPU.

        Args:
            pp_rank: a stage other than this rank's own stage.
            to_empty: if True, allocate uninitialized CPU storage instead of copying.

        Raises:
            AssertionError: if ``pp_rank`` is this rank's own stage.
        """
        assert pp_rank != self.pp_rank, f"unexpected to offload current pp rank [{pp_rank}] to cpu"
        for buffer in self.memory_buffers[pp_rank].values():
            if not to_empty:
                # NOTE(review): non_blocking device->host copy; confirm nothing reads
                # the CPU copy on the host before the CUDA stream is synchronized.
                buffer.data = buffer.data.to('cpu', non_blocking=True)
            else:
                buffer.data = torch.empty_like(buffer.data, device='cpu')
        self._build_param_references(pp_rank)

    def load_params_to_cuda(self, to_empty=False):
        """Move every other stage's buffers to CUDA (see ``_load_params_to_cuda``)."""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._load_params_to_cuda(cur_pp_rank, to_empty=to_empty)

    def allgather_params(self):
        """Broadcast each stage's buffers from its owning rank to the whole PP group.

        Stage ``r`` is the source at PP-group rank ``r``. This issues one
        ``dist.broadcast`` per buffer, so every rank of the PP group must call it.
        """
        for cur_pp_rank in range(self.pp_size):
            global_src = dist.get_global_rank(group=self.pp_group, group_rank=cur_pp_rank)
            for memory_buffer in self.memory_buffers[cur_pp_rank].values():
                dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self.pp_group, async_op=False)

    def forward(self, *inputs, **kwargs):
        """Run the whole model locally by chaining every stage's chunks in order.

        For each chunk index, stages ``0..pp_size-1`` run in turn; each receives the
        previous output via ``set_input_tensor``. The mpu pipeline rank (and virtual
        pipeline rank, if used) is switched per stage and restored afterwards to
        ``self.pp_rank`` (and 0).

        Returns:
            The output of the last chunk of the last stage.
        """
        try:
            prev_output = None
            for cur_chunk_rank in range(self._model_chunk_size):
                if self._vpp_size:
                    mpu.set_virtual_pipeline_model_parallel_rank(cur_chunk_rank)

                for cur_pp_rank in range(self.pp_size):
                    mpu.set_pipeline_model_parallel_rank(cur_pp_rank)
                    chunk = self.pp_models[cur_pp_rank][cur_chunk_rank]
                    chunk.set_input_tensor(prev_output)
                    ret = chunk(*inputs, **kwargs)
                    chunk.set_input_tensor(None)
                    prev_output = ret
        finally:
            if self._vpp_size:
                mpu.set_virtual_pipeline_model_parallel_rank(0)
            mpu.set_pipeline_model_parallel_rank(self.pp_rank)
        return ret

    def __call__(self, *inputs, **kwargs):
        return self.forward(*inputs, **kwargs)

    def eval(self):
        """Put this rank's own stage chunks in eval mode (other stages are untouched)."""
        for model in self.pp_models[self.pp_rank]:
            model.eval()

    def train(self):
        """Put this rank's own stage chunks in train mode (other stages are untouched)."""
        for model in self.pp_models[self.pp_rank]:
            model.train()

    def offload_params_to_cpu(self, to_empty=False):
        """Move every other stage's buffers to CPU (see ``_offload_params_to_cpu``)."""
        for cur_pp_rank in range(self.pp_size):
            if cur_pp_rank != self.pp_rank:
                self._offload_params_to_cpu(cur_pp_rank, to_empty=to_empty)

    def get_all_params(self):
        """Collect the parameters of every stage and chunk.

        Returns:
            A list indexed ``[pp_rank][chunk_idx]`` of dicts mapping parameter name to
            parameter, with wrappers unwrapped and names containing ``'lora'`` skipped.
        """
        params = []
        for pp_rank in range(self.pp_size):
            stage_params = []
            for pp_model in self.pp_models[pp_rank]:
                pp_model = unwrap_model(pp_model, (torchDDP, LocalDDP, Float16Module))
                stage_params.append(
                    {name: param for name, param in pp_model.named_parameters() if 'lora' not in name})
            params.append(stage_params)
        return params

    def update_this_rank_models(self, new_models):
        """Replace this rank's own (wrapped) model chunks.

        ``new_models`` is stored as-is for ``this_rank_models``. As in ``__init__``,
        it is unwrapped from a plain list and stored as an ``nn.ModuleList`` in
        ``pp_models[self.pp_rank]``.

        Args:
            new_models: iterable of (possibly DDP-wrapped) model chunks.
        """
        self._this_rank_models = new_models
        # __init__ hands unwrap_model a plain list; do the same here so an
        # nn.ModuleList argument is unwrapped per chunk rather than treated as one
        # opaque module.
        self._pp_models[self.pp_rank] = nn.ModuleList(unwrap_model(list(new_models), (torchDDP, LocalDDP)))

    @property
    def this_rank_models(self):
        return self._this_rank_models

    @property
    def pp_size(self):
        return self._pp_size

    @property
    def pp_rank(self):
        return self._pp_rank

    @property
    def pp_group(self):
        return self._pp_group

    @property
    def pp_models(self):
        return self._pp_models




from .base import BaseShardingManager

import torch
from torch import nn
import torch.distributed
from torch.distributed import new_group

from verl import DataProto
from verl.utils.torch_functional import (broadcast_dict_tensor, allgather_dict_tensors)
import verl.utils.megatron.tensor_parallel as tp_utils
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.third_party.vllm import LLM
from verl.utils.model import normalize_pp_vpp_params

# Module-level process group for this rank's "micro data parallel" group.
# Starts as None; assigned once in MegatronVLLMShardingManager.__init__ and read
# through get_micro_data_parallel_group() and the helpers that wrap it.
_MICRO_DATA_PARALLEL_GROUP = None


class MegatronVLLMShardingManager(BaseShardingManager):
    """Move weights between the Megatron training model and the vLLM inference engine.

    ``__enter__`` gathers every pipeline stage's weights on this rank, merges tensor-
    parallel shards when the inference TP size is smaller than the training TP size,
    and syncs them into the inference engine. ``__exit__`` undoes that. The
    ``preprocess_data`` / ``postprocess_data`` hooks split and re-gather batches
    across the micro data-parallel group.
    """

    def __init__(self, module: AllGatherPPModel, inference_engine: LLM, model_config, layer_name_mapping):
        self.module = module
        self.inference_engine = inference_engine
        self.model_config = model_config
        self.layer_name_mapping = layer_name_mapping

        # Partition the world into consecutive blocks of train_tp // infer_tp ranks
        # (presumably each block lies inside one training TP group); each block
        # becomes this rank's "micro data parallel" group for inference.
        global _MICRO_DATA_PARALLEL_GROUP
        world_size = torch.distributed.get_world_size()
        my_rank = torch.distributed.get_rank()
        train_tp = mpu.get_tensor_model_parallel_world_size()
        infer_tp = vllm_ps.get_tensor_model_parallel_world_size()

        assert infer_tp <= train_tp, \
            'Not implemented for infer_tp > train_tp'
        assert train_tp % infer_tp == 0

        group_size = train_tp // infer_tp
        group_count = world_size // group_size
        assert _MICRO_DATA_PARALLEL_GROUP is None, ("micro data parallel group is already initialized")
        # new_group is collective: every rank creates every group, in the same order,
        # and keeps only the one that contains it.
        for first in range(0, group_count * group_size, group_size):
            member_ranks = range(first, first + group_size)
            candidate = new_group(ranks=member_ranks)
            if first <= my_rank < first + group_size:
                _MICRO_DATA_PARALLEL_GROUP = candidate

    def default_tp_concat_fn(self, name, param, infer_params, model_config):
        """Merge the TP shards gathered from the micro-DP group into one tensor.

        Args:
            name: parameter name, matched against ``layer_name_mapping``.
            param: this rank's local shard; its partition dim drives the generic case.
            infer_params: list of shards gathered from the micro-DP group.
            model_config: supplies ``num_attention_heads`` / ``num_key_value_heads``.

        Returns:
            The concatenated tensor.
        """
        mapping = self.layer_name_mapping

        if mapping.get("qkv_layer_name") in name:
            # Every shard is split along dim 0 as [q | k | v]; regroup so the result
            # is [all q pieces | all k pieces | all v pieces].
            assert model_config.num_attention_heads % model_config.num_key_value_heads == 0
            q_per_kv = model_config.num_attention_heads // model_config.num_key_value_heads
            assert infer_params[0].shape[0] % (q_per_kv + 2) == 0
            kv_rows = infer_params[0].shape[0] // (q_per_kv + 2)
            sections = [kv_rows * q_per_kv, kv_rows, kv_rows]
            pieces = [shard.split(sections) for shard in infer_params]
            merged_q = torch.cat([q for q, _, _ in pieces], dim=0)
            merged_k = torch.cat([k for _, k, _ in pieces], dim=0)
            merged_v = torch.cat([v for _, _, v in pieces], dim=0)
            return torch.cat((merged_q, merged_k, merged_v), dim=0)

        if mapping.get("gate_proj_layer_name") in name:
            # Every shard is [gate | up] (two chunks along dim 0); regroup to
            # [all gate halves | all up halves].
            halves = [shard.chunk(2) for shard in infer_params]
            merged_gate = torch.cat([g for g, _ in halves], dim=0)
            merged_up = torch.cat([u for _, u in halves], dim=0)
            return torch.cat((merged_gate, merged_up), dim=0)

        return torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(param))

    def _post_process_params(self, params):
        """Replace tensor-parallel params in ``params`` with micro-DP-merged versions.

        Mutates ``params`` in place. Collective over the micro-DP group (one
        ``all_gather`` per tensor-parallel parameter, in ``params`` order).

        Returns:
            ``None`` when the micro-DP group has at most one rank; otherwise a dict
            mapping every name to the original (pre-merge) parameter.
        """
        group_world = get_micro_data_parallel_world_size()
        group = get_micro_data_parallel_group()

        if group_world <= 1:
            return None

        originals = {}
        for key in params:
            local_shard = params[key]
            if tp_utils.is_tensor_parallel_param(local_shard):
                gathered = [torch.empty_like(local_shard) for _ in range(group_world)]
                torch.distributed.all_gather(gathered, local_shard, group=group)
                params[key] = self.default_tp_concat_fn(key, local_shard, gathered, self.model_config)
            originals[key] = local_shard
        return originals

    def __enter__(self):
        # Bring the other stages' buffers onto the GPU and fill them with the
        # current weights from their owning pipeline ranks.
        self.module.load_params_to_cuda()
        self.module.allgather_params()

        per_stage_params = self.module.get_all_params()
        self.params = normalize_pp_vpp_params(params=per_stage_params,
                                              num_hidden_layers=self.model_config.num_hidden_layers,
                                              layer_name='layers')
        self.origin_params = self._post_process_params(self.params)
        self.inference_engine.sync_model_weights(self.params, load_format='megatron')

    def __exit__(self, exc_type, exc_value, traceback):
        self.module.offload_params_to_cpu()

        # Undo the merge done in _post_process_params (only happens when > 1 rank).
        if get_micro_data_parallel_world_size() > 1:
            for key in self.params:
                self.params[key] = self.origin_params[key]

        self.inference_engine.offload_model_weights()
        self.module.train()
        torch.cuda.empty_cache()

    def preprocess_data(self, data: DataProto) -> DataProto:
        """Broadcast the batch across the training TP group and keep this rank's slice.

        The batch tensors are broadcast from the TP source rank; when the micro-DP
        group has more than one rank, the data is chunked and the chunk at this rank's
        micro-DP index is returned.
        """
        group_world = get_micro_data_parallel_world_size()
        group_rank = get_micro_data_parallel_rank()

        broadcast_dict_tensor(data.batch,
                              src=mpu.get_tensor_model_parallel_src_rank(),
                              group=mpu.get_tensor_model_parallel_group())

        if group_world <= 1:
            return data
        return data.chunk(chunks=group_world)[group_rank]

    def postprocess_data(self, data: DataProto) -> DataProto:
        """Re-gather the batch across the micro-DP group and, optionally, the PP group.

        The PP all-gather runs unless ``meta_info['allgather_pp_output']`` is falsy
        (it defaults to True).
        """
        meta_info = data.meta_info

        group_world = get_micro_data_parallel_world_size()
        if group_world > 1:
            data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                                size=group_world,
                                                group=get_micro_data_parallel_group(),
                                                dim=0)

        if meta_info.get('allgather_pp_output', True):
            data.batch = allgather_dict_tensors(data.batch.contiguous(),
                                                size=mpu.get_pipeline_model_parallel_world_size(),
                                                group=mpu.get_pipeline_model_parallel_group(),
                                                dim=0)
        return data




def get_micro_data_parallel_group():
    """Return this rank's micro data-parallel process group.

    Raises:
        AssertionError: if the module-level group has not been assigned yet
            (MegatronVLLMShardingManager.__init__ assigns it).
    """
    group = _MICRO_DATA_PARALLEL_GROUP
    assert group is not None
    return group


def get_micro_data_parallel_world_size():
    """Return the number of ranks in this rank's micro data-parallel group."""
    group = get_micro_data_parallel_group()
    return torch.distributed.get_world_size(group=group)


def get_micro_data_parallel_rank():
    """Return this rank's index within its micro data-parallel group."""
    group = get_micro_data_parallel_group()
    return torch.distributed.get_rank(group=group)
