
from collections import Counter, defaultdict
import gc
import json
import re
import pandas as pd
import pickle

from tqdm import tqdm
import torch
import torch.distributed as dist

from alignment_research.Uncover_implicit_bias.pipeline_state import PipelineConfig, PipelineState
from alignment_research.Uncover_implicit_bias.distributed_tools.tools import all_gather_data, deserialize_data, gather_data, serialize_data, split_data 
from alignment_research.LLMs.utils import parse_response, initialize_model, load_llm_parser


def writeFile(file_name, content):
    """Append ``content`` to the text file ``file_name``.

    The file is opened in append mode (``'a'``), so it is created when it
    does not exist and existing contents are preserved.

    Args:
        file_name: Path of the file to append to.
        content: Text to write.
    """
    # The context manager guarantees the handle is closed even if write()
    # raises, which the manual open/close pair did not.
    with open(file_name, 'a') as handle:
        handle.write(content)


def saveFile(outputname, content):
    """Pickle ``content`` into the file ``outputname``.

    The target file is opened in binary write mode, so any existing file at
    that path is overwritten. A progress message is printed before writing.

    Args:
        outputname: Path of the pickle file to create.
        content: Any picklable object.
    """
    print(f"Saving file {outputname}", flush=True)
    # Close (and therefore flush) the file deterministically instead of
    # relying on garbage collection of an unclosed handle.
    with open(outputname, 'wb') as handle:
        pickle.dump(content, handle)



def majority_vote(predictions):
    """Return the value that occurs most often in ``predictions``.

    Counting is done with ``collections.Counter``; the single top entry from
    ``most_common(1)`` is returned. An empty input raises ``IndexError``
    because there is no top entry to index.
    """
    tally = Counter(predictions)
    top_entries = tally.most_common(1)
    winner = top_entries[0][0]
    return winner

def get_characters(story_row):
    """
    Return the character roles listed in a story row's 'roles' entry.

    The 'roles' value is split on commas, each piece is stripped of
    surrounding whitespace, and pieces that end up empty (e.g. from a
    trailing comma) are dropped.

    Args:
    story_row (pd.Series): A row from the stories DataFrame

    Returns:
    list: A list of character roles

    Raises:
    KeyError: If 'roles' is not present in story_row
    """
    if 'roles' not in story_row:
        raise KeyError("The 'roles' column is missing from the story data.")

    # Strip every comma-separated piece lazily, then keep only non-empty ones.
    stripped_pieces = (piece.strip() for piece in story_row['roles'].split(','))
    return list(filter(None, stripped_pieces))

