# Based on
# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb

# Hugging Face imports
from datasets import load_dataset
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments, Trainer
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
import ast
from tqdm import tqdm
import data.helpers as ph
import string
from peft import PeftModel
from data.constants import COUNTRIES, ALPHABET
from data.utils import get_alpaca_prompt, get_options_str
from termcolor import colored
import pandas as pd
pd.set_option('display.max_colwidth', None)
from scipy.stats import wasserstein_distance
import swifter
import json
from functools import partial
import os
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
import numpy as np
import hydra 
from omegaconf import DictConfig, OmegaConf
from datasets import (
    load_dataset, 
    Dataset,
)
from data.anthropic_global_opinions import get_dataset_oqa, oqa_tnp_dataset
from utils import print_trainable_parameters, set_random_seed
import os

class llmodel:
    """Thin wrapper around a Hugging Face causal LM and its tokenizer.

    Provides helpers to obtain raw model outputs, last-token sentence
    embeddings (single or batched), next-token probabilities and normalized
    probabilities over multiple-choice letters.
    """

    def __init__(self, config):
        """Load the tokenizer and model described by ``config``.

        Args:
            config: Object exposing ``model`` (name/path given to
                ``from_pretrained``), ``use_int8`` (load the model in 8-bit),
                ``use_finetune`` (wrap the model with LoRA weights) and
                ``lora_weights`` (name/path of those weights).
        """
        print(f"Loading model from {config.model}")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        lora_weights = config.lora_weights
        tokenizer = AutoTokenizer.from_pretrained(config.model)
        model = AutoModelForCausalLM.from_pretrained(
            config.model, load_in_8bit=config.use_int8, torch_dtype=torch.float16)
        if config.use_finetune:
            print('------ using finetuned model -------')
            model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
        # The tokenizer may ship without a pad token; reuse EOS so batched
        # tokenization with padding works.
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

        if config.use_int8:
            model = prepare_model_for_int8_training(model)
        else:
            # Inputs are sent to ``self.device`` by the methods below, so the
            # model has to live there too.
            # NOTE(review): 8-bit models are expected to be placed on the GPU
            # by the loader and to reject ``.to()`` -- confirm for the
            # installed transformers version.
            model = model.to(self.device)

        self.tokenizer = tokenizer
        self.model = model

    @staticmethod
    def _last_token_positions(attention_mask):
        """Return, per row, the index of the last position whose mask is 1.

        Works for left- and right-padded batches. A row with no attended
        position yields index 0.
        """
        seq_len = attention_mask.size(1)
        # Weight attended positions by their 1-based index; the largest
        # weighted value then sits at the last attended position.
        position_ids = torch.arange(1, seq_len + 1, device=attention_mask.device)
        return (attention_mask * position_ids).argmax(dim=1)

    def get_predictions(self, sentence):
        """Run the model on ``sentence`` and return the first model output.

        Args:
            sentence: Text to encode with the tokenizer.

        Returns:
            ``outputs[0]`` of the model call (presumably the logits of shape
            (1, seq_len, vocab_size) for a causal LM -- verify against the
            model class in use).
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(inputs)
            predictions = outputs[0]
        return predictions

    def get_sentence_embedding(self, sentence):
        """Return the last-layer hidden state of the final token of ``sentence``.

        Args:
            sentence: Text to encode with the tokenizer.

        Returns:
            Tensor holding the final token's hidden state for the single
            encoded sentence (its shape is also printed).
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt").to(self.device)
        with torch.no_grad():
            last_hidden_states = self.model(inputs, output_hidden_states=True).hidden_states[-1]
        last_token_embd = last_hidden_states[:, -1, :]
        print(last_token_embd.shape)
        return last_token_embd

    def get_batch_sentence_embeddings(self, sentences):
        """Return last-token embeddings for a batch of sentences.

        Sentences are padded to a common length (truncated at 512 tokens).
        The embedding of each sentence is read at its last *attended* token,
        so the result does not depend on the tokenizer's padding side.

        Args:
            sentences: List of strings.

        Returns:
            Tensor with one row per sentence, taken from the last hidden
            layer.
        """
        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = inputs.to(self.device)  # move to device, e.g. GPU
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        # Position -1 is a padding slot for shorter sentences when padding is
        # applied on the right, so pick the last real token from the mask.
        last_positions = self._last_token_positions(inputs["attention_mask"])
        last_positions = last_positions.to(last_hidden_states.device)
        batch_index = torch.arange(last_hidden_states.size(0), device=last_hidden_states.device)
        return last_hidden_states[batch_index, last_positions]

    def get_next_word_probabilities(self, sentence, top_k=500):
        """Return the ``top_k`` most likely next tokens with their probabilities.

        Args:
            sentence: Prompt text.
            top_k: Maximum number of candidates to return; clamped to the
                number of available candidates.

        Returns:
            List of ``(token_text, probability)`` tuples ordered from most to
            least likely. Token texts are stripped of surrounding whitespace,
            so distinct token ids may yield the same text.
        """
        # Get the model predictions for the sentence.
        predictions = self.get_predictions(sentence)

        # Scores for the token following the prompt; computed in float32 so
        # the softmax is not limited by half precision.
        next_token_scores = predictions[0, -1, :].float()

        # torch.topk raises when k exceeds the number of candidates.
        k = min(top_k, next_token_scores.size(-1))
        topk_indices = torch.topk(next_token_scores, k).indices

        # Probabilities are normalized over the full vocabulary, then
        # filtered down to the top-k candidates.
        all_probabilities = torch.nn.functional.softmax(next_token_scores, dim=-1)
        topk_probabilities = all_probabilities[topk_indices].tolist()

        # Decode the top k candidates back to words.
        topk_tokens = [self.tokenizer.decode([idx]).strip() for idx in topk_indices.tolist()]

        return list(zip(topk_tokens, topk_probabilities))

    def get_choice_probs(self, raw_token_probs, num_choices):
        """Normalize token probabilities over the first ``num_choices`` letters.

        Args:
            raw_token_probs: Iterable of ``(token, probability)`` pairs, e.g.
                the output of ``get_next_word_probabilities``.
            num_choices: Number of answer options; choices are the uppercase
                letters ``A``, ``B``, ... up to that count.

        Returns:
            Dict mapping every choice letter to its share of the probability
            mass assigned to choice letters. Letters that never occur get
            ``0.0``; if no letter occurs at all, every value is ``0.0``.
        """
        choices = list(string.ascii_uppercase[:num_choices])
        choice_probs = {choice: 0.0 for choice in choices}

        for token, prob in raw_token_probs:
            if token in choice_probs:
                # The same letter can appear more than once (tokens that only
                # differ in whitespace are stripped to the same text), so
                # accumulate rather than overwrite.
                choice_probs[token] += prob

        # Normalize
        total = sum(choice_probs.values())
        if total > 0:
            for choice in choices:
                choice_probs[choice] /= total
        return choice_probs

