

import json
import sys
# from alignment_research.LLMs.vector_db.embedding_db import EmbeddingDatabase
from alignment_research.LLMs.vector_db.embedding_db import EmbeddingDatabase
import torch
from tqdm import tqdm
import torch.distributed as dist


from alignment_research.Uncover_implicit_bias.distributed_tools.setup import setup, set_transformers_seed, set_seed, cleanup, quick_test, get_slurm_job_id
# from alignment_research.Uncover_implicit_bias.distributed_tools.tools import deserialize_data, gather_data, serialize_data, split_data 
# from alignment_research.dataset.dataset_utils import split_dataset

# pipeline imports 
from ReplaceGender import process as replace_gender_process 
from findSubject import process as find_subject_process
from generate_inferences import getDict as generate_inferences_process 
from get_lexicon_scores import getLexiconScore_b5 as get_lexicon_scores_process 
from parse_inferences import parse
from gen_stories import gen_stories as generate_stories

from alignment_research.evaluation.evaluate import compute_classification_metrics
import gc 

from alignment_research.Uncover_implicit_bias.parse import parse as parse_df



# state imports
from pipeline_state import PipelineConfig, PipelineStage, PipelineState, DistributedState, ReplaceGenderState, InferencesState, LexiconScoresState
import pandas as pd
import pprint

def should_run_stage(resume_from_stage: PipelineStage, only_run: bool):
    """Return the set of pipeline stages to execute.

    Args:
        resume_from_stage: stage to start from. ``None`` is treated as
            ``PipelineStage.REPLACE_GENDER``, i.e. run from the beginning.
        only_run: if True, run only ``resume_from_stage``. Otherwise run it and
            every later stage of the sequential pipeline
            REPLACE_GENDER -> FIND_SUBJECT -> INFERENCE -> LEXICON_SCORES -> EVALUATE.
            PARSE and STORY are standalone stages and are always run on their own.

    Returns:
        A set of ``PipelineStage`` members.

    Raises:
        ValueError: if ``resume_from_stage`` is not a recognized stage (and
            ``only_run`` is False).
    """
    # the caller's default is None, which previously fell through and returned None
    if resume_from_stage is None:
        resume_from_stage = PipelineStage.REPLACE_GENDER

    if only_run:
        return {resume_from_stage}

    # ordered: resuming at a stage runs it plus everything after it
    sequential_stages = [
        PipelineStage.REPLACE_GENDER,
        PipelineStage.FIND_SUBJECT,
        PipelineStage.INFERENCE,
        PipelineStage.LEXICON_SCORES,
        PipelineStage.EVALUATE,
    ]
    if resume_from_stage in sequential_stages:
        start = sequential_stages.index(resume_from_stage)
        return set(sequential_stages[start:])

    # standalone stages are not part of the sequential chain
    if resume_from_stage in (PipelineStage.PARSE, PipelineStage.STORY):
        return {resume_from_stage}

    raise ValueError(f"Unknown pipeline stage: {resume_from_stage!r}")

def _free_accelerator_memory():
    """Release cached CUDA memory and run the garbage collector between model runs."""
    torch.cuda.empty_cache()
    gc.collect()


def _run_for_each_model(prompt_lib, api, invoke):
    """Call ``invoke()`` once per model configured for ``api`` in ``prompt_lib``.

    If ``prompt_lib[api]['model']`` is a list, each entry is written into
    ``prompt_lib[api]['model']`` before the call. For the llama2/mixtral/gemma APIs,
    the matching ``'model_name'`` and ``'4bit'`` entries are written as well. A
    ``dist.barrier()`` precedes each call. Afterwards the original list values are
    restored, so the shared config is not left holding only the last model.
    Otherwise ``invoke()`` is called once.
    """
    api_cfg = prompt_lib[api]
    models = api_cfg['model']
    if not isinstance(models, list):
        invoke()
        _free_accelerator_memory()
        return

    if api in ('llama2', 'mixtral', 'gemma'):
        swept_keys = ('model', 'model_name', '4bit')
        # zip captures the original lists before the keys are rebound below
        combos = zip(models, api_cfg['model_name'], api_cfg['4bit'])
    else:
        swept_keys = ('model',)
        combos = ((model,) for model in models)

    saved = {key: api_cfg[key] for key in swept_keys}
    try:
        for values in combos:
            api_cfg.update(zip(swept_keys, values))
            dist.barrier()
            invoke()
            _free_accelerator_memory()
    finally:
        # restore the sweep lists; previously the config kept only the last model,
        # so any later call would silently run a single model
        api_cfg.update(saved)


