
import os
import sys
import dataclasses
from dataclasses import dataclass, field
from typing import Dict, Optional, List
import random

import torch
from copy import deepcopy
from accelerate import Accelerator
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
from trl import DPOConfig, DPOTrainer

# Project-local trainers, helpers and prompt templates
from src.trainer.dpop_trainer import DPOPTrainer
from src.trainer.udpo_trainer import UDPOTrainer
from src.utils.utils import load_json
from src.utils.prompt import ORG_PROPMT, SQLCODER_PROMPT, MATH_PROMPT, BIRD_PROMPT, BIRD_SYS_PROMPT, SQL_SYS_PROMPT, MISTRAL_BIRD_PROMPT, APPS_PLUS_PROMPT
from src.utils.parser import H4ArgumentParser

import os

# Prompt templates, looked up by the `prompt_type` script argument.
# Each value is a format string filled from a data sample via `str.format_map`.
prompt_dict = {
    "org": ORG_PROPMT,
    "sqlprompt": SQLCODER_PROMPT,
    "math": MATH_PROMPT,
    "bird": BIRD_PROMPT,
    "bird_sys": BIRD_SYS_PROMPT,
    "sql_sys": SQL_SYS_PROMPT,
    "mistral_bird": MISTRAL_BIRD_PROMPT,
    "apps_plus": APPS_PLUS_PROMPT,
}

# Define the command-line arguments (they are parsed in the __main__ block below).
@dataclass
class ScriptArguments(DPOConfig):
    """
    Command-line arguments for the DPO / DPOP / UDPO training script.

    The class extends ``DPOConfig`` so the parsed instance can be passed
    directly to the trainers as ``args``. Fields declared here override any
    same-named fields inherited from ``DPOConfig`` (dataclass semantics), so
    the defaults below are the ones in effect for this script.
    """

    # --- DPO loss ---------------------------------------------------------
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

    # --- model / data locations --------------------------------------------
    model_name_or_path: Optional[str] = field(
        default="../hf_models/Meta-Llama-3-8B-Instruct",
        metadata={"help": "the location of the SFT model name or path"},
    )
    data_path: Optional[str] = field(
        default="data/merge_0611_dpo.json",
        metadata={"help": "the location of the data path"},
    )
    # Fraction of the preference pairs held out for evaluation (0 disables eval).
    eval_ratio: Optional[float] = field(default=0.0, metadata={"help": "the eval ratio"})
    # Must be one of the keys of the module-level `prompt_dict`.
    prompt_type: Optional[str] = field(
        default="org",
        metadata={
            "help": "prompt template key: org, sqlprompt, math, bird, bird_sys, sql_sys, "
            "mistral_bird or apps_plus"
        },
    )

    # --- optimisation -------------------------------------------------------
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    # NOTE(review): not read anywhere in this script's visible code.
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    # NOTE(review): not read anywhere in this script's visible code.
    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    # --- LoRA: LoRA is enabled only when `lora_r` is set ---------------------
    lora_alpha: Optional[float] = field(default=None, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=None, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=None, metadata={"help": "the lora r parameter"})

    # --- sequence lengths / logging cadence ----------------------------------
    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    # Any value other than "float16" / "bfloat16" makes the script load in torch.float.
    torch_dtype: Optional[str] = field(
        default="float16", metadata={"help": "torch_dtype[float16, bfloat16, float] for loading."}
    )

    # --- instrumentation -----------------------------------------------------
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=42, metadata={"help": "Random seed that will be set at the beginning of training."}
    )

    # --- options specific to the DPOP / UDPO variants --------------------------
    # Forwarded to DPOPTrainer when dpo_type == "dpop".
    dpop_lambda: Optional[int] = field(default=50, metadata={"help": "dpop lambda"})
    # Selects the trainer class used by the script.
    dpo_type: Optional[str] = field(default="dpo", metadata={"help": "dpo type, dpo/dpop/udpo"})
    # Forwarded to UDPOTrainer when dpo_type == "udpo".
    udpo_windows: Optional[int] = field(default=5, metadata={"help": "udpo_windows"})
    conf_threshold: Optional[float] = field(default=0.3, metadata={"help": "conf_threshold"})
    iterative: Optional[bool] = field(
        default=False,
        metadata={
            "help": "whether to use iterative training"
        },
    )
    # When set, the reference model is loaded from here instead of model_name_or_path.
    ref_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "the location of the ref model name or path"},
    )

