from collections import Counter, defaultdict
import json
from nltk import edit_distance
import os
import pandas as pd
import numpy as np
import openai
import seaborn as sns
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, pipeline
from datasets import Dataset, DatasetDict, load_metric, load_from_disk
from sklearn.metrics import classification_report
import time

# The OpenAI API key is read from the OPENAI_API_KEY environment variable so
# that the secret never lives in source control (None when the variable is unset).
# NOTE(review): a key that was previously committed here should be treated as
# compromised and revoked.
OPEN_AI_API_KEY = os.environ.get('OPENAI_API_KEY')

def reframe_counterfactual(df):
    """
    Restructures `df` for easier counterfactual analysis.
    In the reframed dataframe, every row represents a counterfactual edit,
    with new columns that keep track of the original description,
    aspect majority, review majority, and length.
    
    Parameters
    ----------
    df : Pandas DataFrame
        Structured as in the provided json file. Each row represents
        either an original or edited restaurant review, and is labeled
        for aspect/edit goal, aspect sentiment, and overall sentiment.
        The columns read here are `is_original`, `original_id`, `edit_type`,
        `description`, `review_majority`, and one `<aspect>_aspect_majority`
        column per aspect that appears in `edit_type`.
    
    Returns
    -------
    Pandas DataFrame
        One row per edit (non-original row), with `edit_type` renamed to `aspect`
        and a fresh 0..n-1 index. Added columns:
        `aspect_majority` (the edit's own label for the edited aspect),
        `original_aspect_majority` (the original review's label for that same aspect),
        `original_description`, `original_review_majority`, and
        `length` (number of whitespace-separated words in the original description).
        Edits whose `original_id` has no original row get missing values in the
        `original_*` and `length` columns.
    
    """
    original_df = df[df['is_original']].rename(columns={'edit_type': 'aspect'})
    cf_df = df[~df['is_original']].rename(columns={'edit_type': 'aspect'})

    original_df = original_df.set_index('original_id', drop=False)
    cf_df = cf_df.set_index('original_id', drop=False)

    # One original row per edit row, in edit order. Several edits may share an
    # original, so the target index can contain duplicates; this only requires
    # the originals themselves to have unique ids.
    aligned_originals = original_df.reindex(cf_df.index)

    # Fill both aspect-specific columns positionally. Writing into plain arrays
    # avoids chained assignment (`frame[col].loc[mask] = ...`), which does not
    # write back to the frame when pandas copy-on-write is active.
    n_edits = len(cf_df)
    aspect_majority = np.empty(n_edits, dtype=object)
    original_aspect_majority = np.zeros(n_edits, dtype=object)
    for aspect in cf_df['aspect'].unique():
        column = f'{aspect}_aspect_majority'
        rows = (cf_df['aspect'] == aspect).to_numpy(dtype=bool)
        aspect_majority[rows] = cf_df[column].to_numpy()[rows]
        original_aspect_majority[rows] = aligned_originals[column].to_numpy()[rows]
    cf_df['aspect_majority'] = aspect_majority
    cf_df['original_aspect_majority'] = original_aspect_majority

    cf_df['original_description'] = aligned_originals['description'].to_numpy()
    cf_df['original_review_majority'] = aligned_originals['review_majority'].to_numpy()

    # Word counts are computed on the originals first so that edits without a
    # matching original end up with a missing length instead of raising.
    word_counts = original_df['description'].apply(lambda text: len(text.split()))
    cf_df['length'] = word_counts.reindex(cf_df.index).to_numpy()

    return cf_df.reset_index(drop=True)
    
