# This file is originally from
# https://github.com/meta-llama/llama-recipes/blob/ed87ba4abe97c16bb81833598f3bb793ff0e009f/src/llama_recipes/finetuning.py
# and is updated by the authors of the FoVer paper.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from torch.distributed.fsdp import StateDictType
from accelerate.utils import is_xpu_available

from llama_recipes.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    # train_config as TRAIN_CONFIG,
)
from src.llm.llama_recipes_updated.configs import train_config as TRAIN_CONFIG
# from llama_recipes.data.concatenator import ConcatDataset
from src.llm.llama_recipes_updated.data.concatenator import NoPartialSentenceConcatDataset

from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_recipes.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)

from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
# originally importing llama_recipes.utils.train_utils
from src.llm.llama_recipes_updated.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    # freeze_LLM_only,
    get_policies,
    print_model_size,
    # print_frozen_model_status,
    setup,
    setup_environ_flags,
    get_save_folder,
    train,
)

from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR, LinearLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
    AutoModelForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)


def setup_wandb(train_config, fsdp_config, **kwargs):
    """Start a Weights & Biases run and record the training/FSDP configs on it.

    Args:
        train_config: Training configuration object; stored in the run config.
        fsdp_config: FSDP configuration object; stored in the run config with
            ``allow_val_change=True``.
        **kwargs: Overrides applied to a fresh ``wandb_config`` instance via
            ``update_config`` before the run is initialised.

    Returns:
        The object returned by ``wandb.init``.

    Raises:
        ImportError: If the ``wandb`` package cannot be imported.
    """
    # wandb is an optional dependency, so it is only imported on demand.
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG

    # Build the wandb settings from defaults plus caller overrides.
    settings = WANDB_CONFIG()
    update_config(settings, **kwargs)

    session = wandb.init(**dataclasses.asdict(settings))
    session.config.update(train_config)
    session.config.update(fsdp_config, allow_val_change=True)
    return session


