import json
import logging
import os
from logging import Logger
import re
import sys
from typing import Dict, Iterator, List, Optional
import datetime

import torch
import transformers

from modeling_llama2_noaffine_inner_ema_step_obs_stem_max3_3_actha2all_4000nas2_out_simple_hesspro import LlamaForCausalLM

from utils.pretrain_trainer3 import PretrainStepTrainer
from utils.process_args import process_args
from torch.utils.data import Dataset, DataLoader
from transformers import AutoConfig, default_data_collator

from transformers import set_seed
from datasets import load_dataset, ClassLabel, concatenate_datasets, load_from_disk

from torch import distributed as dist

from torch.utils.data.distributed import DistributedSampler

# Define a utility method for setting the logging parameters of a logger
def get_logger(logger_name: Optional[str]) -> logging.Logger:
    """Return the named logger configured for INFO-level console output.

    The function is idempotent: calling it several times with the same name
    returns the same logger with a single console handler, so messages are
    never emitted more than once per call site.

    Args:
        logger_name: Name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # ``logging.getLogger`` hands back the same object for the same name, so
    # blindly adding a handler on every call would duplicate each message.
    # FileHandler subclasses StreamHandler, hence the explicit exclusion: a
    # file handler attached elsewhere must not suppress console output.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
        )
        logger.addHandler(console_handler)

    return logger


# Module-wide logger named "EfficientLLM"; train() writes its progress
# messages through it.
log: Logger = get_logger("EfficientLLM")

def _repeat_to_size(dataset, target_size: int):
    """Tile ``dataset`` end-to-end and keep exactly ``target_size`` rows.

    Args:
        dataset: A non-empty dataset accepted by ``concatenate_datasets``.
        target_size: Number of rows the returned dataset must contain.

    Returns:
        A dataset of ``target_size`` rows: whole repetitions of ``dataset``
        followed by a truncated final repetition.

    Raises:
        ValueError: If ``dataset`` is empty, since no amount of repetition
            can reach ``target_size``.
    """
    if len(dataset) == 0:
        raise ValueError("cannot resample an empty dataset")
    # One extra repetition guarantees at least ``target_size`` rows before
    # the truncating ``select``.
    repeats = target_size // len(dataset) + 1
    return concatenate_datasets([dataset] * repeats).select(range(target_size))


def train() -> None:
    """Build the model, the mixed training corpus and the trainer, then train.

    Steps, in order:
      1. Initialise the default process group (gloo for CPU, NCCL for CUDA).
      2. Load config, model and tokenizer from ``input_model_filename``.
      3. Assemble the training mixture from one JSON corpus and three
         on-disk corpora resampled to fixed shares, then shuffle it.
      4. Wrap the mixture in a ``DataLoader`` driven by a
         ``DistributedSampler`` and hand it to ``PretrainStepTrainer``.
      5. Train and save trainer state when ``training_args.do_train`` is set.
    """
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl", timeout=datetime.timedelta(hours=8)
    )
    model_args, _data_args, training_args = process_args()

    # A single seed drives both global RNG seeding and the dataset shuffle.
    seed = 42
    set_seed(seed)

    model_args.input_model_filename = "path"
    config = AutoConfig.from_pretrained(model_args.input_model_filename)
    # Extra attributes read by the custom LlamaForCausalLM implementation.
    config.m_start = 500  # 250
    config.step_quota = 1  # 13
    config.sub = 1

    training_args.report_to = ["wandb"]

    model = LlamaForCausalLM.from_pretrained(
        model_args.input_model_filename,
        config=config,
        attn_implementation="flash_attention_2",
    )

    # Parameter count of the inner ``model.model`` module, in units of 2**20.
    log.info(
        "model size is "
        + str(sum(param.numel() for param in model.model.parameters()) / 1024 / 1024)
    )

    log.info("Start to load tokenizer...")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model_filename,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    log.info("Complete tokenizer loading...")

    ################################################
    # Data mixture. ``web`` is used as-is and is defined to be 70% of the
    # nominal total; the other corpora are tiled/truncated to their share of
    # that total. NOTE(review): the shares sum to 0.985, so the final mixture
    # is slightly smaller than ``all_data_num`` - confirm this is intended.
    web = load_dataset('json', data_files=['path'], cache_dir="path")["train"]
    all_data_num = int(len(web) / 0.7)

    aws = _repeat_to_size(load_from_disk("path"), int(all_data_num * 0.08))
    cos = _repeat_to_size(load_from_disk("path"), int(all_data_num * 0.15))
    # Named ``math_data`` rather than ``math`` to avoid shadowing the stdlib
    # module name.
    math_data = _repeat_to_size(load_from_disk("path"), int(all_data_num * 0.055))

    cat = concatenate_datasets([aws, cos, web, math_data])
    cat = cat.shuffle(seed=seed)

    # Assumes 2048 tokens per sample (TODO confirm); reported in units of 2**30.
    print("### token number:", len(cat)*2048 // 1024 //1024 //1024)
    ################################################

    def custom_collate_fn(batch):
        """Stack each sample's ``token_ids`` into ``input_ids`` and ``labels``."""
        input_ids = [torch.tensor(item['token_ids']) for item in batch]
        # Two separate stacks so ``labels`` does not alias ``input_ids``.
        return dict(input_ids=torch.stack(input_ids), labels=torch.stack(input_ids))

    def get_local_rank() -> int:
        """Return this process's device index on its node."""
        if os.environ.get("LOCAL_RANK"):
            return int(os.environ["LOCAL_RANK"])
        # Use the module logger: a root-level ``logging.warning`` would
        # implicitly call ``basicConfig`` and duplicate later ``log`` output.
        log.warning(
            "LOCAL_RANK from os.environ is None, fall back to get rank from torch distributed"
        )
        global_rank = torch.distributed.get_rank()
        # The global rank is not a valid device index on multi-node runs;
        # reduce it modulo the number of visible GPUs (assumes every node
        # exposes the same number of devices).
        device_count = torch.cuda.device_count()
        return global_rank % device_count if device_count else global_rank

    def get_world_size() -> int:
        """Return the total number of participating processes."""
        if os.environ.get("WORLD_SIZE"):
            return int(os.environ["WORLD_SIZE"])
        log.warning(
            "WORLD_SIZE from os.environ is None, fall back to get world size from torch distributed"
        )
        return torch.distributed.get_world_size()

    sampler = DistributedSampler(
        cat, num_replicas=get_world_size(), rank=torch.distributed.get_rank()
    )

    train_dataloader = DataLoader(
        cat,
        batch_size=training_args.per_device_train_batch_size,
        collate_fn=custom_collate_fn,
        sampler=sampler,
    )

    local_rank = get_local_rank()

    # The trainer is given the ready-made DataLoader in the ``train_dataset``
    # slot rather than a raw dataset.
    trainer = PretrainStepTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataloader,
        eval_dataset=None,
        data_collator=default_data_collator,
    )
    # Wait until every rank has finished setup before training starts.
    torch.distributed.barrier(device_ids=[local_rank])

    if training_args.do_train:
        trainer.train()
        trainer.save_state()

    # Keep ranks together until all of them have finished.
    torch.distributed.barrier(device_ids=[local_rank])


# Start training only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    train()