def few_shot_sample(ex, 
                    df, 
                    n=2,
                    only_correct=True, 
                    match_edit_goal=True,
                    match_aspect=True,
                    match_original_rev_majority=False,
                    match_original_aspect_majority=True,
                    match_length=False,
                    is_same_df=False,
                    sample=True,
                    random_state=None):
    """
    Chooses `n`-shot samples for a particular example `ex` from the counterfactual
    dataframe `df`. Chosen samples can be filtered by the boolean arguments, as explained 
    below. Only the columns needed by the enabled filters are read.
    
    Parameters
    ----------
    ex : Pandas DataFrame row
        Single example for which we want to find few-shot samples for.
    df : Pandas DataFrame
        Counterfactual dataframe, from which to choose the samples. 
    n : int, default 2
        Number of few-shot samples to generate. Values <= 0 yield an empty selection.
    only_correct : bool, default True
        Flag determining whether to filter the sample pool to only include 
        edits which we consider "correct" - when the aspect majority
        as labeled by other MTurk users matches the edit goal for the counterfactual edit.
    match_edit_goal : bool, default True
        Flag determining whether to filter the sample pool to only include
        edits that matched the original edit goal (e.g. if the task is to make the review negative
        about the service, only choose samples that make the ____ negative.)
    match_aspect : bool, default True
        Similar to `match_edit_goal`, but with respect to the aspect. For example, if
        the task is to make the review negative about the service, only choose samples that 
        make the service _____).
    match_original_rev_majority : bool, default False
        Flag determining whether to filter the sample pool to only include
        edit tasks where the original descriptions match in overall sentiment. 
    match_original_aspect_majority : bool, default True
        Flag determining whether to filter the sample pool to only include 
        edit tasks where the original descriptions match in the specific aspect-level sentiment.
        Only set this True if `match_aspect` is also set to True.
    match_length : bool, default False
        Flag determining whether to roughly match the number of words of the original
        descriptions: the pool is narrowed to rows whose `length` differs from
        `ex['length']` by less than the smallest whole number that still leaves
        at least `n` candidates. If fewer than `n` candidates have a usable length
        difference, the length filter is skipped.
    is_same_df : bool, default False
        Only applies if `ex` is drawn from `df`, although this shouldn't often be the case.
        If it is, then setting `is_same_df` to True ensures that the sampled example doesn't match the id
        of the provided example.
    sample : bool, default True
        True to sample `n` examples randomly from `df`, False to deterministically choose the first `n`.
    random_state : optional
        Forwarded to `DataFrame.sample` for reproducible draws; ignored when `sample` is False.
        
    Returns
    -------
    Pandas DataFrame
        Returns `n` sample rows taken from the pool in `df` filtered by the boolean flags.
        If the pool holds fewer than `n` rows, the whole pool is returned
        (shuffled when `sample` is True). An empty DataFrame is returned for `n <= 0`.
    
    """
    if n <= 0:
        return df.iloc[:0]

    # Build the row mask from the enabled filters only.
    keep = pd.Series(True, index=df.index, dtype=bool)
    if is_same_df:
        keep = keep & (df['id'] != ex['id'])
    if only_correct:
        keep = keep & (df['edit_goal'] == df['aspect_majority'])
    if match_edit_goal:
        keep = keep & (df['edit_goal'] == ex['edit_goal'])
    if match_aspect:
        keep = keep & (df['aspect'] == ex['aspect'])
    if match_original_rev_majority:
        keep = keep & (df['original_review_majority'] == ex['original_review_majority'])
    if match_original_aspect_majority:
        keep = keep & (df['original_aspect_majority'] == ex['original_aspect_majority'])
    pool = df[keep]

    # in the case that our sample count is too small, we will simply return however many are left
    if len(pool) < n:
        if sample and len(pool) > 0:
            return pool.sample(len(pool), random_state=random_state)
        return pool

    if match_length:
        # Smallest integer radius r >= 1 such that at least `n` candidates satisfy
        # |length - ex.length| < r, computed directly from the n-th smallest gap
        # rather than by growing r in a loop (which cannot finish when gaps are NaN).
        gaps = (pool['length'] - ex['length']).abs()
        gap_values = gaps.to_numpy(dtype=float)
        finite_gaps = np.sort(gap_values[np.isfinite(gap_values)])
        if len(finite_gaps) >= n:
            radius = int(np.floor(finite_gaps[n - 1])) + 1
            pool = pool[gaps < radius]

    if sample:
        return pool.sample(n, random_state=random_state)
    return pool.iloc[:n]


