from tqdm import tqdm
import joblib

import datasets
import pandas as pd
from langchain.llms import OpenAIChat
from langchain import PromptTemplate, HuggingFacePipeline
from langchain.chains import LLMChain
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV
from scipy.stats.stats import pearsonr
import numpy as np

# Scoring LLM used by the feature-extraction chain. Exactly one backend is
# constructed: set USE_OPENAI = True to score with gpt-3.5-turbo, otherwise the
# local FLAN-T5-large pipeline (on device 0) is used. Building only the selected
# backend avoids creating an unused OpenAI client (which may require credentials).
USE_OPENAI = False
if USE_OPENAI:
    lm = OpenAIChat(model_name='gpt-3.5-turbo', temperature=0)
else:
    lm = HuggingFacePipeline.from_model_id(
        model_id="google/flan-t5-large",
        task="text2text-generation",
        model_kwargs={"temperature": 0},
        device=0,
    )

# Rating prompt shared by every feature. The LLM sees the post ({history}) and
# one reply ({reply}), is asked a single attribute question, and must answer
# with one number from 1 to 10. {attribute_desc}, {attr_min} and {attr_max} are
# filled from a FEATURES entry by the callers of the chain.
feature_extractor_template = """
You will be given a Reddit post and a reply. Your job is to judge whether the reply {attribute_desc}. Score that on a scale from 1 to 10 where 1 means {attr_min} and 10 means {attr_max}.

POST:
{history}

Reply:
{reply}

Answer by outputting a number from 1 to 10 (and nothing else).

Answer:"""
feature_extractor_prompt = PromptTemplate(
    template=feature_extractor_template,
    input_variables=[
        "history",
        "reply",
        "attribute_desc",
        "attr_min",
        "attr_max",
    ],
)
# The chain returns its raw completion under the "score" key of its output dict.
feature_extractor_chain = LLMChain(
    prompt=feature_extractor_prompt,
    llm=lm,
    output_key="score",
)


# Attributes every reply is rated on. Each entry's three strings fill the
# {attribute_desc}, {attr_min} and {attr_max} placeholders of the rating prompt
# (annotate() and the best-of-n loop pass them through get_score as kwargs).
# The keys also name the per-reply columns "<key>_A" / "<key>_B" that annotate()
# writes and train_model() reads from the loaded dataset, so they presumably must
# match the columns of anonymous/shp_with_features_20k -- TODO confirm before
# renaming any key (e.g. the trailing colon in 'biased:' or the 'repetetive'
# spelling).
# NOTE(review): for 'biased:', 'repetetive', 'fail-to-consider-context' and
# 'too-long' the 1/10 anchors run opposite to the question asked in
# attribute_desc (e.g. "is biased" with 1 = "very biased"), while
# 'fail-to-consider-individual-preferences' uses the question's direction.
# Rewording would change the features the cached models were trained on.
FEATURES = {
    'helpfulness': {
        'attribute_desc': "is helpful for the original poster",
        'attr_min': "not helpful",
        'attr_max': "very helpful",
    },
    'specificity': {
        'attribute_desc': "is specific enough",
        'attr_min': "too vague",
        'attr_max': "very specific",
    },
    'intent': {
        'attribute_desc': "understands the original poster's intent",
        'attr_min': "failure of understanding",
        'attr_max': "perfect understanding",
    },
    'factuality': {
        'attribute_desc': "is factually correct",
        'attr_min': "egregiously incorrect",
        'attr_max': "fully correct",
    },
    'easy-to-understand': {
        'attribute_desc': "is easy to understand",
        'attr_min': "very difficult to understand",
        'attr_max': "very easy to understand",
    },
    'relevance': {
        'attribute_desc': "is relevant to the original poster's question",
        'attr_min': "off-topic",
        'attr_max': "very relevant",
    },
    'readability': {
        'attribute_desc': "is easy to read and not too technical for the original poster",
        'attr_min': "very difficult to read",
        'attr_max': "very easy to read",
    },
    'enough-detail': {
        'attribute_desc': "provides enough detail to be helpful",
        'attr_min': "too little detail",
        'attr_max': "very detailed",
    },
    # Key keeps its trailing colon; the resulting column names are 'biased:_A'/'biased:_B'.
    'biased:': {
        'attribute_desc': "is biased or one-sided",
        'attr_min': "very biased",
        'attr_max': "not biased at all",
    },
    'fail-to-consider-individual-preferences': {
        'attribute_desc': "fails to consider the original poster's cultural or individual preferences",
        'attr_min': "takes into account the original poster's preferences",
        'attr_max': "fails to consider the original poster's preferences",
    },
    # Misspelled key ('repetetive') kept as-is because it forms column names.
    'repetetive': {
        'attribute_desc': "is repetitive",
        'attr_min': "very repetitive",
        'attr_max': "not repetitive",
    },
    'fail-to-consider-context': {
        'attribute_desc': "fails to consider the original poster's context",
        'attr_min': "fails to consider the original poster's context",
        'attr_max': "takes into account the original poster's context",
    },
    'too-long': {
        'attribute_desc': "is too long",
        'attr_min': "too long",
        'attr_max': "not too long",
    },
}


