import logging
import random
import sys
import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers import AutoModelForSequenceClassification
import argparse
import gc
import os
import json
import shutil

import numpy as np
from alignment import (
    DataArguments,
    H4ArgumentParser,
    ModelArguments,
    get_checkpoint,
    get_datasets,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
    is_adapter_model,
)
from alignment.data import maybe_insert_system_message, is_openai_format
from peft import PeftConfig, PeftModel
from simpo_trainer import SimPOTrainer
from simpo_config import SimPOConfig
from dataclasses import dataclass, field
from typing import Optional, Literal
import warnings
from tqdm import tqdm
# Silence every Python warning for the whole script run.
warnings.filterwarnings("ignore")
# Module-level logger; its level is set in main() from the training arguments.
logger = logging.getLogger(__name__)

# Jinja chat template installed on the tokenizer by apply_chat_template() when
# change_template == "mistral" (main() selects that when the model path contains
# "mistral"). A leading system message is removed from the turn list; its stripped
# content plus a blank line is prepended to the first remaining turn. User turns
# render as "[INST] <content> [/INST]" and assistant turns as " <content> " followed
# by eos_token. The template itself emits no BOS token.
MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"

from vllm import LLM, SamplingParams



def merge_models(model_path, ref_model_path, torch_dtype, trust_remote_code, temp_path="outputs/temp_path", alpha=0.5, merged_name=None):
    """Linearly interpolate two causal-LM checkpoints and save the result to ``temp_path``.

    For every state-dict entry whose name contains ``merged_name`` (or every entry
    when ``merged_name`` is None) the saved tensor is
    ``alpha * model + (1 - alpha) * ref_model``. All other entries are taken
    unchanged from ``ref_model``.

    Args:
        model_path: Checkpoint whose weights are weighted by ``alpha``.
        ref_model_path: Checkpoint whose weights are weighted by ``1 - alpha`` and
            which supplies every non-merged entry.
        torch_dtype: dtype passed to ``from_pretrained`` for both models.
        trust_remote_code: Forwarded to ``from_pretrained``.
        temp_path: Directory the merged model is written to via ``save_pretrained``.
        alpha: Interpolation weight of ``model_path``.
        merged_name: Optional substring selecting which parameters to merge, e.g.
            ``"o_proj"``, ``"up_proj"``, ``"down_proj"`` or ``"lm_head"``.

    Raises:
        KeyError: if ``ref_model_path`` lacks a parameter present in ``model_path``.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
    ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

    model_params = model.state_dict()
    ref_model_params = ref_model.state_dict()

    merged_params = {}
    for name, theta_model in model_params.items():
        theta_ref_model = ref_model_params[name]
        if merged_name is not None and merged_name not in name:
            merged_params[name] = theta_ref_model
        elif theta_model.is_floating_point():
            # Interpolate in float32 and cast back once. Doing the arithmetic directly
            # in a low-precision dtype (e.g. bf16) rounds after every operation.
            blended = alpha * theta_model.float() + (1 - alpha) * theta_ref_model.float()
            merged_params[name] = blended.to(theta_model.dtype)
        else:
            # Integer/bool buffers cannot be meaningfully interpolated; keep the reference copy.
            merged_params[name] = theta_ref_model

    model.load_state_dict(merged_params)
    model.save_pretrained(temp_path)
    del model, ref_model, merged_params, model_params, ref_model_params
    gc.collect()
    print(f"The merged model has been saved to {temp_path}")


def _as_dialogue(prompt, response):
    """Return a single-turn user/assistant conversation in OpenAI message format."""
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]


def reward_annotate(output_data, reward_model_path="RLHFlow/ArmoRM-Llama3-8B-v0.1"):
    """Score every candidate response with a reward model and build preference pairs.

    Each dict in ``output_data`` must carry ``prompt`` and ``all_generated_responses``
    (as produced by ``rollout``). The dicts are mutated in place, gaining:

    * ``all_rm_scores``: one float score per generated response;
    * ``weight_chosen``: the constant 1;
    * ``chosen`` / ``rejected``: two-turn user/assistant conversations holding the
      highest- and lowest-scoring responses respectively.

    NOTE: when all scores tie, ``np.argmax`` and ``np.argmin`` both return index 0,
    so ``chosen`` and ``rejected`` hold the same response.

    Args:
        output_data: List of dicts to annotate.
        reward_model_path: Sequence-classification reward model. Assumes its forward
            output exposes a ``score`` attribute (as the default ArmoRM checkpoint
            does); confirm this for other models.

    Returns:
        A ``datasets.Dataset`` built from the annotated records.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        reward_model_path,
        device_map="cuda",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(reward_model_path, use_fast=True)

    # Scoring is pure inference; one no_grad scope covers the whole loop.
    with torch.no_grad():
        for data in tqdm(output_data):
            scores = []
            for candidate in data["all_generated_responses"]:
                input_ids = tokenizer.apply_chat_template(
                    _as_dialogue(data["prompt"], candidate), return_tensors="pt"
                ).to(model.device)
                output = model(input_ids)
                scores.append(output.score.float().item())
            data["all_rm_scores"] = scores
            data["weight_chosen"] = 1

    # Binarize: win = highest-scoring response, lose = lowest-scoring response.
    for data in output_data:
        responses = data["all_generated_responses"]
        chosen_idx = np.argmax(data["all_rm_scores"])
        rejected_idx = np.argmin(data["all_rm_scores"])
        data.update({
            "chosen": _as_dialogue(data["prompt"], responses[chosen_idx]),
            "rejected": _as_dialogue(data["prompt"], responses[rejected_idx]),
        })

    annotated = datasets.Dataset.from_list(output_data)
    del model, tokenizer
    return annotated