def generate_prompt(ex, few_shots, start_prompt=True, upper=True, joiner='\n\n'):
    """
    Builds a few-shot prompt for one example. With the defaults it looks like:
    
    Make the following restaurant reviews include NEGATIVE mentions of FOOD.
    
    Original: I loved the pasta!
    
    NEGATIVE mentions of FOOD: I hated the pasta!
    
    Original: I enjoyed the music and the atmosphere.
    
    NEGATIVE mentions of FOOD:
    
    An edit goal of 'unknown' (any casing) is rendered as 'WITHOUT'
    ('Without' when `upper` is False), and the header then reads
    '... reviews WITHOUT mentions of ...' with no 'include'.
    
    Parameters
    ----------
    ex : Pandas DataFrame row
        Single example for which we want to generate a counterfactual.
        Its `aspect`, `edit_goal` and `original_description` are read.
    few_shots : Pandas DataFrame
        Counterfactual examples, as chosen by `few_shot_sample`. Rows are read
        positionally through `.iloc`; an empty sequence contributes nothing.
    start_prompt : bool, default True
        Whether or not to include the instruction line at the beginning. 
        Only set to True if the few-shots match in aspect and edit goal.
    upper : bool, default True
        True to upper-case aspects and edit goals. When False, aspects are
        lower-cased and edit goals are left as given (lower-cased only
        inside the instruction line).
    joiner : str, default two newlines
        Separator placed between the prompt segments.
    
    Returns
    -------
    string
        Few-shot prompt to provide to a language-generating model. 
    """
    def tags(row):
        # (edit goal, aspect) as they should appear in the prompt for `row`
        aspect_tag = row.aspect.upper() if upper else row.aspect.lower()
        goal_tag = row.edit_goal.upper() if upper else row.edit_goal
        if goal_tag.upper() == 'UNKNOWN':
            goal_tag = 'WITHOUT' if upper else 'Without'
        return goal_tag, aspect_tag

    edit_goal, aspect = tags(ex)
    segments = []

    if start_prompt:
        instruction_goal = edit_goal if upper else edit_goal.lower()
        if instruction_goal.upper() != 'WITHOUT':
            instruction_goal = f'include {instruction_goal}'
        segments.append(f'Make the following restaurant reviews {instruction_goal} mentions of {aspect}.')

    for position in range(len(few_shots)):
        shot = few_shots.iloc[position]
        shot_goal, shot_aspect = tags(shot)
        segments.append(f'Original: {shot.original_description}')
        segments.append(f'{shot_goal} mentions of {shot_aspect}: {shot.description}')

    segments.append(f'Original: {ex.original_description}')
    segments.append(f'{edit_goal} mentions of {aspect}:')

    # newlines inside a segment would be confused with the separator, so flatten them
    return joiner.join(segment.replace('\n', ' ') for segment in segments)
    

def generate_prompts(sample_df, train_df, nshot, start_prompt=True, **fs_kwargs):
    """
    Generates prompts for all samples in `sample_df`, where the few-shot
    samples are taken from `train_df`.

    Parameters
    ----------
    sample_df : Pandas DataFrame
        Examples to build prompts for, one prompt per row.
    train_df : Pandas DataFrame
        Pool from which `few_shot_sample` draws the demonstrations.
    nshot : int
        Number of demonstrations per prompt (passed to `few_shot_sample` as `n`).
    start_prompt : bool, default True
        Forwarded to `generate_prompt`.
    **fs_kwargs
        Extra keyword arguments forwarded to `few_shot_sample`.

    Returns
    -------
    list of str
        One prompt per row of `sample_df`, in row order (empty list for an empty frame).
    """
    def gp(ex):
        few_shots = few_shot_sample(ex, train_df, n=nshot, **fs_kwargs)
        return generate_prompt(ex, few_shots, start_prompt=start_prompt)

    # Iterate rows explicitly: `DataFrame.apply(..., axis=1)` on an empty frame
    # can hand back the frame itself, and `list()` of that is its column names.
    return [gp(row) for _, row in sample_df.iterrows()]


