import logging
import os
import sys
import json
import yaml

import wandb
import torch
import torch.distributed as dist

import transformers
from transformers import AutoConfig, AutoTokenizer 
from transformers.integrations import deepspeed
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers.trainer_utils import is_main_process
from transformers import (
    TrainerCallback,
    TrainerState,
    TrainerControl
)

from peft import LoraConfig

from arguments import (
    ModelArguments,
    DataArguments,
    EmbeddingTrainingArguments as TrainingArguments,
    LoraArguments
)

from data import (
    MultiDatasetMNKD,
    TripleCollatorMNKD,
    DynamicBatchSampler
)
from model.model import AutoModelForRanking

from trainer import (
    EmbeddingTrainer as Trainer,
    GCTrainer,
)
from deepspeed import zero
import torch.distributed as dist
from contextlib import contextmanager

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ShuffleCallback(TrainerCallback):
    """Trainer callback that reshuffles the training dataset's batches each epoch."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Hook invoked when an epoch starts.

        Looks up the ``train_dataloader`` entry in ``kwargs`` and calls
        ``shuffle_batch()`` on the dataset it wraps.
        """
        dataset = kwargs['train_dataloader'].dataset
        dataset.shuffle_batch()

def maybe_zero_3(param):
    """Return a detached CPU clone of ``param``, gathering it first if needed.

    Args:
        param: A tensor or parameter. Objects that carry a ``ds_id``
            attribute are treated as DeepSpeed ZeRO-managed parameters and
            are gathered with ``zero.GatheredParameters`` before copying.

    Returns:
        A detached copy of the parameter data located on the CPU.

    Raises:
        AssertionError: If a ZeRO-managed parameter's ``ds_status`` is not
            ``ZeroParamStatus.NOT_AVAILABLE`` when this function is called.
    """
    if hasattr(param, "ds_id"):
        # Imported lazily: the name is only needed on the ZeRO path and was
        # not imported at module level, which made this branch raise
        # NameError.
        from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

        assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
        with zero.GatheredParameters([param]):
            # Clone inside the context so the copy is taken while the
            # gathered data is accessible.
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param