def rollout(model_path, dataset, n=5, max_tokens=4096, top_p=0.95, temperature=0.8,
            tensor_parallel_size=1, gpu_memory_utilization=0.7):
    """Sample ``n`` responses per prompt from ``model_path`` with vLLM.

    Args:
        model_path: Model name or local path passed to ``vllm.LLM``.
        dataset: Dataset with ``prompt``, ``ori_chosen`` and ``ori_rejected`` columns.
            Each prompt is rendered as a single user turn with a generation prompt.
        n: Number of sampled responses per prompt.
        max_tokens: Used both as the per-response generation limit and as vLLM's
            ``max_model_len``.
        top_p: Nucleus-sampling threshold.
        temperature: Sampling temperature.
        tensor_parallel_size: Forwarded to ``vllm.LLM`` (default matches the
            previously hard-coded value).
        gpu_memory_utilization: Forwarded to ``vllm.LLM`` (default matches the
            previously hard-coded value).

    Returns:
        List of dicts with keys ``prompt``, ``all_generated_responses``,
        ``ori_chosen`` and ``ori_rejected``. Prompts whose ``n`` responses are all
        identical are dropped, because any pair built from them would have
        chosen == rejected.

    NOTE(review): dropping the local ``LLM`` reference may not release all GPU memory
    held by vLLM; verify before running another large model in the same process.
    """
    prompts = dataset['prompt']
    chosens = dataset['ori_chosen']
    rejecteds = dataset['ori_rejected']

    model = LLM(
        model_path,
        trust_remote_code=True,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_tokens,
    )
    tokenizer = model.get_tokenizer()
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     top_p=top_p,
                                     temperature=temperature,
                                     n=n)

    # The chat template only needs the messages; sampling parameters belong to generate().
    conversations = [
        tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}],
                                      tokenize=False, add_generation_prompt=True)
        for prompt in prompts
    ]

    # vLLM returns one RequestOutput per input, in input order.
    outputs = model.generate(conversations, use_tqdm=True,
                             sampling_params=sampling_params)

    # Collect the generations in memory, aligned with the original pairs.
    output_data = []
    num_identical = 0
    for prompt, ori_chosen, ori_rejected, output in zip(prompts, chosens, rejecteds, outputs):
        gen_text = [completion.text for completion in output.outputs]
        if len(set(gen_text)) == 1:
            # filter out samples where all generated responses are identical
            num_identical += 1
            continue
        output_data.append({
            'prompt': prompt,
            'all_generated_responses': gen_text,
            'ori_chosen': ori_chosen,
            'ori_rejected': ori_rejected,
        })
    print(f"Filtered out {num_identical} samples with identical generated responses")
    del model
    return output_data

