import torch
from datasets import load_dataset, Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer
import os
from time import perf_counter
import argparse
import pandas as pd
import json
import wandb
from datetime import datetime
import random

def fix_untrained_tokens(model, eps = 1e-16):
    """
    Reset untrained token vectors to the mean of the trained ones.

    Llama-3 for eg has untrained vectors in the base model.
    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
    We reset them to the mean of the rest of the tokens

    A row of the input embedding matrix is treated as untrained when its
    largest element is ``<= eps``. The same row indices are then overwritten
    in both the input embedding matrix and the output (lm_head) matrix; both
    matrices are modified in place.

    Args:
        model: Object exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``, each returning a module with a
            ``weight`` tensor.
        eps: Threshold on the per-row maximum below which a row counts as
            untrained.

    Returns:
        Tuple ``(mean_embedding, mean_lm_head)``: the per-column means over
        the trained rows, cast back to the dtype of the respective matrix.
        If every row is untrained, both means are zero vectors.
    """
    embedding_matrix = model.get_input_embeddings ().weight.data
    lm_head_matrix   = model.get_output_embeddings().weight.data

    # Get untrained tokens
    indicator_untrained = torch.amax(embedding_matrix, dim = 1) <= eps
    where_untrained = torch.where(indicator_untrained)[0]
    n_untrained = where_untrained.shape[0]
    n_trained = embedding_matrix.shape[0] - n_untrained
    if n_untrained != 0:
        print(
            f"Not an error, but your model has {n_untrained} untrained tokens.\n"\
            "We shall set them to the mean of the other trained tokens."
        )

    # Zero the untrained rows first so they do not contribute to the sums.
    embedding_matrix[where_untrained] = 0
    lm_head_matrix  [where_untrained] = 0

    # Accumulate in float32 to limit rounding error for low-precision weights.
    sum_embedding  = torch.sum(embedding_matrix, dim = 0, dtype = torch.float32)
    sum_lm_head    = torch.sum(lm_head_matrix,   dim = 0, dtype = torch.float32)

    # Guard against 0/0 when every row was flagged as untrained: the sums are
    # all zero in that case, so dividing by 1 yields zeros instead of NaNs.
    divisor = max(n_trained, 1)
    mean_embedding = (sum_embedding / divisor).to(embedding_matrix.dtype)
    mean_lm_head   = (sum_lm_head   / divisor).to(lm_head_matrix  .dtype)

    embedding_matrix[where_untrained] = mean_embedding
    lm_head_matrix  [where_untrained] = mean_lm_head

    return mean_embedding, mean_lm_head

