# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.inference.config import DeepSpeedInferenceConfig
from deepspeed.module_inject.replace_policy import replace_policies
from deepspeed.module_inject.utils import policy_to_ds_container
from .engine import DeepSpeedEngine
from .utils import TLinear, get_inactive_params
from deepspeed.runtime.zero import GatheredParameters
import time
import gc
import math
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from torch import nn
from deepspeed.utils import logger

from deepspeed.ops.op_builder import InferenceBuilder

from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
# Optional dependency: OPT's learned positional embedding is only mapped to an inference
# replacement when `transformers` is importable. `except Exception` (instead of a bare
# `except:`) still tolerates a missing package or a lazily-loaded submodule that fails to
# import, but no longer swallows KeyboardInterrupt / SystemExit.
try:
    import transformers
    OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except Exception:
    OPTLearnedPositionalEmbedding = None
# Compiled inference op module; loaded lazily by the first DeepSpeedHybridEngine instance.
inference_cuda_module = None


class DeepSpeedHybridEngine(DeepSpeedEngine):
    r"""DeepSpeed engine for training and inference.

    Extends :class:`DeepSpeedEngine` with DeepSpeed inference containers, built for every
    child module whose class has a registered injection policy. The containers are attached
    to the training parameters via ``set_params_wo_copy``.

    * ``eval()`` swaps the original module forwards for the inference implementations.
    * ``train()`` restores the original forwards.
    * ``generate()`` is installed over ``module.generate`` when at least one container exists.
      It handles ZeRO-3 parameter gathering, optional inference tensor parallelism and LoRA
      fusion around the model's own ``generate``.
    """
    inference_mp_group = None

    def __init__(self, args, model, **kwargs):
        """Build the training engine, then the inference containers and timing state.

        Args:
            args: forwarded unchanged to :class:`DeepSpeedEngine`.
            model: the model to train. Its ``generate`` is replaced by :meth:`generate` when
                at least one inference container is created.
            **kwargs: forwarded unchanged to :class:`DeepSpeedEngine`.
        """
        super().__init__(args, model, **kwargs)

        # Synchronise the RNG state across all ranks (rank 0's state is broadcast).
        _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
        dist.broadcast(_rng_state, 0)
        get_accelerator().set_rng_state(_rng_state.cpu())

        self.Z3_enabled = (self._config.zero_config.stage == 3)
        # With pin_parameters, all layers are gathered up front for generation instead of
        # being gathered layer by layer inside each forward (see eval()).
        self.gather_all_layers = self._config.hybrid_engine.pin_parameters

        # Inference containers and the original modules/forwards they replace.
        self._inference_containers = []
        self._orig_modules = []
        self._orig_fwds = []
        self.create_inference_module()

        # Performance stats.
        self._t_start = None
        self._total_latency = 0
        self._iters = 0
        self._training_start_time = None
        self._generate_latency = 0
        self._training_latency = 0
        self._total_batch_size = None
        self._gather_latency = 0

        global inference_cuda_module
        if inference_cuda_module is None:
            builder = InferenceBuilder()
            inference_cuda_module = builder.load()

        self.is_lora_fused = False

    def convert_to_linear_transposed(self, model):
        """Replace selected ``nn.Linear`` modules of ``model`` in place with ``TLinear``.

        Only modules whose class is exactly ``torch.nn.Linear`` (not subclasses) are
        replaced. The module's grandparent or great-grandparent must also be an
        ``nn.ModuleList``, i.e. the linear sits within two levels below a ModuleList element.
        """

        def _replace_linear_layer(r_module, parent_type=None, prev_type=None):
            # parent_type: class of r_module's parent; prev_type: class of its grandparent.
            for name, child in r_module.named_children():
                if child.__class__ is torch.nn.Linear and \
                    (parent_type is torch.nn.ModuleList or prev_type is torch.nn.ModuleList):
                    setattr(r_module, name, TLinear(child, name))
                else:
                    _replace_linear_layer(child, type(r_module), prev_type=parent_type)
            return r_module

        _replace_linear_layer(model)

    def new_inference_container(self, orig_layer, policy_cls, layer_id):
        """Create and initialise the inference container for one transformer layer.

        Args:
            orig_layer: the training module being wrapped.
            policy_cls: the injection policy class registered for ``orig_layer``'s class.
            layer_id: index of this layer among all containers.

        Returns:
            The container, with its inference module created and its parameters set
            through ``set_params_wo_copy``.
        """
        hybrid_config = self._config.hybrid_engine
        policy = policy_cls(orig_layer, inference=True)
        _container = policy_to_ds_container(
            policy=policy,
            config=DeepSpeedInferenceConfig(
                set_empty_params=True,
                dtype=torch.float16 if self._config.fp16_enabled else torch.float32,
                max_out_tokens=hybrid_config.max_out_tokens,
                # NOTE(review): both bounds use max_out_tokens; confirm min_out_tokens is meant to equal it.
                min_out_tokens=hybrid_config.max_out_tokens,
                transposed_mode=True,
            ),
            model_config=self.module.config if hasattr(self.module, 'config') else None,
            layer_id=layer_id,
            child=orig_layer)

        # The tensor-parallel world comes from the user's mpu, which may expose either
        # naming scheme. Otherwise it comes from the hybrid-engine config and the group
        # built in create_inference_module().
        if self.mpu is not None:
            if hasattr(self.mpu, 'get_model_parallel_world_size'):
                _container.set_tensor_parallel_config(self.mpu.get_model_parallel_world_size(),
                                                      self.mpu.get_model_parallel_group())
            else:
                _container.set_tensor_parallel_config(self.mpu.get_tensor_model_parallel_world_size(),
                                                      self.mpu.get_tensor_model_parallel_group())
        else:
            _container.set_tensor_parallel_config(hybrid_config.inference_tp_size, self.mp_group)
        _container.initialize_tensors(enable_training=True)
        _container.create_ds_model_config()
        _container.create_module()
        _container.set_params_wo_copy(Z3_enabled=self.Z3_enabled)
        return _container

    def populate_all_inference_policies(self):
        """Fill ``self.inference_policies``: module class -> tuple whose first item builds the replacement.

        Transformer-layer classes map to ``(self.new_inference_container, policy_cls)``.
        Simple layers (Linear, Embedding, LayerNorm, and OPT's positional embedding when
        ``transformers`` is available) map to a one-element tuple holding the replacement class.
        """
        self.inference_policies = {}
        for plcy in replace_policies:
            # The instance is discarded; presumably constructing it is needed so that
            # ``_orig_layer_class`` is resolved before being read below.
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    self.inference_policies.update({orig_layer_class: (self.new_inference_container, plcy)})
            elif plcy._orig_layer_class is not None:
                self.inference_policies.update({plcy._orig_layer_class: (self.new_inference_container, plcy)})
        self.inference_policies.update({
            nn.Linear: (LinearLayer, ),
            nn.Embedding: (EmbeddingLayer, ),
            nn.LayerNorm: (Normalize, ),
        })
        # Only register OPT's embedding when transformers provided it (avoids a `None` key).
        if OPTLearnedPositionalEmbedding is not None:
            self.inference_policies[OPTLearnedPositionalEmbedding] = (OPTEmbedding, )

    @staticmethod
    def _iter_lora_deltas(params, lora_params):
        """Yield ``(weight, delta)`` for every weight that has LoRA factors attached.

        The multi-dimensional entries of ``params`` are paired positionally with
        ``lora_params``. Only entries of ``lora_params`` with exactly three items
        ``(right, left, scaling)`` contribute; for those,
        ``delta = scaling * left.T @ right.T``.
        """
        maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
        for lora_param, weight in zip(lora_params, maybe_has_lora_params):
            if len(lora_param) == 3:
                lora_right_weight, lora_left_weight, lora_scaling = lora_param
                yield weight, lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

    def _fuse_lora(self, params, lora_params):
        """Add each LoRA delta into its base weight in place (see ``_iter_lora_deltas``)."""
        for weight, delta in self._iter_lora_deltas(params, lora_params):
            weight.data += delta

    def fuse_lora_weight(self):
        """Fuse LoRA into the base weights of every container layer."""
        for params, lora_params in zip(self.layer_params, self.lora_params):
            self._fuse_lora(params, lora_params)

    def _unfuse_lora(self, params, lora_params):
        """Subtract each LoRA delta from its base weight in place; inverse of ``_fuse_lora``."""
        for weight, delta in self._iter_lora_deltas(params, lora_params):
            weight.data -= delta

    def unfuse_lora_weight(self):
        """Remove fused LoRA deltas from the base weights of every container layer."""
        for params, lora_params in zip(self.layer_params, self.lora_params):
            self._unfuse_lora(params, lora_params)

    def unfuse_lora_weight_non_pinned(self):
        """ZeRO-3 variant of :meth:`unfuse_lora_weight`: gather each layer's inactive params first."""
        for layer_id in range(len(self.layer_params)):
            non_active_params = get_inactive_params(self.layer_params[layer_id])
            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
            non_active_params.extend(non_active_lora_params)

            with GatheredParameters(non_active_params):
                self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])

    def retake_inference_cache(self):
        """Re-acquire the inference workspace when ``release_inference_cache`` is enabled.

        On failure, runs ``gc.collect()`` and empties the accelerator cache, then retries once.

        Raises:
            RuntimeError: if the workspace still cannot be acquired.
        """
        if self._config.hybrid_engine.release_inference_cache:
            retake_success = inference_cuda_module.retake_workspace()

            if not retake_success:
                logger.warning("Unable to acquire workspace on first attempt, emptying cache and retrying.")
                gc.collect()
                get_accelerator().empty_cache()
                retake_success = inference_cuda_module.retake_workspace()

                if not retake_success:
                    raise RuntimeError("Unable to retake inference workspace.")

    @staticmethod
    def _get_input_ids(inputs, kwargs):
        """Return the token-id tensor: the first positional argument, else ``kwargs['input_ids']``."""
        return inputs[0] if len(inputs) > 0 else kwargs['input_ids']

    def generate(self, *inputs, **kwargs):
        """Replacement for ``module.generate`` that runs generation on the inference containers.

        The execution mode depends on the configuration:

        * ZeRO-3 with pinned parameters and ``inference_tp_size > 1``: layers are gathered
          in groups and sliced for tensor parallelism. Inputs are all-gathered across the
          TP group, and this rank's slice of the output is returned.
        * ZeRO-3 with pinned parameters and no TP: all parameters are gathered around a
          single generate call.
        * Otherwise: generate runs directly. Without ZeRO-3, LoRA is fused into the weights
          around the call; with ZeRO-3 it is fused per layer by the forwards installed in
          :meth:`eval`.

        All arguments are forwarded to the wrapped model's ``generate``. The token ids are
        taken from the first positional argument, or from ``kwargs['input_ids']``.
        """
        if self._total_batch_size is None:
            bsz = self._get_input_ids(inputs, kwargs).shape[0]
            self._total_batch_size = bsz * dist.get_world_size()

        self._t0 = time.time()

        if self.Z3_enabled and self.gather_all_layers:
            if self._config.hybrid_engine.inference_tp_size > 1:
                generate_ret_vals = self._generate_zero3_tensor_parallel(inputs, kwargs)
            else:
                generate_ret_vals = self._generate_zero3_gathered(inputs, kwargs)
        else:
            generate_ret_vals = self._generate_default(inputs, kwargs)

        if self._config.hybrid_engine.release_inference_cache:
            inference_cuda_module.release_workspace()
            gc.collect()
            get_accelerator().empty_cache()

        self._generate_latency = time.time() - self._t0 - self._gather_latency

        return generate_ret_vals

    def _generate_zero3_tensor_parallel(self, inputs, kwargs):
        """ZeRO-3 + pinned params + inference TP: gather/shard layers, generate on the TP group's batch."""
        tp_size = self._config.hybrid_engine.inference_tp_size
        num_layers = len(self.layer_params)

        non_tp_params = []
        for other_layer in self._other_layers:
            non_tp_params.extend(list(other_layer.parameters()))

        # Gather layers in groups of `partition_size` to bound the memory held by full parameters.
        partition_size = self._config.hybrid_engine.tp_gather_partition_size
        layer_groups = math.ceil(num_layers / partition_size)
        for lg in range(layer_groups):
            group_layer_ids = range(lg * partition_size, min(num_layers, (lg + 1) * partition_size))
            non_active_params = []
            for layer_id in group_layer_ids:
                # NOTE(review): assumes the first four entries of get_all_params() are the
                # parameters that are not tensor-sliced; confirm against the container.
                non_tp_params.extend(self.layer_params[layer_id][:4])
                non_active_params.extend(get_inactive_params(self.layer_params[layer_id]))
                non_active_params.extend(get_inactive_params(self.layer_lora_params[layer_id]))
            with GatheredParameters(non_active_params):
                for layer_id in group_layer_ids:
                    if len(self.all_lora_params) > 0:
                        self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])

                    if self.mpu is not None:
                        self._inference_containers[layer_id].apply_tensor_parallelism(self.mp_replace,
                                                                                      reversed_dim=True)

        # TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache
        # is enabled.
        gc.collect()
        get_accelerator().empty_cache()

        self._gather_latency = time.time() - self._t0

        # Every rank of the TP group generates for the concatenation of the group's inputs.
        input_ids = self._get_input_ids(inputs, kwargs)
        input_shape = input_ids.shape
        output = torch.zeros((input_shape[0] * tp_size, ) + input_shape[1:],
                             dtype=input_ids.dtype,
                             device=input_ids.device)
        dist.all_gather_into_tensor(output, input_ids.contiguous(), group=self.mp_group)

        if len(inputs) > 0:
            # Keep any further positional arguments instead of silently dropping them.
            inputs = (output, *inputs[1:])
        else:
            kwargs['input_ids'] = output

        self.retake_inference_cache()

        non_active_params = get_inactive_params(non_tp_params)
        with GatheredParameters(non_active_params):
            generate_ret_vals = self._generate(*inputs, **kwargs)

        for layer_id in range(num_layers):
            self._inference_containers[layer_id].release_memory()

        # Return only the rows generated for this rank's original inputs.
        rank = dist.get_rank(group=self.mp_group)
        local_bsz = input_shape[0]
        return generate_ret_vals[local_bsz * rank:local_bsz * (rank + 1)]

    def _generate_zero3_gathered(self, inputs, kwargs):
        """ZeRO-3 + pinned params, no TP: gather every parameter around one generate call."""
        has_lora = len(self.all_lora_params) > 0
        non_active_layers = get_inactive_params(self.all_layers_params)
        non_active_layers.extend(get_inactive_params(self.all_lora_params))
        with GatheredParameters(non_active_layers):
            self._gather_latency = time.time() - self._t0

            if has_lora:
                self.fuse_lora_weight()

            self.retake_inference_cache()
            generate_ret_vals = self._generate(*inputs, **kwargs)

            if has_lora:
                self.unfuse_lora_weight()
        return generate_ret_vals

    def _generate_default(self, inputs, kwargs):
        """No up-front gathering: fuse LoRA locally (non-ZeRO-3) and call generate directly."""
        has_lora = len(self.all_lora_params) > 0
        fused_here = has_lora and not self.Z3_enabled
        if fused_here:
            self.fuse_lora_weight()

        try:
            self.retake_inference_cache()
            generate_ret_vals = self._generate(*inputs, **kwargs)
        except BaseException:
            # Don't leave the base weights permanently fused if generation fails. Unfusing is
            # purely local tensor math (no collectives), so it is safe on the error path.
            if fused_here:
                self.unfuse_lora_weight()
            raise

        if has_lora:
            if not self.Z3_enabled:
                self.unfuse_lora_weight()
            else:
                self.unfuse_lora_weight_non_pinned()
            self.is_lora_fused = False
        return generate_ret_vals

    def create_inference_containers(self, module, layer_id=0):
        """Recursively walk ``module`` and build replacements for every supported child.

        Children handled by ``new_inference_container`` get a container and are recorded in
        ``_inference_containers``, ``_orig_modules``, ``_orig_fwds``, ``layer_params``,
        ``lora_params``, ``layer_lora_params`` and ``all_lora_params``. Other supported
        children (Linear, Embedding, ...) go to ``_other_layers`` and friends. Unsupported
        children are searched recursively.

        Args:
            module: module whose children are inspected.
            layer_id: id to assign to the next container found.

        Returns:
            The next free layer id. Callers recursing into several sibling subtrees thread it
            through, so ids stay unique and match the container's index in the lists above.
        """
        for child in module.children():
            policy_entry = self.inference_policies.get(child.__class__)
            if policy_entry is None:
                layer_id = self.create_inference_containers(child, layer_id=layer_id)
                continue

            factory = policy_entry[0]
            if factory == self.new_inference_container:
                container = factory(child, policy_entry[-1], layer_id)
                self._inference_containers.append(container)
                self._orig_modules.append(child)
                self._orig_fwds.append(child.forward)

                self.layer_params.append(container.get_all_params())

                lora_params = container.get_lora_params()
                self.lora_params.append(lora_params)
                layer_lora = []
                for lora_param in lora_params:
                    # The last item of each LoRA entry is the scaling factor, not a tensor to gather.
                    layer_lora.extend(lora_param[:-1])
                    self.all_lora_params.extend(lora_param[:-1])
                self.layer_lora_params.append(layer_lora)

                layer_id += 1
            else:
                self._other_layers.append(
                    factory(weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
                self._orig_modules_others.append(child)
                self._orig_fwds_others.append(child.forward)
        return layer_id

    def create_inference_module(self):
        """Set up the TP group, the policy table and all inference containers.

        Also installs :meth:`generate` over ``module.generate`` when at least one container
        was created.
        """
        self.layer_params = []
        self.layer_lora_params = []
        self.lora_params = []
        self.all_lora_params = []

        self._other_layers = []
        self._orig_modules_others = []
        self._orig_fwds_others = []

        tp_size = self._config.hybrid_engine.inference_tp_size
        if tp_size > 1:
            from deepspeed.module_inject import ReplaceWithTensorSlicing
            if self.mpu is None:
                # Split the world into consecutive rank blocks of size tp_size. Every rank
                # creates every group: deepspeed.comm presumably forwards to
                # torch.distributed.new_group, which all processes must enter.
                # NOTE(review): if world_size is not divisible by tp_size, the leftover ranks
                # never get mp_group / mp_replace assigned here.
                global_rank = dist.get_rank()
                world_size = dist.get_world_size()
                num_mp_groups = world_size // tp_size
                for mp_group_id in range(num_mp_groups):
                    ranks = list(range(mp_group_id * tp_size, (mp_group_id + 1) * tp_size))
                    mp_group = dist.new_group(ranks)
                    if global_rank in ranks:
                        # mp_group is used for broader collectives.
                        self.mp_group = mp_group

                        # mp_replace is used for container tensor slicing.
                        self.mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group,
                                                                   mp_size=tp_size,
                                                                   out_dim=0,
                                                                   in_dim=1)
            else:
                if hasattr(self.mpu, 'get_model_parallel_group'):
                    self.mp_group = self.mpu.get_model_parallel_group()
                else:
                    self.mp_group = self.mpu.get_tensor_model_parallel_group()

                self.mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group,
                                                           mp_size=tp_size,
                                                           out_dim=0,
                                                           in_dim=1)
        else:
            self.mp_group = None
            self.mp_replace = None
        self.populate_all_inference_policies()
        self.all_layers_params = list(self.module.parameters())
        self.create_inference_containers(self.module)

        if len(self._inference_containers) > 0:
            self._generate = self.module.generate
            self.module.generate = self.generate

        self._t0 = time.time()

    def _zero3_forward(self, layer_id):
        """Build a forward for container ``layer_id`` that gathers its ZeRO-3 partitions on demand.

        This is installed by :meth:`eval` when ZeRO-3 is enabled without pinned parameters.
        Only this layer's inactive parameters and LoRA factors are gathered while it runs.
        While ``is_lora_fused`` is False, LoRA is fused into the layer before the forward
        runs. The flag is raised once the last layer has run, so later calls skip fusing.
        """

        def run_forward(*inputs, **kwargs):
            non_active_params = get_inactive_params(self.layer_params[layer_id])
            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
            non_active_params.extend(non_active_lora_params)

            with GatheredParameters(non_active_params):
                if len(self.all_lora_params) > 0:
                    # Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory
                    if not self.is_lora_fused:
                        self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
                    # Set the is_lora_fused to true when reaching the last layer
                    if layer_id == len(self.layer_params) - 1:
                        self.is_lora_fused = True
                return self._inference_containers[layer_id].module.forward(*inputs, **kwargs)

        return run_forward

    def _print_latency_stats(self, latency):
        """Print the timing breakdown of the iteration that just ended (``latency`` seconds).

        Percentages and throughput are reported as 0 for a zero-length interval. Throughput
        is reported as ``n/a`` until :meth:`generate` has recorded the global batch size.
        """

        def share(part):
            # Percentage of this iteration's wall-clock time spent in `part`.
            return part / latency * 100 if latency > 0 else 0.0

        others = latency - (self._generate_latency + self._training_latency)
        if self._total_batch_size is None:
            throughput = '|CurSamplesPerSec=n/a |AvgSamplesPerSec=n/a'
        else:
            cur = 1 / latency * self._total_batch_size if latency > 0 else 0.0
            avg_latency = self._total_latency / self._iters
            avg = 1 / avg_latency * self._total_batch_size if avg_latency > 0 else 0.0
            throughput = f'|CurSamplesPerSec={cur:.2f} |AvgSamplesPerSec={avg:.2f}'
        print(f'|E2E latency={(latency):.2f}s '
              f'|Gather latency={self._gather_latency:.2f}s ({share(self._gather_latency):.2f}%) '
              f'|Generate time={(self._generate_latency):.2f}s ({share(self._generate_latency):.2f}%) '
              f'|Training time={(self._training_latency):.2f}s ({share(self._training_latency):.2f}%) '
              f'|Others={others:.2f} ({share(others):.2f}%) ' + throughput)

    def eval(self):
        """Switch to inference mode: account and print timing stats, then swap in the inference forwards.

        From the second call onward, the time since the previous call is treated as one
        train+generate iteration. Its breakdown is printed on rank 0, or on every process
        when torch.distributed is not initialised.
        """
        if self._t_start is not None:
            latency = time.time() - self._t_start
            self._total_latency = self._total_latency + latency
            self._iters = self._iters + 1
            if not dist.is_initialized() or dist.get_rank() == 0:
                self._print_latency_stats(latency)
            self._t_start = time.time()
        self._training_latency = 0
        super().eval()
        if len(self._inference_containers) > 0:
            # ZeRO-3 without pinned params gathers each layer on demand inside its forward.
            use_zero3_forward = self.Z3_enabled and not self.gather_all_layers
            for i, (orig_module, inference_container) in enumerate(zip(self._orig_modules,
                                                                       self._inference_containers)):
                if use_zero3_forward:
                    orig_module.forward = self._zero3_forward(i)
                else:
                    orig_module.forward = inference_container.module.forward

                inference_container.transform_for_inference()

            if not use_zero3_forward:
                for orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers):
                    orig_module.forward = inference_layer.forward
        if self.Z3_enabled:
            gc.collect()
            get_accelerator().empty_cache()
        if self._t_start is None:
            self._t_start = time.time()

    def train(self, mode=True):
        """Restore the original training forwards (when ``mode`` is True), then call the base ``train``.

        Also records the start of the training interval used for latency accounting.
        """
        if mode and len(self._orig_modules) > 0:
            for inference_container, orig_module, orig_fwd in zip(self._inference_containers, self._orig_modules,
                                                                  self._orig_fwds):
                inference_container.transform_for_training()
                orig_module.forward = orig_fwd
            for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others):
                orig_module.forward = orig_fwd
        super().train(mode)
        if mode:
            self._training_start_time = time.time()

    def step(self, lr_kwargs=None):
        """Run the optimizer step, reset every container's params, and accumulate training time.

        ``reset_params`` is called on each container after the step; presumably this re-links
        the containers to the updated training weights (confirm in the container
        implementation).
        """
        super().step(lr_kwargs=lr_kwargs)

        if len(self._inference_containers) > 0:
            for inference_container in self._inference_containers:
                inference_container.reset_params()

        if self._training_start_time is not None:
            self._training_latency += (time.time() - self._training_start_time)
            self._training_start_time = time.time()