def oqa_tnp_dataset():
    """Build the per-option training table from the Pew American Trends Panel CSVs.

    Steps:
      1. Read the human-response CSV, keep only the steer groups and drop
         questions that do not have a row for every steer group.
      2. Attach question text and option metadata from the model CSV
         (one row per ``qkey``).
      3. Build an Alpaca-style prompt per row and write the intermediate
         table to ``..._human_tnp_dataframe.csv``.
      4. Expand every row into one row per non-refusal answer option and
         write the result to ``..._human_tnp_train.csv``.

    All files are read from / written to the current working directory.

    NOTE: this definition shadows the ``oqa_tnp_dataset`` imported from
    ``data.anthropic_global_opinions`` at the top of the file.

    Returns:
        pandas.DataFrame: the expanded table, with the extra columns
        ``option_key``, ``option_value``, ``prompt_answer`` and ``prob_y``.

    Raises:
        AssertionError: if the length of a row's ``D_H`` list does not equal
            its number of non-refusal options, or if the expanded row count
            does not match the expected total.
    """
    steer_groups = ['Northeast', 'Conservative', 'South', 'Male', 'College graduate/some postgrad', 'White',
            'Black', 'Moderate', 'Republican', 'Hispanic', 'Hindu', 'Atheist', 'Liberal',
            'Less than $30,000', 'Jewish', 'Asian', 'Female', 'Less than high school', 'Democrat',
            'Muslim', '$100,000 or more', 'Protestant']

    # use 500 controversial questions
    human_df_path = 'Pew_American_Trends_Panel_disagreement_500_default_human.csv'
    human_df = pd.read_csv(human_df_path)
    human_df = human_df[human_df['group'].isin(steer_groups)]

    # Keep only questions that still have one row per steer group.
    rows_per_question = human_df.groupby('qkey').size()
    incomplete_qkeys = rows_per_question[rows_per_question < len(steer_groups)].index
    human_df = human_df[~human_df['qkey'].isin(incomplete_qkeys)]

    model_df_path = 'Pew_American_Trends_Panel_disagreement_500_default_model.csv'
    model_df = pd.read_csv(model_df_path)
    model_df = model_df.drop_duplicates(subset='qkey')

    question_columns = ['qkey', 'question_raw', 'question', 'references', 'mapping',
                        'ordinal', 'ordinal_refs', 'refusal_refs']
    merged_df = human_df.merge(model_df[question_columns], on='qkey', how='left')
    # These columns are stored as Python literals in the CSV.
    for literal_column in ('ordinal_refs', 'refusal_refs', 'mapping'):
        merged_df[literal_column] = merged_df[literal_column].apply(ast.literal_eval)

    def get_prompt(row):
        """Format one row's question and ordinal options as an Alpaca prompt."""
        instruction = "Answer the following question by picking from the given options"
        input_text = "{question}\n\nOptions:{options}".format(
            question=row['question_raw'], options=get_options_str(list(row['ordinal_refs'])))
        return get_alpaca_prompt(instruction=instruction, input_text=input_text)

    merged_df['prompt'] = merged_df.apply(get_prompt, axis=1)
    merged_df.to_csv('Pew_American_Trends_Panel_disagreement_500_human_tnp_dataframe.csv', index=False)

    df = merged_df

    # Sanity check: D_H must hold exactly one probability per non-refusal
    # option; also count how many expanded rows to expect.
    sum_options = 0
    for _, row in df.iterrows():
        num_answer_options = len(row['mapping']) - len(row['refusal_refs'])
        sum_options += num_answer_options
        probs = ast.literal_eval(row['D_H'])
        assert len(probs) == num_answer_options

    def expand_rows(df):
        """Return a frame with one row per (input row, non-refusal option)."""
        rows = []
        for _, row in df.iterrows():
            prob_list = ast.literal_eval(row['D_H'])
            # D_H has no entry for refusal options, so index it by position
            # among the non-refusal options (in mapping order) rather than
            # among all mapping keys.
            answer_options = [
                (key, value) for key, value in row['mapping'].items()
                if value not in row['refusal_refs']
            ]
            for option_index, (key, value) in enumerate(answer_options):
                new_row = row.copy()
                new_row['option_key'] = key
                new_row['option_value'] = value
                new_row['prompt_answer'] = row["prompt"] + "\n### Response:\n" + key + '.' + value
                new_row['prob_y'] = prob_list[option_index]
                rows.append(new_row)
        return pd.DataFrame(rows)

    df_final = expand_rows(df)
    assert sum_options == len(df_final)
    print(sum_options, 'number of training examples')
    df_final.to_csv('Pew_American_Trends_Panel_disagreement_500_human_tnp_train.csv')

    return df_final

    