def prepare_dataset_dpo(
    dataset
):
    """Build (chosen, rejected) preference pairs from scored model responses.

    Each sample must provide ``'response'`` (list of generations) and
    ``'rewards'`` (parallel list of scores). A reward equal to 1 marks a
    correct response and a reward equal to 0 an incorrect one; responses with
    any other reward are ignored.

    Pairing rules per sample:
      * no incorrect response (including empty rewards): the sample is skipped;
      * no correct response: one pair ``(golden answer, random incorrect)``;
      * otherwise: one pair ``(random correct, random incorrect)`` and, when
        at least two incorrect responses exist, a second pair
        ``(golden answer, another random incorrect)``.

    The golden answer is ``sample['answer']`` if present, else
    ``sample['output']``; it is only looked up when a pair actually needs it.

    Args:
        dataset: iterable of sample dicts.

    Returns:
        A list of deep copies of the input samples, each extended with
        ``'chosen'`` and ``'rejected'`` keys. Input samples are not modified.

    Raises:
        KeyError: if a golden answer is needed but the sample has neither
            ``'answer'`` nor ``'output'``.
    """
    preference_data = []
    for sample in dataset:
        responses, rewards = sample['response'], sample['rewards']

        # split responses into chosen and rejected
        chosen_idx = [i for i, reward in enumerate(rewards) if reward == 1]
        reject_idx = [i for i, reward in enumerate(rewards) if reward == 0]

        if not reject_idx:
            # Nothing to contrast against (all correct, empty, or only
            # non-binary rewards): no preference pair can be formed.
            continue

        if not chosen_idx:
            # No correct generation: fall back to the golden answer (None).
            chosen_idx = [None]
            reject_idx = random.sample(reject_idx, k=1)
        else:
            # sample one chosen and one reject
            chosen_idx = random.sample(chosen_idx, k=1)
            if len(reject_idx) > 1:
                # A second reject is paired with the golden answer.
                reject_idx = random.sample(reject_idx, k=2)
                chosen_idx.append(None)
            else:
                reject_idx = random.sample(reject_idx, k=1)

        for c_idx, r_idx in zip(chosen_idx, reject_idx):
            if c_idx is None:
                chosen = sample['answer'] if 'answer' in sample else sample['output']
            else:
                chosen = responses[c_idx]
            sample_instance = deepcopy(sample)
            sample_instance['chosen'] = chosen
            sample_instance['rejected'] = responses[r_idx]
            preference_data.append(sample_instance)
    return preference_data

def process_bird(dataset):
    """Replace each sample's ``'output'`` with a trimmed ``'answer'`` field, in place.

    The answer is the part of ``'output'`` that precedes the first
    ``'\\t----- bird -----\\t'`` marker (the whole string when the marker is
    absent). The ``'output'`` key is removed afterwards.

    Returns the same ``dataset`` object that was passed in.
    """
    marker = '\t----- bird -----\t'
    for record in dataset:
        head, _, _ = record['output'].partition(marker)
        record['answer'] = head
        del record['output']
    return dataset

def process_math(dataset):
    """Filter each sample's responses in place, keeping rewards aligned.

    A (response, reward) pair is kept when the reward equals 0 (a reject), or
    when the response is shorter than 3500 characters and contains the text
    ``'answer is'``. The surviving lists are written back to the sample's
    ``'response'`` and ``'rewards'`` keys.

    Returns the same ``dataset`` object that was passed in.
    """
    for idx, sample in enumerate(dataset):
        kept_pairs = [
            (text, score)
            for text, score in zip(sample['response'], sample['rewards'])
            # rejects are always kept; other responses must be short and well-formed
            if score == 0 or (len(text) < 3500 and 'answer is' in text)
        ]
        dataset[idx]['response'] = [text for text, _ in kept_pairs]
        dataset[idx]['rewards'] = [score for _, score in kept_pairs]
    return dataset

