# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import os
from logging import Logger

import datasets
import torch
import torch.distributed as dist
from torch import nn
from transformers import LlamaTokenizerFast, Trainer, default_data_collator
import transformers
from train_utils.fsdp_trainer import FSDPTrainer
from train_utils.main import prepare_model
from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
from train_utils.optimizer import SGDG, AdamG
from spin_utils.data_utils import CustomJsonDataset
from spin_utils.hadamard_utils import random_hadamard_matrix
from spin_utils.process_args import process_args_ptq
from spin_utils.utils import get_local_rank, get_logger, pt_fsdp_state_dict

# Module-wide logger for this training script (get_logger is from spin_utils.utils).
log: Logger = get_logger("spinquant")


class RotateModule(nn.Module):
    """Hold a single trainable rotation matrix and apply it to inputs.

    The matrix is stored as a float32 ``nn.Parameter`` named ``weight`` so an
    optimizer can update it directly.
    """

    def __init__(self, R_init, device="cuda"):
        """
        Args:
            R_init: initial rotation matrix (any dtype); it is cast to float32.
            device: device on which the parameter is stored. Defaults to
                ``"cuda"``, matching the historical hard-coded behaviour.
        """
        super().__init__()
        # A single .to() performs both the dtype cast and the device move.
        self.weight = nn.Parameter(
            R_init.to(device=torch.device(device), dtype=torch.float32)
        )

    def forward(self, x, transpose=False):
        """Apply the rotation.

        Args:
            x: tensor to rotate.
            transpose: if True return ``x @ weight`` (right-multiply);
                otherwise return ``weight @ x`` (left-multiply).
        """
        if transpose:
            return x @ self.weight
        return self.weight @ x


class BlockRotateModule(nn.Module):
    """Block-diagonal rotation assembled from independently trainable blocks.

    The full ``hidden_size x hidden_size`` matrix is
    ``torch.block_diag(B_0, ..., B_{k-1})``, where each ``B_i`` is a
    ``block_size x block_size`` parameter initialised from
    ``random_hadamard_matrix`` and ``k = hidden_size // block_size``.
    """

    def __init__(self, hidden_size, block_size, device="cuda"):
        """
        Args:
            hidden_size: side length of the assembled rotation matrix.
            block_size: side length of each diagonal block.
            device: device for the block parameters (default ``"cuda"``).

        Raises:
            ValueError: if ``hidden_size`` is not a positive multiple of
                ``block_size``. Otherwise the assembled matrix would silently
                be smaller than ``hidden_size``.
        """
        super().__init__()
        if block_size <= 0 or hidden_size <= 0 or hidden_size % block_size != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be a positive multiple of "
                f"block_size ({block_size})"
            )
        self.hidden_size = hidden_size
        self.block_size = block_size
        self.num_blocks = hidden_size // block_size
        # nn.ParameterList (rather than a plain list) registers the blocks with
        # the module, so they appear in parameters()/state_dict() and follow
        # .to()/.cuda() calls.
        self.block_r = nn.ParameterList(
            [
                nn.Parameter(random_hadamard_matrix(block_size, device).float())
                for _ in range(self.num_blocks)
            ]
        )
        # Most recently assembled matrix; refreshed on every forward() call.
        self.weight = self.get_weight()

    def forward(self, x, transpose=False):
        """Rotate ``x``: ``x @ W`` if ``transpose`` else ``W @ x``."""
        weight = self.get_weight().to(x.device)
        self.weight = weight
        if transpose:
            return x @ weight
        return weight @ x

    def get_weight(self):
        """Assemble the block-diagonal rotation matrix from the current blocks."""
        return torch.block_diag(*self.block_r)


def _model_family(model_name):
    """Map a model name/path to "mistral", "qwen2" or "llama".

    Matching is a case-insensitive substring test in the precedence the model
    loader has always used. Anything unrecognised is treated as Llama, and the
    same answer drives both model and tokenizer selection so they cannot
    disagree.
    """
    lowered = model_name.lower()
    if "mistral" in lowered:
        return "mistral"
    if "qwen2" in lowered:
        return "qwen2"
    return "llama"


def _load_quant_model(model_args, config, dtype, family):
    """Instantiate the quantization-aware causal LM class for ``family``."""
    if family == "mistral":
        from train_utils.modeling_mistral_quant import (
            MistralForCausalLM as MistralForCausalLMQuant,
        )

        model_cls = MistralForCausalLMQuant
    elif family == "qwen2":
        from train_utils.modeling_qwen2_quant import (
            Qwen2ForCausalLM as Qwen2ForCausalLMQuant,
        )

        model_cls = Qwen2ForCausalLMQuant
    else:
        model_cls = LlamaForCausalLMQuant
    return model_cls.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        token=model_args.access_token,
    )


