from fairscale.nn.model_parallel.layers import (
    RowParallelLinear,
    ColumnParallelLinear,
    ParallelEmbedding
)

from src.modeling import RMSNorm
from src.modeling_abstract_hf import (
    AbstractAttentionHF,
    AbstractFeedForwardHF,
    AbstractTransformerBlockHF,
    AbstractLLaMAHF,
    LlamaRotaryEmbedding
)
from src.modeling_args import ModelArgs


class AttentionHF(AbstractAttentionHF):
    """Attention module built from fairscale model-parallel linear layers.

    ``wq``, ``wk`` and ``wv`` are ``ColumnParallelLinear`` layers mapping
    ``args.dim`` to ``args.n_heads * self.head_dim``; ``wo`` is a
    ``RowParallelLinear`` mapping back to ``args.dim``.  Everything else
    (including ``self.head_dim``) is expected to come from
    ``AbstractAttentionHF``.
    """

    def __init__(self, args: ModelArgs):
        """Create the projection layers and the rotary embedding.

        Args:
            args: Model hyper-parameters; ``dim`` and ``n_heads`` are read
                here, the rest is handed to the parent constructor.
        """
        super().__init__(args)

        def passthrough(tensor):
            # Identity initializer: hands the tensor back untouched.
            # NOTE(review): presumably weights are loaded from a checkpoint
            # afterwards -- confirm against the loading code.
            return tensor

        # Combined width of all heads, shared by every projection below.
        heads_width = args.n_heads * self.head_dim
        column_options = {
            "bias": False,
            "gather_output": False,
            "init_method": passthrough,
        }

        # Attribute assignment order is kept: wq, wk, wv, wo, rotary_emb.
        self.wq = ColumnParallelLinear(args.dim, heads_width, **column_options)
        self.wk = ColumnParallelLinear(args.dim, heads_width, **column_options)
        self.wv = ColumnParallelLinear(args.dim, heads_width, **column_options)
        self.wo = RowParallelLinear(
            heads_width,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=passthrough,
        )
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)


class FeedForwardHF(AbstractFeedForwardHF):
    """Feed-forward module built from fairscale model-parallel linear layers.

    ``w1`` and ``w3`` are ``ColumnParallelLinear`` layers going from
    ``self.dim`` to ``self.hidden_dim``; ``w2`` is a ``RowParallelLinear``
    going from ``self.hidden_dim`` back to ``self.dim``.  Both sizes are
    expected to be provided by ``AbstractFeedForwardHF``.
    """

    def __init__(self, args: ModelArgs):
        """Create the three linear layers.

        Args:
            args: Model hyper-parameters, forwarded to the parent constructor.
        """
        super().__init__(args)

        def passthrough(tensor):
            # Identity initializer: hands the tensor back untouched.
            return tensor

        model_dim = self.dim
        inner_dim = self.hidden_dim
        expand_options = {
            "bias": False,
            "gather_output": False,
            "init_method": passthrough,
        }

        # Attribute assignment order is kept: w1, w2, w3.
        self.w1 = ColumnParallelLinear(model_dim, inner_dim, **expand_options)
        self.w2 = RowParallelLinear(
            inner_dim,
            model_dim,
            bias=False,
            input_is_parallel=True,
            init_method=passthrough,
        )
        self.w3 = ColumnParallelLinear(model_dim, inner_dim, **expand_options)


class TransformerBlockHF(AbstractTransformerBlockHF):
    """Transformer block combining ``AttentionHF`` and ``FeedForwardHF``.

    Holds one attention module, one feed-forward module and two ``RMSNorm``
    instances (``attention_norm`` and ``ffn_norm``), each sized by
    ``args.dim`` with ``eps=args.norm_eps``.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        """Create the sub-modules of the block.

        Args:
            layer_id: Index of this block; passed to the parent constructor.
            args: Model hyper-parameters shared by all sub-modules.
        """
        super().__init__(layer_id, args)

        # Attribute assignment order is kept: attention, feed_forward,
        # attention_norm, ffn_norm.
        self.attention = AttentionHF(args)
        self.feed_forward = FeedForwardHF(args)

        norm_width = args.dim
        norm_options = {"eps": args.norm_eps}
        self.attention_norm = RMSNorm(norm_width, **norm_options)
        self.ffn_norm = RMSNorm(norm_width, **norm_options)


class LLaMAHF(AbstractLLaMAHF):
    """Top-level model assembled from model-parallel building blocks.

    Appends ``args.n_layers`` ``TransformerBlockHF`` instances to
    ``self.layers`` (expected to be created by ``AbstractLLaMAHF``), then
    adds a ``ParallelEmbedding`` for tokens, a final ``RMSNorm`` and a
    ``ColumnParallelLinear`` output layer of width ``args.vocab_size``.
    """

    def __init__(self, args: ModelArgs):
        """Create the block stack, embedding, final norm and output layer.

        Args:
            args: Model hyper-parameters; ``n_layers``, ``vocab_size``,
                ``dim`` and ``norm_eps`` are read here.
        """
        super().__init__(args)

        def passthrough(tensor):
            # Identity initializer: hands the tensor back untouched.
            return tensor

        # One block per layer index, appended in ascending order.
        for block_index in range(args.n_layers):
            block = TransformerBlockHF(block_index, args)
            self.layers.append(block)

        vocab_size = args.vocab_size
        model_dim = args.dim

        # Attribute assignment order is kept: tok_embeddings, norm, output.
        self.tok_embeddings = ParallelEmbedding(
            vocab_size,
            model_dim,
            init_method=passthrough,
        )
        self.norm = RMSNorm(model_dim, eps=args.norm_eps)
        # gather_output is deliberately not passed here, as in the original,
        # so the library default applies.
        self.output = ColumnParallelLinear(
            model_dim,
            vocab_size,
            bias=False,
            init_method=passthrough,
        )
