from collections import Counter, defaultdict
import json
import matplotlib.pyplot as plt
from nltk import edit_distance
import os
import pandas as pd
import numpy as np
import openai
import seaborn as sns
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, pipeline
from datasets import Dataset, DatasetDict, load_metric, load_from_disk
from sklearn.metrics import classification_report
import time

# SECURITY: credentials must never be committed to source control. The key is
# read from the standard `OPENAI_API_KEY` environment variable instead; it is an
# empty string when the variable is unset. Any key that was previously
# hard-coded on this line must be treated as leaked and revoked.
OPEN_AI_API_KEY = os.environ.get('OPENAI_API_KEY', '')

# F1 metric object from the `datasets` library, loaded once when this module is
# imported.
# NOTE(review): `load_metric` was deprecated in later `datasets` releases in
# favour of the separate `evaluate` package — confirm against the pinned
# `datasets` version before upgrading.
f1_metric = load_metric("f1")

# Sentiment class names, each stored at the integer index that encodes it in
# the dataset's own `label` column.
LABELS = ['Positive', 'Negative', 'unknown']

# Sentiment class names in the order used to decode ABSA model predictions
# (argmax index -> name). 'Positive' and 'Negative' are swapped relative to
# LABELS, so the two lists are not interchangeable.
ABSA_LABELS = ['Negative', 'Positive', 'unknown']

# Label name -> class index. 'no majority' is folded into the 'unknown' class.
LABEL_TO_IND = {name: index for index, name in enumerate(LABELS)}
LABEL_TO_IND['no majority'] = LABEL_TO_IND['unknown']

# 'LABEL_<index>' identifier -> label name.
# NOTE(review): the keys look like the default class names emitted by a
# HuggingFace classification pipeline — confirm against the caller.
IND_TO_LABEL = dict(zip(('LABEL_0', 'LABEL_1', 'LABEL_2'), LABELS))

def get_absa_df(
    df_in,
    aspect_label_encode=None,
):
    """
    Reshapes a wide frame (one row per review, one label column per aspect) into a
    long frame with one row per (review, aspect) pair.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Must contain a 'description' column and the four columns
        'food_aspect_majority', 'ambiance_aspect_majority',
        'service_aspect_majority' and 'noise_aspect_majority'. It is not modified.
    aspect_label_encode : dict, optional
        Mapping from raw label value to its replacement, applied to every aspect
        label. Values absent from the mapping are left as they are. Defaults to
        ``{"Negative": 0, "Positive": 1, "unknown": 2, "no majority": 2}``.

    Returns
    -------
    pandas.DataFrame
        Frame with columns 'text', 'label' and 'aspect'. The per-aspect blocks are
        stacked in the order food, ambiance, service, noise, and rows whose label
        is the empty string are dropped. The original row index is kept, so index
        values repeat across aspects.
    """
    # Build the default inside the call so no mutable object is shared between
    # invocations.
    if aspect_label_encode is None:
        aspect_label_encode = {
            "Negative": 0,
            "Positive": 1,
            "unknown": 2,
            "no majority": 2,
        }

    per_aspect_frames = []
    for aspect in ("food", "ambiance", "service", "noise"):
        majority_col = f"{aspect}_aspect_majority"
        per_aspect_frames.append(
            df_in[["description", majority_col]]
            .rename(columns={"description": "text", majority_col: "label"})
            .replace({"label": aspect_label_encode})
            .assign(aspect=aspect)
        )

    absa_df = pd.concat(per_aspect_frames)
    # An empty-string label marks an aspect with no recorded annotation.
    return absa_df[absa_df["label"] != ""]

