# Based on
# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb

# Hugging Face imports
from datasets import load_dataset
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments, Trainer
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch
import data.helpers as ph
import string
from peft import PeftModel
from data.constants import COUNTRIES, ALPHABET
from data.utils import get_alpaca_prompt, get_options_str
from termcolor import colored
import pandas as pd
pd.set_option('display.max_colwidth', None)
from scipy.stats import wasserstein_distance
import swifter
import json
from functools import partial
import os
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
import numpy as np
import hydra 
from omegaconf import DictConfig, OmegaConf
from datasets import (
    load_dataset, 
    Dataset,
)

# Language-model wrapper used to score multiple-choice opinion questions
class llmodel:
    """Causal language model wrapper for scoring multiple-choice answers.

    Loads a Hugging Face causal LM (optionally with LoRA weights applied via
    PEFT) and turns its next-token distribution into a probability
    distribution over the answer letters ``A``, ``B``, ...
    """

    def __init__(self, config):
        """Load the tokenizer and model described by ``config``.

        Args:
            config: Object exposing ``model`` (checkpoint name or path),
                ``use_int8`` (load weights in 8-bit), ``use_finetune`` (apply
                LoRA weights) and ``lora_weights`` (path to those weights).
        """
        print(f"Loading model from {config.model}")
        lora_weights = config.lora_weights
        tokenizer = AutoTokenizer.from_pretrained(config.model)
        model = AutoModelForCausalLM.from_pretrained(config.model, load_in_8bit=config.use_int8, torch_dtype=torch.float16)
        if config.use_finetune:
            print('------ using finetuned model -------')
            model = PeftModel.from_pretrained(model, lora_weights, torch_dtype=torch.float16)
        # Reuse EOS as the padding token (presumably because the tokenizer
        # defines none of its own).
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

        if config.use_int8:
            model = prepare_model_for_int8_training(model)

        self.tokenizer = tokenizer
        self.model = model

    def _input_device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def get_predictions(self, sentence):
        """Run the model on ``sentence`` and return the first output element.

        The result is indexed downstream as ``[batch, position, vocab]`` logits.
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        # The tokenizer returns CPU tensors while the model may live on a GPU
        # (e.g. when loaded in 8-bit), so move the ids next to the weights.
        inputs = inputs.to(self._input_device())
        with torch.no_grad():
            outputs = self.model(inputs)
        return outputs[0]

    def get_next_word_probabilities(self, sentence, top_k=500):
        """Return the ``top_k`` most likely next tokens and their probabilities.

        Args:
            sentence: Prompt text.
            top_k: Number of candidates to return; clamped to the vocabulary
                size so small vocabularies do not make ``torch.topk`` fail.

        Returns:
            List of ``(token_text, probability)`` pairs, most likely first.
            ``token_text`` is the decoded token with surrounding whitespace
            stripped, so distinct tokens may share the same text.
        """
        predictions = self.get_predictions(sentence)

        # Logits for the token that would follow the last prompt position.
        next_token_logits = predictions[0, -1, :]

        # Softmax in float32: the model is loaded in float16 and half-precision
        # probabilities lose accuracy on small values.
        probabilities = torch.nn.functional.softmax(next_token_logits.float(), dim=-1)

        k = min(top_k, probabilities.shape[-1])
        top = torch.topk(probabilities, k)
        topk_ids = top.indices.tolist()
        topk_probs = top.values.tolist()

        topk_tokens = [self.tokenizer.decode([idx]).strip() for idx in topk_ids]
        return list(zip(topk_tokens, topk_probs))

    def get_choice_probs(self, raw_token_probs, num_choices):
        """Convert next-token probabilities into a distribution over answer letters.

        Args:
            raw_token_probs: Iterable of ``(token_text, probability)`` pairs,
                e.g. the output of ``get_next_word_probabilities``.
            num_choices: Number of answer options; the letters considered are
                the first ``num_choices`` uppercase ASCII letters.

        Returns:
            Dict mapping ``A``, ``B``, ... (inserted in that order) to
            normalized probabilities. Probability mass of every token whose
            text equals a letter is summed. Letters never seen get 0.0. If no
            letter appears at all, a uniform distribution is returned instead
            of dividing by zero.
        """
        choices = list(string.ascii_uppercase[:num_choices])
        # Keys are inserted in letter order because callers read .values()
        # positionally against the option order, not the probability rank.
        choice_probs = dict.fromkeys(choices, 0.0)

        for token, prob in raw_token_probs:
            if token in choice_probs:
                # Different tokens (e.g. "A" and " A") can strip to the same
                # letter; accumulate instead of letting the last one overwrite.
                choice_probs[token] += prob

        total = sum(choice_probs.values())
        if total <= 0:
            return {c: 1.0 / num_choices for c in choices}
        return {c: p / total for c, p in choice_probs.items()}

def _load_steer_context(context, group):
    """Return the steering text to prepend to every question for ``group``.

    Reads the tab-separated ``{context}.csv`` and takes the ``question`` of the
    first row whose ``subgroup`` equals ``group``. For ``context == 'steer-qa'``
    the group name is appended on its own line.

    Raises:
        ValueError: If no row has ``subgroup == group``.
    """
    steer_df = pd.read_csv(f'{context}.csv', delimiter='\t')
    matches = steer_df.loc[steer_df['subgroup'] == group, 'question'].values
    if len(matches) == 0:
        raise ValueError(f"No steering context for subgroup {group!r} in {context}.csv")
    context_input = matches[0]
    if context == 'steer-qa':
        context_input += '\n' + group
    return context_input


def _parse_float_list(text):
    """Parse the printed form of a 1-D numeric array (e.g. ``"[0.1 0.2 0.7]"``) into floats."""
    # split() without an argument also copes with runs of spaces and with the
    # newlines a wrapped array repr may contain.
    return [float(tok) for tok in text[1:-1].split()]


def _load_saved_human_opinions(output_dir, pew_survey_list, info_df):
    """Load the saved per-wave human distributions for the questions in ``info_df``.

    Reads ``{output_dir}/American_Trends_Panel_W{wave}_default_human.csv`` for
    every wave (files presumably written by an earlier per-wave run), keeps the
    rows whose ``qkey`` belongs to that wave in ``info_df``, and parses the
    ``D_H`` column from its string form into lists of floats.
    """
    frames = []
    for wave in pew_survey_list:
        sn = f'American_Trends_Panel_W{wave}'
        hdf = pd.read_csv(os.path.join(output_dir, f'{sn}_default_human.csv'))
        idf = info_df[info_df['survey'] == f'Pew_{sn}']
        frames.append(hdf[hdf['qkey'].isin(idf['key'].values)])
    human_df = pd.concat(frames)
    human_df['D_H'] = human_df['D_H'].apply(_parse_float_list)
    return human_df


def get_dataset_oqa(attribute, group, PEW_SURVEY_LIST, CONTEXT):
    """Build the OpinionQA evaluation dataset for one demographic group.

    For every survey in scope, merges the model opinion distributions read from
    ``runs/`` (only result files whose name contains the survey name and
    ``context=default``) with the human opinion distributions. It keeps the
    rows matching ``attribute``/``group`` and emits one example per row.

    Args:
        attribute: Value kept from the merged frame's ``attribute`` column.
        group: Value kept from the merged frame's ``group`` column.
        PEW_SURVEY_LIST: Iterable of American Trends Panel wave numbers.
        CONTEXT: ``"default"`` uses every wave plus the
            ``Pew_American_Trends_Panel_disagreement_500`` subset, with no
            added context. Any other value uses only the disagreement subset
            and prepends the steering text from ``{CONTEXT}.csv`` to each
            question.

    Returns:
        ``datasets.Dataset`` with columns ``question``, ``selections`` (human
        distribution ``D_H``), ``options`` (``ordinal_refs``), ``ordinal`` and
        ``d_m_other`` (model distribution ``D_M`` from the result files).

    Raises:
        ValueError: If ``group`` is missing from the steering file, or a row's
            human distribution and options differ in length.
    """
    print('group:', attribute, group)

    oqa_datasets = {
        "question": [],
        "selections": [],
        "options": [],
        "ordinal": [],
        "d_m_other": [],
    }
    DATASET_DIR = 'human_resp/'
    RESULT_DIR = 'runs'
    OUTPUT_DIR = 'distributions'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    DISAGREEMENT_SURVEY = 'Pew_American_Trends_Panel_disagreement_500'
    if CONTEXT == "default":
        SURVEY_LIST = [f'American_Trends_Panel_W{SURVEY_WAVE}' for SURVEY_WAVE in PEW_SURVEY_LIST] + \
                      [DISAGREEMENT_SURVEY]
    else:
        SURVEY_LIST = [DISAGREEMENT_SURVEY]
        # Depends only on CONTEXT and group, so read the steering file once.
        context_input = _load_steer_context(CONTEXT, group)

    for SURVEY_NAME in SURVEY_LIST:
        print(colored(SURVEY_NAME, "red"))
        RESULT_FILES = [f for f in os.listdir(RESULT_DIR) if SURVEY_NAME in f and 'context=default' in f]

        ## Read survey info, questions and options
        # NOTE(review): eval() executes whatever the CSV cells contain; these
        # files are assumed trusted. ast.literal_eval would be safer if the
        # cells are plain Python literals.
        info_df = pd.read_csv(os.path.join(DATASET_DIR, SURVEY_NAME, 'info.csv'))
        info_df['option_ordinal'] = info_df.apply(lambda x: eval(x['option_ordinal']), axis=1)
        info_df['references'] = info_df.apply(lambda x: eval(x['references']), axis=1)

        ## Survey metadata: option order for every demographic key
        md_df = pd.read_csv(os.path.join(DATASET_DIR, SURVEY_NAME, 'metadata.csv'))
        md_df['options'] = md_df.apply(lambda x: eval(x['options']), axis=1)
        md_order = {'Overall': {'Overall': 0}}
        md_order.update({k: {o: oi for oi, o in enumerate(opts)} for k, opts in zip(md_df['key'], md_df['options'])})

        ## Model opinion distribution
        model_df = ph.get_model_opinions(RESULT_DIR, RESULT_FILES, info_df)

        ## Human opinion distribution
        if SURVEY_NAME != DISAGREEMENT_SURVEY:
            resp_df = pd.read_csv(os.path.join(DATASET_DIR, SURVEY_NAME, 'responses.csv'))
            wave = int(SURVEY_NAME.split('_W')[1])
            human_df = pd.concat([
                ph.extract_human_opinions(resp_df, model_df, md_df, demographic=demographic, wave=wave)
                for demographic in ph.DEMOGRAPHIC_ATTRIBUTES
            ])
        else:
            human_df = _load_saved_human_opinions(OUTPUT_DIR, PEW_SURVEY_LIST, info_df)

        ## Combine and keep only the requested attribute/group
        combined_df = pd.merge(model_df, human_df)
        combined_df['group_order'] = combined_df.apply(lambda x: md_order[x['attribute']][x['group']], axis=1)
        combined_df = combined_df[(combined_df['attribute'] == attribute) & (combined_df['group'] == group)]

        for _, r in combined_df.iterrows():
            question = r['question_raw']
            if CONTEXT != 'default':
                # Add the optional steering context.
                question = context_input + '\n' + question
            options = r['ordinal_refs']
            selections = list(r['D_H'])
            if len(selections) != len(options):
                raise ValueError(
                    f"{SURVEY_NAME}: human distribution has {len(selections)} entries "
                    f"but the question has {len(options)} options")
            oqa_datasets['question'].append(question)
            oqa_datasets['selections'].append(selections)
            oqa_datasets['options'].append(options)
            oqa_datasets['ordinal'].append(r['ordinal'])
            oqa_datasets['d_m_other'].append(r['D_M'])
    return Dataset.from_dict(oqa_datasets)


@hydra.main(config_path="configs", config_name="train")
def main(config: DictConfig) -> float:
    """Measure how well the model's answer distribution matches human opinions.

    Builds the dataset selected by ``config.data.task`` and splits it with
    ``config.data.test_split`` / ``config.seed``. Each test question is turned
    into a prompt by ``get_alpaca_prompt`` and the model's probabilities over
    the answer letters are read. That distribution is compared with the human
    one via the Wasserstein distance over the options' ordinal values,
    normalized by ``ph.get_max_wd``.

    Returns:
        Mean representativeness ``1 - normalized_WD`` over the test split.

    Raises:
        NotImplementedError: For ``anthropic_global_opinions`` (no loader yet).
        ValueError: For an unknown task, or when a question's model and human
            distributions differ in length.
    """
    if config.data.task == "anthropic_global_opinions":
        # TODO: load via get_dataset_Global(config.data.group) once the ordinal
        # column for options is available.
        raise NotImplementedError("Task 'anthropic_global_opinions' is not implemented yet")
    elif config.data.task == "opinion_qa":
        ds = get_dataset_oqa(config.data.attribute, config.data.group, config.data.PEW_SURVEY_LIST, config.data.CONTEXT)
    else:
        raise ValueError(f"Unknown task {config.data.task}")

    ds = ds.train_test_split(test_size=config.data.test_split, seed=config.seed)
    print("Train dataset size:", len(ds["train"]))
    print("Test dataset size:", len(ds["test"]))
    ds_df = pd.DataFrame(ds["test"])

    model = llmodel(config)

    instruction = "Answer the following question by picking from the given options"
    # Model distributions D_m, one list per test row.
    dm_values = []

    for _, r in ds_df.iterrows():
        # TODO: batchify this
        input_text = "{question}\n\nOptions:{options}".format(
            question=r['question'], options=get_options_str(r['options']))
        prompt = get_alpaca_prompt(instruction=instruction, input_text=input_text)

        word_probs = model.get_next_word_probabilities(prompt)
        num_choices = len(r['ordinal'])
        D_m = model.get_choice_probs(word_probs, num_choices)
        # Read the letters explicitly in option order (A, B, C, ...) so D_m
        # lines up with `ordinal` and `selections` whatever the dict's order.
        d_m = [D_m[letter] for letter in string.ascii_uppercase[:num_choices]]
        if len(d_m) != len(r['selections']):
            raise ValueError(
                f"Model distribution has {len(d_m)} entries but human distribution "
                f"has {len(r['selections'])}")
        dm_values.append(d_m)
    ds_df['d_m'] = dm_values

    ds_df['WD'] = ds_df.swifter.apply(lambda x: wasserstein_distance(x['ordinal'],
                                                                     x['ordinal'],
                                                                     x['d_m'], x['selections']) / ph.get_max_wd(x['ordinal']),
                                      axis=1)

    # Representativeness: 1 means identical distributions.
    ds_df['Rep'] = 1 - ds_df['WD']
    alignment_score = ds_df['Rep'].mean()
    print(alignment_score)
    return alignment_score

if __name__ == "__main__":
    # Hydra composes the `train` config under `configs/` (plus any
    # command-line overrides) and passes it to main(), so no argument here.
    # NOTE(review): depending on the Hydra version, the decorated main() may
    # return None rather than the alignment score.
    score = main()
    