def main(**kwargs):
    """Fine-tune a Llama / Mllama / Qwen2 / Gemma2 model, optionally with FSDP and PEFT.

    The function builds the training and FSDP configs, loads the model and
    tokenizer, optionally wraps the model with PEFT and FSDP, prepares the
    train/validation dataloaders, creates the optimizer and a linear warmup
    scheduler, and finally calls ``train``. On the main process it prints the
    returned results, mirrors them to wandb when enabled, and (for sharded
    checkpoints) saves the tokenizer and optimizer state.

    Args:
        **kwargs: Overrides forwarded to ``update_config`` for the train, FSDP,
            quantization, wandb, PEFT and dataset configs.

    Raises:
        ValueError: For an unsupported ``model_type``, for 8bit quantization
            combined with FSDP, when the tokenizer is larger than the embedding
            matrix, when packing is requested for a vision model, or when the
            validation dataloader is empty.
        NotImplementedError: If ``fsdp_config.optimizer`` is not one of the
            supported optimizers.
    """
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)
    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    # Ranks stay None for single-process (non-FSDP) runs so that later code
    # never touches an unbound name.
    local_rank = rank = None
    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])

    # The one process that prints, logs to wandb and saves artifacts.
    is_main_process = not train_config.enable_fsdp or rank == 0

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb and is_main_process:
        wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    config = AutoConfig.from_pretrained(train_config.model_name)
    # Arguments shared by every supported model family.
    common_load_kwargs = dict(
        quantization_config=bnb_config,
        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
        device_map=(
            "auto"
            if train_config.quantization and not train_config.enable_fsdp
            else None
        ),
        torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
    )
    if config.model_type == "mllama":
        is_vision = True
        # Note: use_cache is not passed for the vision model.
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            **common_load_kwargs,
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            use_cache=use_cache,
            **common_load_kwargs,
        )
    elif config.model_type in ["qwen2", "gemma2"]:
        is_vision = False
        model = AutoModelForCausalLM.from_pretrained(
            train_config.model_name,
            use_cache=use_cache,
            **common_load_kwargs,
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    # NOTE(review): this truthiness test also replaces a legitimate
    # pad_token_id of 0 with the EOS id. Kept as-is because switching to
    # `is None` would change padding for tokenizers whose pad id is 0 --
    # confirm the intended behavior before changing.
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # A tokenizer larger than the embedding matrix would require resizing the
    # embeddings; that path is deliberately unsupported here, so fail loudly.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        raise ValueError("We do not use this case")

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()

    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        # Layer freezing (freeze_layers / freeze_LLM_only) from upstream
        # llama-recipes is intentionally not wired up: the FoVer authors
        # disabled it because train_config.freeze_layers is not supported by
        # the llama-recipes version in use.

        mixed_precision_policy, llama_wrapping_policy = get_policies(fsdp_config, rank)

        # Transformer layer classes that FSDP should wrap for this model family.
        if config.model_type == "mllama":
            wrapped_layer_classes = [
                MllamaSelfAttentionDecoderLayer,
                MllamaCrossAttentionDecoderLayer,
                MllamaVisionEncoderLayer,
            ]
        elif config.model_type == "llama":
            wrapped_layer_classes = [LlamaDecoderLayer]
        elif config.model_type == "qwen2":
            from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

            wrapped_layer_classes = [Qwen2DecoderLayer]
        elif config.model_type == "gemma2":
            from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer

            wrapped_layer_classes = [Gemma2DecoderLayer]
        else:
            raise ValueError(
                f"Model type {config.model_type} is not supported."
            )

        if train_config.use_peft:
            # PEFT models use the llama-recipes PEFT-aware wrapping policy.
            auto_wrap_policy = fsdp_auto_wrap_policy(model, wrapped_layer_classes)
        elif config.model_type in ["mllama", "llama"]:
            auto_wrap_policy = llama_wrapping_policy
        else:
            # we need to change the wrapping policy for models other than llama and mllama
            import functools
            from torch.distributed.fsdp.wrap import (
                transformer_auto_wrap_policy,
            )

            auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls=set(wrapped_layer_classes),
            )

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        # Upstream sets this to True only for freeze_LLM_only, which is
        # disabled in this version.
        use_orig_params = False
        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization:
        # Single-process, non-quantized run: move the whole model to the accelerator.
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)
    # Vision models are preprocessed with the processor, text models with the tokenizer.
    dataset_processor = processor if is_vision else tokenizer

    # Load and preprocess the dataset for training and validation
    if is_main_process:
        print(dataset_processor)

    dataset_train = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="train",
    )
    if is_main_process:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="test",
    )
    if is_main_process:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        dataset_train = NoPartialSentenceConcatDataset(
            dataset_train, chunk_size=train_config.context_length,
            pad_token_id=tokenizer.pad_token_id
        )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processor, "train"
    )
    print("length of dataset_train", len(dataset_train))
    custom_data_collator = get_custom_data_collator(dataset_processor, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator
    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            dataset_val = NoPartialSentenceConcatDataset(
                dataset_val, chunk_size=train_config.context_length,
                pad_token_id=tokenizer.pad_token_id
            )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processor, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    elif fsdp_config.optimizer == "AdamW":
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    elif fsdp_config.optimizer == "RecAdam":
        from rec_adam import build_rec_adam_optimizer
        optimizer = build_rec_adam_optimizer(
            model=model,
            learning_rate=train_config.lr,
            weight_decay=train_config.weight_decay,
            fisher_coef=train_config.rec_adam_fisher_coef,
        )
    else:
        raise NotImplementedError(f"Optimizer {fsdp_config.optimizer} is not supported")

    # The upstream StepLR scheduler was replaced by the FoVer authors with a
    # linear warmup: 50% of the first epoch is warmup, starting from 1% of the
    # target learning rate.
    # Clamp to at least one step: with very small datasets the integer division
    # yields 0, and LinearLR with total_iters=0 would keep the learning rate at
    # start_factor instead of ever reaching the target value.
    num_warmup_steps = max(
        1, len(train_dataloader) // 2 // train_config.gradient_accumulation_steps
    )
    scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0,
                         total_iters=num_warmup_steps)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank,  # None when FSDP is disabled
        rank,  # None when FSDP is disabled
        wandb_run,
    )
    if is_main_process:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v

        # Save the tokenizer and optimizer state next to the model checkpoint
        if fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
            model_saved_folder = get_save_folder(train_config)

            # save tokenizer
            tokenizer.save_pretrained(model_saved_folder)

            # save optimizer
            torch.save(optimizer.state_dict(), os.path.join(model_saved_folder, "optimizer.pt"))
            

if __name__ == "__main__":
    # Script entry point: hand main() to python-fire, which exposes it as a
    # command-line interface (CLI flags are passed to main as keyword arguments).
    fire.Fire(main)
