import logging
import os
import sys

from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
    set_seed,
)
from transformers.integrations import TrainerCallback, TensorBoardCallback

from arguments import ModelArguments, DataArguments
from arguments import DenseTrainingArguments as TrainingArguments
import trainers
import networks
import dataloaders


# Module-level logger shared by the callbacks and main(); its format and level
# are configured inside main() through logging.basicConfig.
logger = logging.getLogger(__name__)

       
class MyStopTrainStepCallback(TrainerCallback):
    """Stop training once ``args.early_stop_step`` global steps have run.

    The threshold is read from the training arguments on every step, so the
    callback itself carries no state.
    """

    def on_step_end(self, args, state, control, **kwargs):
        """Request a training stop when the step threshold is reached.

        Args:
            args: Training arguments; must expose ``early_stop_step``.
            state: Trainer state; ``state.global_step`` is compared against
                the threshold.
            control: Trainer control object; ``should_training_stop`` is set
                to ``True`` when the threshold is reached.
            **kwargs: Extra objects passed by the trainer (unused).

        Returns:
            The (possibly modified) ``control`` object.
        """
        stop_step = args.early_stop_step
        # ``>=`` instead of ``==``: a run resumed from a checkpoint that is
        # already past the threshold would otherwise never stop. A missing or
        # non-positive threshold means "disabled", matching how main() only
        # registers this callback when the value is > 0.
        if stop_step is not None and stop_step > 0 and state.global_step >= stop_step:
            logger.info("End training at step: %d", state.global_step)
            control.should_training_stop = True

        return control
    
class MyStopTrainEpochCallback(TrainerCallback):
    """Stop training once ``args.early_stop_epoch`` epochs have been completed.

    The check runs at the end of every training step (not only at epoch
    boundaries) and compares the trainer's running epoch counter against the
    threshold read from the training arguments.
    """

    def on_step_end(self, args, state, control, **kwargs):
        """Request a training stop when the epoch threshold is reached.

        Args:
            args: Training arguments; must expose ``early_stop_epoch``.
            state: Trainer state; ``state.epoch`` is compared against the
                threshold.
            control: Trainer control object; ``should_training_stop`` is set
                to ``True`` when the threshold is reached.
            **kwargs: Extra objects passed by the trainer (unused).

        Returns:
            The (possibly modified) ``control`` object.
        """
        stop_epoch = args.early_stop_epoch
        current_epoch = state.epoch
        # ``state.epoch`` is a fractional float in transformers, so an exact
        # ``==`` match is fragile and is missed entirely when the counter
        # never lands exactly on the threshold (or when resuming past it).
        # ``>=`` fires on the first step at or beyond the requested epoch.
        if (
            stop_epoch is not None
            and stop_epoch > 0
            and current_epoch is not None
            and current_epoch >= stop_epoch
        ):
            logger.info("End training at epoch: %d", current_epoch)
            control.should_training_stop = True

        return control
    

def _freeze_clm_parameters(model):
    """Freeze every parameter whose name contains ``"clm"`` and count elements.

    Parameters that remain trainable are printed (name and size) so the tuned
    subset can be inspected from the console.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        A ``(tunable_elements, frozen_elements)`` tuple: the total element
        count of parameters still requiring gradients, and of the parameters
        frozen by this call.
    """
    frozen_elements = 0
    tunable_elements = 0
    for name, param in model.named_parameters():
        if "clm" in name:
            param.requires_grad = False
            # numel() handles any rank; the previous 1-D/2-D special-casing
            # aborted the whole loop on other shapes, leaving later "clm"
            # parameters trainable.
            frozen_elements += param.numel()
        elif param.requires_grad:
            print(name, param.size())
            tunable_elements += param.numel()
    return tunable_elements, frozen_elements


def _build_callbacks(training_args):
    """Assemble the trainer callbacks requested by ``training_args``.

    Args:
        training_args: Training arguments exposing ``early_stop_step``,
            ``early_stop_epoch`` and ``tensorboard``.

    Returns:
        A list of callbacks to hand to the trainer (possibly empty).
    """
    callbacks = []
    # NOTE(review): the early-stop callbacks are appended as classes, not
    # instances, exactly as before; confirm trainers.get_trainer forwards them
    # to a transformers Trainer, which accepts either form.
    if training_args.early_stop_step > 0:
        logger.info("Setting early stop step at: %d", training_args.early_stop_step)
        callbacks.append(MyStopTrainStepCallback)
    if training_args.early_stop_epoch > 0:
        logger.info("Setting early stop epoch at: %d", training_args.early_stop_epoch)
        callbacks.append(MyStopTrainEpochCallback)

    if training_args.tensorboard:
        logger.info("Setting Tensorboard ...")
        callbacks.append(TensorBoardCallback())
    return callbacks


def main():
    """Parse arguments, build the model and data pipeline, and run training.

    Arguments come either from a single ``.json`` file given as the only
    command-line argument or from regular command-line flags. After training,
    the model is saved and, on the main process, the tokenizer is written to
    ``training_args.output_dir``.

    Raises:
        ValueError: If the output directory exists, is non-empty, training is
            requested and ``overwrite_output_dir`` is not set.
        NotImplementedError: If ``model_args.param_efficient`` is set; this
            script does not construct the delta model that mode requires.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        # Adjacent literals instead of a backslash continuation, which used to
        # embed the source indentation in the middle of the message.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) "
            "already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # ``delta_model`` was referenced when building the trainer but is never
    # defined in this file, so param_efficient runs died with a NameError only
    # after the model and datasets had been loaded. Fail fast instead.
    if model_args.param_efficient:
        raise NotImplementedError(
            "param_efficient is set, but this script never builds the delta_model "
            "that must be passed to the trainer."
        )

    # Setup logging: only the main process (local_rank -1 or 0) logs at INFO.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARNING,
    )

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)

    set_seed(training_args.seed)

    num_labels = 1
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )

    ## Model
    model = networks.get_network(
        model_args,
        data_args,
        training_args,
        config=config,
        tokenizer=tokenizer,
        cache_dir=model_args.cache_dir,
        do_train=True,
    )

    if model_args.fix_gpt:
        tot_lmort_params, tot_clm_size = _freeze_clm_parameters(model)
        logger.info("tune params %.1f", tot_lmort_params)
        logger.info("Fixed GPT params %.1f", tot_clm_size)
        if tot_clm_size > 0:
            logger.info(
                "Fix GPT, tune param ratio: %.1f",
                (tot_lmort_params / tot_clm_size) * 100,
            )
        else:
            # Nothing matched "clm": the ratio is undefined, so report it
            # rather than crash with ZeroDivisionError.
            logger.warning(
                "fix_gpt is set but no parameter name contains 'clm'; nothing was frozen."
            )

    ## Train dataset and batchfy
    train_dataset, eval_dataset, QPCollator = dataloaders.get_train_dataset(
        tokenizer=tokenizer,
        data_args=data_args,
    )

    ## early-stop or tensorboard
    callbacks = _build_callbacks(training_args)

    ## training func
    trainer = trainers.get_trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=QPCollator(
            tokenizer,
            max_p_len=data_args.p_max_len,
            max_q_len=data_args.q_max_len,
        ),
        callbacks=callbacks,
        # Always None here: the param_efficient path is rejected above.
        delta_model=None,
    )

    # Give the dataset a back-reference to the trainer.
    train_dataset.trainer = trainer

    if training_args.resume_from_checkpoint is not None:
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    else:
        trainer.train()

    trainer.save_model()

    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)
    


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