def get_peft_state_maybe_zero_3(named_params, bias):
    """Select the LoRA (and optionally bias) parameters and copy them to CPU.

    Args:
        named_params: Iterable of ``(name, parameter)`` pairs, for example
            the result of ``model.named_parameters()``. It is consumed once.
        bias: Which bias parameters to keep next to the ``lora_`` weights:

            * ``"none"``: only parameters whose name contains ``"lora_"``.
            * ``"all"``: parameters whose name contains ``"lora_"`` or
              ``"bias"``.
            * ``"lora_only"``: ``lora_`` parameters plus those bias
              parameters whose name equals the prefix of a ``lora_``
              parameter name (the text before ``"lora_"``) followed by
              ``"bias"``.

    Returns:
        A dict mapping each selected parameter name to the result of
        ``maybe_zero_3`` applied to the parameter.

    Raises:
        NotImplementedError: If ``bias`` is not one of the values above.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                # Name the bias of the module this LoRA weight belongs to.
                lora_bias_names.add(k.split("lora_")[0] + "bias")
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only biases that belong to a LoRA-adapted module. Iterate the
        # items (not just the keys) and test each bias by its own name.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError(f"Unsupported bias mode: {bias!r}")
    return {k: maybe_zero_3(v) for k, v in to_return.items()}

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"):
    """Collect the model state dict and write it to ``output_dir``.

    Args:
        trainer: The trainer whose model should be saved.
        output_dir: Directory passed on to ``trainer._save``.
        bias: Bias selection mode forwarded to
            ``get_peft_state_maybe_zero_3`` when LoRA is enabled and
            DeepSpeed ZeRO-3 is not.

    The state dict is chosen as follows:

    * DeepSpeed ZeRO-3 enabled: the consolidated 16-bit state dict of the
      wrapped model.
    * otherwise, with ``trainer.args.use_lora``: only the LoRA (and selected
      bias) parameters.
    * otherwise: the model's full ``state_dict()``.
    """
    # check if zero3 mode enabled
    if deepspeed.is_deepspeed_zero3_enabled():
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    elif trainer.args.use_lora:
        state_dict = get_peft_state_maybe_zero_3(
            trainer.model.named_parameters(), bias
        )
    else:
        state_dict = trainer.model.state_dict()
    # Accept local_rank == -1 as well as 0, matching how the rest of this
    # script identifies the main process; checking only for 0 would skip the
    # save whenever local_rank is left at -1 (non-distributed run).
    if trainer.args.should_save and trainer.args.local_rank in (-1, 0):
        trainer._save(output_dir, state_dict=state_dict)

def main():
    """Parse arguments, build the tokenizer, model and dataset, train, and save.

    Arguments are read from a JSON file when the script receives exactly one
    command-line argument ending in ``.json``; otherwise they are parsed from
    the regular command-line flags.

    Raises:
        ValueError: If the output directory exists and is non-empty while
            training is requested and ``overwrite_output_dir`` is not set.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, LoraArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args, lora_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()

    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    lora_args: LoraArguments
    # training_args.remove_unused_columns = False

    # Refuse to silently reuse a populated output directory.
    if (
            os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    print('lora targe', lora_args.lora_target_modules)

    # Setup logging: INFO on local rank -1/0, WARN on every other rank.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    if training_args.local_rank in (0, -1):
        logger.info("Training/evaluation parameters %s", training_args)
        logger.info("Model parameters %s", model_args)
        logger.info("Data parameters %s", data_args)

    set_seed(training_args.seed)

    num_labels = 1
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
        pad_token='<|endoftext|>'
    )
    tokenizer.padding_side = 'left'

    # A LoRA configuration is only built when LoRA training is requested.
    lora_config = None
    if training_args.use_lora:
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_args.lora_target_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
            modules_to_save=None,  # This argument serves for adding new tokens.
            exclude_modules=lora_args.exclude_modules
        )

    model = AutoModelForRanking(
        model_args.model_name_or_path,
        train_args=training_args,
        data_args=data_args,
        temperature=training_args.temperature,
        pooling=model_args.pooling,
        normalize=model_args.normalize,
        cache_dir=model_args.cache_dir,
        trust_remote_code=True,
        attn_type=model_args.attn_type,
        use_lora=training_args.use_lora,
        lora_config=lora_config,
        train_type=training_args.train_type,
    )

    # Every process with local_rank >= 0 reaches exactly one barrier here;
    # with local_rank == -1 no barrier is executed.
    if training_args.local_rank > 0:
        print("Waiting for main process to perform the mapping")
        torch.distributed.barrier()
    if training_args.local_rank == 0:
        print("Loading results from main process")
        torch.distributed.barrier()

    # Use a context manager so the config file handle is closed promptly.
    with open(data_args.finetune_data_config) as f:
        ft_data_configs = yaml.safe_load(f)  # List[Dict]
    length_config = None
    if data_args.length_config:
        with open(data_args.length_config) as f:
            length_config = json.load(f)
    with training_args.main_process_first(local=False, desc="loading dataset"):
        train_dataset = MultiDatasetMNKD(
            data_configs=ft_data_configs,
            length_config=length_config,
            default_batch_size=training_args.per_device_train_batch_size,
            neg_per_ins=data_args.neg_per_ins,
            instruction=data_args.instruction,
            doc_instruction=data_args.doc_instruction,
            num_gpu=training_args.world_size,
            boq_token=data_args.boq_token,
            bod_token=data_args.bod_token,
            random_neg=data_args.random_neg,
            default_max_length=data_args.max_len
        )
    sampler = DynamicBatchSampler(train_dataset, num_process=training_args.world_size, process_index=training_args.process_index)

    # NOTE(review): data_args.eod_token is not forwarded to the collator —
    # confirm whether that is intended.
    data_collator = TripleCollatorMNKD(
        tokenizer,
        max_length=data_args.max_len
    )

    # Gradient-cache training uses its own trainer class.
    trainer_cls = GCTrainer if training_args.grad_cache else Trainer
    print('training args', training_args)

    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        batch_sampler=sampler,
    )

    if training_args.continue_train:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    # NOTE(review): the save helper runs with its default bias="none", so
    # lora_args.lora_bias is not forwarded here — confirm this is intended.
    safe_save_model_for_hf_trainer(trainer, output_dir=training_args.output_dir)
    


# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
