# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from pkg_resources import packaging

import fire
import random
import torch
import torch.optim as optim
from peft import get_peft_model, prepare_model_for_int8_training
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
import torch.distributed as dist

from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.configs import model_config as MODEL_CONFIG
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
    get_dataloader_kwargs,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

from llama_recipes.utils.train_utils import (
    train,
    evaluation,
    freeze_transformer_layers,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies
)
from accelerate.utils import is_xpu_available

from lte_model_generator import model_generator_llama

def main(**kwargs):
    """Fine-tune a Llama model with optional FSDP, PEFT and LTE (MoE) settings.

    Keyword arguments come from the command line via ``fire``. They are applied
    to ``train_config`` / ``fsdp_config`` through ``update_config`` and, on the
    non-``low_cpu_fsdp`` path, additionally copied verbatim onto ``model_config``.

    Steps:
    1. Seed RNGs and (when ``enable_fsdp``) initialise the distributed setup.
    2. Build the model, optionally quantize / cast / apply PEFT / wrap in FSDP.
    3. Configure LTE: add MoE routers and load a checkpoint for the soft
       (phase 2) or hard (phase 3) stage.
    4. Build train/val dataloaders and an AdamW-style optimizer with separate
       parameter groups for the base LLM and the MoE parameters.
    5. Either run one evaluation pass (``train_config.eval_mode``) and return,
       or call ``train`` and save ``last-ckpt.pt`` to ``train_config.output_dir``
       from rank 0.

    Raises:
        Exception: ``low_cpu_fsdp`` requested without a recent PyTorch nightly.
        ValueError: ``eval_mode`` requested but no evaluation dataloader exists.
    """
    # Update the configuration for the training and sharding process
    train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    else:
        torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    # Single-process defaults. The rank checks below (printing, checkpoint
    # saving, evaluation) need these names bound even when FSDP is disabled.
    local_rank = 0
    rank = 0
    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        print('local_rank', local_rank)
        if local_rank == 0:
            print(fsdp_config)
            print(train_config)

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        else:
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    # ------------------------------------------------------------------ model
    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        # For FSDP, CPU memory can be saved by loading the pretrained model on
        # rank 0 only. This avoids CPU OOM for large models such as Llama 70B,
        # where the model alone would consume 2+TB of CPU memory (70 * 4 * 8).
        # It adds some communication overhead and requires a recent nightly.
        v = packaging.version.parse(torch.__version__)
        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        if not verify_latest_nightly:
            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
                            "please install latest nightly.")
        if rank == 0:
            model = LlamaForCausalLM.from_pretrained(
                train_config.model_name,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
            )
        else:
            # Non-zero ranks build an empty (meta-device) model; weights are
            # broadcast from rank 0 via sync_module_states in the FSDP wrap.
            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
            llama_config.use_cache = use_cache
            with torch.device("meta"):
                model = LlamaForCausalLM(llama_config)
    else:
        # Every CLI kwarg (including train/fsdp options) is copied onto
        # model_config; presumably so LTE/model options can be set from the
        # command line.
        for k, v in kwargs.items():
            setattr(model_config, k, v)

        print(model_config)

        model = model_generator_llama(model_config)

    if train_config.enable_fsdp and train_config.use_fast_kernels:
        # For FSDP and FSDP+PEFT, 'use_fast_kernels' enables Flash Attention or
        # Xformer memory-efficient kernels depending on the hardware.
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

    # rank is 0 for single-process runs.
    print_model_size(model, train_config, rank)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_int8_training(model)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            # NOTE(review): only the layer count is passed, not the model;
            # confirm this matches freeze_transformer_layers' signature.
            freeze_transformer_layers(train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

        model = FSDP(
            model,
            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
            if train_config.low_cpu_fsdp and rank != 0 else None,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        else:
            model.to("cuda")

    # ------------------------------------------------------------ LTE setting
    model.config.lte = train_config.lte

    if model_config.lte:
        print('model_config.moe_type', model_config.moe_type)
        model.config.lte = model_config.lte
        model.config.hard = model_config.hard
        model.config.moe_routing_mode = model_config.moe_routing_mode
        model.config.kmean_group = model_config.kmean_grouping

        print('model.config.hard', model.config.hard)
        print('model_config.kmean_grouping', model_config.kmean_grouping)
        print('model_config.moe_routing_mode', model_config.moe_routing_mode)

        # NOTE(review): the MoE routers are added after the FSDP wrap above, so
        # FSDP did not see them when it was constructed; confirm intended.
        if model_config.moe_type == 'block':  # Construct the moe routers
            if model_config.kmean_grouping:
                model.model.add_moe(moe_type=model_config.moe_type, experts=model_config.moe_experts, split_path=model_config.kmean_grouping_path, k=model_config.moe_experts_selected, hard=model_config.hard)
            else:
                model.model.add_moe(moe_type=model_config.moe_type, experts=model_config.moe_experts, split_path=None, k=None, hard=model_config.hard)

        if not model_config.hard:  # soft mode (phase 2)
            print('Load vanilla fine-tuned model!')
            print(f'Load model at {model_config.ckpt_path}')
            if not model_config.use_pretrained:
                state_dict = torch.load(model_config.ckpt_path, map_location=torch.device('cpu'))
                # strict=False: the checkpoint may lack keys (e.g. new MoE params).
                model.load_state_dict(state_dict, strict=False)
        else:  # hard mode (phase 3)
            print('Load soft LTE model!')
            print(f'Load model at {model_config.ckpt_path}')
            state_dict = torch.load(model_config.ckpt_path, map_location=torch.device('cpu'))
            model.load_state_dict(state_dict, strict=False)

            model.model.set_moe_hard()
            model.model.reset_moe_sparsity_statistics()

    # Cast again: parameters created after the first bf16 cast (presumably the
    # MoE modules added above) would otherwise stay in their original dtype.
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    # ------------------------------------------------------------------- data
    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.model_max_length = 2048

    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )

    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
    # Overrides any batch_size chosen by get_dataloader_kwargs for training.
    train_dl_kwargs['batch_size'] = 1

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )

    # ------------------------------------------------ trainable parameters
    # Hard mode: MoE parameters are frozen.
    if model_config.lte and model_config.hard:
        for name, param in model.named_parameters():
            if 'moe' in name:
                param.requires_grad = False

    # Soft mode: MoE parameters are trained, except the expert masks.
    if model_config.lte and not model_config.hard:
        for name, param in model.named_parameters():
            if 'moe' in name:
                param.requires_grad = True
            if 'moe.experts_masks' in name:
                param.requires_grad = False

    llm_params = [p for n, p in model.named_parameters() if 'moe' not in n]
    moe_params = [p for n, p in model.named_parameters() if 'moe' in n and p.requires_grad]

    print('moe_params', len(moe_params))
    print('model.device', model.device)

    # MoE parameters are always stored in bfloat16 on the model's device.
    moe_device = model.device
    for n, p in model.named_parameters():
        if 'moe' in n:
            p.data = p.data.to(device=moe_device, dtype=torch.bfloat16)

    if rank == 0:
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(n)

    # MoE parameters get lr 0 in hard mode (they are frozen anyway).
    moe_lr = 0 if model_config.hard else 0.01

    if train_config.use_peft:
        model.print_trainable_parameters()

    if (not train_config.enable_fsdp) or local_rank == 0:
        print('moe_lr', moe_lr)

    optimizer_grouped_parameters = [
        {'params': llm_params, "lr": train_config.lr, 'weight_decay': train_config.weight_decay},
        {"params": moe_params, "lr": moe_lr, "weight_decay": 0.1}
    ]

    # Initialize the optimizer and learning rate scheduler
    print('fsdp_config.optimizer', fsdp_config.optimizer)
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        # NOTE(review): this branch uses a single parameter group, so moe_lr
        # and the MoE weight decay above are not applied; confirm intended.
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(optimizer_grouped_parameters)
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    # -------------------------------------------------------- eval-only mode
    if train_config.eval_mode:
        if model_config.keep_activation_output:
            eval_dataloader = train_dataloader
        if eval_dataloader is None:
            raise ValueError(
                "eval_mode requires run_validation=True or keep_activation_output=True"
            )
        evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
        return

    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        model_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f'Key: {k}, Value: {v}')

    # ------------------------------------------------------------ checkpoint
    # Make sure training is done on all ranks; a barrier only exists when a
    # process group was initialised (FSDP runs).
    if dist.is_initialized():
        dist.barrier()
    os.makedirs(train_config.output_dir, exist_ok=True)
    saving_path = os.path.join(train_config.output_dir, 'last-ckpt.pt')
    # state_dict() is called on every rank (under FSDP it may involve
    # collective communication); only rank 0 writes the file.
    states = model.state_dict()
    if rank == 0:
        print(f'Saving ckpt at {saving_path}')
        torch.save(states, saving_path)

if __name__ == "__main__":
    # Expose main() as a CLI: every --flag becomes a keyword argument.
    fire.Fire(component=main)