def _attach_rotations(model, block_rotation):
    """Attach a trainable ``R1`` to ``model`` and an ``R2`` to every attention layer.

    - With ``block_rotation``: each matrix is ``dim // 32`` copies of a single
      32x32 ``random_hadamard_matrix`` placed on the diagonal.
    - Otherwise: each matrix is a full-size ``random_hadamard_matrix``.

    R1 is drawn first, then one R2 per layer in layer order.
    """
    config = model.config
    # Per-head dimension (128 for the Llama models).
    head_dim = config.hidden_size // config.num_attention_heads

    def make_rotation(dim):
        if block_rotation:
            block = random_hadamard_matrix(32, "cuda")
            return torch.block_diag(*[block] * (dim // 32))
        return random_hadamard_matrix(dim, "cuda")

    model.R1 = RotateModule(make_rotation(config.hidden_size))
    for i in range(config.num_hidden_layers):
        model.model.layers[i].self_attn.R2 = RotateModule(make_rotation(head_dim))


def _load_tokenizer(model_args, training_args, family):
    """Load the fast tokenizer matching ``family`` (right padding, no BOS/EOS added)."""
    if family == "qwen2":
        from transformers import Qwen2TokenizerFast

        tokenizer_cls = Qwen2TokenizerFast
    else:
        # Llama and Mistral both use LlamaTokenizerFast here.
        tokenizer_cls = LlamaTokenizerFast
    return tokenizer_cls.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        add_eos_token=False,
        add_bos_token=False,
        token=model_args.access_token,
    )


def _use_fsdp(training_args):
    """Return True when ``training_args.fsdp`` is neither ``""`` nor ``[]``."""
    return training_args.fsdp not in ("", [])


def _collect_rotations(state_dict):
    """Keep only the R1/R2 entries of ``state_dict``, with ".weight" stripped from keys."""
    return {
        key.replace(".weight", ""): value
        for key, value in state_dict.items()
        if "R1.weight" in key or "self_attn.R2" in key
    }


def train() -> None:
    """Train the SpinQuant rotation matrices and save them to ``R.bin``.

    Steps:
    1. Load the quantization-aware model and freeze all of its parameters.
    2. Attach the trainable rotations R1 and per-layer R2.
    3. Optimize the rotations on wikitext-2 with a Stiefel-constrained
       optimizer.
    4. On rank 0, write the learned rotations to
       ``<output_rotation_path>/R.bin``.
    """
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    model_args, training_args, ptq_args = process_args_ptq()
    local_rank = get_local_rank()

    log.info("the rank is {}".format(local_rank))
    dist.barrier()

    config = transformers.AutoConfig.from_pretrained(
        model_args.input_model, token=model_args.access_token
    )

    # Llama 3.2 specific: SpinQuant is not compatible with tie_word_embeddings.
    # Untie them in the config and clone lm_head from embed_tokens after loading.
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True

    family = _model_family(model_args.input_model)
    dtype = torch.bfloat16 if training_args.bf16 else torch.float16
    model = _load_quant_model(model_args, config, dtype, family)
    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model = prepare_model(ptq_args, model)
    # Only the rotation matrices attached below are trained.
    for param in model.parameters():
        param.requires_grad = False
    _attach_rotations(model, ptq_args.block_rotation)

    if local_rank == 0:
        log.info("Model init completed for training {}".format(model))
        log.info("Start to load tokenizer...")
    tokenizer = _load_tokenizer(model_args, training_args, family)
    log.info("Complete tokenizer loading...")
    model.config.use_cache = False
    calibration_datasets = datasets.load_dataset(
        "Salesforce/wikitext", "wikitext-2-raw-v1"
    )

    # Calibration sequence length is capped at 2048 tokens.
    block_size = min(training_args.model_max_length, 2048)
    # NOTE(review): presumably None means "no cap on samples"; confirm in CustomJsonDataset.
    target_final_samples = None
    train_data = CustomJsonDataset(
        calibration_datasets["train"],
        tokenizer,
        block_size=block_size,
        max_samples=target_final_samples,
    )

    if local_rank == 0:
        log.info(
            f"Final training dataset contains {len(train_data)} samples of length {block_size}")

    trainable_parameters = [model.R1.weight] + [
        model.model.layers[i].self_attn.R2.weight
        for i in range(model.config.num_hidden_layers)
    ]
    model.seqlen = training_args.model_max_length
    # stiefel=True presumably keeps the rotations orthogonal during updates.
    optimizer_cls = AdamG if ptq_args.block_rotation else SGDG
    optimizer = optimizer_cls(
        trainable_parameters, lr=training_args.learning_rate, stiefel=True
    )

    # Use FSDP for 70B rotation training.
    use_fsdp = _use_fsdp(training_args)
    trainer_cls = FSDPTrainer if use_fsdp else Trainer
    trainer = trainer_cls(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=None,
        data_collator=default_data_collator,
        optimizers=(optimizer, None),
    )
    dist.barrier()

    trainer.train()
    # The state dict is gathered on every rank; only rank 0 writes it.
    if use_fsdp:
        cpu_state = pt_fsdp_state_dict(trainer.model)
    else:
        cpu_state = trainer.model.state_dict()

    R_dict = _collect_rotations(cpu_state)
    if local_rank == 0:
        os.makedirs(model_args.output_rotation_path, exist_ok=True)
        path = os.path.join(model_args.output_rotation_path, "R.bin")
        torch.save(R_dict, path)
    dist.barrier()


# Script entry point. train() initialises an NCCL process group, so this is
# presumably launched via torchrun (or another launcher that sets the
# torch.distributed environment) -- confirm against the launch scripts.
if __name__ == "__main__":
    train()