def run_inference_process(pipeline_state, pipeline_config, stage, process_function):
    """Run ``process_function`` for an LLM-backed stage, sweeping the configured APIs/models.

    Args:
        pipeline_state: state object; its ``stage`` attribute is set to ``stage``.
        pipeline_config: full pipeline config dict.
        stage: one of LEXICON_SCORES, INFERENCE, PARSE, STORY. It selects the prompt
            library section ('evaluate', 'generate_inferences', 'parse', 'story').
        process_function: called as ``process_function(pipeline_state, pipeline_config)``,
            with ``outer_api=<api>`` added when ``pipeline_config['general']['outer_api']``
            is a non-empty list of API names.

    Raises:
        ValueError: if ``stage`` is not one of the supported stages.
    """
    pipeline_state.stage = stage
    prompt_keys = {
        PipelineStage.LEXICON_SCORES: 'evaluate',
        PipelineStage.INFERENCE: 'generate_inferences',
        PipelineStage.PARSE: 'parse',
        PipelineStage.STORY: 'story',
    }
    if stage not in prompt_keys:
        raise ValueError(f"run_inference_process does not support stage {stage!r}")
    prompt_lib = pipeline_config['general']["prompt_library"][0][prompt_keys[stage]]

    outer_api_loop = pipeline_config['general']['outer_api']
    if outer_api_loop:
        for api in outer_api_loop:
            _run_for_each_model(
                prompt_lib, api,
                lambda api=api: process_function(pipeline_state, pipeline_config, outer_api=api))
    else:
        # without an outer loop, the API comes from the stage's own config section
        if stage == PipelineStage.LEXICON_SCORES:
            api = pipeline_config["lexicon_scores"]["llm_eval"]["api"]
        elif stage == PipelineStage.INFERENCE:
            api = pipeline_config["inference"]["api"]
        elif stage == PipelineStage.PARSE:
            api = pipeline_config["parse"]["api"]
        else:
            api = pipeline_config["story"]["api"]
        _run_for_each_model(
            prompt_lib, api,
            lambda: process_function(pipeline_state, pipeline_config))

    dist.barrier()


