# This script assumes the preference pairs have already been constructed as a JSON file containing a list of objects with "prompt", "chosen", and "rejected" keys.


import torch
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, GenerationConfig
from tqdm.auto import tqdm
from trl import DPOTrainer, DPOConfig
import itertools
import os
import argparse 
import wandb
import sys
import json
import transformers
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.covariance import MinCovDet



def get_num_tokens(generation, max_special_id=2):  # generation: num_seq x max(num_tokens)
    """Count, for each generated sequence, the token ids above ``max_special_id``, plus one.

    Args:
        generation: iterable of token-id rows (e.g. a 2-D tensor or list of lists).
        max_special_id: ids ``<= max_special_id`` are treated as special and not
            counted. The default of 2 matches vocabularies where ids 0-2 are
            unk/bos/eos (Llama-2 style); NOTE(review): for other vocabularies
            (e.g. Llama-3, where ids 0-2 are ordinary tokens and padding uses a
            large EOS id) this rule miscounts — confirm before relying on it.

    Returns:
        list[int]: one count per row, each being (#ids above the threshold) + 1.
    """
    return [
        sum(1 for token_id in ids if token_id > max_special_id) + 1
        for ids in generation
    ]


def getEigenIndicator_v2(hidden_states, num_tokens, layer_start=20, layer_end=-2, alpha=1e-3):
    """Compute a layer-averaged EigenScore for one batch of sampled generations.

    For every layer, each sequence is summarised by the mean of its per-step hidden
    states over the generation steps ``t`` with ``t < num_tokens[i]``. The
    sequences-by-sequences covariance of those summaries (``torch.cov`` treats rows
    as variables) is regularised with ``alpha * I``. The mean log10 of its singular
    values is that layer's indicator.

    Args:
        hidden_states: ``hidden_states`` from a ``generate(..., output_hidden_states=True,
            return_dict_in_generate=True)`` call. ``hidden_states[0]`` is skipped (in
            Hugging Face output it holds the prompt pass); ``hidden_states[t + 1][layer]``
            is read as ``[seq, 0, :]``.
        num_tokens: per-sequence step counts, one per sequence in the batch.
        layer_start, layer_end: slice of the per-layer indicators that is averaged
            into the score. The default ``[20:-2]`` drops the first 20 entries and
            the last two.
        alpha: ridge added to the covariance diagonal before the SVD.

    Returns:
        ``(score, s)``, where ``s`` is the singular-value array of the last layer.
        Returns ``(0, "None")`` when there is no generation step after ``hidden_states[0]``.
    """
    if len(hidden_states) < 2:
        return 0, "None"

    steps = hidden_states[1:]
    reference = steps[0][-1]
    # Use the tensors' own device instead of assuming "cuda".
    device = reference.device
    num_seqs = reference.shape[0]

    # mask[t, i] == 1 when step t contributes to sequence i (t < num_tokens[i]).
    lengths = torch.tensor([int(num_tokens[i]) for i in range(num_seqs)], device=device)
    step_ids = torch.arange(len(steps), device=device)
    mask = (step_ids[:, None] < lengths[None, :]).float()
    # Divide by the number of steps actually summed. Dividing by num_tokens - 1
    # disagrees with the summed count and is zero when num_tokens == 1.
    counts = mask.sum(dim=0).clamp(min=1.0)[:, None]
    weights = mask[:, :, None]

    layer_eigens = []
    s = None
    for layer_ind in range(len(hidden_states[0])):
        # (steps, seqs, hidden), accumulated in float32.
        stacked = torch.stack(
            [step[layer_ind][:, 0, :] for step in steps]
        ).to(device=device, dtype=torch.float32)
        mean_embeddings = (stacked * weights).sum(dim=0) / counts
        # torch.cov squeezes a single-sequence result to 0-d; keep it a matrix.
        cov = np.atleast_2d(torch.cov(mean_embeddings).cpu().numpy().astype(float))
        _, s, _ = np.linalg.svd(cov + alpha * np.eye(cov.shape[0]))
        layer_eigens.append(np.mean(np.log10(s)))

    return np.mean(np.array(layer_eigens)[layer_start:layer_end]), s



MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
LEARNING_RATE = 2e-4
BATCH_SIZE = 1
# NOTE(review): NUM_EPOCHS is never passed to DPOConfig in this file, so training
# runs for DPOConfig's default number of epochs — confirm intent.
NUM_EPOCHS = 10

#############################

# Optional Hugging Face cache directory, read from the HF_CACHE_DIR environment
# variable. When it is unset (None), the library's default cache location is used.
cache_dir = os.environ.get("HF_CACHE_DIR")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=cache_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("Model loaded on %s" % device)


# Upper bound on sequences sampled per `generate` call. Each call also returns the
# hidden states of every layer at every step, so this caps peak memory.
_MAX_GENERATION_BATCH = 10