def load_dataset(
    data_dir,
    sanity_check=False,
    eval_ratio=0.0,
    prompt_type='org',
    tokenizer=None,
    iterative=False
):
    """Load scored responses from a JSON file and build DPO train/eval datasets.

    Steps:
      1. Read the samples with ``load_json``; if the path contains ``'bird'``
         or ``'math'`` the matching ``process_*`` clean-up is applied.
      2. Turn the samples into preference pairs with ``prepare_dataset_dpo``.
      3. Render each prompt with ``prompt_dict[prompt_type]``. Unless
         ``prompt_type`` is ``'math'`` or the tokenizer name contains
         "mistral" (case-insensitive), the prompt is additionally wrapped by
         the tokenizer's chat template together with a system prompt
         (``'bird_sys'`` for ``prompt_type == 'bird'``, ``'sql_sys'`` otherwise).
      4. Optionally split off an evaluation set and truncate for sanity checks.

    Args:
        data_dir: path of the JSON data file.
        sanity_check: if true, keep at most the first 1000 training rows.
        eval_ratio: fraction of pairs held out for evaluation; no split is
            made when this yields zero rows.
        prompt_type: key into ``prompt_dict``.
        tokenizer: tokenizer providing ``name_or_path`` and
            ``apply_chat_template``; required unless ``prompt_type == 'math'``.
        iterative: accepted for interface compatibility; both modes currently
            use the same pairing routine.

    Returns:
        ``(train_dataset, eval_dataset)``; each row has the keys ``'prompt'``,
        ``'chosen'`` and ``'rejected'``. ``eval_dataset`` is ``None`` when no
        split is made.

    Raises:
        ValueError: if a tokenizer is needed but none was given.
    """
    dataset = load_json(data_dir)
    if 'bird' in data_dir:
        dataset = process_bird(dataset)
    elif 'math' in data_dir:
        dataset = process_math(dataset)

    # The iterative and non-iterative modes share the same pairing logic.
    dataset = prepare_dataset_dpo(dataset)
    print('Total preference data: ', len(dataset))

    # FORMAT the data
    template = prompt_dict[prompt_type]
    dataset = [{'prompt': template.format_map(instance),
                'chosen': instance['chosen'],
                'rejected': instance['rejected']} for instance in dataset]

    if prompt_type != 'math':
        if tokenizer is None:
            raise ValueError(
                f"a tokenizer is required to format prompts for prompt_type={prompt_type!r}"
            )
        # Case-insensitive so that paths such as ".../Mistral-7B" are also
        # recognised, matching how the 'llama' check is done in the main block.
        if 'mistral' not in tokenizer.name_or_path.lower():
            system_key = 'bird_sys' if prompt_type == 'bird' else 'sql_sys'
            system_prompt = prompt_dict[system_key]
            for row in dataset:
                row['prompt'] = tokenizer.apply_chat_template(
                    [{"role": "system", "content": system_prompt},
                     {"role": "user", "content": row['prompt']}],
                    tokenize=False,
                    add_generation_prompt=True,
                )

    dataset = Dataset.from_list(dataset)

    # train_test_split rejects an integer test_size of 0, so only split when
    # the ratio actually yields at least one evaluation row.
    num_eval = int(eval_ratio * len(dataset)) if eval_ratio > 0 else 0
    if num_eval > 0:
        split = dataset.train_test_split(test_size=num_eval, shuffle=True)
        train_dataset, eval_dataset = split['train'], split['test']
    else:
        train_dataset, eval_dataset = dataset, None

    if sanity_check:
        # Cap at the dataset size so small datasets do not raise on select().
        train_dataset = train_dataset.select(range(min(1000, len(train_dataset))))

    return train_dataset, eval_dataset