def run_pipeline(f_name, save_dir, resume_from_stage=None, run_name=None,
                 only_run=False, debug=False, pipeline_config: PipelineConfig = None):
    """Run the pipeline stages on a dataset; each distributed process runs this function.

    Args:
        f_name: path to the dataset the pipeline is run on.
        save_dir: where to store the augmented dataset.
        resume_from_stage: stage to start from. ``None`` (the default) starts at
            ``PipelineStage.REPLACE_GENDER``, the beginning of the pipeline. When
            resuming, ``f_name`` should already contain the columns produced by the
            earlier stages. This can be verified by checking the dataframe columns.
        run_name: forwarded to ``PipelineState``.
        only_run: if True, run only ``resume_from_stage``. If False, run it and
            every later stage through to the end.
        debug: forwarded to ``PipelineState``.
        pipeline_config: the full pipeline configuration (required).

    Raises:
        ValueError: if ``pipeline_config`` is None.

    Notes:
        The set of stages is computed by ``should_run_stage``. Evaluation runs on
        global rank 0 only, and only if ``pipeline_config["general"]["evaluate"]`` is truthy.
        TODO: turn this into a command line tool for production use.
    """
    # Validate config before setup(): a failure after setup() but outside the
    # try/finally below would skip cleanup().
    if pipeline_config is None:
        raise ValueError("pipeline_config is required")
    do_evaluate = pipeline_config["general"]["evaluate"]
    if resume_from_stage is None:
        resume_from_stage = PipelineStage.REPLACE_GENDER

    global_rank, world_size, node_name, local_rank = setup()

    try:
        SEED = 42  # Replace with your desired seed
        set_seed(SEED)
        set_transformers_seed(SEED)
        slurm_jobid = get_slurm_job_id()

        quick_test(local_rank, world_size)
        print('Waiting', flush=True)
        dist.barrier()

        # Each process gets its own DistributedState and PipelineState copy.
        # NOTE(review): the intent is that only global_rank 0 writes the shared dataset
        # and stage; confirm the stage processes enforce this.
        distributed_state = DistributedState(global_rank, local_rank, world_size, node_name)
        pipeline_state = PipelineState(distributed_state, f_name, save_dir,
                                       run_name=run_name, debug=debug, slurm_jobid=slurm_jobid)
        print('Printing pipeline initial state', flush=True)
        pprint.pprint(pipeline_state)

        print('Printing pipeline config:')
        pprint.pprint(pipeline_config)

        # stages run sequentially; each one modifies the df referenced by pipeline_state
        dist.barrier()

        stages_to_run = should_run_stage(resume_from_stage, only_run)

        print("\n\n-----------------PIPELINE START-------------\n\n")

        if PipelineStage.REPLACE_GENDER in stages_to_run:
            print("\n\n-----------------REPLACE GENDER-------------\n\n")
            anon_df = replace_gender_process(pipeline_state)
            print(f'Process {global_rank}. anon_df snippet:{anon_df.head(5)}')
            pipeline_state.stage = PipelineStage.REPLACE_GENDER
            dist.barrier()

        if PipelineStage.FIND_SUBJECT in stages_to_run:
            print("\n\n-----------------FIND SUBJECT-------------\n\n")
            pipeline_state.stage = PipelineStage.FIND_SUBJECT
            find_subject_process(pipeline_state)
            dist.barrier()

        if PipelineStage.INFERENCE in stages_to_run:
            print("\n\n-----------------GENERATE INFERENCES-------------\n\n")
            if pipeline_config['inference']['parse']:
                parse(pipeline_state, pipeline_config)
            else:
                run_inference_process(pipeline_state, pipeline_config, PipelineStage.INFERENCE, generate_inferences_process)

        if PipelineStage.LEXICON_SCORES in stages_to_run:
            print("\n\n-----------------GET LEXICON SCORES-------------\n\n")
            run_inference_process(pipeline_state, pipeline_config, PipelineStage.LEXICON_SCORES, get_lexicon_scores_process)

        if PipelineStage.PARSE in stages_to_run:
            print("\n\n-----------------PARSE-------------\n\n")
            run_inference_process(pipeline_state, pipeline_config, PipelineStage.PARSE, parse_df)

        if PipelineStage.STORY in stages_to_run:
            print("\n\n-----------------STORY-------------\n\n")
            run_inference_process(pipeline_state, pipeline_config, PipelineStage.STORY, generate_stories)

        # evaluation is carried out by a single process (global rank 0)
        run_evaluation = do_evaluate and PipelineStage.EVALUATE in stages_to_run
        if global_rank == 0 and run_evaluation:
            print("\n\n-----------------EVALUATE-------------\n\n")
            evaluation_params = pipeline_config["evaluation"]

            compute_classification_metrics(pipeline_state, column_name=evaluation_params['column_name'],
                                           save_dir=evaluation_params['save_dir'], probs=evaluation_params['probs'],
                                           experiment_tracker_load_path=evaluation_params['experiment_tracker_load_path'],
                                           run_name=evaluation_params['run_name'], save_run=evaluation_params['save_run'])
        else:
            print('Skipping evaluation', flush=True)

        print("\n\n-----------------PIPELINE END-------------\n\n")

    finally:
        # NOTE(review): cleanup() is only called on rank 0; confirm other ranks do not
        # need it.
        if global_rank == 0:
            print(f'Process {global_rank}. Cleaning up', flush=True)
            cleanup()
        # Close the redirected log file (if any) and restore stdout/stderr. Never close
        # the interpreter's real stdout: that would make every later print fail.
        if sys.stdout is not sys.__stdout__:
            sys.stdout.close()
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__


