import csv
import gc
import json
import os
from typing import List
import pandas as pd
from alignment_research.LLMs.vector_db.embedding_db import EmbeddingDatabase
from alignment_research.Uncover_implicit_bias.pipeline_state import LexiconScores, PipelineConfig, PipelineState
from tqdm import tqdm
# df = pd.read_csv('Lexicons of bias - Gender stereotypes.csv')
# import gensim
from gensim.models import KeyedVectors
import sys
import torch
import torch.distributed as dist
from scipy import stats
import pickle
from scipy import spatial
import pickle
import numpy as np
import math
from empath import Empath
import numpy as np
from collections import Counter
import ast
# lexicon = Empath()

from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers

from alignment_research.LLMs.utils import format_list_string


# from gensim.models import Word2Vec

# better to store on disk and load when needed since this takes long to decompress
# model1 = KeyedVectors.load_word2vec_format('/h/anon/internship/Uncover_implicit_bias/data/GoogleNews-vectors-negative300.bin.gz',binary=True)


from alignment_research.Uncover_implicit_bias.distributed_tools.tools import all_gather_data, deserialize_data, gather_data, serialize_data, split_data 

from alignment_research.LLMs.gen_claude import prompt as prompt_claude
from alignment_research.LLMs.gen_claude import parse_response as parse_claude

from alignment_research.LLMs.gen_comet import prompt as prompt_comet

from alignment_research.LLMs.gen_llama2 import prompt as prompt_llama2
# from alignment_research.LLMs.gen_llama2 import parse_response as parse_llama2

from alignment_research.LLMs.gen_openai import prompt as prompt_openai
# from alignment_research.LLMs.gen_openai import parse_response as parse_openai

from alignment_research.LLMs.gen_mixtral import prompt as prompt_mixtral
# from alignment_research.LLMs.gen_mixtral import parse_response as parse_mixtral


from alignment_research.LLMs.gen_gemma import prompt as prompt_gemma
# from alignment_research.LLMs.gen_gemma import parse_response as parse_gemma

from alignment_research.LLMs.utils import parse_response, parse_inferences


# intellect = ["intellectual", "intuitive", "imaginative", "knowledgeable", "ambitious", "intelligent", "opinionated", "admirable", "eccentric", "crude", "likable", "empathetic", "superficial", "tolerant", "resourceful", "uneducated", "academically", "studious", "temperamental", "exceptional", "cynical", "outspoken", "destructive", "dependable", "amiable", "impulsive", "frivolous", "insightful", "overconfident", "charismatic", "prideful", "influential", "likeable", "unconventional", "educated", "flawed", "articulate", "pretentious", "perceptive", "vulgar", "easygoing", "listener", "skillful", "assertive", "philosophical", "rebellious", "selfless", "cunning", "deceptive", "artistic", "appalling", "overbearing", "temperament", "diligent", "charitable", "disposition", "quirky", "strategic", "compulsive", "benevolent", "pessimistic", "scientific", "flamboyant", "obsessive", "selective", "oriented", "humorous", "narcissistic", "reliable", "headstrong", "manipulative", "practical", "rewarding", "refined", "resilient", "desirable", "spiritual", "tendencies", "pompous", "judgmental", "respected", "inexperienced", "compassionate", "promiscuous", "argumentative", "conventional", "intellectually", "expressive", "impractical", "observant", "fickle", "hyperactive", "immoral", "straightforward", "vindictive"]
# # intellect = ["intellectual", "intuitive", "imaginative", "knowledgeable", "ambitious", "intelligent", "opinionated", "admirable", "eccentric", "crude", "likable", "empathetic", "superficial", "tolerant", "resourceful", "uneducated", "academically", "studious", "temperamental", "exceptional", "cynical", "outspoken", "destructive", "dependable", "amiable", "impulsive", "frivolous", "insightful", "overconfident", "charismatic", "prideful", "influential", "likeable", "unconventional", "educated", "flawed", "articulate", "pretentious", "perceptive", "vulgar", "easygoing", "listener", "skillful", "assertive", "philosophical", "rebellious", "selfless", "cunning", "deceptive", "artistic", "appalling", "overbearing", "temperament", "diligent", "charitable", "disposition", "quirky", "strategic", "compulsive", "benevolent", "pessimistic", "scientific", "flamboyant", "obsessive", "selective", "oriented", "humorous", "narcissistic", "reliable", "headstrong", "manipulative", "practical", "rewarding", "refined", "resilient", "desirable", "spiritual", "tendencies", "pompous", "judgmental", "respected", "inexperienced", "compassionate", "promiscuous", "argumentative", "conventional", "intellectually", "expressive", "impractical", "observant", "fickle", "hyperactive", "immoral", "straightforward", "vindictive"]