def _parse_args():
    """Parse and validate command-line arguments.

    Exits through ``argparse`` with a usage message when a required flag is missing
    or ``--num_generations_per_prompt`` is not positive.
    """
    parser = argparse.ArgumentParser(description="Process a file")
    # Required: the preference data is always loaded. Without an output path,
    # `DataFrame.to_csv(None)` returns the CSV text instead of writing a file.
    parser.add_argument('--preference_data', type=str, required=True, help='the preference data')
    parser.add_argument('--num_generations_per_prompt', type=int, default=10)
    parser.add_argument('--test_data', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True, help='path to save the fine-tuned model')
    parser.add_argument('--top_p', type=float, default=0.99)
    parser.add_argument('--temperature', type=float, default=1)
    parser.add_argument('--top_k', type=int, default=10)
    parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct",
                        help='currently unused: the base model is loaded from MODEL_NAME at import time')
    parser.add_argument('--output', type=str, required=True, help='name of inference output file')
    args = parser.parse_args()
    if args.num_generations_per_prompt < 1:
        parser.error("--num_generations_per_prompt must be at least 1")
    return args


def _load_preference_dataset(path):
    """Build a `Dataset` from a JSON list of {"prompt", "chosen", "rejected"} objects."""
    with open(path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    return Dataset.from_dict({
        'prompt': [r['prompt'] for r in records],
        # Every pair is tagged with the constant position 'support'.
        'position': ['support'] * len(records),
        'chosen': [r['chosen'] for r in records],
        'rejected': [r['rejected'] for r in records],
    })


def _train_dpo(train_dataset, model_path):
    """Run DPO with a LoRA adapter on the module-level `model` and save to `model_path`."""
    training_args = DPOConfig(
        output_dir="llama",
        logging_steps=10,
        per_device_train_batch_size=BATCH_SIZE,
        save_only_model=True,
        learning_rate=LEARNING_RATE,
    )
    peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
    )
    dpo_trainer = DPOTrainer(
        model,
        args=training_args,
        train_dataset=train_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    dpo_trainer.train()
    dpo_trainer.save_model(model_path)


def _count_generated_tokens(generation, stop_ids):
    """Return, per row of `generation`, the number of non-stop/pad token ids plus one.

    This follows the "+1" convention of `get_num_tokens`, but identifies special
    tokens by id. It does not assume ids <= 2 are special: in this model's
    vocabulary that rule would count EOS padding and skip ordinary tokens.
    """
    special = torch.tensor(sorted({i for i in stop_ids if i is not None}), dtype=generation.dtype)
    real_tokens = ~torch.isin(generation, special)
    return (real_tokens.sum(dim=1) + 1).tolist()


def _run_inference(args):
    """Sample generations for every test prompt, score them, and write a CSV.

    Adds a "batch_generations" column (decoded samples) and an "eigenscore_new"
    column to the test CSV. The eigenscore column holds the mean of the per-batch
    scores, which equals the single batch score when at most
    `_MAX_GENERATION_BATCH` generations are requested.
    """
    print("Running inference now:)!")
    df_test = pd.read_csv(args.test_data)

    # Loop-invariant generation settings, built once.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    generation_config = transformers.GenerationConfig(
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=terminators
    )

    batch_generations = []
    eigen_scores = []
    # NOTE(review): relies on DPOTrainer having injected the LoRA layers into
    # `model` in place — confirm with the installed peft/trl versions.
    model.eval()

    for _, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Evaluating"):
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": row['prompt']},
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(model.device)
        input_length = inputs.shape[1]
        torch.cuda.empty_cache()

        generations = []
        prompt_scores = []
        remaining = args.num_generations_per_prompt
        with torch.no_grad():
            while remaining > 0:
                batch_size = min(_MAX_GENERATION_BATCH, remaining)
                outputs = model.generate(
                    inputs,
                    # Single unpadded prompt: every position is attended to.
                    attention_mask=torch.ones_like(inputs),
                    do_sample=True,
                    top_p=args.top_p,
                    top_k=args.top_k,
                    temperature=args.temperature,
                    num_return_sequences=batch_size,
                    generation_config=generation_config,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                    output_scores=False
                )
                generation = outputs.sequences[:, input_length:].cpu()
                generations.extend(generation)
                remaining -= batch_size

                num_tokens = _count_generated_tokens(generation, terminators)
                score, _ = getEigenIndicator_v2(outputs.hidden_states, num_tokens)
                prompt_scores.append(score)
                del outputs, generation
                torch.cuda.empty_cache()

        # Rows are EOS-padded within their batch; skip_special_tokens drops the padding.
        batch_generations.append(
            [tokenizer.decode(g, skip_special_tokens=True) for g in generations]
        )
        # One score per test row, so the column length always matches df_test.
        eigen_scores.append(float(np.mean(prompt_scores)))

    df_test["batch_generations"] = batch_generations
    df_test["eigenscore_new"] = eigen_scores
    df_test.to_csv(args.output, index=False)
    print(f"Saved generations and eigenscores to {args.output}")


def main():
    """Train the DPO LoRA adapter on the preference data, then run scored inference."""
    args = _parse_args()
    train_dataset = _load_preference_dataset(args.preference_data)
    _train_dpo(train_dataset, args.model_path)
    _run_inference(args)


if __name__ == "__main__":
    main()
