# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from dataclasses import dataclass
import json
from memgpt.trl.utils.utils_metrics import compute_loss_func, set_wandb, set_tokenizer, compute_metrics, set_use_special_dblookup_tokens, preprocess_logits_for_metrics
from memgpt.trl.utils.load_sft_dataset import prepare_pretrain_data
from memgpt.trl.utils.load_model import initialize_model_for_pretraining
from memgpt.trl.utils_pack_dataset import create_packed_dataset
from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

from accelerate import Accelerator

import torch
import numpy as np
import random
import os
from accelerate import Accelerator


def set_random_seed(seed: int):
    """Seed every RNG used by the training stack and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and torch
            (CPU generator and all CUDA devices).
    """
    # Exported for child processes; the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    

def _evaluate_and_report(trainer, header):
    """Run ``trainer.evaluate()`` and print the metrics as indented JSON.

    Args:
        trainer: The constructed ``SFTTrainer``.
        header: Banner line printed above the metrics.

    Returns:
        The metrics dict returned by ``trainer.evaluate()``.
    """
    results = trainer.evaluate()
    print(header)
    print(json.dumps(results, indent=4))
    return results


def main(script_args, training_args, model_args, pretrain_args):
    """Pretrain (or only evaluate) a model with ``SFTTrainer``, then save it.

    Args:
        script_args: TRL ``ScriptArguments``; ``dataset_name`` is used when
            pushing to the Hub and the object is forwarded to dataset prep.
        training_args: ``SFTConfig``. This function overrides
            ``remove_unused_columns`` and ``max_seq_length`` on it.
        model_args: TRL ``ModelConfig`` forwarded to model init and
            ``get_peft_config``.
        pretrain_args: ``PretrainConfig`` with this script's extra flags.

    Raises:
        ValueError: if ``plain_baseline`` and ``use_special_dblookup_tokens``
            are both enabled.
    """
    set_random_seed(getattr(training_args, "seed", 42))

    accelerator = Accelerator()
    if accelerator.is_main_process:
        set_wandb()

    print(f"use_special_dblookup_tokens: {pretrain_args.use_special_dblookup_tokens}")
    # Explicit raise (not assert) so the check survives `python -O`.
    if pretrain_args.plain_baseline and pretrain_args.use_special_dblookup_tokens:
        raise ValueError("Cannot have both plain baseline and dblookup tokens enabled")

    ################
    # Model init kwargs & Tokenizer
    ################
    init_kwargs = {"use_special_dblookup_tokens": pretrain_args.use_special_dblookup_tokens}
    if training_args.resume_from_checkpoint:
        init_kwargs["resume_from_checkpoint"] = training_args.resume_from_checkpoint
    model, tokenizer = initialize_model_for_pretraining(model_args, **init_kwargs)

    ################
    # Dataset
    ################
    train_dataset, eval_dataset = prepare_pretrain_data(
        script_args, pretrain_args.use_special_dblookup_tokens, pretrain_args.plain_baseline
    )

    if accelerator.is_main_process:
        # Sanity-check one formatted sample; guarded so tiny debug datasets don't IndexError.
        if len(train_dataset) > 10:
            print(f"train_dataset[10]: {train_dataset[10][training_args.dataset_text_field]}")
        print("==== Training dataset ====")
        print(train_dataset)
        print("==== Eval dataset ====")
        print(eval_dataset)

    # NOTE(review): these presumably configure module-level state read by
    # compute_metrics / compute_loss_func in utils_metrics — confirm.
    set_use_special_dblookup_tokens(pretrain_args.use_special_dblookup_tokens)
    set_tokenizer(tokenizer)

    # Hard overrides of whatever was given on the CLI / in the config file.
    training_args.remove_unused_columns = False
    training_args.max_seq_length = 1024

    ################
    # Evaluation
    ################
    # Pretrain weighted loss. It is also passed to SFTTrainer directly below; the
    # attribute is kept on the args object in case other code reads it from there.
    training_args.compute_loss_func = compute_loss_func

    # Trainer.evaluate() raises when no eval dataset is available, so every
    # evaluation below is gated on having one.
    eval_split = eval_dataset if training_args.eval_strategy != "no" else None
    has_eval = eval_split is not None

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_split,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
        compute_metrics=compute_metrics,
        compute_loss_func=compute_loss_func,
    )

    if has_eval and not pretrain_args.eval_only:
        # Baseline metrics before any training step. In eval_only mode the final
        # evaluation below would be an identical repeat, so it is skipped here.
        _evaluate_and_report(trainer, "==== Pre-training Eval Results ====")

    if not pretrain_args.eval_only:
        if training_args.resume_from_checkpoint:
            trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        else:
            trainer.train()

    if has_eval:
        _evaluate_and_report(trainer, "==== Eval Results ====")

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

@dataclass
class PretrainConfig:
    """Extra command-line flags for this pretraining script.

    Parsed alongside the TRL ``ScriptArguments`` / ``SFTConfig`` / ``ModelConfig``.

    Raises:
        ValueError: on construction if ``plain_baseline`` and
            ``use_special_dblookup_tokens`` are both enabled; the two modes are
            mutually exclusive.
    """

    # Forwarded to model/tokenizer init, dataset preparation, and the metrics utilities.
    use_special_dblookup_tokens: bool = False
    # Forwarded to prepare_pretrain_data; cannot be combined with use_special_dblookup_tokens.
    plain_baseline: bool = False
    # Skip trainer.train() and only run evaluation.
    eval_only: bool = False
    # NOTE(review): not read anywhere in this file — confirm whether it is still needed.
    stupid_only: bool = False

    def __post_init__(self):
        # Validate at construction so every consumer of the config (including the
        # CLI parser) gets the check, and it is not stripped under `python -O`.
        if self.use_special_dblookup_tokens and self.plain_baseline:
            raise ValueError("Cannot have both plain baseline and dblookup tokens enabled")

def make_parser(subparsers: argparse._SubParsersAction = None):
    """Build the argument parser for this script.

    The parser fills ``ScriptArguments``, ``SFTConfig``, ``ModelConfig`` and
    ``PretrainConfig``, in that order.

    Args:
        subparsers: When given, an ``"sft"`` sub-command is registered on it and
            returned. Otherwise a standalone ``TrlParser`` is created.

    Returns:
        The parser (standalone or sub-command) configured with the dataclasses.
    """
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, PretrainConfig)
    if subparsers is None:
        return TrlParser(dataclass_types)
    return subparsers.add_parser(
        "sft",
        help="Run the SFT training script",
        dataclass_types=dataclass_types,
    )


if __name__ == "__main__":
    # Parse CLI flags / config file into the four config dataclasses, then run.
    cli_parser = make_parser()
    parsed_configs = cli_parser.parse_args_and_config()
    script_args, training_args, model_args, pretrain_args = parsed_configs
    main(script_args, training_args, model_args, pretrain_args)