# logical_mathematical_intelligence_lexicon = [
# "Abstract Reasoning",
# "Algorithm",
# "Analysis",
# "Calculus",
# "Computation",
# "Conceptual Thinking",
# "Data Analysis",
# "Deductive Reasoning",
# "Experimentation",
# "Geometry",
# "Inductive Reasoning",
# "Logic",
# "Mathematical Operations",
# "Number Theory",
# "Pattern Recognition",
# "Problem Solving",
# "Quantitative Analysis",
# "Scientific Method",
# "Statistics",
# "Symbolic Logic",
# "Theorems",
# "Variables"]

# logical_mathematical_intelligence_lexicon = [
# "intellectual",
# "knowledgeable",
# "intelligent",
# "academically",
# "studious",
# "exceptional",
# "insightful",
# "educated",
# "articulate",
# "perceptive",
# "skillful",
# "philosophical",
# "strategic",
# "scientific"]


# intellect = logical_mathematical_intelligence_lexicon


# intellect = ["intellectual"]


def aggregate_scores():
    """Placeholder for score aggregation; currently does nothing and returns None."""
    return None


def get_words(cates, csv_path=''):
    """Collect the distinct lexicon words stored under the given CSV columns.

    Args:
        cates: A column label, or a list of column labels, to read.
        csv_path: Path of the lexicon CSV. Defaults to the empty string that
            used to be hard-coded here, which pandas cannot open.
            NOTE(review): a commented-out read of
            'Lexicons of bias - Gender stereotypes.csv' near the imports
            suggests the intended file -- confirm and pass it explicitly.

    Returns:
        A set with every non-float cell of the selected columns. Empty CSV
        cells are read by pandas as float NaN and are therefore dropped.
    """
    df = pd.read_csv(csv_path)
    # ravel() gives one flat pass whether `cates` selected a single column
    # (1-D values) or several columns (2-D values).
    cells = df.loc[:, cates].to_numpy().ravel()
    # isinstance also filters numpy float scalars, which `type(x) is float` let through.
    return {cell for cell in cells if not isinstance(cell, float)}

# def get_words(category, seeds, size=100):
#     lexicons = lexicon.create_category(category,seeds, size=size)
#     return lexicons
    