def get_model_and_tokenizer(model_path,device):
    """Load a 4-bit quantized causal LM and its tokenizer from ``model_path``.

    The tokenizer's pad token is set to its EOS token. The model is loaded
    with an NF4, double-quantized ``BitsAndBytesConfig`` whose compute dtype
    is ``"bfloat16"``, and ``device`` is forwarded as ``device_map``. On the
    loaded model's config, ``use_cache`` is turned off and ``pretraining_tp``
    is set to 1.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    # Reuse the end-of-sequence token for padding.
    tok.pad_token = tok.eos_token

    quant_settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16",
        "bnb_4bit_use_double_quant": True,
    }
    llm = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=BitsAndBytesConfig(**quant_settings),
        device_map=device,
    )

    llm.config.use_cache = False
    llm.config.pretraining_tp = 1
    return llm, tok

def formatted_prompt(question)-> str:
    """Wrap ``question`` in the chat template used for inference.

    The result consists of a fixed system turn ("You are a helpful
    assistant"), a user turn containing ``question``, and an opening
    assistant header followed by a newline so the model continues from there.
    """
    system_turn = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant<|eot_id|>"
    )
    user_turn = "<|start_header_id|>user<|end_header_id|>\n" + f"{question}" + "<|eot_id|>"
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n"
    return system_turn + user_turn + assistant_header


def formatted_train(input,response)->str:
    """Render one supervised training sample in the chat template.

    The sample is a fixed system turn carrying the prefix-optimization
    instructions, a user turn containing ``input``, and an assistant turn
    containing ``response``; each turn is closed with ``<|eot_id|>``.
    """
    # Fixed instruction text; it contains no placeholders, so a plain
    # (non-formatted) literal is enough.
    system_turn = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in prompt engineering. Help me optimize the prefix (between the special token '<|Begin_of_Fairwords|>' and '<|End_of_Fairwords|>') before a query to ensure it encourages fair, helpful, and harmless responses. The optimized prefix should follow these guidelines:
1. The prefix should prompt fair, helpful, and harmless response but should not be overly specific.
2. Only modify the prefix between the special tokens '<|Begin_of_Fairwords|>' and '<|End_of_Fairwords|>', not the rest of the prompt. Just return the prefix.
3. Avoid overly specific constraints unrelated to the query.
4. Keep the prefix under 30 tokens.
5. Focus on improving the prefix, DO NOT response to the query.
6. Minimize unnecessary changes to the prefix.<|eot_id|>
"""
    user_turn = "<|start_header_id|>user<|end_header_id|>" + f"{input}" + "<|eot_id|>\n"
    assistant_turn = "<|start_header_id|>assistant<|end_header_id|>" + f"{response}" + "<|eot_id|>"
    return system_turn + user_turn + assistant_turn

def prepare_train_data(data):
    """Convert raw records into a ``Dataset`` with a formatted ``text`` column.

    ``data`` is loaded into a DataFrame; its ``prompt`` and
    ``optimized_fairwords`` columns are paired row by row and passed through
    ``formatted_train`` to produce the ``text`` column. The resulting frame
    is returned as a ``Dataset`` built with ``Dataset.from_pandas``.
    """
    frame = pd.DataFrame(data)
    prompts = frame["prompt"]
    targets = frame["optimized_fairwords"]
    frame["text"] = list(map(formatted_train, prompts, targets))
    return Dataset.from_pandas(frame)

def get_model_param_sum(model):
    """Return the sum of all elements over every parameter of ``model``.

    Each tensor yielded by ``model.named_parameters()`` is summed on its own,
    and the per-tensor totals are accumulated in iteration order. Returns 0
    when the model yields no parameters.
    """
    per_tensor_totals = (
        tensor.data.sum().item() for _, tensor in model.named_parameters()
    )
    running_total = 0
    # Explicit left-to-right accumulation keeps the float rounding order fixed.
    for value in per_tensor_totals:
        running_total = running_total + value
    return running_total

def _parse_args():
    """Define and parse the command-line options of the SFT run."""
    parser = argparse.ArgumentParser(description='The sft parameters.')
    parser.add_argument('--model_path', default='/your/path/to/the/Llama3.2-3b-model', type=str, help='The model path')
    parser.add_argument('--output_path', default='/your/output/path/of/the/sft/model', type=str, help='The sft model path')
    parser.add_argument('--train_data', default='../data/fmtrain.json', type=str, help='The training data path')
    parser.add_argument('--per_device_train_batch_size', default=4, type=int)
    parser.add_argument('--gradient_accumulation_steps', default=4, type=int)
    # Parsed as float so fractional clipping thresholds (e.g. 0.3) are
    # accepted; integer inputs such as "1" still work.
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument('--lr', default=1e-5, type=float)
    parser.add_argument('--epochs', default=10, type=int)
    parser.add_argument('--gpu_id', default=0, help='The gpu id')
    parser.add_argument('--wandb_name', help='The name of record in wandb')
    return parser.parse_args()


def main():
    """Run LoRA supervised fine-tuning and report parameter sums.

    Steps: parse the CLI options, start a wandb run, load the 4-bit base
    model and tokenizer, read and shuffle the JSON training data (fixed seed
    42), train with ``SFTTrainer``, save the result to a timestamped output
    directory, reload it from there, and print the parameter sums of the
    in-memory model and the reloaded one.
    """
    args = _parse_args()

    # One timestamp is shared by the wandb run name and the output directory.
    run_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_name = f'{args.wandb_name}-{run_time}'
    wandb.init(project='fairmaker', name=run_name)

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu_id}")
    else:
        device = torch.device("cpu")
    model, tokenizer = get_model_and_tokenizer(args.model_path, device)

    output_path = f'{args.output_path}_{run_time}'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)

    # Explicit encoding so decoding does not depend on the platform default.
    with open(args.train_data, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Fixed seed keeps the shuffled sample order reproducible across runs.
    random.seed(42)
    random.shuffle(data)
    train_data = prepare_train_data(data)

    peft_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # NOTE(review): fp16 is enabled here while get_model_and_tokenizer sets
    # the quantization compute dtype to "bfloat16" - confirm this is intended.
    training_arguments = TrainingArguments(
        output_dir=output_path,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        optim="paged_adamw_32bit",
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        save_strategy="epoch",
        logging_steps=20,
        num_train_epochs=args.epochs,
        fp16=True,
        run_name=run_name,
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=train_data,
        peft_config=peft_config,
        dataset_text_field="text",
        args=training_arguments,
        tokenizer=tokenizer,
        packing=False,
        max_seq_length=1024,
    )

    trainer.train()
    wandb.finish()
    trainer.save_model(output_path)

    # Reload what was just saved; the reloaded tokenizer is not needed.
    sft_model, _ = get_model_and_tokenizer(output_path, device)
    # NOTE(review): no explicit merge call is visible in this script; the two
    # sums compare the in-memory model after training with the reloaded one.
    weight_before_merge = get_model_param_sum(model)
    weight_after_merge = get_model_param_sum(sft_model)
    print(f"weight before merge: {weight_before_merge}\nweight after merge: {weight_after_merge}")


if __name__ == "__main__":
    main()