@hydra.main(config_path="configs", config_name="train")
def main(config: DictConfig) -> None:
    """Embed every ``prompt_answer`` of the training table and pickle the result.

    Reads ``Pew_American_Trends_Panel_disagreement_500_human_tnp_train.csv``
    (the file written by ``oqa_tnp_dataset()``), computes one embedding per
    row with ``llmodel.get_batch_sentence_embeddings`` and stores the table,
    extended by an ``embedding`` column, in
    ``TNP_OQA_with_alpaca_embeddings.pkl``.

    Args:
        config: Hydra configuration; ``config.seed`` seeds the random number
            generators and the whole object is passed on to ``llmodel``.
    """
    set_random_seed(config.seed)
    # Call oqa_tnp_dataset() first if this CSV has not been generated yet.
    df = pd.read_csv('Pew_American_Trends_Panel_disagreement_500_human_tnp_train.csv')
    group_counts = df.groupby('group').size()
    print(group_counts)

    model = llmodel(config)
    batch_size = 4
    embeddings = []

    prompts = df['prompt_answer'].tolist()
    for start in tqdm(range(0, len(prompts), batch_size)):
        batch_sentences = prompts[start:start + batch_size]
        batch_embeddings = model.get_batch_sentence_embeddings(batch_sentences)
        # Store plain Python lists so the column is independent of torch.
        embeddings.extend(batch_embeddings.cpu().numpy().tolist())
    print(len(embeddings))
    df['embedding'] = embeddings
    df.to_pickle("TNP_OQA_with_alpaca_embeddings.pkl")


if __name__ == "__main__":
    # main() returns nothing; its results are written to disk.
    main()
    