# Lexicons are built for the categories below.
# Using Empath may yield biased lexicons -- e.g. feminine-coded terms for appearance.
# TODO: consider further filtering of masculine/feminine terms.
# NOTE: everything below runs at import time, so importing this module reads the lexicon CSV via get_words.
# Column labels handed to get_words for the appearance lexicon (the spelling is kept: it is a module-level name).
appearence = ["beautiful","sexual"]
# filtered out terms like girly, feminine, masculine -- should be ideally gender neutral
# appear = ['beautiful', 'sexual', 'sexy', 'perfect', 'attractive', 'romantic', 'irresistible', 'unique', 'sweet', 'gorgeous', 'pretty', 'touching', 'pleasing', 'amazing', 'provocative', 'enchanting', 'meaningful', 'sinful', 'intimate', 'poetic', 'desirable', 'flirtatious', 'flattering', 'unattractive', 'innocent', 'cute', 'inappropriate', 'flawless', 'sensitive', 'modest', 'captivating', 'enticing', 'daring', 'incredible', 'creative', 'vulgar', 'alluring', 'weird', 'spontaneous', 'mysterious', 'wonderful', 'cliché', 'addictive', 'dreamy', 'erotic', 'charming', 'cheesy', 'intriguing', 'repulsive', 'adorable', 'extraordinary', 'freaky', 'likeable', 'exotic', 'expressive', 'unbelievable', 'realistic', 'sappy', 'disturbing', 'distracting', 'inspirational', 'shocking', 'outrageous', 'freaky', 'endearing', 'perverted', 'flirty', 'undeniably', 'predictable', 'shy', 'boyish', 'stereotypical', 'good', 'imaginative', 'different', 'inspiring', 'special', 'exquisite', 'quirky', 'fabulous', 'classy', 'real', 'straightforward', 'bold', 'perverted', 'subtle', 'sophisticated', 'freaky', 'lovable', 'heartwarming', 'artsy', 'just_sex', 'insightful']
appear = get_words(appearence)
# `power` first holds the column labels and is then rebound to the word set returned by get_words.
power = ["dominant","strong"]
# power =  ['strong', 'dominant', 'powerful', 'resilient', 'fierce', 'fearless', 'weak', 'ruthless', 'brave', 'courageous', 'tough', 'stronger', 'submissive', 'vicious', 'independent', 'invincible', 'vulnerable', 'aggressive', 'fragile', 'weakest', 'dangerous', 'possessive', 'lethal', 'stubborn', 'gentle', 'unstoppable', 'formidable', 'confident', 'flexible', 'firm', 'sensitive', 'destructive', 'dominate', 'brutal', 'violent', 'strongest', 'inexperienced', 'human', 'persuasive', 'agile', 'resistant', 'fighter', 'Carpathian', 'overpower', 'potent', 'headstrong', 'experienced', 'deadly', 'skilled', 'selfless', 'delicate', 'ferocious', 'primitive', 'determined', 'daring', 'good_fighter', 'masculine', 'cunning', 'savage', 'rational', 'forceful', 'persistent', 'warrior', 'wise', 'ambitious', 'demanding', 'yet', 'stong', 'feeble', 'hard', 'unpredictable', 'loyal', 'territorial', 'controlled', 'compelling', 'passive', 'superior', 'dependable', 'unbreakable', 'admirable', 'overbearing', 'desirable', 'powerless', 'defiant', 'humble', 'humane', 'fit', 'bloodthirsty', 'intelligent', 'empowered', 'volatile', 'obedient', 'so_much_power', 'compassionate', 'untouchable', 'cruel', 'skillful', 'frail', 'resourceful', 'compulsion']
power = get_words(power)
# Same pattern for `weak`: column labels first, then the word set.
weak = ['submissive','weak','dependent','afraid']
# weak = ['weak', 'afraid', 'vulnerable', 'submissive', 'strong', 'powerless', 'dependent', 'defenseless', 'invincible', 'helpless', 'foolish', 'fragile', 'ruthless', 'scared', 'defenceless', 'dominant', 'Yet', 'inexperienced', 'reckless', 'fear', 'dangerous', 'unfit', 'human', 'physically', 'ashamed', 'shameful', 'rational', 'desperate', 'powerful', 'certain', 'fearful', 'fearless', 'selfish', 'unstable', 'brave', 'greedy', 'unhappy', 'independent', 'crippled', 'dependant', 'immune', 'heartless', 'cruel', 'stubborn', 'violent', 'unworthy', 'inferior', 'nuisance', 'useless', 'yet', 'confident', 'Because', 'Though', 'frightened', 'willing', 'attached', 'feared', 'destructive', 'angry', 'hurt', 'wounded', 'mortal', 'weakest', 'careless', 'hopeless', 'irrational', 'cowardly', 'feeble', 'prone', 'harmful', 'susceptible', 'sensitive', 'unwanted', 'obedient', 'tolerant', 'own_person', 'shamed', 'flawed', 'cautious', 'unreasonable', 'aggressive', 'critical', 'mindset', 'desirable', 'determined', 'truly', 'rebellious', 'weakness', 'own_will', 'Yet', 'resilient', 'vicious', 'superior', 'pathetic', 'extreme', 'painful', 'brutal']
weak = get_words(weak)
# Progress marker printed once the three lexicon sets above have been built.
print("in lexicon scores. (no heavy compute needed) finishing getting ap lexicons (intelligence passed as pipeline param)")