def adjust_proportion(dataset, on_policy_data_proportion):
    """Mix on-policy and original preference pairs in ``dataset``.

    A random subset of ``int(len(dataset) * on_policy_data_proportion)`` rows keeps its
    ``chosen``/``rejected`` pair. Every other row has ``chosen``/``rejected`` replaced
    by ``ori_chosen``/``ori_rejected``. Row selection uses the module-level ``random``
    state, which ``set_seed`` seeds in main().

    Args:
        dataset: Object supporting ``len()`` and ``.map(fn, with_indices=True)`` (a
            ``datasets.Dataset``) whose rows have ``chosen``, ``rejected``,
            ``ori_chosen`` and ``ori_rejected``.
        on_policy_data_proportion: Fraction in ``[0, 1]`` of rows to keep on-policy.

    Returns:
        The mapped dataset.

    Raises:
        ValueError: if ``on_policy_data_proportion`` is outside ``[0, 1]`` (or NaN).
    """
    if not 0.0 <= on_policy_data_proportion <= 1.0:
        raise ValueError(
            f"on_policy_data_proportion must be within [0, 1], got {on_policy_data_proportion}"
        )
    print("=====================================================================")
    print("on_policy_data_proportion ", on_policy_data_proportion)
    print("=====================================================================")

    num_rows = len(dataset)
    num_keep_chosen = int(num_rows * on_policy_data_proportion)

    mask = np.zeros(num_rows, dtype=bool)
    mask[random.sample(range(num_rows), num_keep_chosen)] = True

    def map_function(example, idx):
        # Rows selected in `mask` keep the on-policy pair; the rest revert to the original pair.
        if mask[idx]:
            return {
                'chosen': example['chosen'],
                'rejected': example['rejected']
            }
        return {
            'chosen': example['ori_chosen'],
            'rejected': example['ori_rejected']
        }

    return dataset.map(map_function, with_indices=True)

def _strip_leading_bos(text, bos_token):
    """Remove one leading ``bos_token`` from ``text``; no-op when the tokenizer has no BOS token."""
    if bos_token and text.startswith(bos_token):
        return text[len(bos_token):]
    return text