def get_score(history, reply, **feature_kwargs):
    """Ask the scoring LLM to rate ``reply`` (to ``history``) on one attribute.

    ``feature_kwargs`` are the prompt fields of a FEATURES entry
    (``attribute_desc``, ``attr_min``, ``attr_max``).

    Returns the completion parsed as an ``int``, or ``None`` if the chain call
    or the integer conversion fails; the error is printed in that case
    (deliberate best-effort so one bad completion does not stop a long run).
    """
    chain_inputs = dict(history=history, reply=reply)
    chain_inputs.update(feature_kwargs)
    try:
        raw_answer = feature_extractor_chain(chain_inputs)['score']
        return int(raw_answer)
    except Exception as err:
        print(err)
        return None


def annotate(element):
    """Add an LLM rating for every FEATURES attribute to both replies of ``element``.

    Reads ``element['history']``, ``element['human_ref_A']`` and
    ``element['human_ref_B']``, and writes ``<feature>_A`` / ``<feature>_B``
    (an int, or None when scoring failed). Mutates and returns ``element``.
    """
    post = element['history']
    for feature_name, prompt_fields in FEATURES.items():
        # Score reply A before reply B, feature by feature.
        for side in ('A', 'B'):
            element[f'{feature_name}_{side}'] = get_score(
                history=post,
                reply=element[f'human_ref_{side}'],
                **prompt_fields,
            )
    return element


def is_correct(row):
    """Return 1 if the helpfulness ratings agree with ``row['labels']``, else 0.

    Treats label 1 as "A preferred" (a helpfulness tie counts for A) and
    label 0 as "B preferred" (strictly higher helpfulness for B).
    """
    helpful_a = row['helpfulness_A']
    helpful_b = row['helpfulness_B']
    if helpful_a >= helpful_b and row['labels'] == 1:
        return 1
    if helpful_a < helpful_b and row['labels'] == 0:
        return 1
    return 0


# Load the 'train' and 'test' splits of the feature-annotated SHP subset as
# pandas DataFrames (train -> dataset_A, test -> dataset_B).
dataset_A, dataset_B = (
    datasets.load_dataset('anonymous/shp_with_features_20k', split=split_name).to_pandas()
    for split_name in ('train', 'test')
)


def train_model(dataset, model_name):
    """Fit a standardized logistic-regression preference model by grid search.

    Inputs are every FEATURES score for both replies, ordered as all ``_A``
    columns followed by all ``_B`` columns; the target is ``dataset['labels']``.
    Penalty, C and solver are chosen by 8-fold cross-validated accuracy. The
    best pipeline is written to ``<model_name>.joblib`` and returned.
    """
    feature_names = list(FEATURES.keys())
    columns = [f'{name}_{suffix}' for suffix in ('A', 'B') for name in feature_names]
    search = GridSearchCV(
        Pipeline(steps=[('scaler', StandardScaler()), ('classifier', LogisticRegression())]),
        {
            'classifier__penalty': ['l1', 'l2'],
            'classifier__C': np.logspace(-5, 5, 12),
            'classifier__solver': ['liblinear', 'saga'],
        },
        cv=8,
        scoring='accuracy',
        n_jobs=-1,
    )
    search.fit(dataset[columns], dataset['labels'])
    print(f"{model_name} Best hyperparameters:", search.best_params_)
    winner = search.best_estimator_
    # best_score_ is the mean cross-validated accuracy of the best setting.
    print(f"{model_name}  accuracy:", search.best_score_)
    joblib.dump(winner, f'{model_name}.joblib')
    return winner


# Reuse preference models cached by a previous run, or train them from the two
# dataset splits. joblib.load raises FileNotFoundError (it never returns None)
# when a cache file is missing, so that exception is what triggers training.
try:
    model_A = joblib.load('model_A.joblib')
    model_B = joblib.load('model_B.joblib')
except FileNotFoundError:
    print('No cached models found, training...')
    model_A = train_model(dataset_A, 'model_A')
    model_B = train_model(dataset_B, 'model_B')
else:
    print('Loaded trained models.')