def run_absa_model(sentences, aspects, model, tokenizer, batch_size=200, device=0):
    """
    Runs the provided ABSA `model` on the input `sentences` and `aspects`, returning
    a list of predictions of aspect-level sentiment analysis for each (sentence, aspect) pair.
    
    Parameters
    ----------
    sentences : List of str
        List of input sentences.
    aspects : List of str
        List where each element is one of 'food', 'service', 'ambiance', or 'noise'. Specifies
        which aspect-level sentiment the model should produce. Must be aligned with `sentences`.
    model
        ABSA model for aspect-level sentiment classification. Called with the tokenizer output
        as keyword arguments; its result must expose a `logits` tensor.
    tokenizer 
        Tokenizer for the ABSA model. Called on a (sentences, aspects) batch with
        `return_tensors='pt'`, padding and truncation enabled.
    batch_size : int, default 200
        Specifies batch size for running model, to avoid out-of-memory issues.
    device : default 0
        Device the tokenized batch is moved to before the forward pass. The model
        itself is not moved by this function.
    
    Returns
    -------
    List of int
        Predictions for aspect-level sentiment analysis, corresponding to the indices in `ABSA_LABELS`.
        Empty when `sentences` is empty.
    """
    results = []
    # Inference only: with autograd disabled no graph is recorded, so activations
    # are not kept alive for a backward pass that never happens.
    with torch.no_grad():
        for start in range(0, len(sentences), batch_size):
            stop = start + batch_size
            tokens = tokenizer(
                sentences[start:stop],
                aspects[start:stop],
                return_tensors='pt',
                padding=True,
                truncation=True,
            ).to(device)
            logits = model(**tokens).logits
            # The highest-scoring class per row, returned as plain Python ints.
            results.extend(logits.argmax(dim=1).cpu().tolist())
    return results

def test_absa_model_on_aspect(test_df, model, tokenizer, aspect, device=0):
    """
    Evaluates the ABSA `model` on a single `aspect` and prints a classification report.

    Parameters
    ----------
    test_df
        DataFrame-like object with a `description` column (the texts) and a `label`
        column of integer indices into `LABELS`.
    model
        ABSA model, forwarded to `run_absa_model`.
    tokenizer
        Tokenizer for the ABSA model, forwarded to `run_absa_model`.
    aspect : str
        The aspect every text is scored against.
    device : default 0
        Forwarded to `run_absa_model`.

    Returns
    -------
    None
        The report is emitted through `print_classification_report`.
    """
    gold_labels = [LABELS[index] for index in test_df.label.values]
    texts = list(test_df.description.astype(str).values)
    # Every text is paired with the same aspect.
    repeated_aspect = [aspect for _ in texts]
    raw_predictions = run_absa_model(
        texts, repeated_aspect, model, tokenizer, device=device
    )
    # Model outputs index into ABSA_LABELS, whose ordering differs from LABELS.
    model_labels = [ABSA_LABELS[index] for index in raw_predictions]
    print_classification_report(gold_labels, model_labels)

def test_absa_model(cf_df, model, tokenizer, device=0):
    """
    Evaluates the ABSA `model` on every row of `cf_df` and prints a classification report.

    Parameters
    ----------
    cf_df
        DataFrame-like object with columns `description` (the text), `aspect`
        (the aspect to score for that row) and `aspect_majority` (the reference
        label for that row).
    model
        ABSA model, forwarded to `run_absa_model`.
    tokenizer
        Tokenizer for the ABSA model, forwarded to `run_absa_model`.
    device : default 0
        Forwarded to `run_absa_model`.

    Returns
    -------
    None
        The report is emitted through `print_classification_report`.
    """
    texts = list(cf_df.description.astype(str).values)
    row_aspects = list(cf_df.aspect.values)
    reference_labels = list(cf_df.aspect_majority.values)
    raw_predictions = run_absa_model(
        texts, row_aspects, model, tokenizer, device=device
    )
    # Translate predicted class indices into label names before comparing.
    model_labels = [ABSA_LABELS[index] for index in raw_predictions]
    print_classification_report(reference_labels, model_labels)
    