def load_file(file_path):
    """Unpickle and return the object stored in the file at ``file_path``.

    The file is opened in binary mode and closed before returning.
    Unpickling can execute arbitrary code, so only point this at trusted files.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

# getting word vectors for the lexicons
# might need to remove the .wv
# weak_vecs = [model1[i] for i in weak if i in model1]
# power_vecs = [model1[i] for i in power if i in model1]
# appear_vecs = [model1[i] for i in appear if i in model1]
# intellect_vecs = [model1[i] for i in intellect if i in model1]


def calculateSubspace(A, B, model1):
    """
    computes (2) in the paper -- the power semantic axis
    avg strong embedding subtracted by avg weak embedding
    Subspace in this case = a vector in the direction of power

    Args:
        A: Iterable of words for the positive pole (e.g. the power lexicon).
        B: Iterable of words for the negative pole (e.g. the weak lexicon).
        model1: Word -> vector mapping supporting ``word in model1`` and
            ``model1[word]``.

    Returns:
        mean(vectors of A found in model1) - mean(vectors of B found in model1).

    Raises:
        ValueError: if none of the words in A, or none in B, is in ``model1``.
    """
    A_vecs = [model1[i] for i in A if i in model1]
    B_vecs = [model1[i] for i in B if i in model1]
    if not A_vecs or not B_vecs:
        raise ValueError("calculateSubspace needs at least one in-vocabulary word in both A and B")

    # Average over the vectors that were actually found: out-of-vocabulary
    # words are skipped above and must not count in the denominator.
    return np.mean(A_vecs, axis=0) - np.mean(B_vecs, axis=0)



def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between ``vec1`` and ``vec2``.

    Computed as dot(vec1, vec2) / (||vec1|| * ||vec2||). No special handling
    is applied when either norm is zero.
    """
    magnitude = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / magnitude


""" 
manually insert words_clusters
"""
def _stack_lexicon_embeddings(lexicon, db, api, model, save_path):
    """Fetch the embedding of every lexicon entry and stack them row-wise.

    Returns ``(matrix, row_norms)``, or ``(None, None)`` when no entry has an
    embedding. Entries for which ``db.get_embedding`` returns ``False`` are skipped.
    """
    vectors = []
    for entry in lexicon:
        embedding = db.get_embedding((entry, api, model), save_path)
        if embedding is False:
            continue
        vectors.append(np.asarray(embedding))
    if not vectors:
        return None, None
    matrix = np.stack(vectors)
    return matrix, np.linalg.norm(matrix, axis=1)


def compute(words_clusters, db:"EmbeddingDatabase", api:str, model:str, intellect_lexicon, save_path:str):
    """ word_clusters = Comet inferences for a given gender. Ex. xAttr for female
        computes cossim b/w word_clusters and the specified words/phrases in intellect_lexicon
        <db,api,model,save_path> used for retrieving vector embeddings for the words/phrases

        computes (1) in the paper -- association score between lexicon (i.e. intellect) and protagonist inferences (i.e. xAttr, xReact, ...)

    Args:
        words_clusters: Iterable of words/phrases to score.
        db: Embedding store; ``db.get_embedding((text, api, model), save_path)``
            must return a vector, or ``False`` when none is available.
        api, model: Second and third components of the embedding lookup key.
        intellect_lexicon: Iterable of lexicon words/phrases.
        save_path: Forwarded unchanged to ``db.get_embedding``.

    Returns:
        One entry per word that has an embedding, in input order: the mean
        cosine similarity between that word and the lexicon entries that have
        embeddings. Words without an embedding are skipped, so the result may
        be shorter than ``words_clusters``. When no lexicon entry has an
        embedding the entry is NaN.
    """
    print(f'In compute()')
    print(f'words_clusters: {words_clusters}')
    print(f'intellect_lexicon: {intellect_lexicon}')

    intel_sum = []
    lexicon_matrix = lexicon_norms = None
    lexicon_loaded = False

    for x in words_clusters:
        x_embedding = db.get_embedding((x, api, model), save_path)
        if x_embedding is False:
            continue

        # The lexicon embeddings do not depend on x: fetch them once, and only
        # after a first word is known to have an embedding (so an input with no
        # scorable words triggers no lexicon lookups).
        if not lexicon_loaded:
            lexicon_matrix, lexicon_norms = _stack_lexicon_embeddings(
                intellect_lexicon, db, api, model, save_path)
            lexicon_loaded = True

        if lexicon_matrix is None:
            # Nothing to average against; keep the NaN marker the caller already sees.
            print('nan found')
            intel_sum.append(np.nan)
            continue

        # assumes embeddings are 1-D vectors of equal length -- TODO confirm
        x_vec = np.asarray(x_embedding)
        # Cosine similarity of x against every lexicon row in one shot:
        # dot(L_i, x) / (||L_i|| * ||x||)
        sims = lexicon_matrix @ x_vec / (lexicon_norms * np.linalg.norm(x_vec))
        avg_cossim = np.mean(sims)
        if np.isnan(avg_cossim):
            print('nan found')
        intel_sum.append(avg_cossim)

    return intel_sum