if __name__ == "__main__":
    # Entry point: parse arguments, load model/tokenizer/data, train, save.
    parser = H4ArgumentParser(ScriptArguments)
    script_args = parser.parse()
    print(script_args)

    # Fail fast on an unsupported trainer type, before any model is loaded.
    supported_dpo_types = ("dpo", "dpop", "udpo")
    if script_args.dpo_type not in supported_dpo_types:
        raise ValueError(
            f"Unsupported dpo_type {script_args.dpo_type!r}; expected one of {supported_dpo_types}"
        )

    set_seed(script_args.seed)

    # 1. load a pretrained model
    # Any torch_dtype value other than "float16"/"bfloat16" falls back to torch.float.
    torch_dtype = torch.float
    if script_args.torch_dtype == "float16":
        torch_dtype = torch.float16
    elif script_args.torch_dtype == "bfloat16":
        torch_dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        torch_dtype=torch_dtype,
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    # 2. load the tokenizer and make sure it has a padding token
    if 'llama' in script_args.model_name_or_path.lower():
        tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
        if tokenizer.pad_token_id is None:
            # 128002 is presumably a reserved special token of the Llama-3
            # vocabulary; only use it when that id exists in this tokenizer,
            # otherwise fall back to EOS like the non-llama branch.
            if len(tokenizer) > 128002:
                tokenizer.pad_token_id = 128002
            else:
                tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            script_args.model_name_or_path, add_eos_token=False, add_bos_token=False
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    # 3. load the preference dataset
    train_dataset, eval_dataset = load_dataset(
        script_args.data_path,
        sanity_check=script_args.sanity_check,
        eval_ratio=script_args.eval_ratio,
        prompt_type=script_args.prompt_type,
        tokenizer=tokenizer,
        iterative=script_args.iterative,
    )

    # 4. LoRA configuration (enabled only when lora_r is given) and reference model
    if script_args.lora_r is not None:
        # Only forward the optional LoRA knobs that were actually set, so that
        # LoraConfig's own defaults apply instead of an explicit None.
        optional_lora_kwargs = {
            key: value
            for key, value in (
                ("lora_alpha", script_args.lora_alpha),
                ("lora_dropout", script_args.lora_dropout),
            )
            if value is not None
        }
        peft_config = LoraConfig(
            r=script_args.lora_r,
            # NOTE(review): Llama/Mistral-style models presumably name the
            # attention output projection "o_proj"; confirm "out_proj" is intended.
            target_modules=[
                "q_proj",
                "v_proj",
                "k_proj",
                "out_proj",
            ],
            bias="none",
            task_type="CAUSAL_LM",
            **optional_lora_kwargs,
        )
    else:
        peft_config = None

    # Pick the reference-model source once so it is never loaded twice:
    # an explicit ref_model_path wins; without LoRA a copy of the policy
    # checkpoint is used; with LoRA and no explicit path no ref model is passed.
    if script_args.ref_model_path is not None:
        ref_model_source = script_args.ref_model_path
    elif peft_config is None:
        ref_model_source = script_args.model_name_or_path
    else:
        ref_model_source = None

    if ref_model_source is not None:
        ref_model = AutoModelForCausalLM.from_pretrained(
            ref_model_source,
            torch_dtype=torch_dtype,
        )
    else:
        ref_model = None

    # 5. initialize the trainer selected by dpo_type
    trainer_kwargs = dict(
        model=model,
        ref_model=ref_model,
        args=script_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )
    if script_args.dpo_type == 'dpo':
        dpo_trainer = DPOTrainer(**trainer_kwargs)
    elif script_args.dpo_type == 'dpop':
        dpo_trainer = DPOPTrainer(
            dpop_lambda=script_args.dpop_lambda,
            **trainer_kwargs,
        )
    else:  # 'udpo' (dpo_type was validated above)
        dpo_trainer = UDPOTrainer(
            udpo_windows=script_args.udpo_windows,
            conf_threshold=script_args.conf_threshold,
            **trainer_kwargs,
        )

    # 6. train
    train_result = dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    metrics = train_result.metrics
    dpo_trainer.log_metrics("train", metrics)
    dpo_trainer.save_metrics("train", metrics)
    dpo_trainer.save_state()
    print('*** Training complete ***')

    # 7. save the final weights in a dedicated sub-directory
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)