def validate_mturk_counterfactuals(cf_df, model, tokenizer, 
                                   aspects=('ambiance', 'food', 'noise', 'service'), without=False, device=0):
    """
    Scores the counterfactual texts in `cf_df` with the ABSA `model` and prints a
    classification report.

    Parameters
    ----------
    cf_df
        DataFrame-like object with columns `description` (counterfactual text),
        `aspect` (the aspect that was targeted by the edit), `edit_goal` and one
        `<aspect>_aspect_majority` column per entry of `aspects`.
    model
        ABSA model, forwarded to `run_absa_model`.
    tokenizer
        Tokenizer for the ABSA model, forwarded to `run_absa_model`.
    aspects : iterable of str, default ('ambiance', 'food', 'noise', 'service')
        Aspects to include in the evaluation.
    without : bool, default False
        If False, each row whose `aspect` is in `aspects` is scored on that aspect
        and compared with its `edit_goal`. If True, each aspect is scored on the
        rows that targeted a *different* aspect and that have a non-empty
        `<aspect>_aspect_majority`, and compared with that recorded value.
    device : default 0
        Forwarded to `run_absa_model`.

    Returns
    -------
    None
        The report is emitted through `print_classification_report`.
    """
    # Materialise once: the default is an immutable tuple, and callers may pass
    # any iterable.
    aspects = list(aspects)
    if without:
        x = []
        a = []
        true_labels = []
        for aspect in aspects:
            majority_col = f'{aspect}_aspect_majority'
            # choose all rows where the aspect was NOT asked to be changed
            # and where the aspect is recorded (will be the same as the ORIGINAL aspect)
            a_cf_df = cf_df[(cf_df.aspect != aspect) & (cf_df[majority_col] != '')]
            x.extend(a_cf_df.description.astype(str).values)
            a.extend([aspect] * len(a_cf_df))
            true_labels.extend(a_cf_df[majority_col].values)
    else:
        aspect_cf_df = cf_df[cf_df.aspect.isin(aspects)]
        x = list(aspect_cf_df.description.astype(str).values)
        a = list(aspect_cf_df.aspect.values)
        true_labels = list(aspect_cf_df.edit_goal.values)
    predictions = run_absa_model(x, a, model, tokenizer, device=device)
    # Model outputs are class indices; decode them to label names for the report.
    predicted_labels = [ABSA_LABELS[p] for p in predictions]
    print_classification_report(true_labels, predicted_labels)

    
def validate_gpt3_counterfactuals(gpt3_output, cf_df, model, tokenizer, 
                                  aspects=('ambiance', 'food', 'noise', 'service'), without=False, device=0):
    """
    Scores generated counterfactual texts with the ABSA `model` and prints a
    classification report.

    Parameters
    ----------
    gpt3_output : sequence of dict
        One entry per row of `cf_df`, in the same order; the text is read from each
        entry's 'generated_answer' key. Rows are matched to entries by position,
        so the two must have the same length.
    cf_df
        DataFrame-like object with columns `aspect` (the aspect that was targeted
        by the edit), `edit_goal` and one `<aspect>_aspect_majority` column per
        entry of `aspects`.
    model
        ABSA model, forwarded to `run_absa_model`.
    tokenizer
        Tokenizer for the ABSA model, forwarded to `run_absa_model`.
    aspects : iterable of str, default ('ambiance', 'food', 'noise', 'service')
        Aspects to include in the evaluation.
    without : bool, default False
        If False, each row whose `aspect` is in `aspects` is scored on that aspect
        and compared with its `edit_goal`. If True, each aspect is scored on the
        rows that targeted a *different* aspect and that have a non-empty
        `<aspect>_aspect_majority`, and compared with that recorded value.
    device : default 0
        Forwarded to `run_absa_model`.

    Returns
    -------
    None
        The report is emitted through `print_classification_report`.
    """
    # Materialise once: the default is an immutable tuple, and callers may pass
    # any iterable.
    aspects = list(aspects)
    counterfactuals = np.array([o['generated_answer'] for o in gpt3_output])
    if without:
        c = []
        a = []
        true_labels = []
        for aspect in aspects:
            majority_col = f'{aspect}_aspect_majority'
            # choose all rows where the aspect was NOT asked to be changed
            # and where the aspect is recorded (will be the same as the ORIGINAL aspect)
            row_mask = (cf_df.aspect != aspect) & (cf_df[majority_col] != '')
            a_cf_df = cf_df[row_mask]
            # Explicit positional mask: the array has no index to align on.
            c.extend(counterfactuals[row_mask.to_numpy()])
            a.extend([aspect] * len(a_cf_df))
            true_labels.extend(a_cf_df[majority_col].values)
    else:
        row_mask = cf_df.aspect.isin(aspects)
        selected = cf_df[row_mask]
        c = list(counterfactuals[row_mask.to_numpy()])
        a = list(selected.aspect.values)
        true_labels = list(selected.edit_goal.values)
    predictions = run_absa_model(c, a, model, tokenizer, device=device)
    # Model outputs are class indices; decode them to label names for the report.
    predicted_labels = [ABSA_LABELS[p] for p in predictions]
    print_classification_report(true_labels, predicted_labels)