def apply_chat_template(
    example,
    tokenizer,
    task: Literal["sft", "generation", "rm", "simpo"],
    auto_insert_empty_system_msg: bool = True,
    change_template = None,
):
    """Render one dataset example into text fields using the tokenizer's chat template.

    Args:
        example: Row dict. "sft"/"generation" need ``messages``. "rm"/"simpo" need
            ``chosen`` and ``rejected`` (OpenAI message lists); "simpo" optionally
            uses an OpenAI-format ``prompt``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``bos_token``.
            When ``change_template == "mistral"`` its ``chat_template`` is overwritten
            with ``MISTRAL_CHAT_TEMPLATE`` (a side effect on the tokenizer).
        task: One of "sft", "generation", "rm", "simpo".
        auto_insert_empty_system_msg: Whether to call ``maybe_insert_system_message``
            on the (prompt) messages before rendering.
        change_template: Pass "mistral" to switch to the Mistral template.

    Returns:
        The same ``example`` dict with fields added:

        * ``text`` for "sft"/"generation";
        * ``text_chosen``/``text_rejected`` for "rm";
        * ``text_prompt``/``text_chosen``/``text_rejected`` for "simpo".

    Raises:
        ValueError: if required keys are missing, messages are not in OpenAI format
            ("simpo"), or ``task`` is unsupported.
    """
    if change_template == "mistral":
        tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE
    if task in ["sft", "generation"]:
        messages = example["messages"]
        # We add an empty system message if there is none
        if auto_insert_empty_system_msg:
            maybe_insert_system_message(messages, tokenizer)
        example["text"] = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=(task == "generation"),
        )
    elif task == "rm":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            chosen_messages = example["chosen"]
            rejected_messages = example["rejected"]
            # We add an empty system message if there is none
            if auto_insert_empty_system_msg:
                maybe_insert_system_message(chosen_messages, tokenizer)
                maybe_insert_system_message(rejected_messages, tokenizer)

            example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
            example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        else:
            raise ValueError(
                f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
            )
    elif task == "simpo":
        if all(k in example.keys() for k in ("chosen", "rejected")):
            if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
                raise ValueError(
                    f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
                )

            # Inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected`
            # are the final turn of a dialogue. Without an explicit prompt, the first N-1
            # turns of `chosen` form the prompt.
            if "prompt" in example and is_openai_format(example["prompt"]):
                prompt_messages = example["prompt"]
                chosen_messages = example["chosen"]
                rejected_messages = example["rejected"]
            else:
                prompt_messages = example["chosen"][:-1]
                # Now we extract the final turn to define chosen/rejected responses
                chosen_messages = example["chosen"][-1:]
                rejected_messages = example["rejected"][-1:]

            # Prepend a system message if the first message is not a system message
            if auto_insert_empty_system_msg:
                maybe_insert_system_message(prompt_messages, tokenizer)

            example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
            # Drop a BOS the template may prepend to the responses. Presumably the trainer
            # concatenates them after text_prompt, so a second BOS would be duplicated;
            # confirm in SimPOTrainer. Tokenizers without a BOS token (bos_token None)
            # are left untouched instead of raising TypeError.
            example["text_chosen"] = _strip_leading_bos(
                tokenizer.apply_chat_template(chosen_messages, tokenize=False), tokenizer.bos_token
            )
            example["text_rejected"] = _strip_leading_bos(
                tokenizer.apply_chat_template(rejected_messages, tokenize=False), tokenizer.bos_token
            )
        else:
            raise ValueError(
                f"Could not format example as dialogue for `{task}` task! Require either the "
                f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
            )
    else:
        raise ValueError(
            f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'simpo']"
        )
    return example

def get_number(s):
    """Return the chunk index written just before ``_cur_`` in a run folder name, as a float.

    Reset-run folders are named ``<output_dir>_<chunk_num>_cur_...``. The full token
    after the last ``_`` preceding ``_cur_`` is parsed, so multi-digit indices such as
    ``_12_cur_`` sort correctly. (Reading a single character would turn 12 into 2.)

    Raises:
        ValueError: if ``_cur_`` does not occur in ``s`` or the token is not numeric.
    """
    marker = s.find("_cur_")
    if marker < 0:
        raise ValueError(f"'_cur_' not found in {s!r}")
    return float(s[:marker].rsplit("_", 1)[-1])
def get_reset_checkpoint(path):
    """Return the highest-numbered ``checkpoint-<N>`` subdirectory of ``path``.

    Only directories whose name starts with ``checkpoint-`` and whose second
    ``-``-separated field parses as an int are considered. Returns None when there
    are no such directories.
    """
    candidates = []
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if not (entry.startswith('checkpoint-') and os.path.isdir(full_path)):
            continue
        try:
            step = int(entry.split('-')[1])
        except ValueError:
            continue
        candidates.append((step, full_path))
    # Tuples compare by step first, then by path, matching max() over (number, path).
    return max(candidates)[1] if candidates else None

"""
ACCELERATE_LOG_LEVEL=info  accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/Llama-3.2-3B-Instruct/dpo.yaml
"""