def generate_finetune_data(df, nshot=0, start_prompt=True):
    """
    Generates prompts for fine-tuning (0-shot recommended), and reformats
    them according to the GPT-3 API guides.
    https://beta.openai.com/docs/guides/fine-tuning
    
    Parameters
    ----------
    df : Pandas DataFrame 
        Dataset on which we want to finetune our model.
    nshot : int, default 0
        Number of examples to include in a prompt (recommended 0).
        Demonstrations are drawn from `df` itself, excluding the row the
        prompt is being built for.
    start_prompt : bool, default True
        Whether or not to include a start prompt.
    
    Returns
    -------
    List of dicts
        Finetuning data, consisting of prompt, completion pairs. The completion
        is the row's `description` with a leading space and a trailing newline:
        {'prompt': 'Original: I thought the service was mediocre.\n\nNEGATIVE mentions of FOOD:',
         'completion': ' I thought the hamburger was mediocre.\n'}
    """
    # The demonstration pool is the dataset itself, so tell the sampler to skip
    # the row being prompted; otherwise its own target text could show up as a
    # demonstration inside its prompt.
    prompts = generate_prompts(sample_df=df,
                               train_df=df,
                               nshot=nshot,
                               start_prompt=start_prompt,
                               is_same_df=True)

    return [{'prompt': prompt, 'completion': f' {target}\n'}
            for prompt, target in zip(prompts, df.description)]

def _find_generated_answer(tokens, newline="\n" ): 
    """Our LMs tend to insert initial newline characters before
    they begin generating text. This function ensures that we 
    properly capture the true first line as the answer while
    also ensuring that token probabilities are aligned."""        
    answer_token_indices = []
    char_seen = False            
    for i, tok in enumerate(tokens):
        # This is the main condition: a newline that isn't an initial
        # string of newlines:
        if tok == newline and char_seen:
            break
        # Keep the initial newlines for consistency:
        elif tok == newline and not char_seen:
            answer_token_indices.append(i)
        # Proper tokens:
        elif tok != newline:
            char_seen = True
            answer_token_indices.append(i)
    return answer_token_indices 

def run_gpt3(prompts, engine="text-curie-001", model=None, temperature=0.9, max_tokens=42, **gpt3_kwargs):
    """
    Runs GPT-3 on a list of prompts.
    
    Parameters
    ----------
    prompts : iterable of str
        Prompts to complete in a single API request. A single string is
        treated as a one-element batch.
    engine : str  
        https://beta.openai.com/docs/engines/gpt-3
        Used only when `model` is None.
    model : str, optional
        Id of a finetuned model. Takes precedence over `engine` when given.
    temperature : float
        It seems best to set it high for this task!
    max_tokens: int
        Limits how many tokens the model is asked to generate.
        
    For information about values for `gpt3_kwargs`, see
    
    https://beta.openai.com/docs/api-reference/completions
    
    Returns
    -------
    list of dicts
        One dict per prompt, in prompt order, with keys `prompt`,
        `generated_text`, `generated_tokens`, `generated_probs` (everything up to
        an "<|endoftext|>" token, if one was produced) and `generated_answer`,
        `generated_answer_tokens`, `generated_answer_probs` (the first generated
        line, as selected by `_find_generated_answer`). Empty list for no prompts.
    """
    openai.api_key = OPEN_AI_API_KEY

    assert (engine is not None) or (model is not None), 'Please provide an engine or a finetuned model id.'

    # Materialise the prompts: they are sent to the API and then walked again to
    # pair them with the completions. A bare string would otherwise be paired
    # character by character.
    prompts = [prompts] if isinstance(prompts, str) else list(prompts)
    if not prompts:
        return []

    # go with pretrained model if provided, else use engine
    if model is not None:
        gpt3_kwargs['model'] = model
    else:
        gpt3_kwargs['engine'] = engine
        
    response = openai.Completion.create(
        prompt=prompts,
        temperature=temperature,
        echo=False,   # This function will not work
        logprobs=1,   # properly if any of these
        n=1,          # are changed!
        max_tokens=max_tokens,
        **gpt3_kwargs)
    
    # With batched prompts, each choice carries the position of its prompt in
    # `index`; order by it instead of trusting the order of the returned list.
    choices = sorted(response["choices"], key=lambda choice: choice["index"])

    # From here, we parse each example to get the values
    # we need:
    data = []
    for ex, prompt in zip(choices, prompts):
        tokens = ex["logprobs"]["tokens"]
        logprobs = ex["logprobs"]["token_logprobs"]        
        probs = list(np.exp(logprobs))
        if "<|endoftext|>" in tokens:
            end_i = tokens.index("<|endoftext|>")
            tokens = tokens[ : end_i]  # This leaves off the "<|endoftext|>"
            probs = probs[ : end_i]    # token -- perhaps dubious.
        ans_indices = _find_generated_answer(tokens)
        answer_tokens = [tokens[i] for i in ans_indices]
        answer_probs = [probs[i] for i in ans_indices]
        answer = "".join(answer_tokens)        
        data.append({
            "prompt": prompt,
            "generated_text": ex["text"],
            "generated_tokens": tokens,
            "generated_probs": probs,
            "generated_answer": answer,
            "generated_answer_tokens": answer_tokens,
            "generated_answer_probs": answer_probs})
        
    return data

