              
                                                      
                                               

from dataclasses import asdict
from typing import Callable, List, Optional
from typing_extensions import override
from copy import deepcopy

import torch
from torch import nn, Tensor

import megatron.core.parallel_state as mpu

from megatron.core.transformer.custom_layers.transformer_engine import (
    TEColumnParallelGroupedLinear,
    TERowParallelGroupedLinear,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
    TEColumnParallelLinear,
)
from megatron.core.tensor_parallel.mappings import (
    gather_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
    scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.layers import (
    RowParallelLinear,
    ColumnParallelLinear,
)
from megatron.core.utils import divide

from gpatch.core.transformer.transformer_config import GpatchTransformerConfig


class TEColumnParallelLoRALinear(TEColumnParallelLinear):
    """Transformer-Engine column-parallel linear layer with an additive LoRA branch.

    The base projection is the inherited ``TEColumnParallelLinear`` (no
    layernorm is applied here). On top of it, a low-rank update
    ``lora_b(lora_a(x)) * (lora_alpha / lora_r)`` is added to the base output.
    Both adapter projections are column-parallel layers:
    ``lora_a`` maps ``input_size -> lora_r`` and ``lora_b`` maps
    ``lora_r -> output_size``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        """Build the base projection and the two LoRA projections.

        Arguments mirror ``TEColumnParallelLinear``. ``config`` must also
        provide ``lora_r`` (adapter rank) and ``lora_alpha``; the LoRA output
        is scaled by ``lora_alpha / lora_r``. When ``tp_comm_buffer_name`` is
        given, the adapters use it with a ``'_cola'`` / ``'_colb'`` suffix.
        """
        super(TEColumnParallelLoRALinear, self).__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        self.lora_r = config.lora_r
        self.lora_alpha = config.lora_alpha
        self.lora_scaling = self.lora_alpha / self.lora_r

        # lora_a: input_size -> lora_r. Its output is never gathered by the
        # layer itself; forward() gathers it explicitly.
        self.lora_a = TEColumnParallelLinear(
            input_size=input_size,
            output_size=self.lora_r,
            config=config,
            init_method=init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=(tp_comm_buffer_name + '_cola'
                                 if tp_comm_buffer_name is not None else None),
        )
        # lora_b: lora_r -> output_size, with the same gather behaviour as the
        # base layer so both outputs have matching shapes.
        self.lora_b = TEColumnParallelLinear(
            input_size=self.lora_r,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=(tp_comm_buffer_name + '_colb'
                                 if tp_comm_buffer_name is not None else None),
        )

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the scaled LoRA update added to ``output``.

        ``bias`` is passed through unchanged from the base layer.
        """
        output, bias = super().forward(x)
        # lora_a is column-parallel, so each TP rank holds only a slice of the
        # rank-``lora_r`` activation; gather it across the TP group.
        lora_output, _ = self.lora_a(x)
        lora_output = gather_from_tensor_model_parallel_region(lora_output)
        if self.config.sequence_parallel:
            # Re-split along the sequence dimension before lora_b; presumably
            # lora_b (like the base layer) expects sequence-parallel input here.
            lora_output = scatter_to_sequence_parallel_region(lora_output)
        lora_output, _ = self.lora_b(lora_output)

        output = output + lora_output * self.lora_scaling
        return output, bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Return the base sharded state dict extended with each child module's entries.

        Children (the LoRA adapters) are keyed under ``f'{prefix}{name}.'``;
        ``nn.Dropout`` children are skipped. Child entries overwrite any
        same-named keys produced by the base implementation.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for name, module in self._modules.items():
            sub_prefix = f'{prefix}{name}.'
            if not isinstance(module, nn.Dropout):
                sub_sd = module.sharded_state_dict(sub_prefix, sharded_offsets, metadata)
                sharded_state_dict.update(sub_sd)
        return sharded_state_dict


class TERowParallelLoRALinear(TERowParallelLinear):
    """Row-parallel Transformer-Engine linear layer with an additive LoRA branch.

    Behaves like the inherited ``TERowParallelLinear`` (megatron
    ``RowParallelLinear`` semantics) and adds
    ``lora_b(lora_a(x)) * lora_alpha / lora_r`` to the base output.

    * ``lora_a`` (``input_size -> lora_r``) is row-parallel and is built from a
      deep copy of ``config`` with ``sequence_parallel`` switched off.
    * ``lora_b`` (``lora_r -> output_size``) is row-parallel with
      ``input_is_parallel=True`` and uses the original ``config``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            bias=bias,
            input_is_parallel=input_is_parallel,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        rank = config.lora_r
        alpha = config.lora_alpha
        self.lora_r = rank
        self.lora_alpha = alpha
        self.lora_scaling = alpha / rank
        self.input_is_parallel = input_is_parallel

        def _suffixed(suffix):
            # Adapter comm-buffer name derived from the base one (or None).
            if tp_comm_buffer_name is None:
                return None
            return tp_comm_buffer_name + suffix

        # Options shared by both adapter projections.
        adapter_kwargs = dict(
            init_method=init_method,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
        )

        # Down-projection runs without sequence parallelism; forward() splits
        # its output across the TP group before the up-projection.
        down_config = deepcopy(config)
        down_config.sequence_parallel = False
        self.lora_a = TERowParallelLinear(
            input_size=input_size,
            output_size=rank,
            config=down_config,
            input_is_parallel=input_is_parallel,
            tp_comm_buffer_name=_suffixed('_rowa'),
            **adapter_kwargs,
        )
        self.lora_b = TERowParallelLinear(
            input_size=rank,
            output_size=output_size,
            config=config,
            input_is_parallel=True,
            tp_comm_buffer_name=_suffixed('_rowb'),
            **adapter_kwargs,
        )

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the scaled LoRA update added to ``output``."""
        base_out, base_bias = super().forward(x)
        adapter_hidden, _ = self.lora_a(x)
        # lora_b takes TP-partitioned input (input_is_parallel=True), so split
        # the down-projection output across the tensor-parallel group.
        adapter_hidden = scatter_to_tensor_model_parallel_region(adapter_hidden)
        adapter_out, _ = self.lora_b(adapter_hidden)
        return base_out + adapter_out * self.lora_scaling, base_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Merge each non-dropout child's sharded state into the base sharded state dict."""
        merged = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for child_name, child in self._modules.items():
            if isinstance(child, nn.Dropout):
                continue
            merged.update(
                child.sharded_state_dict(f'{prefix}{child_name}.', sharded_offsets, metadata))
        return merged


class TELayerNormColumnParallelLoRALinear(TELayerNormColumnParallelLinear):
    """Fused layernorm + column-parallel linear (Transformer Engine) with LoRA.

    The inherited layer is told to also return its layernorm output; that
    normalized tensor feeds ``lora_a`` (``input_size -> lora_r``) and then
    ``lora_b`` (``lora_r -> output_size``). The result, scaled by
    ``lora_alpha / lora_r``, is added to the base output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        # Make the base forward also hand back the normalized input.
        self.return_layernorm_output = True

        rank = config.lora_r
        alpha = config.lora_alpha
        self.lora_r = rank
        self.lora_alpha = alpha
        self.lora_scaling = alpha / rank

        def _suffixed(suffix):
            # Adapter comm-buffer name derived from the base one (or None).
            if tp_comm_buffer_name is None:
                return None
            return tp_comm_buffer_name + suffix

        # Options shared by both adapter projections.
        adapter_kwargs = dict(
            config=config,
            init_method=init_method,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
        )
        self.lora_a = TEColumnParallelLinear(
            input_size=input_size,
            output_size=rank,
            gather_output=False,
            tp_comm_buffer_name=_suffixed('_norma'),
            **adapter_kwargs,
        )
        self.lora_b = TEColumnParallelLinear(
            input_size=rank,
            output_size=output_size,
            gather_output=gather_output,
            tp_comm_buffer_name=_suffixed('_normb'),
            **adapter_kwargs,
        )

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the scaled LoRA update added to ``output``."""
        returns_bias = self.return_bias
        result = super().forward(x)
        # The base result's packing depends on whether it returns a bias.
        if returns_bias:
            base_out, base_bias, normed = result
        else:
            (base_out, normed), base_bias = result

        adapter_hidden, _ = self.lora_a(normed)
        adapter_hidden = gather_from_tensor_model_parallel_region(adapter_hidden)
        if self.config.sequence_parallel:
            adapter_hidden = scatter_to_sequence_parallel_region(adapter_hidden)
        adapter_out, _ = self.lora_b(adapter_hidden)
        return base_out + adapter_out * self.lora_scaling, base_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Merge each non-dropout child's sharded state into the base sharded state dict."""
        merged = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for child_name, child in self._modules.items():
            if isinstance(child, nn.Dropout):
                continue
            merged.update(
                child.sharded_state_dict(f'{prefix}{child_name}.', sharded_offsets, metadata))
        return merged


class TELayerNormColumnParallelLoRAMergeLinear(TELayerNormColumnParallelLinear):
    """Fused layernorm + column-parallel linear with a two-way split LoRA branch.

    Intended for gated-linear-unit projections (asserted at construction):
    the output is treated as two equal halves. A shared ``lora_a``
    (``input_size -> 2 * lora_r``) is followed by one ``lora_b`` per half, and
    each half's scaled LoRA update is added to the matching local slice of the
    base output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        assert config.gated_linear_unit, "use class TELayerNormColumnParallelLoRALinear when gated_linear_unit is False"
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        # Make the base forward also hand back the normalized input.
        self.return_layernorm_output = True

        self.lora_r = config.lora_r
        self.lora_alpha = config.lora_alpha
        self.lora_scaling = self.lora_alpha / self.lora_r
        half_width = output_size // 2
        self.output_size_list = [half_width, half_width]
        assert sum(
            self.output_size_list) == output_size, f'output_size_list must be equal to output_size'
        self.tp_size = mpu.get_tensor_model_parallel_world_size()
        # Per-TP-rank width of each half, used to slice the local base output.
        self.output_split_list = [width // self.tp_size for width in self.output_size_list]

        def _suffixed(suffix):
            # Adapter comm-buffer name derived from the base one (or None).
            if tp_comm_buffer_name is None:
                return None
            return tp_comm_buffer_name + suffix

        # Options shared by every adapter projection.
        adapter_kwargs = dict(
            config=config,
            init_method=init_method,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
        )
        self.lora_a = TEColumnParallelLinear(
            input_size=input_size,
            output_size=self.lora_r * len(self.output_size_list),
            gather_output=False,
            tp_comm_buffer_name=_suffixed('_norma'),
            **adapter_kwargs,
        )
        self.lora_b = torch.nn.ModuleList([
            TEColumnParallelLinear(
                input_size=self.lora_r,
                output_size=width,
                gather_output=gather_output,
                tp_comm_buffer_name=_suffixed('_normb'),
                **adapter_kwargs,
            ) for width in self.output_size_list
        ])

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with a scaled LoRA update added to each half."""
        returns_bias = self.return_bias
        result = super().forward(x)
        # The base result's packing depends on whether it returns a bias.
        if returns_bias:
            base_out, base_bias, normed = result
        else:
            (base_out, normed), base_bias = result

        adapter_hidden, _ = self.lora_a(normed)
        adapter_hidden = gather_from_tensor_model_parallel_region(adapter_hidden)

        # One rank-``lora_r`` chunk per half; lora_b[k] projects chunk k.
        adapter_chunks = torch.chunk(adapter_hidden, len(self.output_size_list), dim=-1)
        base_parts = list(torch.split(base_out, self.output_split_list, dim=-1))
        for idx, (chunk, up_proj) in enumerate(zip(adapter_chunks, self.lora_b)):
            if up_proj.sequence_parallel:
                chunk = scatter_to_sequence_parallel_region(chunk)
            delta, _ = up_proj(chunk)
            base_parts[idx] = base_parts[idx] + delta * self.lora_scaling
        return torch.cat(base_parts, dim=-1), base_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Merge child sharded state (expanding ``nn.ModuleList`` entries by index) into the base one."""
        merged = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for child_name, child in self._modules.items():
            if isinstance(child, nn.Dropout):
                continue
            child_prefix = f'{prefix}{child_name}.'
            if isinstance(child, nn.ModuleList):
                members = [(f'{child_prefix}{idx}.', member) for idx, member in enumerate(child)]
            else:
                members = [(child_prefix, child)]
            for member_prefix, member in members:
                merged.update(member.sharded_state_dict(member_prefix, sharded_offsets, metadata))
        return merged


class TELayerNormColumnParallelLoRAQKVLinear(TELayerNormColumnParallelLinear):
    """Fused layernorm + column-parallel QKV projection (Transformer Engine) with LoRA.

    The inherited layer is told to also return its layernorm output. That
    normalized input feeds a shared ``lora_a`` (``input_size -> 3 * lora_r``).
    Its gathered output is split into three rank-``lora_r`` chunks, and each
    chunk is projected by its own ``lora_b`` entry: Q (``query_projection_size``
    wide), K and V (``kv_projection_size`` wide each).

    The three scaled LoRA outputs are regrouped per local query group and
    concatenated inside each group. This presumably matches the base layer's
    grouped ``[q..., k, v]`` local QKV layout (Megatron convention; confirm).
    They are then added to the base output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        """Build the base QKV projection and the LoRA adapters.

        ``output_size`` must equal ``query_projection_size + 2 * kv_projection_size``
        as derived from ``config.kv_channels``, ``config.num_attention_heads``
        and ``config.num_query_groups`` (asserted).
        """
        super(TELayerNormColumnParallelLoRAQKVLinear, self).__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        # Make the base forward also hand back the normalized input.
        self.return_layernorm_output = True

        self.tp_size = mpu.get_tensor_model_parallel_world_size()
        self.query_projection_size = config.kv_channels * config.num_attention_heads
        self.kv_projection_size = config.kv_channels * config.num_query_groups
        self.hidden_size_per_attention_head = divide(self.query_projection_size,
                                                     config.num_attention_heads)
        self.num_attention_heads_per_partition = divide(config.num_attention_heads, self.tp_size)
        self.num_query_groups_per_partition = divide(config.num_query_groups, self.tp_size)

        self.lora_r = config.lora_r
        self.lora_alpha = config.lora_alpha
        self.lora_scaling = self.lora_alpha / self.lora_r
        # Full (unpartitioned) widths of Q, K and V; one lora_b per entry.
        self.output_size_list = [
            self.query_projection_size,
            self.kv_projection_size,
            self.kv_projection_size,
        ]
        assert sum(
            self.output_size_list) == output_size, f'output_size_list must be equal to output_size'
        # Per-TP-rank widths of Q, K and V (not used by forward(); kept as a
        # public attribute).
        self.output_split_list = [ele // self.tp_size for ele in self.output_size_list]

        self.lora_a = TEColumnParallelLinear(
            input_size=input_size,
            output_size=self.lora_r * len(self.output_size_list),
            config=config,
            init_method=init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=(tp_comm_buffer_name + '_norma'
                                 if tp_comm_buffer_name is not None else None),
        )
        self.lora_b = torch.nn.ModuleList()
        for lora_output_size in self.output_size_list:
            self.lora_b.append(
                TEColumnParallelLinear(
                    input_size=self.lora_r,
                    output_size=lora_output_size,
                    config=config,
                    init_method=init_method,
                    gather_output=gather_output,
                    bias=False,
                    skip_bias_add=skip_bias_add,
                    is_expert=is_expert,
                    skip_weight_param_allocation=skip_weight_param_allocation,
                    tp_comm_buffer_name=(tp_comm_buffer_name + '_normb'
                                         if tp_comm_buffer_name is not None else None),
                ))

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the regrouped, scaled LoRA Q/K/V update added."""
        # The base result's packing depends on whether it returns a bias.
        if self.return_bias:
            output, bias, norm = super().forward(x)
        else:
            (output, norm), bias = super().forward(x)
        lora_output, _ = self.lora_a(norm)
        lora_output = gather_from_tensor_model_parallel_region(lora_output)

        # One rank-``lora_r`` chunk per Q/K/V; lora_b[k] projects chunk k.
        lora_b_inputs = torch.chunk(lora_output, len(self.output_size_list), dim=-1)
        lora_outputs = []
        for lora_b, lora_b_input in zip(self.lora_b, lora_b_inputs):
            if lora_b.sequence_parallel:
                lora_b_input = scatter_to_sequence_parallel_region(lora_b_input)
            lora_b_output, _ = lora_b(lora_b_input)
            lora_outputs.append(lora_b_output * self.lora_scaling)

        assert len(lora_outputs) == 3
        # Regroup per local query group: Q -> [..., groups, heads_per_group * head_dim],
        # K/V -> [..., groups, head_dim]; concatenate inside each group and
        # flatten back to the base output's shape.
        heads_per_group = (self.num_attention_heads_per_partition //
                           self.num_query_groups_per_partition)
        lora_q_shape = lora_outputs[0].size()[:-1] + (
            self.num_query_groups_per_partition,
            heads_per_group * self.hidden_size_per_attention_head,
        )
        lora_kv_shape = lora_outputs[1].size()[:-1] + (
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        lora_outputs[0] = lora_outputs[0].reshape(*lora_q_shape)
        lora_outputs[1] = lora_outputs[1].reshape(*lora_kv_shape)
        lora_outputs[2] = lora_outputs[2].reshape(*lora_kv_shape)
        # Concatenate on the last (per-group feature) axis. For 3-D inputs this
        # is the same as the previous hard-coded dim=3, but it also works for
        # inputs of any other rank.
        lora_output = torch.cat(lora_outputs, dim=-1).reshape(*output.size())

        # lora_outputs were already scaled above.
        output = output + lora_output
        return output, bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Return the base sharded state dict extended with each child module's entries.

        ``nn.ModuleList`` children are expanded per element under
        ``f'{prefix}{name}.{i}.'``; ``nn.Dropout`` children are skipped.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for name, module in self._modules.items():
            sub_prefix = f'{prefix}{name}.'
            if isinstance(module, nn.Dropout):
                continue
            if isinstance(module, nn.ModuleList):
                for i, sub_module in enumerate(module):
                    sub_sd = sub_module.sharded_state_dict(f'{sub_prefix}{i}.', sharded_offsets,
                                                           metadata)
                    sharded_state_dict.update(sub_sd)
            else:
                sub_sd = module.sharded_state_dict(sub_prefix, sharded_offsets, metadata)
                sharded_state_dict.update(sub_sd)
        return sharded_state_dict


class ColumnParallelLoRALinear(ColumnParallelLinear):
    """Megatron (non-TE) column-parallel linear layer with an additive LoRA branch.

    ``forward`` returns the base layer's ``(output, bias)`` with
    ``lora_b(lora_a(x)) * lora_alpha / lora_r`` added to ``output``. Both
    adapters are ``ColumnParallelLinear``: ``lora_a`` is
    ``input_size -> lora_r`` and ``lora_b`` is ``lora_r -> output_size``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        rank = config.lora_r
        alpha = config.lora_alpha
        self.lora_r = rank
        self.lora_alpha = alpha
        self.lora_scaling = alpha / rank

        def _suffixed(suffix):
            # Adapter comm-buffer name derived from the base one (or None).
            if tp_comm_buffer_name is None:
                return None
            return tp_comm_buffer_name + suffix

        # Options shared by both adapter projections.
        adapter_kwargs = dict(
            config=config,
            init_method=init_method,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
        )
        self.lora_a = ColumnParallelLinear(
            input_size=input_size,
            output_size=rank,
            gather_output=False,
            tp_comm_buffer_name=_suffixed('_cola'),
            **adapter_kwargs,
        )
        self.lora_b = ColumnParallelLinear(
            input_size=rank,
            output_size=output_size,
            gather_output=gather_output,
            tp_comm_buffer_name=_suffixed('_colb'),
            **adapter_kwargs,
        )

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the scaled LoRA update added to ``output``."""
        base_out, base_bias = super().forward(x)
        # Collect the full rank-``lora_r`` activation on every TP rank.
        adapter_hidden = gather_from_tensor_model_parallel_region(self.lora_a(x)[0])
        if self.config.sequence_parallel:
            adapter_hidden = scatter_to_sequence_parallel_region(adapter_hidden)
        adapter_out = self.lora_b(adapter_hidden)[0]
        return base_out + adapter_out * self.lora_scaling, base_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Merge each non-dropout child's sharded state into the base sharded state dict."""
        merged = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for child_name, child in self._modules.items():
            if isinstance(child, nn.Dropout):
                continue
            merged.update(
                child.sharded_state_dict(f'{prefix}{child_name}.', sharded_offsets, metadata))
        return merged


class RowParallelLoRALinear(RowParallelLinear):
    """Megatron (non-TE) row-parallel linear layer with an additive LoRA branch.

    * ``lora_a`` (``input_size -> lora_r``) is row-parallel and is built from a
      deep copy of ``config`` with ``sequence_parallel`` switched off.
    * ``lora_b`` (``lora_r -> output_size``) is row-parallel with
      ``input_is_parallel=True`` and uses the original ``config``.

    ``forward`` adds ``lora_b(...) * lora_alpha / lora_r`` to the base output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            bias=bias,
            input_is_parallel=input_is_parallel,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        rank = config.lora_r
        alpha = config.lora_alpha
        self.lora_r = rank
        self.lora_alpha = alpha
        self.lora_scaling = alpha / rank
        self.input_is_parallel = input_is_parallel

        def _suffixed(suffix):
            # Adapter comm-buffer name derived from the base one (or None).
            if tp_comm_buffer_name is None:
                return None
            return tp_comm_buffer_name + suffix

        # Options shared by both adapter projections.
        adapter_kwargs = dict(
            init_method=init_method,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
        )

        # Down-projection runs without sequence parallelism; forward() splits
        # its output across the TP group before the up-projection.
        down_config = deepcopy(config)
        down_config.sequence_parallel = False
        self.lora_a = RowParallelLinear(
            input_size=input_size,
            output_size=rank,
            config=down_config,
            input_is_parallel=input_is_parallel,
            tp_comm_buffer_name=_suffixed('_rowa'),
            **adapter_kwargs,
        )
        self.lora_b = RowParallelLinear(
            input_size=rank,
            output_size=output_size,
            config=config,
            input_is_parallel=True,
            tp_comm_buffer_name=_suffixed('_rowb'),
            **adapter_kwargs,
        )

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the scaled LoRA update added to ``output``."""
        base_out, base_bias = super().forward(x)
        # lora_b takes TP-partitioned input (input_is_parallel=True).
        adapter_hidden = scatter_to_tensor_model_parallel_region(self.lora_a(x)[0])
        adapter_out = self.lora_b(adapter_hidden)[0]
        return base_out + adapter_out * self.lora_scaling, base_bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Merge each non-dropout child's sharded state into the base sharded state dict."""
        merged = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for child_name, child in self._modules.items():
            if isinstance(child, nn.Dropout):
                continue
            merged.update(
                child.sharded_state_dict(f'{prefix}{child_name}.', sharded_offsets, metadata))
        return merged


class ColumnParallelLoRAQKVLinear(ColumnParallelLinear):
    """Megatron (non-TE) column-parallel QKV projection with a split LoRA branch.

    A shared ``lora_a`` (``input_size -> 3 * lora_r``) produces three
    rank-``lora_r`` chunks after gathering. Each chunk is projected by its own
    ``lora_b`` entry: Q (``query_projection_size`` wide), K and V
    (``kv_projection_size`` wide each).

    The scaled LoRA outputs are regrouped per local query group and
    concatenated inside each group. This presumably matches the base layer's
    grouped ``[q..., k, v]`` local QKV layout (Megatron convention; confirm).
    They are then added to the base output.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
    ):
        """Build the base QKV projection and the LoRA adapters.

        ``output_size`` must equal ``query_projection_size + 2 * kv_projection_size``
        as derived from ``config.kv_channels``, ``config.num_attention_heads``
        and ``config.num_query_groups`` (asserted).
        """
        super(ColumnParallelLoRAQKVLinear, self).__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

        self.tp_size = mpu.get_tensor_model_parallel_world_size()
        self.query_projection_size = config.kv_channels * config.num_attention_heads
        self.kv_projection_size = config.kv_channels * config.num_query_groups
        self.hidden_size_per_attention_head = divide(self.query_projection_size,
                                                     config.num_attention_heads)
        self.num_attention_heads_per_partition = divide(config.num_attention_heads, self.tp_size)
        self.num_query_groups_per_partition = divide(config.num_query_groups, self.tp_size)

        self.lora_r = config.lora_r
        self.lora_alpha = config.lora_alpha
        self.lora_scaling = self.lora_alpha / self.lora_r
        # Full (unpartitioned) widths of Q, K and V; one lora_b per entry.
        self.output_size_list = [
            self.query_projection_size,
            self.kv_projection_size,
            self.kv_projection_size,
        ]
        assert sum(
            self.output_size_list) == output_size, f'output_size_list must be equal to output_size'
        # Per-TP-rank widths of Q, K and V (not used by forward(); kept as a
        # public attribute).
        self.output_split_list = [ele // self.tp_size for ele in self.output_size_list]

        self.lora_a = ColumnParallelLinear(
            input_size=input_size,
            output_size=self.lora_r * len(self.output_size_list),
            config=config,
            init_method=init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=(tp_comm_buffer_name + '_norma'
                                 if tp_comm_buffer_name is not None else None),
        )
        self.lora_b = torch.nn.ModuleList()
        for lora_output_size in self.output_size_list:
            self.lora_b.append(
                ColumnParallelLinear(
                    input_size=self.lora_r,
                    output_size=lora_output_size,
                    config=config,
                    init_method=init_method,
                    gather_output=gather_output,
                    bias=False,
                    skip_bias_add=skip_bias_add,
                    is_expert=is_expert,
                    skip_weight_param_allocation=skip_weight_param_allocation,
                    tp_comm_buffer_name=(tp_comm_buffer_name + '_normb'
                                         if tp_comm_buffer_name is not None else None),
                ))

    @override
    def forward(self, x):
        """Return the base ``(output, bias)`` with the regrouped, scaled LoRA Q/K/V update added."""
        output, bias = super().forward(x)

        lora_output, _ = self.lora_a(x)
        lora_output = gather_from_tensor_model_parallel_region(lora_output)

        # One rank-``lora_r`` chunk per Q/K/V; lora_b[k] projects chunk k.
        lora_b_inputs = torch.chunk(lora_output, len(self.output_size_list), dim=-1)
        lora_outputs = []
        for lora_b, lora_b_input in zip(self.lora_b, lora_b_inputs):
            if lora_b.sequence_parallel:
                lora_b_input = scatter_to_sequence_parallel_region(lora_b_input)
            lora_b_output, _ = lora_b(lora_b_input)
            lora_outputs.append(lora_b_output * self.lora_scaling)

        assert len(lora_outputs) == 3
        # Regroup per local query group: Q -> [..., groups, heads_per_group * head_dim],
        # K/V -> [..., groups, head_dim]; concatenate inside each group and
        # flatten back to the base output's shape.
        heads_per_group = (self.num_attention_heads_per_partition //
                           self.num_query_groups_per_partition)
        lora_q_shape = lora_outputs[0].size()[:-1] + (
            self.num_query_groups_per_partition,
            heads_per_group * self.hidden_size_per_attention_head,
        )
        lora_kv_shape = lora_outputs[1].size()[:-1] + (
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head,
        )
        lora_outputs[0] = lora_outputs[0].reshape(*lora_q_shape)
        lora_outputs[1] = lora_outputs[1].reshape(*lora_kv_shape)
        lora_outputs[2] = lora_outputs[2].reshape(*lora_kv_shape)
        # Concatenate on the last (per-group feature) axis. For 3-D inputs this
        # is the same as the previous hard-coded dim=3, but it also works for
        # inputs of any other rank.
        lora_output = torch.cat(lora_outputs, dim=-1).reshape(*output.size())

        # lora_outputs were already scaled above.
        output = output + lora_output
        return output, bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Return the base sharded state dict extended with each child module's entries.

        ``nn.ModuleList`` children are expanded per element under
        ``f'{prefix}{name}.{i}.'``; ``nn.Dropout`` children are skipped.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for name, module in self._modules.items():
            sub_prefix = f'{prefix}{name}.'
            if isinstance(module, nn.Dropout):
                continue
            if isinstance(module, nn.ModuleList):
                for i, sub_module in enumerate(module):
                    sub_sd = sub_module.sharded_state_dict(f'{sub_prefix}{i}.', sharded_offsets,
                                                           metadata)
                    sharded_state_dict.update(sub_sd)
            else:
                sub_sd = module.sharded_state_dict(sub_prefix, sharded_offsets, metadata)
                sharded_state_dict.update(sub_sd)
        return sharded_state_dict


class ColumnParallelLoRAMergeLinear(ColumnParallelLinear):
    """Column-parallel linear layer with one LoRA adapter per output half.

    Only valid when ``config.gated_linear_unit`` is set (asserted in the
    constructor). The layer output is treated as two equal halves, presumably
    the two GLU projections of a fused fc1 (confirm against the caller):

    * ``lora_a`` is one shared down-projection producing ``lora_r`` features
      per half, i.e. ``2 * lora_r`` features in total.
    * ``lora_b[i]`` is the up-projection for half ``i``.

    Each half of the result is
    ``base_half_i + lora_b[i](chunk_i(lora_a(x))) * (lora_alpha / lora_r)``.

    NOTE(review): ``lora_b`` is built with the same ``init_method`` as the base
    weight. Conventional LoRA zero-initializes B so the adapter starts as a
    no-op; confirm whether that happens elsewhere.
    NOTE(review): ``forward`` splits the base output with per-partition sizes
    (``output_size // 2 // tp_size``), so it assumes the base output is not
    gathered across TP ranks when ``tp_size > 1``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: GpatchTransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool = False,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        assert config.gated_linear_unit, "use class ColumnParallelLoRALinear when gated_linear_unit is False"
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            gather_output=gather_output,
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

        # LoRA hyper-parameters; the adapter output is scaled by alpha / r.
        self.lora_r = config.lora_r
        self.lora_alpha = config.lora_alpha
        self.lora_scaling = self.lora_alpha / self.lora_r

        # Global size of each of the two output halves; they must add up exactly.
        half_size = output_size // 2
        self.output_size_list = [half_size, half_size]
        assert sum(self.output_size_list) == output_size, f'output_size_list must be equal to output_size'

        # Per-rank width of each half, used to split the local base output.
        self.tp_size = mpu.get_tensor_model_parallel_world_size()
        self.output_split_list = [size // self.tp_size for size in self.output_size_list]

        # Distinct communication-buffer names for the adapter layers, if any.
        if tp_comm_buffer_name is None:
            lora_a_buffer = None
            lora_b_buffer = None
        else:
            lora_a_buffer = tp_comm_buffer_name + '_norma'
            lora_b_buffer = tp_comm_buffer_name + '_normb'

        # Settings shared by every adapter projection (never biased).
        adapter_kwargs = dict(
            config=config,
            init_method=init_method,
            bias=False,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
        )

        # Shared down-projection: lora_r features for each output half.
        self.lora_a = ColumnParallelLinear(
            input_size=input_size,
            output_size=self.lora_r * len(self.output_size_list),
            gather_output=False,
            tp_comm_buffer_name=lora_a_buffer,
            **adapter_kwargs,
        )
        # One up-projection per output half, built in list order.
        self.lora_b = torch.nn.ModuleList([
            ColumnParallelLinear(
                input_size=self.lora_r,
                output_size=part_size,
                gather_output=gather_output,
                tp_comm_buffer_name=lora_b_buffer,
                **adapter_kwargs,
            )
            for part_size in self.output_size_list
        ])

    @override
    def forward(self, x):
        """Return ``(base_output + scaled LoRA update, bias)`` for input ``x``."""
        base_out, bias = super().forward(x)

        # The down-projection runs column-parallel; gather its output so every
        # rank sees the full 2 * lora_r features.
        down, _ = self.lora_a(x)
        down = gather_from_tensor_model_parallel_region(down)

        down_parts = torch.chunk(down, len(self.output_size_list), dim=-1)
        base_parts = torch.split(base_out, self.output_split_list, dim=-1)

        merged = []
        for base_part, down_part, up_proj in zip(base_parts, down_parts, self.lora_b):
            # With sequence parallelism the up-projection expects sequence-split
            # input, so scatter the gathered activations back first.
            if up_proj.sequence_parallel:
                down_part = scatter_to_sequence_parallel_region(down_part)
            up_out, _ = up_proj(down_part)
            merged.append(base_part + up_out * self.lora_scaling)
        return torch.cat(merged, dim=-1), bias

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Merge the base layer's sharded state dict with those of its children.

        ``nn.ModuleList`` children are exported element-wise under
        ``'<prefix><child_name>.<index>.'``. Entries produced by a child
        overwrite same-named entries from the parent dict. ``nn.Dropout``
        children are skipped.
        """
        merged_sd = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        for child_name, child in self._modules.items():
            if isinstance(child, nn.Dropout):
                continue  # dropout carries no parameters to checkpoint
            child_prefix = f'{prefix}{child_name}.'
            if isinstance(child, nn.ModuleList):
                for idx, member in enumerate(child):
                    merged_sd.update(
                        member.sharded_state_dict(f'{child_prefix}{idx}.', sharded_offsets,
                                                  metadata))
            else:
                merged_sd.update(child.sharded_state_dict(child_prefix, sharded_offsets, metadata))
        return merged_sd