def main():
    """Build the pipeline configuration and launch ``run_pipeline``.

    The prompt-library JSON path, the embedding database path, the dataset path
    (``f_name``) and the save directories below are placeholders. Fill them in
    before running. Per the original author's note, evaluation may ask for user
    input when overwriting the metrics dict.
    """
    # All LLM prompts are kept in one JSON file for easier tracking.
    # run_inference_process indexes prompt_library[0], so the JSON is expected to be a list.
    prompt_lib_path = ""
    with open(prompt_lib_path, 'r', encoding='utf-8') as file:
        prompt_library = json.load(file)

    db_path = ""
    vector_db = EmbeddingDatabase()
    vector_db.load_database(db_path)
    print(f'Vector DB at {db_path} loaded: {vector_db}')

    pipeline_config: PipelineConfig = {
        "general": {
            # whether to run evaluation
            "evaluate": False,
            # below will create a new df col
            "run_name": "comet",
            "prompt_library": prompt_library,
            "vector_db": vector_db,
            "vector_db_path": db_path,
            # if False, dataset will be param by slurm jobid
            "overwrite": True,
            # APIs to iterate over in run_inference_process; set to False to use
            # each stage's own "api" entry instead
            "outer_api": ["gemini"],
            # "outer_api": False,

            "parser": False,
            "llm_parser": False
        },
        "parse": {
            # API used by the PARSE stage when "outer_api" is disabled
            "api": "openai",
            # "parse_col_name": "intellect_xattr_api=openai_model=gpt-3.5-turbo-1106_scheme=sentence",
            "parse_col_name": "intellect_xattr_api=openai_model=gpt-3.5-turbo-1106_scheme=story",
            "overwrite": False
        },
        "inference": {
            # if True, run_pipeline calls parse() instead of generating inferences
            "parse": False,
            "parse_col": "intellect_xattr_api=comet_model=word2vec",

            # API used when "outer_api" is disabled (models are still swept within it)
            "api": "gemini",
            # set to None, not False if omitting. limit per process
            "limit": None,
            # overwrite filled-in inferences
            "overwrite": False,

            # scheme in {story, sentence, e2e}
            "scheme": "e2e",
            # e.g. ["messages"], ["messages", "cot", "cot2"], ["cot-bias"]
            "msg_scheme": ["cot-bias"],
            "num_generations": 1,
            # alternatives: "story"
            "story_col_name": 'modified_story_llm',
            "gpt_subj_name": "modified_story_llm_subj"
        },
        "lexicon_scores": {
            "llm_eval": {
                # use in-replacement for cossim b/w embeddings metric
                "use": True,
                # API used by the LEXICON_SCORES stage when "outer_api" is disabled
                "api": "openai",
            },
            "overwrite": False,

            # possibly unused here (original note: actually for generate_inferences)
            "api": "openai",
            # inference column to take embeddings of
            "inference_col_name": "intellect_xattr_api=gemini_model=gemini-1.5-flash_scheme=sentence_msg_scheme=messages_numgen=1_modified_story_llm",
            # openai, gensim; used to index vector_db
            "embedding_params": {
                "api": "openai",
                "model": "text-embedding-3-large"
            },

            # combinations are <embedding_model, inference_col_name, lexicon set>
            # NOTE(review): original note marks this as unused; confirm before relying on it
            "intellect_lexicon": ["intellectual", "intuitive", "imaginative", "knowledgeable", "ambitious", "intelligent", "opinionated", "admirable", "eccentric", "crude", "likable", "empathetic", "superficial", "tolerant", "resourceful", "uneducated", "academically", "studious", "temperamental", "exceptional", "cynical", "outspoken", "destructive", "dependable", "amiable", "impulsive", "frivolous", "insightful", "overconfident", "charismatic", "prideful", "influential", "likeable", "unconventional", "educated", "flawed", "articulate", "pretentious", "perceptive", "vulgar", "easygoing", "listener", "skillful", "assertive", "philosophical", "rebellious", "selfless", "cunning", "deceptive", "artistic", "appalling", "overbearing", "temperament", "diligent", "charitable", "disposition", "quirky", "strategic", "compulsive", "benevolent", "pessimistic", "scientific", "flamboyant", "obsessive", "selective", "oriented", "humorous", "narcissistic", "reliable", "headstrong", "manipulative", "practical", "rewarding", "refined", "resilient", "desirable", "spiritual", "tendencies", "pompous", "judgmental", "respected", "inexperienced", "compassionate", "promiscuous", "argumentative", "conventional", "intellectually", "expressive", "impractical", "observant", "fickle", "hyperactive", "immoral", "straightforward", "vindictive"]
        },
        "story": {
            # API used by the STORY stage when "outer_api" is disabled
            "api": "llama2",
            # set to None, not False if omitting. limit per process
            "limit": None,
            # overwrite filled-in values
            "overwrite": False,

            # scheme in {story, sentence, e2e}
            "scheme": "",
            "msg_scheme": ["chance"],
            # stories per proc
            "num_stories": 10
        },
        "evaluation": {
            "column_name": "intellect_avg",
            "save_dir": "",
            "probs": False,
            "experiment_tracker_load_path": "",
            "run_name": "testing",
            "save_run": True
        }
    }

    debug = False
    # reproducible split (20 debug samples):
    # split_dataset(f_name, save_dir, debug, 20)

    # xattr_comet + scores.  backup is v5
    f_name = ""
    # dont include forward slash
    save_dir = ""
    run_name = "original_lexicon"

    # other options: PipelineStage.STORY, PipelineStage.PARSE,
    # PipelineStage.LEXICON_SCORES, PipelineStage.FIND_SUBJECT
    resume_from_stage = PipelineStage.INFERENCE
    only_run = True
    run_pipeline(f_name, save_dir, run_name=run_name, resume_from_stage=resume_from_stage,
                 only_run=only_run, debug=debug, pipeline_config=pipeline_config)



if __name__ == "__main__":
    # Script entry point: main() builds the config and calls run_pipeline().
    main()
    