def upload_finetune_data_gpt3(filename, data, delete=True):
    """
    Uploads finetune data to the GPT-3 API.
    The data should be structured as shown in `generate_finetune_data`.

    Parameters
    ----------
    filename : str
        Path of the local JSON-lines file that is written and then uploaded.
    data : iterable of dict
        Records to serialise, one JSON object per line.
    delete : bool, default True
        Whether to remove the local file afterwards. The file is removed
        even if writing or uploading raises.

    Returns
    -------
    The `id` field of the API's file-creation response.
    """
    openai.api_key = OPEN_AI_API_KEY
    try:
        with open(filename, 'w') as f:
            for record in data:
                json.dump(record, f)
                f.write('\n')

        with open(filename) as f:
            response = openai.File.create(
                file=f,
                purpose='fine-tune'
            )
    finally:
        # clean up the temporary file on failure as well as on success
        if delete and os.path.exists(filename):
            os.remove(filename)

    return response['id']

def finetune_gpt3(train, val=None, engine='curie', asyn=False, **gpt3_kwargs):
    """
    Finetunes GPT-3 on a training set, with an optional validation set.
    
    Parameters
    ----------
    train : iterable of dict
        Training data, formatted as in `generate_finetune_data`.
    val : iterable of dict, optional
        Validation data, formatted as in `generate_finetune_data`.
    engine : str  
        https://beta.openai.com/docs/engines/gpt-3                
    asyn : bool, default False
        False to block (polling every 30 seconds) until the finetuning job has
        finished; True to return right after the job has been created.
        
    For information about values for `gpt3_kwargs`, see
    
    https://beta.openai.com/docs/api-reference/fine-tunes
    
    Returns
    -------
    dict 
        'model': name of finetuned model (`None` for `asyn=True`)
        'id': id of finetune job (relevant mostly for `asyn=True`)
        'results': id of the first result file, for investigation of training
                   and validation scores (`None` for `asyn=True`).

    Raises
    ------
    RuntimeError
        If, while waiting, the job reports a status of 'failed' or 'cancelled'.
    """
    openai.api_key = OPEN_AI_API_KEY
    
    # upload train (and optionally validation) data to GPT-3 API
    train_file_id = upload_finetune_data_gpt3('train-temp.jsonl', train)
    if val is not None:
        gpt3_kwargs['validation_file'] = upload_finetune_data_gpt3('val-temp.jsonl', val)
    
    # create FineTune task
    response = openai.FineTune.create(
        training_file=train_file_id,
        model=engine,
        **gpt3_kwargs
    )
    
    # if we are asynchronous, no need to stall until finetuning job is complete
    if asyn:
        return {'model': None, 'id': response['id'], 'results': None}

    jobid = response['id']
    print(f'Model queued with id {jobid}.')
    running = False
    while response['status'] != 'succeeded':
        # A job that ended unsuccessfully will never reach 'succeeded';
        # stop polling instead of waiting forever.
        if response['status'] in ('failed', 'cancelled'):
            raise RuntimeError(
                f"Finetune job {jobid} ended with status '{response['status']}'.")
        time.sleep(30)
        response = openai.FineTune.retrieve(id=jobid)
        if response['status'] == 'running' and not running:
            print('Model is running.')
            running = True

    print(f"Model finished with id {response['fine_tuned_model']}.")

    return {'model': response['fine_tuned_model'], 'id': response['id'], 'results': response['result_files'][0]['id']}

