from __future__ import annotations

import dataclasses
import os
import sys
from dataclasses import dataclass, field
from typing import *

from transformers.hf_argparser import DataClassType, HfArgumentParser


@dataclass
class ScriptArguments:
    """
    Script-level options for the DPO training script.

    Every attribute is declared with `dataclasses.field`, carrying its default
    value and a `help` string in the field metadata. All fields have defaults,
    so `ScriptArguments()` can be constructed without arguments.
    """

    # --- Trainer / model selection ---
    trainer_type: str = field(
        metadata={"help": "The type of trainer to use. "},
        default="vanilla",
    )
    ref_model: Optional[str] = field(
        metadata={
            "help": (
                "The path to the reference model for fine-tuning. "
                "Leave blank for using the same model as the policy."
            )
        },
        default=None,
    )

    # --- Data locations and evaluation split ---
    train_path: str = field(
        metadata={"help": "The path to the training data."},
        default="",
    )
    eval_path: Optional[str] = field(
        metadata={
            "help": (
                "The path to the evaluation data. "
                "Leave blank for not evaluating."
            )
        },
        default=None,
    )
    split_eval_from_train: Optional[int] = field(
        metadata={
            "help": (
                "The number of samples to split from the training data for evaluation. "
                "Leave blank for not evaluating."
            )
        },
        default=None,
    )

    # --- Loss / preprocessing knobs ---
    use_eos_padding: Optional[bool] = field(
        metadata={
            "help": (
                "Whether to use end-of-sentence (EOS) padding for the input sequences. "
                "Set this according to your RM training."
            )
        },
        default=False,
    )
    margin_scale: Optional[float] = field(
        metadata={
            "help": (
                "The margin for scaling the difference between "
                "the chosen and rejected scores."
            )
        },
        default=1.0,
    )
    soft_threshold: Optional[float] = field(
        metadata={
            "help": (
                "The threshold of using soft dpo. "
                "If the reward gap is lower than threshold, use soft dpo"
            )
        },
        default=0.1,
    )

    # --- Sampling, labels and reproducibility ---
    sanity_check: Optional[bool] = field(
        metadata={"help": "Only train on 1000 samples."},
        default=False,
    )
    max_training_samples: Optional[int] = field(
        metadata={
            "help": (
                "The maximum number of training samples to use. "
                "Leave blank for using all available samples."
            )
        },
        default=None,
    )
    label_type: Optional[str] = field(
        metadata={"help": "The type of oracle label to use."},
        default=None,
    )
    manual_seed: Optional[int] = field(
        metadata={"help": "The random seed to use."},
        default=0,
    )

    # --- Run bookkeeping ---
    n_iter: Optional[int] = field(
        metadata={
            "help": "The n-th iteration of DPO training. Must be manually set."
        },
        default=None,
    )
    output_model_name: Optional[str] = field(
        metadata={"help": "The name for the output dataset"},
        default=None,
    )


class H4ArgumentParser(HfArgumentParser):
    """
    `HfArgumentParser` subclass that can layer `--key=value` command line overrides on top of a YAML config file.

    Taken from huggingface/alignment-handbook
    """

    @staticmethod
    def _strip_optional(field_type):
        """Return `T` when `field_type` is `Optional[T]`; return any other annotation unchanged."""
        # Function-scope imports keep this helper independent of the module's star import.
        import types
        import typing

        origin = typing.get_origin(field_type)
        # `types.UnionType` (the `T | None` spelling) only exists on Python 3.10+.
        is_union = origin is typing.Union or (
            hasattr(types, "UnionType") and origin is types.UnionType
        )
        if is_union:
            members = [t for t in typing.get_args(field_type) if t is not type(None)]
            # Only a union of exactly one real type with None is an Optional.
            if len(members) == 1:
                return members[0]
        return field_type

    @staticmethod
    def _cast_override(field_type, val: str):
        """Convert the raw command line string `val` to the value expected by a field annotated `field_type`."""
        base_type = H4ArgumentParser._strip_optional(field_type)

        # cast type for ints, floats (default to strings)
        if base_type in (int, float):
            return base_type(val)

        if base_type == List[str]:
            return [str(v) for v in val.split(",")]

        # bool of a non-empty string is True, so we manually check for bools
        if base_type is bool:
            return val in ("true", "True")

        return val

    def parse_yaml_and_args(
        self, yaml_arg: str, other_args: Optional[List[str]] = None
    ) -> List[dataclass]:
        """
        Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.

        Overrides whose name matches no field of any dataclass are ignored. If the same name is given more
        than once in `other_args`, the last occurrence wins.

        Args:
            yaml_arg (`str`):
                The path to the config file used
            other_args (`List[str]`, *optional*):
                A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
                `None` is treated as an empty list.

        Returns:
            [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line

        Raises:
            ValueError: if an entry of `other_args` contains no `=`, if a value cannot be converted to the
                field's `int`/`float` type, or if one override name is a field of more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))

        # strip other args list into dict of key-value pairs; split on the first "=" only so that
        # values which themselves contain "=" are kept intact
        overrides = {}
        for raw in other_args or []:
            name, sep, value = raw.partition("=")
            if not sep:
                raise ValueError(
                    f"Expected command line overrides of the form --arg=value, got: {raw}"
                )
            overrides[name.strip("-")] = value

        outputs = []
        # names already applied to an earlier dataclass, used to detect ambiguous overrides
        used_args = set()

        # overwrite the default/loaded value with the value provided to the command line
        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            init_fields = {f.name: f for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in init_fields}
            for arg, val in overrides.items():
                # add only if this dataclass declares the field
                if arg not in init_fields:
                    continue

                inputs[arg] = self._cast_override(init_fields[arg].type, val)

                # add to used-args so we can check if double add
                if arg in used_args:
                    raise ValueError(
                        f"Duplicate argument provided: {arg}, may cause unexpected behavior"
                    )
                used_args.add(arg)

            outputs.append(data_class(**inputs))

        return outputs

    def parse(self) -> DataClassType | Tuple[DataClassType]:
        """
        Build the dataclasses from `sys.argv`, choosing the parsing mode from the argument layout.

        - exactly one argument ending in ".yaml": load everything from that YAML file;
        - a ".yaml" first argument followed by more arguments: load the YAML file, then apply the
          remaining arguments as overrides via `parse_yaml_and_args`;
        - anything else: regular command line parsing via `parse_args_into_dataclasses`.

        Returns:
            The single parsed dataclass when only one was produced, otherwise the whole collection.
        """
        if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
            # If we pass only one argument to the script and it's the path to a YAML file,
            # let's parse it to get our arguments.
            output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
        # parse command line args and yaml file
        elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
            output = self.parse_yaml_and_args(
                os.path.abspath(sys.argv[1]), sys.argv[2:]
            )
        # parse command line args only
        else:
            output = self.parse_args_into_dataclasses()

        # unwrap so callers with a single dataclass do not have to index the result
        if len(output) == 1:
            output = output[0]
        return output
