# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
from logging import Logger
import re
import sys
from typing import Dict, Iterator, List, Optional
import datetime

import torch
import transformers

from inverse_llama2 import LlamaForCausalLM

from utils.pretrain_trainer_cooldown import PretrainTrainer
from utils.process_args import process_args
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, default_data_collator

from transformers import set_seed
from datasets import load_dataset, ClassLabel, concatenate_datasets, load_from_disk

from torch import distributed as dist

from torch.utils.data.distributed import DistributedSampler

# Define a utility method for setting the logging parameters of a logger
def get_logger(logger_name: Optional[str]) -> logging.Logger:
    """Return the logger named ``logger_name`` configured for INFO console output.

    The logger's level is set to ``logging.INFO`` and a ``StreamHandler`` using
    the format ``"<asctime> - <name> - <levelname> - <message>"`` is attached.

    The function is idempotent: the console handler is tagged with a name, and
    a logger that already carries that tagged handler is returned unchanged.
    Without this guard each call would attach another handler and every
    message would be emitted once per call.

    Args:
        logger_name: Name passed to ``logging.getLogger``; ``None`` selects the
            root logger.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Identify "our" handler by name so handlers attached by other code are
    # neither mistaken for it nor prevent it from being installed.
    handler_name = f"get_logger.console:{logger.name}"
    if any(handler.get_name() == handler_name for handler in logger.handlers):
        return logger

    console_handler = logging.StreamHandler()
    console_handler.set_name(handler_name)
    console_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(console_handler)

    return logger


# Module-wide logger used by the training entry point below.
log: Logger = get_logger(logger_name="EfficientLLM")

def train() -> None:
    """Pre-train the model on a fixed-ratio mixture of four token datasets.

    Steps, in order:

    1. Initialise the default process group (``gloo`` for CPU, ``nccl`` for
       CUDA) with an 8-hour timeout.
    2. Parse arguments with ``process_args()``, call ``set_seed(42)`` and load
       the config, the model (``flash_attention_2`` attention) and the
       tokenizer.
    3. Build the training mixture: the ``web`` split is used as-is and treated
       as 70% of the total; the ``aws``, ``cos`` and ``math`` datasets are
       repeated and truncated to 8%, 15% and 5.5% of that total. The four
       parts are concatenated and shuffled with the same seed.
    4. Wrap the mixture in a ``DataLoader`` driven by a ``DistributedSampler``
       and hand it to ``PretrainTrainer``.
    5. Synchronise all ranks, run training and save the trainer state when
       ``training_args.do_train`` is set, then synchronise again.

    Raises:
        ValueError: If one of the datasets that must be resampled is empty.
    """
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl", timeout=datetime.timedelta(hours=8)
    )

    model_args, _, training_args = process_args()
    seed = 42
    set_seed(seed)
    # NOTE(review): this overwrites whatever model path process_args() parsed
    # with a hard-coded placeholder — confirm this is intended.
    model_args.input_model_filename = "path"
    config = AutoConfig.from_pretrained(model_args.input_model_filename)

    training_args.report_to = ["wandb"]

    model = LlamaForCausalLM.from_pretrained(
        model_args.input_model_filename,
        config=config,
        attn_implementation="flash_attention_2",
    )

    # Parameter count of `model.model` only, expressed in units of 2**20.
    log.info(
        "model size is "
        + str(sum(param.numel() for param in model.model.parameters()) / 1024 / 1024)
    )

    log.info("Start to load tokenizer...")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model_filename,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    log.info("Complete tokenizer loading...")

    web = load_dataset("json", data_files=["path"], cache_dir="path")["train"]

    # The web split is kept whole and defines the total: len(web) == 70% of it.
    all_data_num = int(len(web) / 0.7)

    # NOTE(review): 0.7 + 0.08 + 0.15 + 0.055 == 0.985, so the final mixture is
    # slightly smaller than `all_data_num` — confirm the ratios are intended.
    aws = _upsample_to_size(load_from_disk("path"), int(all_data_num * 0.08))
    cos = _upsample_to_size(load_from_disk("path"), int(all_data_num * 0.15))
    math_data = _upsample_to_size(load_from_disk("path"), int(all_data_num * 0.055))

    cat = concatenate_datasets([aws, cos, web, math_data]).shuffle(seed=seed)

    # Presumably every row holds 2048 tokens; the count is printed in units of
    # 2**30 tokens — TODO confirm the sequence length against the data.
    print("### token numbers", len(cat)*2048 // 1024 //1024 //1024)

    sampler = DistributedSampler(
        cat, num_replicas=_get_world_size(), rank=torch.distributed.get_rank()
    )
    train_dataloader = DataLoader(
        cat,
        batch_size=training_args.per_device_train_batch_size,
        collate_fn=_collate_token_ids,
        sampler=sampler,
    )
    local_rank = _get_local_rank()

    # NOTE(review): a DataLoader (not a Dataset) is passed as `train_dataset`;
    # PretrainTrainer is project code and presumably expects this — confirm.
    trainer = PretrainTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataloader,
        eval_dataset=None,
        data_collator=default_data_collator,
    )
    torch.distributed.barrier(device_ids=[local_rank])

    if training_args.do_train:
        _ = trainer.train()
        trainer.save_state()

    torch.distributed.barrier(device_ids=[local_rank])


def _upsample_to_size(dataset, target_size: int):
    """Return exactly ``target_size`` rows of ``dataset``, repeating it as needed."""
    if len(dataset) == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError("cannot resample an empty dataset to a target size")
    repeats = target_size // len(dataset) + 1
    return concatenate_datasets([dataset] * repeats).select(range(target_size))


def _collate_token_ids(batch):
    """Stack each item's ``token_ids`` into ``input_ids`` and ``labels`` tensors."""
    input_ids = [torch.tensor(item["token_ids"]) for item in batch]
    # Two separate stack calls so `input_ids` and `labels` do not share storage.
    return dict(input_ids=torch.stack(input_ids), labels=torch.stack(input_ids))


def _get_local_rank() -> int:
    """Return this process's local rank, preferring the ``LOCAL_RANK`` env var."""
    env_value = os.environ.get("LOCAL_RANK")
    if env_value:
        return int(env_value)
    log.warning(
        "LOCAL_RANK from os.environ is None, fall back to get rank from torch distributed"
    )
    # The global rank is only a valid device index on a single node; reduce it
    # modulo the number of visible CUDA devices so it stays usable as
    # `device_ids` (assumes ranks are assigned contiguously per node).
    return torch.distributed.get_rank() % max(torch.cuda.device_count(), 1)


def _get_world_size() -> int:
    """Return the world size, preferring the ``WORLD_SIZE`` env var."""
    env_value = os.environ.get("WORLD_SIZE")
    if env_value:
        return int(env_value)
    log.warning(
        "WORLD_SIZE from os.environ is None, fall back to get world size from torch distributed"
    )
    return torch.distributed.get_world_size()


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