def run_gpt3_experiment(name,
                        train_df, 
                        dev_df, 
                        finetune=True, 
                        nshot=0, 
                        engine='curie', 
                        start_prompt=True, 
                        temperature=0.7, 
                        batch_size=20,
                        ft_kwargs=None,
                        fs_kwargs=None,
                        **gpt3_kwargs):
    """
    Uses GPT-3 to generate counterfactuals on `dev_df`, using `train_df` to either
    finetune the model or to sample few-shot examples. 
    
    Parameters
    ----------
    name : string
        Name of the experiment. GPT-3 outputs will be saved to a json file with 
        the provided name (`<name>.json`).
    train_df : Pandas DataFrame
        Training data, formatted as output of `reframe_counterfactual`.
    dev_df : Pandas DataFrame
        Validation data, formatted as output of `reframe_counterfactual`.
    finetune : bool, default True  
        Whether to finetune a GPT-3 model using the training set.           
    nshot : int, default 0
        Number of few-shot examples per prompt, used both for the finetuning
        data and for the prompts run on `dev_df`.
    engine : string, default 'curie'
        One of 'ada', 'babbage', 'curie', or 'davinci'. Without finetuning the
        name is expanded to 'text-<engine>-001' ('text-davinci-002' for davinci).
        https://beta.openai.com/docs/engines/gpt-3
    start_prompt : bool, default True
        Whether prompts should contain an initial prefix.
    temperature : float, default 0.7
        Temperature of GPT-3 model (higher -> more creative)
    batch_size : int, default 20
        Batch sizes for GPT-3 API (20 is largest for free account)
    ft_kwargs : dict, optional
        Keyword arguments for finetuning, `finetune_gpt3`.
    fs_kwargs : dict, optional
        Keyword arguments for few-shot generation, `few_shot_sample`.
    gpt3_kwargs 
        Keyword arguments for the GPT-3 model, `run_gpt3`
        https://beta.openai.com/docs/api-reference/completions
    
    Returns
    -------
    tuple (output, finetune_results)
        output : list of dicts, see `run_gpt3`.
        finetune_results : the dict returned by `finetune_gpt3` when
        `finetune` is True, otherwise an empty list.
    """
    # None instead of `{}` as default avoids sharing one dict across calls
    ft_kwargs = {} if ft_kwargs is None else ft_kwargs
    fs_kwargs = {} if fs_kwargs is None else fs_kwargs

    if finetune:
        train_finetune = generate_finetune_data(train_df, nshot=nshot, start_prompt=start_prompt)
        dev_finetune = generate_finetune_data(dev_df, nshot=nshot, start_prompt=start_prompt)

        # going with the defaults for finetuning right now -- should we play with n_epochs and learning rate?
        finetune_results = finetune_gpt3(train_finetune, 
                                dev_finetune, 
                                engine=engine, 
                                **ft_kwargs)
        
        model = finetune_results['model']
    else:
        model = None
        engine = 'text-' + engine + '-00' + ('2' if engine == 'davinci' else '1')
        finetune_results = []
    
    prompts = generate_prompts(dev_df, train_df=train_df, nshot=nshot, start_prompt=start_prompt, **fs_kwargs)
    
    # query the API one batch of prompts at a time
    output = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        output += run_gpt3(batch, engine=engine, model=model, temperature=temperature, **gpt3_kwargs)
    
    with open(f'{name}.json', 'w') as f:
        json.dump(output, f)
    
    return output, finetune_results