def main():
    """Train SimPO, optionally over on-policy data chunks with rollouts and reward annotation.

    Flow as implemented below:

    1. Parse ``ModelArguments`` / ``DataArguments`` / ``SimPOConfig``, then force
       several trainer settings: gradient accumulation, no intermediate
       checkpoints, no eval, ``chunk_num=1``, ``is_reset=False``.
    2. Load the train/test splits and cut ``train`` into ``chunk_num`` chunks.
    3. For each chunk (and, on the reset path, each outer epoch):

       * if ``is_reset``: roll out the current policy with vLLM, score the
         responses with the reward model, and for chunk > 0 prepend the previous
         chunk's saved training data. The reference model is the previous chunk's
         output; the policy starts from the base model, with ``o_proj`` weights
         merged with the previous output for chunk > 0;
       * otherwise, when ``num_train_epochs > 1``, repeat the training data that
         many times and train a single epoch;
       * mix on-policy and original pairs per ``on_policy_data_proportion``;
       * render the pairs with the chat template and train a ``SimPOTrainer``;
       * save model, metrics and model card, then reset accelerator state so the
         next iteration can build a fresh trainer.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
    model_args, data_args, training_args = parser.parse()
    #######
    # Setup
    #######
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # Keep accumulation * GPUs ~= 64 (same as the DPO setup); never let it drop to 0 on >64 GPUs.
    training_args.gradient_accumulation_steps = max(1, 64 // num_gpus)
    training_args.chunk_num = 1
    training_args.is_reset = False
    # No intermediate checkpoints: the final model of every chunk is saved explicitly below.
    training_args.save_steps = 100000
    training_args.save_strategy = "no"
    training_args.do_eval = False

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check for last checkpoint
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # Set seed for reproducibility (also seeds `random`, used by adjust_proportion).
    set_seed(training_args.seed)

    ###############
    # Load datasets
    ###############
    raw_datasets = get_datasets(
        data_args,
        splits=data_args.dataset_splits,
        configs=data_args.dataset_configs,
        columns_to_keep=["chosen", "rejected", "prompt", "ori_chosen", "ori_rejected"],
    )
    logger.info(
        f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = raw_datasets["train"].features
    #####################################
    # Load tokenizer and process datasets
    #####################################
    data_args.truncation_side = "left"  # Truncate from left to ensure we don't lose labels in final turn
    tokenizer = get_tokenizer(model_args, data_args)
    if "mistral" in model_args.model_name_or_path.lower():
        change_template = "mistral"
    else:
        change_template = None

    ###############################################
    # Replay ratio: split "train" into n chunks; the last chunk absorbs the remainder.
    n = training_args.chunk_num
    chunk_size = len(raw_datasets["train"]) // n
    raw_datasets_all_chunk = []
    for i in range(n):
        start = i * chunk_size
        end = start + chunk_size if i < n - 1 else len(raw_datasets["train"])
        chunk_train = raw_datasets["train"][start:end]
        if isinstance(chunk_train, dict):
            # Slicing a Dataset returns a column dict; rebuild a Dataset from it.
            chunk_train = datasets.Dataset.from_dict(chunk_train)
        chunk = datasets.DatasetDict({
            "train": chunk_train,
            "test": raw_datasets["test"]
        })
        raw_datasets_all_chunk.append(chunk)

    training_args.output_dir = training_args.output_dir + f"-proportion{training_args.on_policy_data_proportion}-sft{training_args.sft_weight}-chunk{n}"
    output_dir_ori = training_args.output_dir
    ###############################################
    if training_args.is_reset:
        # Find non-empty folders from an earlier reset run so finished chunks can be skipped.
        folder_names = []
        base_path = output_dir_ori.split("/")[0]
        for p in os.listdir(base_path):
            if os.path.isdir(os.path.join(base_path, p)):
                has_files = any(os.scandir(os.path.join(base_path, p)))
                if has_files:
                    if training_args.sft_reset and "cur_sft" in p:
                        folder_names.append(os.path.join(base_path, p))
                    if not training_args.sft_reset and "cur_pro" in p:
                        folder_names.append(os.path.join(base_path, p))
        run_num = len(folder_names)
        folder_names = sorted(folder_names, key=get_number, reverse=False)

        if run_num > 0:
            print("*" * 30)
            print(folder_names)
            print("*" * 30)
    ###############################################

    my_epoch = int(training_args.num_train_epochs)
    for epoch in range(my_epoch):
        if epoch > 0 and not training_args.is_reset:
            # The non-reset path already repeats the data `my_epoch` times and always starts
            # from the base model, so further outer passes would only retrain the same model.
            break
        training_args.num_train_epochs = 1
        for chunk_num, chunk_datasets in enumerate(raw_datasets_all_chunk):
            # Work on a shallow copy so the edits to "train" below do not leak back into
            # raw_datasets_all_chunk and accumulate across outer epochs.
            raw_datasets = datasets.DatasetDict(chunk_datasets)

            if training_args.is_reset:
                if run_num > 0:
                    # Chunk already trained by an earlier run: reuse its output dir and move on.
                    run_num -= 1
                    training_args.output_dir = folder_names[chunk_num]
                    print("now train", training_args.output_dir)
                    continue

                # Roll out with the base model for the first chunk, else the previous chunk's output.
                if chunk_num == 0:
                    model = model_args.model_name_or_path
                else:
                    model = training_args.output_dir
                print("==========================TOBE Rollout Model", model)
                print("==========================TOBE Rollout Model", model)
                print("==========================TOBE Rollout Model", model)

                train_dataset = rollout(model, raw_datasets['train'])
                raw_datasets['train'] = reward_annotate(train_dataset, reward_model_path=training_args.my_reward_model_path)
                del train_dataset
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                print(f"the data_chunk {chunk_num}/{n} has been rollouted and annotated")

                # Linear schedule across chunks; guard the divisor so a single chunk
                # (n == 1) does not divide by zero.
                schedule_span = max(n - 1, 1)
                if training_args.sft_reset:
                    training_args.sft_weight = training_args.reset_size - chunk_num * training_args.reset_size / schedule_span
                else:  # proportion reset
                    training_args.on_policy_data_proportion = 1 - chunk_num * (1 - training_args.reset_size) / schedule_span

                if chunk_num > 0:
                    # output_dir still points at the previous chunk, which saved its data there.
                    prv_raw_datasets_train = datasets.load_dataset("json", data_files=os.path.join(training_args.output_dir, "raw_datasets.json"), split="train")
                    raw_datasets["train"] = datasets.concatenate_datasets([prv_raw_datasets_train, raw_datasets["train"]])

            # Adjust the proportion of on-policy data.
            raw_datasets["train"] = adjust_proportion(raw_datasets["train"], training_args.on_policy_data_proportion)

            torch_dtype = (
                model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
            )
            quantization_config = get_quantization_config(model_args)

            model_kwargs = dict(
                revision=model_args.model_revision,
                trust_remote_code=model_args.trust_remote_code,
                torch_dtype=torch_dtype,
                use_cache=False if training_args.gradient_checkpointing else True,
                device_map=get_kbit_device_map() if quantization_config is not None else None,
                quantization_config=quantization_config,
                attn_implementation=model_args.attn_implementation,
            )

            training_args.model_init_kwargs = model_kwargs

            model = model_args.model_name_or_path
            ref_model = model
            if training_args.is_reset:
                training_args.save_steps = 100000  # do not use checkpoint
                training_args.save_strategy = "no"
                if chunk_num == 0:
                    ref_model = model_args.model_name_or_path
                else:
                    ref_model = training_args.output_dir
                model = model_args.model_name_or_path

                if training_args.sft_reset:
                    training_args.output_dir = output_dir_ori + f"_{chunk_num}_cur_sft{round(training_args.sft_weight, 2)}"
                else:  # proportion reset
                    training_args.output_dir = output_dir_ori + f"_{chunk_num}_cur_pro{round(training_args.on_policy_data_proportion, 2)}"

                if chunk_num > 0:
                    # Start from the base weights with o_proj interpolated toward the previous chunk.
                    merge_models(model, ref_model, torch_dtype=torch_dtype, trust_remote_code=model_args.trust_remote_code,
                                 temp_path="outputs/temp_path", alpha=training_args.dropout_size, merged_name="o_proj")
                    model = "outputs/temp_path"

                # The new chunk's output directory does not exist yet; create it before writing.
                os.makedirs(training_args.output_dir, exist_ok=True)
                raw_datasets["train"].to_json(os.path.join(training_args.output_dir, 'raw_datasets.json'))
            elif my_epoch > 1:
                # Emulate `my_epoch` epochs by repeating the data and training a single epoch.
                raw_datasets["train"] = datasets.concatenate_datasets([raw_datasets["train"] for _ in range(my_epoch)])
                training_args.output_dir = training_args.output_dir + f"_epoch{my_epoch}"

            #####################
            # Apply chat template
            #####################
            if "Phi" in model_args.model_name_or_path:
                # Phi checkpoints ship custom code files that must sit next to the saved weights.
                os.makedirs(training_args.output_dir, exist_ok=True)
                shutil.copyfile(os.path.join(model_args.model_name_or_path, 'configuration_phi3.py'), os.path.join(training_args.output_dir, 'configuration_phi3.py'))
                shutil.copyfile(os.path.join(model_args.model_name_or_path, 'modeling_phi3.py'), os.path.join(training_args.output_dir, 'modeling_phi3.py'))

            if "Qwen" in model_args.model_name_or_path:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.bos_token = tokenizer.eos_token

            raw_datasets = raw_datasets.map(
                apply_chat_template,
                fn_kwargs={
                    "tokenizer": tokenizer,
                    "task": "simpo",
                    "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
                    "change_template": change_template,
                },
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                desc="Formatting comparisons with prompt template",
            )
            # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected
            for split in ["train", "test"]:
                raw_datasets[split] = raw_datasets[split].rename_columns(
                    {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
                )
            #########################
            # Instantiate SimPO trainer
            #########################
            print("#################################TO BE TRAINED####################################")
            print("model", model)
            print("ref_model", ref_model)
            print("#################################TO BE TRAINED####################################")

            trainer = SimPOTrainer(
                model=model,
                ref_model=ref_model,
                args=training_args,
                train_dataset=raw_datasets["train"],
                eval_dataset=raw_datasets["test"],
                tokenizer=tokenizer,
                peft_config=get_peft_config(model_args),
            )

            ###############
            # Training loop
            ###############
            # Reset runs always train from scratch; otherwise resume when a checkpoint exists.
            checkpoint = None
            if not training_args.is_reset:
                last_checkpoint = get_checkpoint(training_args)
                if training_args.resume_from_checkpoint is not None:
                    checkpoint = training_args.resume_from_checkpoint
                elif last_checkpoint is not None:
                    checkpoint = last_checkpoint
            print("#################################checkpoint####################################")
            print("checkpoint", checkpoint)
            print("#################################checkpoint####################################")
            train_result = trainer.train(resume_from_checkpoint=checkpoint)

            metrics = train_result.metrics
            metrics["train_samples"] = len(raw_datasets["train"])
            trainer.log_metrics("train", metrics)
            trainer.save_metrics("train", metrics)
            trainer.save_state()

            logger.info("*** Training complete ***")

            ##################################
            # Save model and create model card
            ##################################
            logger.info("*** Save model ***")
            trainer.save_model(training_args.output_dir)
            logger.info(f"Model saved to {training_args.output_dir}")

            # Save everything else on main process
            kwargs = {
                "finetuned_from": model_args.model_name_or_path,
                "dataset": list(data_args.dataset_mixer.keys()),
                "dataset_tags": list(data_args.dataset_mixer.keys()),
                "tags": ["alignment-handbook"],
            }
            if trainer.accelerator.is_main_process:
                trainer.create_model_card(**kwargs)
                # Restore k,v cache for fast inference
                trainer.model.config.use_cache = True
                trainer.model.config.save_pretrained(training_args.output_dir)

            ##########
            # Evaluate
            ##########
            if training_args.do_eval:
                logger.info("*** Evaluate ***")
                metrics = trainer.evaluate()
                metrics["eval_samples"] = len(raw_datasets["test"])
                trainer.log_metrics("eval", metrics)
                trainer.save_metrics("eval", metrics)

            if training_args.push_to_hub is True:
                logger.info("Pushing to hub...")
                trainer.push_to_hub(**kwargs)

            # Drop model references and reset accelerate's global state so the next
            # iteration can construct a fresh trainer in the same process.
            trainer.model = None
            trainer.ref_model = None
            trainer.accelerator.state._reset_state(True)
            trainer.accelerator.gradient_state._reset_state()
            trainer.accelerator = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("*** Training complete! ***")


# Script entry point: run the training pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