def get_stats(l):
    """Summarise a list of cosine similarities between a lexicon L and words X.

    l: a list of cosine similarities, e.g. [... -1 -1 -0.5 0.8 1 ...]

    Returns:
    A plain Python list of eleven floats, in this order:
    - maximum value
    - minimum value
    - 25th percentile (q1)
    - 75th percentile (q3)
    - median
    - average
    - standard deviation
    - range (max - min)
    - interquartile range (q3 - q1)
    - variance
    - median of the z-scores of l
    """
    # Same evaluation order as before: z-scores first, then max, then min.
    zscore_median = np.median(stats.zscore(l))
    highest = np.max(l)
    lowest = np.min(l)
    first_quartile = np.percentile(l, 25)
    third_quartile = np.percentile(l, 75)

    summary = [
        highest,
        lowest,
        first_quartile,
        third_quartile,
        np.median(l),
        np.average(l),
        np.std(l),
        highest - lowest,
        third_quartile - first_quartile,
        np.var(l),
        zscore_median,
    ]
    # Going through an ndarray coerces every entry to one float dtype before
    # tolist() turns them into built-in floats.
    return np.array(summary).tolist()





def _load_local_text_generation(model_dir, load_in_4bit):
    """Load a causal LM and tokenizer from /model-weights/<model_dir>/ and wrap them in a text-generation pipeline."""
    model_path = f"/model-weights/{model_dir}/"
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", load_in_4bit=load_in_4bit)
    print(f'Loaded model: {model_path} on device: {model.device}', flush=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    pipe = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
    return model, tokenizer, pipe


def compute_llm_scores(state:PipelineState, pipeline_config:PipelineConfig, outer_api=None):
    """Label every story's inferences with an LLM and gather the results across ranks.

    Reads the dataset at ``state.f_name``, keeps this rank's shard
    (``split_data``), calls the configured LLM once per row on the raw value
    of the inference column, stores the reply in the column
    ``intellect_xattr=<col>_api=<api>_model=<model>_llmeval`` and finally
    gathers all shards with ``all_gather_data``.

    Args:
        state: Pipeline state; ``distributed_state``, ``f_name`` and
            ``save_dataset`` are used.
        pipeline_config: Configuration. Keys read:
            ``lexicon_scores.llm_eval.api``, ``lexicon_scores.overwrite``,
            ``lexicon_scores.inference_col_name``, ``general.overwrite`` and
            ``general.prompt_library[0].evaluate[<api>]``.
        outer_api: Optional api name overriding the configured one. Supported
            values: 'openai', 'anthropic', 'llama2', 'mixtral', 'gemma'.

    Returns:
        On global rank 0 the gathered DataFrame (also saved through
        ``state.save_dataset``); ``None`` on every other rank.

    Raises:
        AssertionError: if the api is not one of the supported values.
    """
    global_rank = state.distributed_state.global_rank
    local_rank = state.distributed_state.local_rank
    world_size = state.distributed_state.world_size
    f_name = state.f_name

    api = outer_api if outer_api else pipeline_config["lexicon_scores"]["llm_eval"]["api"]
    prompt_lib = pipeline_config['general']["prompt_library"][0]['evaluate']

    general_overwrite = pipeline_config['general']['overwrite']
    evaluate_overwrite = pipeline_config['lexicon_scores']['overwrite']

    print(f'In compute_llm_scores()')
    xattr_col_name = pipeline_config["lexicon_scores"]["inference_col_name"]
    print(f'xattr_col_name: {xattr_col_name}')

    remote_predictors = {'openai': prompt_openai, 'anthropic': prompt_claude}
    local_predictors = {'llama2': prompt_llama2, 'mixtral': prompt_mixtral, 'gemma': prompt_gemma}
    if api not in remote_predictors and api not in local_predictors:
        raise AssertionError(f'Unsupported llm_eval api: {api!r}')

    # Work on a copy: the pipeline/tokenizer injected below must not leak into
    # (or later be deleted from) the shared prompt-library configuration.
    generate_params = dict(prompt_lib[api])
    model = tokenizer = pipe = None
    if api in remote_predictors:
        predict = remote_predictors[api]
        model_label = prompt_lib[api]['model']
    else:
        predict = local_predictors[api]
        model_dir = prompt_lib[api]['model']
        model_label = prompt_lib[api]['model_name']
        if api == 'llama2':
            # Only the Llama-3 and 13b checkpoints are loaded quantised for llama2.
            load_in_4bit = 'Llama-3' in model_label or '13b' in model_label
        else:
            load_in_4bit = True
        model, tokenizer, pipe = _load_local_text_generation(model_dir, load_in_4bit)
        # Local prompt functions receive the pipeline under 'model'.
        generate_params['model'] = pipe
        generate_params['tokenizer'] = tokenizer

    # Hosted apis are labelled by their 'model' entry, local ones by 'model_name'.
    col_name = f'intellect_xattr={xattr_col_name}_api={api}_model={model_label}_llmeval'
    print(f'Col name: {col_name}', flush=True)

    print(f'Reading df from: {f_name}', flush=True)
    df:pd.DataFrame = pd.read_csv(f_name)

    if col_name in df.keys() and evaluate_overwrite:
        df[col_name] = ""

    df:pd.DataFrame = split_data(df, global_rank, world_size)

    for _, row in tqdm(df.iterrows(), total=len(df), desc='Inference iteration'):
        # Keep an existing, non-NaN label unless overwriting was requested.
        if not evaluate_overwrite and col_name in row and not pd.isna(row[col_name]):
            continue

        try:
            xattr = row[xattr_col_name]
        except KeyError:
            # Previously this fell through to parsing whatever `xattr` held
            # from an earlier row; a missing column is now an explicit skip.
            print(f'Failed to extract xattr inferences. Skipping ...', flush=True)
            continue

        try:
            if pd.isna(xattr) or xattr == {} or xattr == '{}':
                print(f"Skipping story ID: {row['story_id']} due to NaN in gpt_subj")
                continue
        except Exception:
            # The emptiness check itself failed; fall back to the parser.
            # NOTE(review): on the normal path the raw cell is sent to the LLM
            # unparsed -- confirm that parse_inferences is only meant as a fallback.
            try:
                xattr = parse_inferences(xattr)
            except Exception:
                print(f'Failed to extract xattr inferences. Skipping ...', flush=True)
                continue

        story_idx = row['story_id']
        xattr_list = xattr
        if xattr_list == []:
            continue

        # prompt llm for eval using prompt library
        intelligence_label = predict(**generate_params, story=xattr_list)
        print(f'Intelligence label: {intelligence_label}', flush=True)
        df.loc[df['story_id'] == story_idx, col_name] = intelligence_label

    # distributed code here
    serialized_inferences = serialize_data(df.to_dict('records'))
    gathered_tensors = all_gather_data(global_rank, local_rank, world_size, serialized_inferences)

    print(f'Finished collecting data', flush=True)
    all_stories = None
    if global_rank == 0:
        # Deserialize the gathered data
        all_stories = deserialize_data(gathered_tensors)
        all_stories = pd.DataFrame(all_stories)
        state.save_dataset(all_stories, overwrite=general_overwrite)

    # Drop every local reference to the model objects before asking the
    # allocator to release memory. The names are always bound (None for hosted
    # apis), so this no longer fails for openai/anthropic.
    generate_params.clear()
    del model, tokenizer, pipe

    gc.collect()
    torch.cuda.empty_cache()
    return all_stories
    
    

    

def _parse_inference_words(raw):
    """Turn one stored inference cell into the list of words to score.

    Returns ``None`` when the cell is not a string (e.g. NaN for a missing value).
    A list of lists is flattened, de-duplicated in first-seen order and
    stripped of the 'none' marker; any other literal is returned unchanged.
    """
    if not isinstance(raw, str):
        return None
    # SECURITY: this used to be eval() on CSV content. literal_eval accepts
    # Python literals only, so a crafted cell cannot execute code.
    xattr = ast.literal_eval(raw)
    if all(isinstance(i, list) for i in xattr):
        flattened = [item for sublist in xattr for item in sublist]
        # dict.fromkeys de-duplicates like set() but keeps a reproducible order.
        return [word for word in dict.fromkeys(flattened) if word != 'none']
    return xattr


def _mean_of_json_scores(cell):
    """Mean of a JSON-encoded list of scores; NaN for a missing cell or an empty list."""
    if not isinstance(cell, str):
        return np.nan
    # json.loads understands the NaN tokens json.dumps emits; eval() did not.
    scores = json.loads(cell)
    return np.mean(scores) if scores else np.nan


def getLexiconScore_b5(state:PipelineState, pipeline_config:PipelineConfig,
                       pick_beam='agg', check_num_inferences=False,
                       overwrite=False, outer_api=None):
    """
    Computes lexicon scores for the inference column of the dataset.

    **A general function regardless of gender**

    When ``lexicon_scores.llm_eval.use`` is set, the work is delegated to
    ``compute_llm_scores`` and its result is returned. Otherwise, for every
    dict-valued entry of ``general.prompt_library[0].lexicon_scores`` and every
    lexicon in its ``"lexicon"`` list, each story of this rank's shard gets:

    - ``intellect_xattr=..._key=<key>_<len(lexicon)>``: JSON list with the
      per-word scores returned by ``compute``;
    - ``intellect_avg_<that column>``: the mean of that list.

    NOTE(review): the column name depends only on the key and the lexicon
    length, so two same-length lexicons under one key share a column.

    Args:
        state: Pipeline state (``distributed_state``, ``f_name``, ``save_dataset``).
        pipeline_config: Pipeline configuration (see keys read below).
        pick_beam, check_num_inferences, overwrite: Accepted for backward
            compatibility; not read by this function.
        outer_api: Forwarded to ``compute_llm_scores`` in LLM-eval mode.

    Returns:
        The DataFrame rebuilt from the data gathered across ranks (rank 0 also
        saves it), or the result of ``compute_llm_scores`` in LLM-eval mode.
    """
    global_rank = state.distributed_state.global_rank
    local_rank = state.distributed_state.local_rank
    world_size = state.distributed_state.world_size
    f_name = state.f_name

    llm_eval = pipeline_config["lexicon_scores"]["llm_eval"]["use"]

    if llm_eval:
        return compute_llm_scores(state, pipeline_config, outer_api=outer_api)

    intellect_lexicon_lists = pipeline_config['general']["prompt_library"][0]["lexicon_scores"]
    xattr_col_name = pipeline_config["lexicon_scores"]["inference_col_name"]

    vector_db:EmbeddingDatabase = pipeline_config["general"]["vector_db"]
    db_save_path = pipeline_config["general"]["vector_db_path"]
    embedding_api = pipeline_config["lexicon_scores"]["embedding_params"]["api"]
    embedding_model = pipeline_config["lexicon_scores"]["embedding_params"]["model"]

    print(f'In get_lexicon_score()')
    print(f'xattr_col_name: {xattr_col_name}')
    print(f'db_save_path: {db_save_path}')
    print(f'embedding_api: {embedding_api}')
    print(f'embedding_model: {embedding_model}')

    df:pd.DataFrame = pd.read_csv(f_name)
    df:pd.DataFrame = split_data(df, global_rank, world_size)

    # outer loop through intellect lexicons. param by lexicon keyname
    for key, value in intellect_lexicon_lists.items():
        print(f'Key: {key}, Value: {value}')
        # Only dict-valued entries carry a "lexicon" list; everything else is skipped.
        if not isinstance(value, dict):
            continue

        # inner loop is through specific lexicons (i.e. arrays)
        for lexicon in value["lexicon"]:
            col_name = f'intellect_xattr={xattr_col_name}_embed_api={embedding_api}_embed_model={embedding_model}_key={key}_{len(lexicon)}'
            print(f'Lexicon: {lexicon}\nCol name: {col_name}', flush=True)

            # cossim between an intellect lexicon and the stories in the df
            for _, row in tqdm(df.iterrows()):
                xattr_list = _parse_inference_words(row[xattr_col_name])
                if xattr_list is None:
                    # Missing inferences: leave the score cell empty instead of crashing.
                    continue
                story_idx = row['story_id']

                intellect = compute(xattr_list, vector_db, embedding_api, embedding_model, lexicon, db_save_path)
                # numpy scalars are not JSON serialisable; store built-in floats.
                intellect = [float(x) for x in intellect]
                df.loc[df['story_id'] == story_idx, col_name] = json.dumps(intellect)

            if col_name not in df.columns:
                # No row was scored on this rank (empty shard or no usable inferences).
                df[col_name] = np.nan

            # intellect score mean aggregation
            df[f"intellect_avg_{col_name}"] = df[col_name].apply(_mean_of_json_scores)

    # distributed code here
    serialized_inferences = serialize_data(df.to_dict('records'))
    gathered_tensors = all_gather_data(global_rank, local_rank, world_size, serialized_inferences)

    # Deserialize the gathered data
    all_stories = deserialize_data(gathered_tensors)
    all_stories = pd.DataFrame(all_stories)

    if global_rank == 0:
        state.save_dataset(all_stories)

    return all_stories
