# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions for fine-tuning a text generation model with LoRA and reinforcement learning.
Modified from https://github.com/lqtrung1998/mwp_ReFT/
"""
import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import copy
from dataclasses import dataclass, field
from functools import partial
import json
import logging
from typing import Optional, Union, List, Dict
import torch
import warnings

import transformers
from transformers import (
    HfArgumentParser, 
    GenerationConfig,
    set_seed,
)
from transformers.trainer_utils import is_main_process, PREFIX_CHECKPOINT_DIR

from examples.run_lora import (
    ModelArguments, 
    ExtendedDataArguments, 
    ExtendedTrainingArguments, 
    GenerationArguments, 
    ImageProcessingArguments,
    get_tokenizer,
    get_image_processor,
    get_data,
    load_model,
    HaltTrainingCallback,
    PeftSaveCallback,
    Seq2SeqMetricsOnSeqIDs, 
    Seq2SeqMetricsOnGenerationSeqIDs,
    RenameCKPTFiles,
    CKPT_FOLDER,
    DTYPE_CLASS,
)
from utils.rl_utils import PRESET_REWARD_FUNCS

# Module-level logger; its format and level are configured by logging.basicConfig in do_training().
logger = logging.getLogger(__name__)


# RL algorithm names accepted for `rl_algorithm` (checked in ExtendedTRLTrainingArguments.__post_init__ and get_trainer).
RL_ALGORITHMS = ['grpo']

@dataclass
class ExtendedTRLTrainingArguments(ExtendedTrainingArguments):
    """Training arguments for RL fine-tuning, layered on ``ExtendedTrainingArguments``.

    ``reward_func`` and ``reward_weights`` are declared as strings so they can be
    given on the command line (e.g. ``"['a','b']"``); ``__post_init__`` turns both
    into Python lists and validates them.
    """

    rl_algorithm: str = field(
        default='grpo',
        metadata={"help": f"The optimization algorithm to use. Currently support: {RL_ALGORITHMS}."},
    )
    reward_func: str = field(
        default="['length_scaled_math_accuracy','think_ans_format']",
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    reward_weights: str = field(
        default="None",
        metadata={"help": "List of reward functions weights"},
    )
    
    # RL optimization arguments
    kl_coef: float = field(
        default=0.05,
        metadata={"help": "KL coefficient."},
    )
    num_iterations: int = field(
        default=1,
        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
    )
    epsilon: float = field(
        default=0.2,
        metadata={"help": "Epsilon value for clipping."},
    )

    # Generation arguments
    num_generations: int = field(
        default=8,
        metadata={"help": "Number of generations to sample."},
    )
    target_generations: int = field(
        default=8,
        metadata={"help": "Number of target generations to sample."},
    )
    min_correct_generations: int = field(
        default=4,
        metadata={"help": "Minimum number of correct generations"},
    )
    min_incorrect_generations: int = field(
        default=4,
        metadata={"help": "Minimum number of incorrect generations"},
    )
    p_low: float = field(
        default=0.25,
        metadata={"help": "Lower probability threshold"},
    )
    p_high: float = field(
        default=0.5,
        metadata={"help": "Higher probability threshold"},
    )

    # DeepSpeed and VLLM settings
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={"help": "DeepSpeed ZeRO-3 gathering setting"},
    )
    use_vllm: bool = field(
        default=False,
        metadata={"help": "Whether to use vLLM for generating completions."},
    )
    vllm_device: str = field(
        default="auto",
        metadata={"help": "Device for vLLM generation"},
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.7,
        metadata={"help": "GPU memory utilization ratio for vLLM"},
    )
    vllm_dtype: str = field(
        default="auto",
        metadata={"help": "Data type for vLLM generation"},
    )
    # Optional because the default is None (no explicit limit given here).
    vllm_max_model_len: Optional[int] = field(
        default=None,
        metadata={"help": "Max model length for vLLM"},
    )
    
    # Reward function specific arguments
    LSMA_min_value_wrong: float = field(
        default=-1.0, 
        metadata={"help": "Minimum reward when completion is wrong."}
    )
    LSMA_max_value_wrong: float = field(
        default=-0.5, 
        metadata={"help": "Maximum reward when completion is wrong."}
    )
    LSMA_min_value_correct: float = field(
        default=0.5, 
        metadata={"help": "Minimum reward when completion is correct."}
    )
    LSMA_max_value_correct: float = field(
        default=1.0, 
        metadata={"help": "Maximum reward when completion is correct."}
    )
    LSMA_max_seq_length: float = field(
        default=3072, 
        metadata={"help": "Maximum sequence length for reward scaling."}
    )
    
    # Other arguments
    log_completions: bool = field(
        default=False,
        metadata={"help": "Whether to log the completions during training."},
    )

    def __post_init__(self):
        """Normalize and validate the RL-specific options.

        After this runs, ``rl_algorithm`` is lower-cased, ``reward_func`` is a
        list of names/paths, ``reward_weights`` is a list, and
        ``model_init_kwargs`` is reset to an empty dict.

        Raises:
            NotImplementedError: if ``rl_algorithm`` is not in ``RL_ALGORITHMS``,
                or a reward function is neither a key of ``PRESET_REWARD_FUNCS``
                nor an existing directory.
            TypeError: if ``reward_func`` does not resolve to a list.
            ValueError: if a list literal cannot be parsed, ``reward_func`` is
                empty, or ``reward_weights`` is not a list.
        """
        # Function-scope import so this block does not depend on a file-level import.
        import ast

        def _literal_list(raw, option):
            # literal_eval accepts only Python literals, unlike eval which would
            # execute arbitrary expressions taken from the command line.
            try:
                return ast.literal_eval(raw)
            except (ValueError, TypeError, SyntaxError) as err:
                raise ValueError(f"Failed to parse {option} as a list literal: {raw}") from err

        super().__post_init__()

        # Explicit raises (not assert) so validation survives `python -O`.
        self.rl_algorithm = self.rl_algorithm.lower()
        if self.rl_algorithm not in RL_ALGORITHMS:
            raise NotImplementedError(
                f"The following rl_algorithm is not implemented: {self.rl_algorithm}"
            )

        # --- reward_func: "[...]" string -> list, bare string -> one-element list ---
        if isinstance(self.reward_func, str):
            if self.reward_func.startswith("[") and self.reward_func.endswith("]"):  # lazy check
                self.reward_func = _literal_list(self.reward_func, "reward_func")
            else:
                self.reward_func = [self.reward_func]
        if not isinstance(self.reward_func, list):
            raise TypeError(f"Expect self.reward_func to be a list; got {self.reward_func}")
        if not self.reward_func:
            raise ValueError("Expect self.reward_func to contain at least one reward function")

        # Each entry must be a preset name or a directory (a reward model path, as
        # accepted by get_reward_funcs). Names are stripped here because
        # get_reward_funcs strips them before lookup.
        unsupported = [
            name
            for name in self.reward_func
            if not (
                isinstance(name, str)
                and (name.strip() in PRESET_REWARD_FUNCS or os.path.isdir(name.strip()))
            )
        ]
        if unsupported:
            raise NotImplementedError(
                "The following reward funcs are not implemented: {}".format(unsupported)
            )

        # --- reward_weights: default to equal weights, otherwise require a list ---
        weights = self.reward_weights
        if weights is None or (isinstance(weights, str) and weights in {"None", ""}):
            weights = [1.0] * len(self.reward_func)
        elif isinstance(weights, str):
            if weights.startswith("[") and weights.endswith("]"):  # lazy check
                weights = _literal_list(weights, "reward_weights")
            else:
                raise ValueError(f"Expect self.reward_weights to be a list of float; got {self.reward_weights}")
        elif isinstance(weights, (list, tuple)):
            # Already parsed (e.g. constructed programmatically); keep as a list.
            weights = list(weights)
        else:
            raise ValueError(f"Expect self.reward_weights to be a list of float; got {self.reward_weights}")
        self.reward_weights = weights

        # others
        self.model_init_kwargs = {}



def parse_args(args):
    """Parse command-line arguments into the five argument dataclasses.

    Args:
        args: list of argument strings (without the program name). If it holds
            exactly one element ending in ``.json``, that file is parsed instead.

    Returns:
        Tuple ``(model_args, data_args, training_args, generation_args,
        image_processing_args)``.

    Raises:
        ValueError: if the output directory already contains checkpoint folders
            while training without ``overwrite_output_dir``, or if
            ``torch_dtype`` disagrees with the ``fp16``/``bf16`` flags.
    """
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, ExtendedDataArguments, ExtendedTRLTrainingArguments, GenerationArguments, ImageProcessingArguments))

    if len(args) == 1 and args[0].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments. Note: abspath needs the path string
        # (args[0]), not the whole argument list.
        parsed = parser.parse_json_file(json_file=os.path.abspath(args[0]))
    else:
        parsed = parser.parse_args_into_dataclasses(args)
    model_args, data_args, training_args, generation_args, image_processing_args = parsed

    # Refuse to train into a directory that already holds checkpoint folders
    # unless the caller explicitly allows overwriting.
    if (
        os.path.exists(training_args.output_dir)
        and any(_dir.startswith(PREFIX_CHECKPOINT_DIR) for _dir in os.listdir(training_args.output_dir))
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # check torch_dtype and fp16/bf16 options (explicit raises so the checks
    # are not stripped under `python -O`)
    if model_args.torch_dtype == "float16":
        if not (training_args.fp16 and (not training_args.bf16)):
            raise ValueError(
                "When torch_dtype is float16, fp16 must be True and bf16 must be False.")
    elif model_args.torch_dtype == "bfloat16":
        if not ((not training_args.fp16) and training_args.bf16):
            raise ValueError(
                "When torch_dtype is bfloat16, fp16 must be False and bf16 must be True.")
        warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")

    return model_args, data_args, training_args, generation_args, image_processing_args


def get_generation_config(data_args, generation_args, tokenizer):
    """Create the ``GenerationConfig`` used for generation.

    The config is seeded from a deep copy of ``generation_args.__dict__``, then
    receives the tokenizer's pad/bos/eos token ids and the role settings from
    ``data_args``. Score output is switched on when the ensemble method is
    ``"reward_model"``, when answer probabilities are requested, or for the
    ``granite_guardian_qa`` prompt template.

    Returns:
        The configured ``GenerationConfig`` (also logged at INFO level).
    """
    # Deep copy so the config never shares mutable values with generation_args.
    config_kwargs = copy.deepcopy(generation_args.__dict__)
    generation_config = GenerationConfig(**config_kwargs)

    # Special token ids are taken from the tokenizer.
    for token_attr in ("pad_token_id", "bos_token_id", "eos_token_id"):
        setattr(generation_config, token_attr, getattr(tokenizer, token_attr))

    # Role tags / role map are taken from the data arguments.
    for role_attr in ("role_tags", "role_map"):
        setattr(generation_config, role_attr, getattr(data_args, role_attr))

    # Both conditions request the same pair of flags.
    uses_reward_model = generation_config.ensemble_method == "reward_model"
    wants_answer_probs = generation_config.output_answer_probs == True  # noqa: E712 - equality kept on purpose
    if uses_reward_model or wants_answer_probs:
        generation_config.return_dict_in_generate = True
        generation_config.output_scores = True

    if data_args.prompt_templates in ("granite_guardian_qa",):
        generation_config.output_scores = True

    logger.info(generation_config)
    return generation_config



def get_reward_funcs(
    training_args, 
    base_model_prefix=None, 
    base_model_path=None, 
    tokenizer=None,
    model_args=None, 
    data_args=None, 
    _print=None
    ):
    """Resolve ``training_args.reward_func`` into a list of reward callables/models.

    Each entry is either a key of ``PRESET_REWARD_FUNCS`` (some presets are bound
    to extra keyword arguments with ``functools.partial``) or a directory path,
    in which case ``load_model`` is called with that path.

    Args:
        training_args: object exposing ``reward_func`` (list or string) plus the
            attributes read for specific presets (``p_high``, ``p_low``,
            ``target_generations``).
        base_model_prefix, base_model_path, tokenizer, model_args, data_args:
            forwarded to ``load_model`` for directory entries; ``tokenizer`` is
            also bound into some presets.
        _print: optional callable taking one message string, used for progress logs.

    Returns:
        Non-empty list of reward functions / reward models, in input order.

    Raises:
        ValueError: if ``reward_func`` cannot be parsed or yields no entries.
        NotImplementedError: if an entry is neither a preset nor a directory.
    """
    def _log(message):
        # Logging is optional; stay silent when no printer was supplied.
        if _print:
            _print(message)

    raw_spec = training_args.reward_func
    _log(f"Raw reward_func: {raw_spec}")
    _log(f"Type of reward_func: {type(raw_spec)}")

    try:
        if not isinstance(raw_spec, str):
            reward_names = raw_spec
        elif raw_spec.startswith('[') and raw_spec.endswith(']'):
            import ast
            reward_names = ast.literal_eval(raw_spec)
        else:
            reward_names = [raw_spec]
        _log(f"Parsed reward_func_list: {reward_names}")
    except Exception as e:
        raise ValueError(f"Failed to parse reward_func: {training_args.reward_func}. Error: {e}")

    # Presets that are parameterized by the number of target generations.
    process_presets = ("process_switch", "process_depth", "process_output", "process_sequential")

    resolved = []
    for entry in reward_names:
        name = entry.strip()
        _log(f"Processing reward function: {name}")

        if name in PRESET_REWARD_FUNCS:
            preset = PRESET_REWARD_FUNCS[name]
            # Work out which keyword arguments (if any) to pre-bind for this preset.
            if name == "adaptive_reasoning_control":
                bound_kwargs = dict(
                    tokenizer=tokenizer,
                    p_threshold=training_args.p_high,
                    p_low=training_args.p_low,
                )
            elif name == "math_accuracy":
                bound_kwargs = dict(tokenizer=tokenizer)
            elif name in process_presets:
                bound_kwargs = dict(n_generations=training_args.target_generations)
            else:
                bound_kwargs = None
            resolved.append(preset if bound_kwargs is None else partial(preset, **bound_kwargs))
            _log(f"Added reward function: {name}")
            continue

        if os.path.isdir(name):
            resolved.append(
                load_model(base_model_prefix, base_model_path, tokenizer, training_args, model_args, data_args, name, _print=_print)
            )
            _log(f"Added reward model from path: {name}")
            continue

        _log(f"Invalid reward function: {name}")
        _log(f"Available reward functions: {list(PRESET_REWARD_FUNCS.keys())}")
        raise NotImplementedError(
            f"Reward function '{name}' is not implemented. "
            f"Available options are: {list(PRESET_REWARD_FUNCS.keys())}"
        )

    if not resolved:
        raise ValueError("No valid reward functions were created")

    _log(f"Created {len(resolved)} reward functions")
    return resolved


def get_trainer(model, reward_funcs, training_args, model_args, data_args, generation_config, train_data, validation_data, tokenizer):
    """Assemble the RL trainer with its callbacks and evaluation metrics.

    Args:
        model: the policy model to train.
        reward_funcs: list of reward functions / reward models.
        training_args: training arguments; ``model_init_kwargs`` is overwritten here.
        model_args: model arguments (8-bit loading, dtype, flash-attention flag).
        data_args: data arguments; ``max_seq_length`` is copied onto ``generation_config``.
        generation_config: generation settings passed to the trainer (mutated in place).
        train_data: training dataset, used only when ``training_args.do_train``.
        validation_data: evaluation dataset, used only when ``training_args.do_eval``.
        tokenizer: tokenizer, passed to the trainer as ``processing_class``.

    Returns:
        The constructed trainer instance.

    Raises:
        NotImplementedError: if ``training_args.rl_algorithm`` is not supported.
    """
    callbacks = []
    if training_args.halt_step > -1:
        callbacks.append(HaltTrainingCallback(halt_step=training_args.halt_step))
    if training_args.save_strategy in ['steps', 'epoch']:
        callbacks.append(PeftSaveCallback(model_args))

    # Setup evaluation
    compute_metrics = None
    if validation_data and training_args.do_eval:
        # `padding_side` was previously an undefined name here (NameError as soon
        # as evaluation was enabled).
        # NOTE(review): assumes the metrics should use the tokenizer's own padding
        # side - confirm against Seq2SeqMetricsOn*SeqIDs in examples.run_lora.
        padding_side = tokenizer.padding_side
        if training_args.evaluation_method == 'autoregressive':
            compute_metrics = Seq2SeqMetricsOnGenerationSeqIDs(
                training_args.validation_metrics, tokenizer=tokenizer, padding_side=padding_side)
        else:
            compute_metrics = Seq2SeqMetricsOnSeqIDs(
                training_args.validation_metrics, tokenizer=tokenizer, padding_side=padding_side)

    # Imported lazily so the trainer module is only loaded for the chosen algorithm.
    if training_args.rl_algorithm == 'grpo':
        from utils.trainer_utils import MyGRPOTrainer as Trainer
    else:
        raise NotImplementedError(
            f"The following rl_algorithm is not implemented: {training_args.rl_algorithm}"
        )

    generation_config.max_seq_length = data_args.max_seq_length
    training_args.model_init_kwargs = {
        "load_in_8bit": model_args.load_in_8bit,
        "torch_dtype": DTYPE_CLASS[model_args.torch_dtype],
        "device_map": None,
        "config": model.config,
        "use_flash_attention_2": model_args.use_flash_attn,
    }
    trainer = Trainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_data if training_args.do_train else None,
        eval_dataset=validation_data if training_args.do_eval else None,
        # tokenizer=tokenizer,  # deprecated, use processing_class
        processing_class=tokenizer,
        callbacks=callbacks,
        compute_metrics=compute_metrics,
        evaluation_method=training_args.evaluation_method,
        generation_config=generation_config
    )

    return trainer


def do_training(args):
    """Run the full RL fine-tuning pipeline from a list of CLI arguments.

    Steps: parse arguments and set up logging; build the tokenizer, generation
    config, image processor, datasets, model, reward functions and trainer;
    then, if ``do_train`` is set, train, save the model, and write
    ``train_results.txt`` and ``trainer_state.json`` to the output directory
    (the two files are written on the main process only).

    Args:
        args: list of argument strings, or a single path to a ``.json`` config
            (see ``parse_args``).
    """
    # Step 1. Prepare arguments and logging
    model_args, data_args, training_args, generation_args, image_processing_args = parse_args(args)
    on_main_process = is_main_process(training_args.local_rank)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if on_main_process else logging.WARN,
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    if on_main_process:
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # print
    logger.info("Available cuda devices: %s", torch.cuda.device_count())
    logger.info("Pytorch version: %s; transformers version %s", torch.__version__, transformers.__version__)
    logger.info("\nData related arguments:\n %s", data_args)
    logger.info("\nModel arguments:\n %s", model_args)
    logger.info("\nTraining/evaluation parameters:\n %s", training_args)
    logger.info("\nTraining generation parameters:\n %s", generation_args)

    # Step 2. Configure tokenizer, dataset, model and training
    base_model_path = os.path.join(CKPT_FOLDER, model_args.base_model)
    # Everything before the last '-' of the base model name.
    base_model_prefix = model_args.base_model.rsplit('-', 1)[0]
    # Set seed before initializing model.
    set_seed(training_args.seed)
    # set up tokenizer
    tokenizer = get_tokenizer(base_model_prefix, base_model_path)
    generation_config = get_generation_config(data_args, generation_args, tokenizer)
    # image processor
    image_processor = get_image_processor(base_model_prefix, image_processing_args)
    # load datasets
    train_data, validation_data = get_data(data_args, training_args, tokenizer, image_processor)
    logger.info("========================= Datasets loaded. =========================")
    # configure model_config, load model, load lora, etc
    # model_path is only used when model_name_or_path points at an existing directory.
    model_path = (
        model_args.model_name_or_path
        if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
        else None
    )
    model = load_model(base_model_prefix, base_model_path, tokenizer, training_args, model_args, data_args, model_path=model_path, _print=logger.info)
    logger.info("========================= Model configured. =========================")
    # configure reward functions
    reward_funcs = get_reward_funcs(training_args, base_model_prefix, base_model_path, tokenizer, model_args, data_args, _print=logger.info)
    logger.info("========================= Reward funcs configured. =========================")

    trainer = get_trainer(model, reward_funcs, training_args, model_args, data_args, generation_config, train_data, validation_data, tokenizer)
    logger.info("========================= Trainer configured. =========================")

    # Step 3. Train the model
    if training_args.do_train:
        # rename trainer_state.json, so to avoid trainer.train loading trainer_state.json to check epochs_trained
        ckpt_file_handler = RenameCKPTFiles(model_path)
        hide_trainer_state = model_path is not None and training_args.ignore_trainer_state
        if hide_trainer_state:
            ckpt_file_handler.rename_files()

        try:
            train_result = trainer.train(resume_from_checkpoint=None)  # TODO previously set resume_from_checkpoint=model_path; will cause error when trainer_state shall be ignored
            model.save_pretrained(training_args.output_dir)
        finally:
            # rename trainer_state.json back, even if training or saving failed,
            # so the source checkpoint is never left with renamed files
            if hide_trainer_state:
                ckpt_file_handler.restore_file_names()

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if on_main_process:
            with open(output_train_file, "w") as writer:
                for key, value in sorted(train_result.metrics.items()):
                    writer.write(f"{key} = {value}\n")

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

        logger.info("========================= Training completed. =========================")
    

# Script entry point: pass the command-line arguments (without the program name) to do_training.
if __name__ == "__main__":
    do_training(sys.argv[1:])