def getDict(state:PipelineState, pipeline_config:PipelineConfig, mode="xAttr", outer_api=None):
    """Generate inferences for every story and save the gathered dataset.

    For each entry of ``pipeline_config['inference']['msg_scheme']`` the
    function re-reads the CSV at ``state.f_name``, obtains this process's
    rows from ``split_data``, runs the model returned by ``initialize_model``
    on each story (schemes 'story' / 'e2e') or on each per-character sentence
    (scheme 'sentence'), and writes the JSON-encoded results into a column
    whose name encodes api/model/scheme/msg_scheme/num_generations/story
    column, plus a companion ``<column>_raw`` column. The rows are then
    exchanged through ``all_gather_data`` and, on global rank 0 only,
    deserialized into a DataFrame and passed to ``state.save_dataset``.

    Original author notes: ``mode`` could become part of ``state``. When this
    stage is run on a dataset that already went through pipeline steps 1-2,
    the intent is to add a new column rather than re-save a new dataset; this
    holds as long as the generated column names differ.

    Args:
        state: Provides ``distributed_state`` (global_rank, local_rank,
            world_size, node_name), ``save_dir``, ``f_name`` and
            ``save_dataset``.
        pipeline_config: Nested mapping; the 'general' and 'inference'
            sections are read here.
        mode: "xAttr" calls ``predict``; "xReact", "oReact" and "motivation"
            call ``comet_model.predict`` and are only handled in the
            'sentence' scheme. In the 'story'/'e2e' scheme only "xAttr" has
            any effect.
        outer_api: When truthy, used instead of
            ``pipeline_config["inference"]["api"]``.

    Returns:
        On global rank 0, the DataFrame built for the last message scheme;
        ``None`` on every other rank.

    NOTE(review): ``all_stories`` is only assigned inside the msg_scheme loop,
    so an empty ``msg_scheme`` list makes the final ``return`` raise
    UnboundLocalError.
    """

    global_rank = state.distributed_state.global_rank
    local_rank = state.distributed_state.local_rank
    world_size = state.distributed_state.world_size
    # node_name and save_dir are read but not used further in this function.
    node_name = state.distributed_state.node_name
    save_dir = state.save_dir

    f_name = state.f_name

    # load model for making inferences. model.predict(**params)
    # might be passed in via outer loop
    if outer_api:
        api = outer_api 
    else:
        api = pipeline_config["inference"]['api']

    prompt_lib = pipeline_config['general']["prompt_library"][0]['generate_inferences']
    # NOTE(review): this counter is decremented inside the row loop and is not
    # reset per msg_scheme, so the limit is shared across all message schemes
    # on this rank -- confirm that is intended.
    inference_limit = pipeline_config["inference"]["limit"]
    # overwrite = pipeline_config["inference"]["overwrite"]
    # applies to creating a new df
    general_overwrite = pipeline_config['general']['overwrite']
    # applies to overwriting vals in an existing df
    inference_overwrite = pipeline_config['inference']['overwrite']
    # Same config entry as msg_schemes below; `prompts` itself is not used again.
    prompts = pipeline_config['inference']['msg_scheme']

    parser = pipeline_config['general']['parser']
    llm_parser = pipeline_config['general']['llm_parser']
    scheme = pipeline_config['inference']['scheme']
    msg_schemes = pipeline_config['inference']['msg_scheme']
    num_generations = pipeline_config['inference']['num_generations']

    story_col = pipeline_config['inference']['story_col_name']
    subj_name = pipeline_config['inference']['gpt_subj_name']

    print(f'Using scheme: {scheme}')
    print(f'msg schemes: {msg_schemes}')
    print(f'Num generations: {num_generations}')
    print(f'Story col: {story_col}', flush=True)

    # parsing_params / parsing_prompter are only bound when `parser` is truthy;
    # they are only read later under `if parser:`.
    if parser and llm_parser:
        parsing_params, parsing_prompter = load_llm_parser(pipeline_config)
        # parsing_prompter = prompt_llama2
        # parsing_prompter = prompt_claude
    
    if parser and not llm_parser:
        parsing_params = False
        parsing_prompter = False

    generate_params, predict, model, model_name, four_bit = initialize_model(api, prompt_lib, local_rank=local_rank)
    
    for msg_scheme in msg_schemes:

        # NOTE(review): this gate tests `model` while the column-name branches
        # below test `api` against 'comet' -- confirm both are intended.
        if model != 'comet':
            print(f'Current msg scheme: {msg_scheme}')
            generate_params["messages"] = generate_params["prompt_config"][scheme][msg_scheme]

        # Build the output column name; the llama2/mixtral/gemma branch uses
        # model_name and additionally records the four_bit flag.
        if api != 'llama2' and api != 'mixtral' and api != 'gemma' and api != 'comet':
            column_name = f'intellect_xattr_api={api}_model={model}_scheme={scheme}_msg_scheme={msg_scheme}_numgen={num_generations}'

        elif api == 'comet':
            column_name = f'intellect_xattr_api={api}_model={model}_numgen={num_generations}'
        else:
            column_name = f'intellect_xattr_api={api}_model={model_name}_scheme={scheme}_msg_scheme={msg_scheme}_numgen={num_generations}_4bit={four_bit}' 

        column_name = f'{column_name}_{story_col}'
        
        print(f'xattr column name is: {column_name}', flush=True)

        # The CSV is re-read from disk on every msg_scheme iteration.
        df = pd.read_csv(f_name) 
        # Rename 'narrative' column to the configured story column if it exists
        if 'narrative' in df.columns:
            df = df.rename(columns={'narrative': story_col})

        # Add 'story_id' column if it doesn't exist
        if 'story_id' not in df.columns:
            df['story_id'] = range(1, len(df) + 1)

        # debug
        # df = df.head(2)

        # df.set_index("story_id")
        f = df 
        # run_name = pipeline_config["general"]["run_name"]
        # column_name = f'{mode}_{run_name}'
        
        print(f'Creating new column name: {column_name}')

        # overwrite if exists
        
        if column_name in f.keys() and inference_overwrite:
            f[column_name] = ""

        
        # presumably the subset of rows assigned to this rank -- verify against split_data
        stories:pd.DataFrame = split_data(f, global_rank, world_size)
        # cnt counts processed texts (sentences or stories); it is not read afterwards.
        cnt = 0
        
        # need to sample a character from the narrative to classify
        # iterrows() yields (index, Series) tuples, hence the row[1] accesses below.
        for row in tqdm(stories.iterrows(), total=len(stories), desc=f'{scheme} iteration'):
            story_idx = row[1]['story_id']
            print(f'\nStory ID: {story_idx}', flush=True)

            story = row[1][story_col]

            # Get all characters in the story
            characters = get_characters(row[1])
            
            # character -> list of parsed / raw inferences for this story
            all_character_inferences = defaultdict(list)
            all_character_raw_inferences = defaultdict(list)

            # Without overwrite, skip rows whose target cell already holds a
            # value; a cell that is NaN or the literal string '""' counts as
            # empty and is processed.
            if not inference_overwrite:
                if column_name in row[1] and (pd.isna(row[1][column_name]) or row[1][column_name] == '""'):
                    pass
                elif column_name in row[1]:
                    continue
                else:
                    pass

            # A limit of 0 stops the row loop; a falsy non-zero limit (e.g.
            # None) means no limit is applied.
            if inference_limit == 0:
                break 

            # The decrement happens before the sentence-scheme skips below, so
            # rows skipped there still consume one unit of the limit.
            if inference_limit:
                inference_limit -= 1
            
            if scheme == 'sentence':
                # Check if gpt_subj is NaN
                if pd.isna(row[1][subj_name]):
                    print(f"Skipping story ID: {row[1]['story_id']} due to NaN in {subj_name}")
                    continue

                try:
                    # Extract the JSON content from the string
                    json_content = re.search(r'```json\s*(.*?)\s*```', row[1][subj_name], re.DOTALL)
                    if json_content:
                        gpt_subj = json.loads(json_content.group(1))
                    else:
                        # If the content is not wrapped in ```json ```, try parsing it directly
                        gpt_subj = json.loads(row[1][subj_name])
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON for story ID: {row[1]['story_id']}")
                    print(f"JSON error: {e}")
                    print(f"Raw content: {row[1][subj_name]}")

                    continue
                except Exception as e:
                    print(f"Unexpected error processing story ID: {row[1]['story_id']}")
                    print(f"Error: {e}")
                    continue

                for character in characters:
                    # Use only sentences pertaining to the current character
                    # (a missing 'character_sentences' key is not caught here)
                    txt = gpt_subj['character_sentences'].get(character, [])
                    print(f'Num sentences for character {character}: {len(txt)}', flush=True)
                    # Modify generate_params to include character information
                    # (shallow copy: nested values are shared with generate_params)
                    character_generate_params = generate_params.copy()
                    character_generate_params['character'] = character
                    
                    for st in txt:
                        cnt += 1
                        if mode == "xAttr":
                            # Generate num_generations inferences for this
                            # sentence and collect them; no averaging is
                            # performed in this function.
                            all_inferences = []
                            unparsed_inferences = []

                            for i in range(num_generations):
                                inference = predict(**character_generate_params, story=st, level='sentence')
                                if parser:
                                    # The 'e2e' branch cannot be taken here because
                                    # this code only runs when scheme == 'sentence'.
                                    if scheme == 'e2e':
                                        parsed_inference = parse_response(inference, None, response_type='score', pointwise=True, llm_parser=parsing_params, prompter=parsing_prompter)
                                    else:
                                        parsed_inference = parse_response(inference, None, response_type='inference', pointwise=True, llm_parser=parsing_params, prompter=parsing_prompter)
                                else:
                                    parsed_inference = inference
                                unparsed_inferences.append(inference)
                                all_inferences.append(parsed_inference)

                            # Store inferences for this character
                            all_character_inferences[character].extend(all_inferences)
                            all_character_raw_inferences[character].extend(unparsed_inferences)

                        # NOTE(review): comet_model is not defined or imported
                        # anywhere in the visible source, so the three modes
                        # below would raise NameError unless it is provided
                        # elsewhere -- confirm.
                        elif mode == "xReact":
                            inference = comet_model.predict(st, "xReact", num_beams=5)
                            all_character_inferences[character].append(inference)
                        elif mode == "oReact":
                            inference = comet_model.predict(st, "oReact", num_beams=5)
                            all_character_inferences[character].append(inference)
                        elif mode == "motivation":
                            # "motivation" appends three relation results in order:
                            # xWant, xNeed, xIntent.
                            inference = comet_model.predict(st, "xWant", num_beams=5)
                            all_character_inferences[character].append(inference)
                            
                            inference = comet_model.predict(st, "xNeed", num_beams=5)
                            all_character_inferences[character].append(inference)

                            inference = comet_model.predict(st, "xIntent", num_beams=5)
                            all_character_inferences[character].append(inference)
                
                # One JSON object per story: {character: [inferences...]}
                stories.loc[stories['story_id'] == story_idx, column_name] = json.dumps(all_character_inferences)
                stories.loc[stories['story_id'] == story_idx, column_name+'_raw'] = json.dumps(all_character_raw_inferences)
            
            if scheme == 'story' or scheme == 'e2e':
                txt = row[1][story_col]
                cnt += 1
                if mode == "xAttr":
                    # Generate multiple inferences and average them
                    # all_inferences = []
                    # unparsed_inferences = []

                    # NOTE(review): each generation overwrites the same cells,
                    # so only the last of the num_generations results is kept
                    # -- confirm that is intended.
                    for i in range(num_generations):
                        inference = predict(**generate_params, story=txt, level='story')
                        if parser:
                            if scheme == 'e2e':
                                parsed_inference = parse_response(inference, None, response_type='score', pointwise=True, llm_parser=parsing_params, prompter=parsing_prompter)
                            else:
                                parsed_inference = parse_response(inference, None, response_type='inference', pointwise=True, llm_parser=parsing_params, prompter=parsing_prompter)
                        else:
                            parsed_inference = inference
                        # unparsed_inferences.append(inference)
                        # all_inferences.append(parsed_inference)
                        stories.loc[stories['story_id'] == story_idx, column_name] = json.dumps(parsed_inference)
                        stories.loc[stories['story_id'] == story_idx, column_name+'_raw'] = json.dumps(inference)


            # if scheme == 'sentence':

                # else:
                    # stories.loc[stories['story_id'] == story_idx, column_name] = json.dumps(all_character_inferences)
                    # stories.loc[stories['story_id'] == story_idx, column_name+'_raw'] = json.dumps(all_character_raw_inferences)
            
            # if scheme == 'story' or scheme == 'e2e':
            #     # Store results for all characters
            #     stories.loc[stories['story_id'] == story_idx, column_name] = json.dumps(all_character_inferences)
            #     stories.loc[stories['story_id'] == story_idx, column_name+'_raw'] = json.dumps(all_character_raw_inferences)
            #     # stories.loc[stories['story_id'] == story_idx, column_name] = json.dumps(inference)
            # else:
            #     stories.loc[stories['story_id'] == story_idx, column_name] = json.dumps(inferences)
            #     stories.loc[stories['story_id'] == story_idx, column_name+'_raw'] = json.dumps(raw_inferences)

            # stories.loc[stories['story_id'] == story_idx, column_name+'_raw'] = json.dumps(raw_inferences)
            


        # Exchange this rank's rows (as a list of record dicts) with the other
        # ranks; the exact wire format is defined by serialize_data /
        # all_gather_data.
        serialized_inferences = serialize_data(stories.to_dict('records'))
        gathered_tensors = all_gather_data(global_rank, local_rank, world_size, serialized_inferences)


        # print(f'Process {global_rank}. Length of all_stories = {len(all_stories)}', flush=True)
        print(f'Finished collecting data', flush=True)
        all_stories = None

        # Only global rank 0 materialises and saves the merged dataset.
        if global_rank == 0:

            # gathered_tensors presumably holds the rows from all processes
            # Deserialize the gathered data
            all_stories = deserialize_data(gathered_tensors)
            all_stories = pd.DataFrame(all_stories)
            # parse_response(all_stories, column_name)

            # save the collected stories
            state.save_dataset(all_stories, overwrite=general_overwrite)

        # Every rank reaches this barrier once per msg_scheme.
        dist.barrier()
    


    # Clean up and free GPU memory

    if 'generate_params' in locals() and 'model' in generate_params:
        del generate_params['model']

    if 'generate_params' in locals() and 'tokenizer' in generate_params:
        del generate_params['tokenizer']

    if 'model' in locals():
        del model

    # NOTE(review): tokenizer and pipe are never assigned in this function, so
    # the two guards below appear to be always False -- confirm whether they
    # can be removed.
    if 'tokenizer' in locals():
        del tokenizer

    if 'pipe' in locals():
        del pipe

    gc.collect()
    torch.cuda.empty_cache()
    return all_stories

