import os
import torch
import random
from itertools import chain
from datasets import Dataset
from datasets import load_from_disk
from peft import PeftModel
from data_modules.base_data import load_arxiv_train_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from trl import (
    PPOConfig,
    PPOTrainer,
    AutoModelForCausalLMWithValueHead,
)

# Fall back to the "trl-trackio" Trackio space unless the caller's environment
# already names one (presumably read by TRL's Trackio logging — confirm).
if "TRACKIO_SPACE_ID" not in os.environ:
    os.environ["TRACKIO_SPACE_ID"] = "trl-trackio"

import argparse

def parse_args():
    """Build the command-line parser for this PPO run and parse ``sys.argv``.

    Every option is required except ``--data_epoch`` (default 1),
    ``--learning_rate`` (default 1.5e-4), ``--per_device_train_batch_size``
    (default 64) and ``--base_model_path`` (default
    ``meta-llama/Llama-2-7b-chat-hf``).

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser()
    must = {"required": True}
    # (flag, value type, extra add_argument kwargs), in declaration order.
    option_specs = (
        ("--seed", int, must),
        ("--epoch", int, must),  # PPO epochs per rollout batch
        ("--data_epoch", int, {"required": False, "default": 1}),
        ("--k_slice", int, must),
        ("--learning_rate", float, {"default": 1.5e-4}),
        ("--per_device_train_batch_size", int, {"default": 64}),
        ("--dataset", str, must),
        ("--dataset_path", str, must),
        ("--model_family", str, must),
        ("--base_model_path", str, {"default": "meta-llama/Llama-2-7b-chat-hf"}),
        ("--policy_model_path", str, must),
        ("--ref_model_path", str, must),
        ("--reward_base_model_path", str, must),
        ("--reward_model_path", str, must),
        ("--output_dir", str, must),
        ("--response_length", int, must),
        ("--class_num", int, must),
        ("--forget_label", int, must),
    )
    for flag, value_type, extra in option_specs:
        parser.add_argument(flag, type=value_type, **extra)
    return parser.parse_args()


# Row-layout constants for the prepared dataset (see ``wm_labels`` and
# ``build_training_mix``). Which label group the trainer treats as "forget"
# is decided by ``--forget_label`` inside PPOTrainer, not visible here.
LABEL1_TAIL = 400  # the last LABEL1_TAIL rows get wm label [1]; earlier rows get [0]
LABEL1_USED = 80  # rows taken from the very end of the dataset (all label [1])
LABEL0_SAMPLE = 400  # rows randomly drawn from before the label-[1] tail
NUM_TRAINABLE_VALUE_LAYERS = 2  # final decoder layers left trainable in the value model
SUPPORTED_MODEL_FAMILIES = ("llama", "qwen")
SUPPORTED_DATASETS = ("TOFU", "arxiv")


def build_prompt(model_family, text):
    """Wrap ``text`` in the single-turn chat prompt used by ``model_family``.

    Raises:
        ValueError: if ``model_family`` is not ``"llama"`` or ``"qwen"``.
    """
    if model_family == "llama":
        return "[INST] " + text + " [/INST]"
    if model_family == "qwen":
        return "<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n"
    raise ValueError(
        f"Unsupported model_family {model_family!r}; expected one of {SUPPORTED_MODEL_FAMILIES}"
    )


def wm_labels(indices, num_rows, tail=LABEL1_TAIL):
    """Return ``[0]`` for rows before the last ``tail`` rows and ``[1]`` otherwise."""
    boundary = num_rows - tail
    return [[0] if i < boundary else [1] for i in indices]


def last_layer_markers(num_layers, count=NUM_TRAINABLE_VALUE_LAYERS):
    """Return parameter-name fragments matching the last ``count`` decoder layers.

    The trailing dot keeps ``layers.3.`` from also matching ``layers.30.``.
    """
    first = max(num_layers - count, 0)
    return tuple(f"layers.{i}." for i in range(first, num_layers))


def interleave_columns(first, second, columns):
    """Return ``{col: [first[0], second[0], first[1], second[1], ...]}`` per column.

    ``first``/``second`` are anything indexable by column name yielding an
    iterable (a ``datasets.Dataset`` or a dict of lists). Pairing stops at the
    shorter input (``zip`` semantics), so surplus rows of the longer one are dropped.
    """
    return {
        key: list(chain.from_iterable(zip(first[key], second[key])))
        for key in columns
    }


def load_reward_model(args, tokenizer, device):
    """Load the LoRA reward classifier, merge its adapter, move it to ``device`` in eval mode."""
    base = AutoModelForSequenceClassification.from_pretrained(
        args.reward_base_model_path, num_labels=args.class_num
    )
    base.config.pad_token_id = tokenizer.pad_token_id
    model = PeftModel.from_pretrained(base, args.reward_model_path).merge_and_unload()
    model = model.to(device)
    model.eval()
    return model


def load_value_model(args, device):
    """Load the value model and freeze all but its value head and last decoder layers."""
    value_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        args.base_model_path,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
    )
    value_model.to(device)
    # NOTE(review): the PPO trainer appears to read the value via ``model.score``
    # and the backbone via ``getattr(model, model.base_model_prefix)``; alias both
    # onto the value-head wrapper. Confirm against the installed trl.
    value_model.score = value_model.v_head
    if not hasattr(value_model, "base_model_prefix"):
        value_model.base_model_prefix = value_model.pretrained_model.base_model_prefix
    backbone = getattr(
        value_model.pretrained_model, value_model.pretrained_model.base_model_prefix
    )
    setattr(value_model, value_model.base_model_prefix, backbone)

    # Derive the trainable layers from the model depth instead of hard-coding
    # layers 30/31 (only correct for 32-layer models such as Llama-2-7B).
    markers = last_layer_markers(value_model.pretrained_model.config.num_hidden_layers)
    for name, param in value_model.named_parameters():
        param.requires_grad = any(marker in name for marker in markers)
    for param in value_model.v_head.parameters():
        param.requires_grad = True
    value_model.train()
    return value_model


def load_policy_model(args, device):
    """Load the base LM plus the policy LoRA adapter, with the adapter trainable."""
    base = AutoModelForCausalLM.from_pretrained(
        args.base_model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    base = base.to(device)
    base.config.use_cache = False
    base.generation_config.do_sample = True
    # is_trainable=True: PeftModel.from_pretrained otherwise loads the adapter
    # frozen (inference mode), and PPO would never update the policy.
    policy = PeftModel.from_pretrained(base, args.policy_model_path, is_trainable=True)
    policy = policy.to(device)
    policy.train()
    return policy


def load_ref_model(args, device):
    """Load the frozen reference model (base LM + reference LoRA adapter)."""
    base = AutoModelForCausalLM.from_pretrained(
        args.base_model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
    )
    ref = PeftModel.from_pretrained(base, args.ref_model_path)
    for param in ref.parameters():
        param.requires_grad = False
    ref.eval()
    ref.to(device)
    ref.config.use_cache = False
    return ref


def load_train_data(args):
    """Load the raw training split; return ``(dataset, prompt_text_column)``.

    Raises:
        ValueError: if ``args.dataset`` is not ``"TOFU"`` or ``"arxiv"``.
    """
    if args.dataset == "TOFU":
        data = load_from_disk(args.dataset_path)
        # Strip zero-width space (U+200B) and zero-width non-joiner (U+200C).
        data = data.map(lambda x: {
            **x,
            "answer_split": x["answer_split"].replace('\u200b', '').replace('\u200c', ''),
        })
        return data, "question"
    if args.dataset == "arxiv":
        data, _, _ = load_arxiv_train_dataset(
            hf_dataset_name=args.dataset_path,
            hf_dataset_split="unwatermarked_forget_01",
            is_wtm=False,
            forget_ratio=0.05,
        )
        return data, "text"
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected one of {SUPPORTED_DATASETS}"
    )


def prepare_dataset(args, dataset, tokenizer, text_field):
    """Tokenize chat prompts and attach ``wm_label_text``; drop all original columns."""
    num_rows = len(dataset)

    def tokenize(batch, indices):
        texts = [build_prompt(args.model_family, s) for s in batch[text_field]]
        outputs = tokenizer(texts, padding=False)
        return {
            "input_ids": outputs["input_ids"],
            "wm_label_text": wm_labels(indices, num_rows),
        }

    return dataset.map(
        tokenize,
        batched=True,
        with_indices=True,
        remove_columns=dataset.column_names,
    )


def build_training_mix(dataset):
    """Interleave label-[1] tail rows with randomly sampled label-[0] rows.

    Takes the last ``LABEL1_USED`` rows (inside the label-[1] tail) and up to
    ``LABEL0_SAMPLE`` random rows from before the tail (label [0]), then
    alternates them ``[tail0, head0, tail1, head1, ...]``. Pairing stops at the
    shorter group, so the result has at most ``2 * LABEL1_USED`` rows.
    Draws from the global ``random`` state; seed it first for reproducibility.
    """
    num_rows = len(dataset)
    tail_rows = dataset.select(range(num_rows - LABEL1_USED, num_rows))
    head_pool = dataset.select(range(0, num_rows - LABEL1_TAIL))
    # Clamp so a small pool does not make random.sample raise.
    sample_size = min(LABEL0_SAMPLE, len(head_pool))
    head_rows = head_pool.select(random.sample(range(len(head_pool)), sample_size))
    return Dataset.from_dict(
        interleave_columns(tail_rows, head_rows, tail_rows.column_names)
    )


if __name__ == "__main__":
    args = parse_args()
    seed = args.seed
    epoch = args.epoch
    k_slice = args.k_slice

    # Fail fast on unsupported choices instead of after minutes of model loading.
    if args.model_family not in SUPPORTED_MODEL_FAMILIES:
        raise ValueError(
            f"Unsupported --model_family {args.model_family!r}; "
            f"expected one of {SUPPORTED_MODEL_FAMILIES}"
        )
    if args.dataset not in SUPPORTED_DATASETS:
        raise ValueError(
            f"Unsupported --dataset {args.dataset!r}; expected one of {SUPPORTED_DATASETS}"
        )

    device = torch.device("cuda:0")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    print("pad token id:", tokenizer.pad_token_id)

    reward_model = load_reward_model(args, tokenizer, device)
    value_model = load_value_model(args, device)
    policy_model = load_policy_model(args, device)
    ref_model = load_ref_model(args, device)

    raw_data, text_field = load_train_data(args)
    train_data = prepare_dataset(args, raw_data, tokenizer, text_field)
    print(len(train_data))

    # Seed right before the only use of Python's RNG so --seed makes the
    # label-[0] subsample reproducible.
    random.seed(seed)
    train_data = build_training_mix(train_data)
    print(len(train_data))

    ppo_config = PPOConfig(
        output_dir=args.output_dir,
        seed=seed,
        kl_coef=0.1,
        vf_coef=0.2,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        num_ppo_epochs=epoch,
        num_train_epochs=args.data_epoch,
        logging_steps=10,
        save_strategy="no",
        gradient_accumulation_steps=1,
        gradient_checkpointing=False,
        num_sample_generations=0,
        bf16=True,
        temperature=0.8,
        cliprange=0.2,
        response_length=args.response_length,
    )
    # NOTE(review): ``device``, ``target_wtm``, ``k_slice`` and ``class_num`` look
    # like extensions of a customised PPOTrainer; confirm against the installed trl.
    trainer = PPOTrainer(
        args=ppo_config,
        processing_class=tokenizer,
        model=policy_model,
        ref_model=ref_model,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_data,
        eval_dataset=None,
        peft_config=None,
        device=device,
        target_wtm=args.forget_label,
        k_slice=k_slice,
        class_num=args.class_num,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
    print(args.output_dir)