# The models' inputs are ordered as all `_A` feature columns followed by all
# `_B` columns, so the first len(FEATURES) coefficients and scaler statistics
# are the ones for the `_A` features.
n_features = len(FEATURES)
weights_A = model_A.named_steps.classifier.coef_[0, :n_features]
means_A = model_A.named_steps.scaler.mean_[:n_features]
stds_A = model_A.named_steps.scaler.scale_[:n_features]
weights_B = model_B.named_steps.classifier.coef_[0, :n_features]
means_B = model_B.named_steps.scaler.mean_[:n_features]
stds_B = model_B.named_steps.scaler.scale_[:n_features]
print('Model A weights:', weights_A)
print('Model B weights:', weights_B)
print('Model A means:', means_A)
print('Model B means:', means_B)
print('Model A stds:', stds_A)
print('Model B stds:', stds_B)


def recalibrate_scores(scores, weights, means, vars):
    """Map raw per-feature scores onto a trained preference model's scale.

    Each score is standardized as ``(score - mean) / scale`` and the result is
    dotted with the model's feature weights. All arguments may be plain
    sequences or NumPy arrays of equal length.

    Note: despite its name, ``vars`` is the per-feature *standard deviation*
    (the callers pass ``StandardScaler.scale_``). The name is kept so existing
    keyword callers keep working.

    Returns the weighted sum as a NumPy scalar.
    """
    # Converting up front lets plain lists work; list - list would raise TypeError.
    scores_arr = np.asarray(scores, dtype=float)
    means_arr = np.asarray(means, dtype=float)
    scale_arr = np.asarray(vars, dtype=float)
    normalized_scores = (scores_arr - means_arr) / scale_arr
    return np.dot(normalized_scores, weights)


# Score up to 16 FLAN-T5 samples for each of the first 100 prompts with both
# preference models, then record the best-of-n pick under PM A for several n.
t5_samples = pd.read_json('generated_flan_t5_large.json').to_dict()
all_scores_A, all_scores_B = [], []
bon_rows = []
for _, entry in tqdm(list(t5_samples.items())[:100]):
    prompt = entry['prompt']
    entry['score_A'] = []
    entry['score_B'] = []
    # Responses that were fully scored, kept index-aligned with
    # entry['score_A'] / entry['score_B'] (failed responses are skipped, so
    # indexing entry['response'] directly would pick the wrong text).
    scored_responses = []
    for response in entry['response'][:16]:
        scores = [get_score(history=prompt, reply=response, **feature_kwargs)
                  for feature_kwargs in FEATURES.values()]
        if any(score is None for score in scores):
            continue
        score_A = recalibrate_scores(scores, weights_A, means_A, stds_A)
        score_B = recalibrate_scores(scores, weights_B, means_B, stds_B)
        entry['score_A'].append(score_A)
        entry['score_B'].append(score_B)
        all_scores_A.append(score_A)
        all_scores_B.append(score_B)
        scored_responses.append(response)
    if not scored_responses:
        # np.argmax raises on an empty sequence; nothing to select for this prompt.
        continue
    for n in [1, 2, 4, 8, 16]:
        best_of_A = int(np.argmax(entry['score_A'][:n]))
        bon_rows.append({
            'n': n,
            'best_of_A_score_A': entry['score_A'][best_of_A],
            'best_of_A_score_B': entry['score_B'][best_of_A],
            'response': scored_responses[best_of_A],
            'prompt': prompt,
        })

# DataFrame.append was removed in pandas 2.0; build the frame once from the rows.
bon_data = pd.DataFrame(bon_rows, columns=['n', 'best_of_A_score_A', 'best_of_A_score_B', 'response', 'prompt'])

bon_data.to_csv('bon.csv')
# Compare the two preference models' scores over every scored response.
sns.set_theme('paper')
plt.scatter(all_scores_A, all_scores_B)
plt.xlabel('Score according to PM A')
plt.ylabel('Score according to PM B')  # y-axis holds all_scores_B
plt.show()
# Pearson correlation (and p-value) between the two models' scores.
print('corr', pearsonr(all_scores_A, all_scores_B))
# Persist the samples together with their per-response score_A / score_B lists.
df = pd.DataFrame.from_dict(t5_samples, orient='index')
df.to_json('generated_flan_t5_large_with_scores.json')

# Reload the saved outputs and plot the best-of-n scores (selected by PM A)
# against n under both preference models.
df = pd.read_json('generated_flan_t5_large_with_scores.json')
bon = pd.read_csv('bon.csv')
fig, ax = plt.subplots()
for column, curve_label in (
    ('best_of_A_score_A', 'PM A score (used for argmax)'),
    ('best_of_A_score_B', 'PM B score'),
):
    sns.lineplot(data=bon, x='n', y=column, ax=ax, legend='full', label=curve_label)
# n doubles at each step; plt.xscale('log') would space the points evenly.
ax.set_xlabel('n')
ax.set_ylabel('Score')
plt.show()