from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import json
from tqdm import tqdm
import pandas as pd
from random import sample
import csv 
import random

# Device that the model and the tokenized prompts are moved onto.
device = "cuda"

# Tokenizer and model are created from the same checkpoint id.
_MISTRAL_CHECKPOINT = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(_MISTRAL_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_MISTRAL_CHECKPOINT)

####################################
def get_mistral_answer(prompt, role="", n=1):
    """Sample ``n`` completions from the module-level model for one user turn.

    Args:
        prompt: Text of the user message.
        role: Optional prefix; the message content is ``role + " " + prompt``.
        n: Number of sequences to sample (``num_return_sequences``).

    Returns:
        A single string when ``n == 1``, otherwise a list of ``n`` strings.
        Each string is the generated continuation only, with special tokens
        removed and surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": role + " " + prompt}]
    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
    model_inputs = encodeds.to(device)
    # The module-level model is created without a device argument, so it is
    # moved here; repeating the call on later invocations changes nothing.
    model.to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=1500, do_sample=True, num_return_sequences=n)

    # Every returned sequence starts with the prompt tokens. Slicing them off
    # by length is robust even when the prompt or the answer happens to contain
    # the literal markers "[/INST]" or "</s>", which text splitting is not.
    prompt_length = model_inputs.shape[1]
    completion_ids = generated_ids[:, prompt_length:]
    answers = [
        text.strip()
        for text in tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    ]

    return answers[0] if n == 1 else answers

# Load data
import json

# JSON text is UTF-8; decoding it explicitly keeps loading independent of the
# platform's default locale encoding.
with open("data/cnn_valid_50.json", encoding="utf-8") as file:
    valid_data = json.load(file)

with open("data/cnn_test_200.json", encoding="utf-8") as file:
    test_data = json.load(file)

# Candidate quality metrics offered to the model during metric selection.
# The literal is grouped by where each name was taken from; the final list is
# kept in alphabetical order.
PRE_DEFINED_ASSESSMENT_METRICS = sorted([
    # ABCs of clear communication
    "Accuracy", "Brevity", "Clarity",
    # BARTScore paper
    "Relevance", "Coherence",
    # GPTScore paper
    "Semantic Coverage", "Factuality", "Fluency", "Informativeness",
    "Consistency", "Engagement", "Specificity", "Correctness",
    "Understandability", "Diversity",
    # Self-derived
    "Completeness", "Conciseness", "Neutrality", "Naturalness",
    "Readability", "Creativity", "Rationalness", "Truthfulness",
    "Respect of Chronology", "Non-repetitiveness", "Indicative",
    "Resolution",
])

# Task-specific Metrics
import random
import re
from tqdm import tqdm

# Human-readable task name interpolated into the metric-selection,
# metric-definition and scoring prompts.
TASK_NAME = "highlights summarization"

def remove_elements_in_list1(list1, list2):
    """Return the items of ``list2`` that do not occur in ``list1``.

    Order and duplicates of ``list2`` are preserved; membership is tested
    with ``in`` against ``list1``.
    """
    kept = []
    for candidate in list2:
        if candidate in list1:
            continue
        kept.append(candidate)
    return kept

def convert_score(score):
    """Round ``score`` to an integer, sending a fractional part of .5 or more up.

    The integer part is obtained with ``int()`` (truncation toward zero), and
    one is added when what remains is at least 0.5.
    """
    whole = int(score)
    remainder = score - whole
    return whole + 1 if remainder >= 0.5 else whole

def get_metric_general_definition(metrics):
    """Ask the model for a bullet-point definition of each metric in ``metrics``.

    ``metrics`` is interpolated into the prompt as-is; the model's reply is
    returned unchanged.
    """
    prompt_lines = (
        f"Define the list of following metrics in details as the quality of the generation expected for the {TASK_NAME} task.",
        f"{metrics}",
        "Give me the list in bullet points.",
        "",  # keeps the trailing newline of the prompt
    )
    return get_mistral_answer("\n".join(prompt_lines))

def _parse_metric_list(raw_answer):
    """Extract metric names from a model reply; return ``[]`` when none parse."""
    import ast  # local import: keeps this block usable on its own

    # Replies may wrap the list in prose or code fences, so only the outermost
    # [...] span is handed to the literal parser.
    start, end = raw_answer.find("["), raw_answer.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        # literal_eval only accepts Python literals, unlike eval(), which would
        # execute arbitrary expressions produced by the model.
        parsed = ast.literal_eval(raw_answer[start:end + 1])
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return []
    if not isinstance(parsed, (list, tuple, set)):
        return []
    return [item for item in parsed if isinstance(item, str)]


def step_1_collecting_metrics(valid_data, batch_size, num_iterations):
    """Collect the metrics the model considers most important for the task.

    For up to ``num_iterations`` rounds, a batch of at most ``batch_size``
    examples is drawn without replacement from ``valid_data``, shown to the
    model as demonstrations, and the model is asked to pick metrics from
    ``PRE_DEFINED_ASSESSMENT_METRICS``. Replies that cannot be parsed as a
    list are ignored.

    Args:
        valid_data: Sequence of dicts with "article" and "highlights" keys.
        batch_size: Number of demonstrations per round.
        num_iterations: Maximum number of rounds.

    Returns:
        Alphabetically sorted list of the distinct metric names collected.
    """
    remaining = list(valid_data)
    collected_metrics = set()
    for _ in range(num_iterations):
        if not remaining:
            # Every example has been used already; sampling again would fail.
            break
        selected_valid = random.sample(remaining, min(batch_size, len(remaining)))
        remaining = [elem for elem in remaining if elem not in selected_valid]

        # Demonstrations come from the freshly sampled batch, not from the
        # head of valid_data, so each round shows different examples.
        DEMONSTRATION_STRING = "".join(
            f"""INPUT {idx}: {str(example["article"])}\nOUTPUT {idx}: {str(example["highlights"])}\n\n"""
            for idx, example in enumerate(selected_valid, start=1)
        )

        collecting_metrics_prompt = f"""Select top-5 metrics which are the most important from the list below to evaluate a a special way of {TASK_NAME}.
{str(PRE_DEFINED_ASSESSMENT_METRICS)}

Here are some demonstrations of the task {TASK_NAME}:
{DEMONSTRATION_STRING}

Output your list of metrics in Python list format without any explanation: [...].
"""

        list_metrics = get_mistral_answer(collecting_metrics_prompt)
        collected_metrics.update(_parse_metric_list(list_metrics))

    return sorted(collected_metrics)


def _parse_score_dict(raw_answer):
    """Extract a ``{metric: score}`` dict from a model reply, or return ``None``."""
    import ast  # local import: keeps this block usable on its own

    # Only the outermost {...} span is parsed so that surrounding prose or
    # code fences in the reply do not invalidate an otherwise good answer.
    start, end = raw_answer.find("{"), raw_answer.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        # literal_eval only accepts Python literals, unlike eval(), which would
        # execute arbitrary expressions produced by the model.
        parsed = ast.literal_eval(raw_answer[start:end + 1])
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    return parsed if isinstance(parsed, dict) else None


def step_2_collecting_task_based_scores(valid_data, collected_metrics, num_samples=25):
    """Estimate, per metric, the 1-5 score the reference outputs achieve.

    Up to ``num_samples`` examples are drawn at random from ``valid_data``
    (which is left unmodified). For each one the model is asked three times to
    rate the reference output on every collected metric, and the numeric
    ratings are averaged per metric.

    Args:
        valid_data: Sequence of dicts with "article" and "highlights" keys.
        collected_metrics: Metric names to score.
        num_samples: Maximum number of examples to rate.

    Returns:
        Dict mapping each metric in ``collected_metrics`` to an integer score
        produced by ``convert_score`` on the average rating. A metric that
        never received a usable rating defaults to 5.
    """
    EVALUATION_FORMAT = {key: "1-5" for key in collected_metrics}
    METRICS_DEFINITIONS = get_metric_general_definition(collected_metrics)
    print(METRICS_DEFINITIONS)

    score_totals = {metric: 0 for metric in collected_metrics}
    score_counts = {metric: 0 for metric in collected_metrics}

    # Sample from a copy so the caller's list keeps its order.
    pool = list(valid_data)
    sampled_examples = random.sample(pool, min(num_samples, len(pool)))
    for dt in tqdm(sampled_examples):
        source_text = str(dt["article"])
        target = str(dt["highlights"])

        # Evaluation step
        evaluation_prompt = f"""You are given an input, and an output of a {TASK_NAME} task.
Input: {source_text}
Output: {target}

Your task is to evaluate the following criteria in a scale of 1-5, with 1 is worst and 5 is best.
{EVALUATION_FORMAT}

The definitions of the criteria are:
{METRICS_DEFINITIONS}

Your output must be in Python dictionary format without any explanation.
"""
        evaluation_outcomes = get_mistral_answer(evaluation_prompt, n=3)
        for raw_outcome in evaluation_outcomes:
            parsed_outcome = _parse_score_dict(raw_outcome)
            if parsed_outcome is None:
                continue
            for metric, value in parsed_outcome.items():
                # Skip metrics that were not requested and non-numeric ratings;
                # each rating is counted on its own, so one bad entry no longer
                # leaves a half-applied update behind.
                if metric not in score_totals:
                    continue
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    continue
                score_totals[metric] += value
                score_counts[metric] += 1

    MC_collected_scores = {}
    for metric in score_totals:
        if score_counts[metric]:
            MC_collected_scores[metric] = convert_score(score_totals[metric] / score_counts[metric])
        else:
            MC_collected_scores[metric] = 5

    return MC_collected_scores

def getting_task_metrics(valid_data, batch_size, num_iterations):
    """Run the three metric-learning steps and return their results.

    Returns:
        A tuple ``(definitions, metrics, scores)``: the model-written
        description of the expected quality per metric, the list of collected
        metric names, and the dict of per-metric scores.
    """
    # Step 1: Collecting metrics
    metrics = step_1_collecting_metrics(valid_data, batch_size, num_iterations)
    print("Getting task metrics: Finish step 1, collecting metrics.")
    print(f"Collected metrics: {metrics}")

    # Step 2: Collecting metrics' scores
    scores = step_2_collecting_task_based_scores(valid_data, metrics)
    print("Getting task metrics: Finish step 2, collecting metrics' scores.")

    # Step 3: Collecting the metrics' definitions
    metrics_string = ", ".join(metrics)
    collecting_definitions_prompt = f"""Now you are given the following metrics: {metrics_string} for the {TASK_NAME} task.
Based on these scores on a scale of 5 for the quality of the highlighted summary: {str(scores)}, define the expected quality of the highlights for each metric in natural language. Give me the list in bullet points."""

    definitions = get_mistral_answer(collecting_definitions_prompt)

    # When the reply has blank-line-separated paragraphs, only the second one
    # is kept.
    # NOTE(review): presumably this drops an introductory sentence; if the
    # bullets themselves are separated by blank lines only the first bullet
    # survives — confirm this is intended.
    paragraphs = definitions.split("\n\n")
    if len(paragraphs) > 1:
        definitions = paragraphs[1].strip()

    return definitions, list(metrics), scores

# Linguistics metrics
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk import pos_tag
from collections import Counter
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
import math

def count_sentences(text):
    """Return the number of sentences ``sent_tokenize`` finds in ``text``."""
    sentences = sent_tokenize(text)
    return len(sentences)

def count_words(text):
    """Return the number of tokens ``word_tokenize`` produces for ``text``."""
    return len(word_tokenize(text))

def count_verbs(text):
    """Return how many tokens of ``text`` receive a verb POS tag.

    A token counts as a verb when its tag starts with ``'VB'``
    (VB, VBD, VBG, VBN, VBP, VBZ).
    """
    tagged_tokens = pos_tag(word_tokenize(text))
    verb_tags = [tag for _, tag in tagged_tokens if tag.startswith('VB')]
    return len(verb_tags)

def count_nouns(text):
    """Return how many tokens of ``text`` receive a noun POS tag.

    A token counts as a noun when its tag starts with ``'NN'``
    (NN, NNS, NNP, NNPS).
    """
    tagged_words = pos_tag(word_tokenize(text))
    return sum(1 for _, pos in tagged_words if pos.startswith('NN'))


def get_majority_voice(paragraph):
    """Return the most frequent verb POS tag in ``paragraph``.

    Despite the name, the result is a tag starting with ``'VB'`` (for example
    ``'VBD'``), not an active/passive label: the function tags every sentence
    and counts how often each verb tag occurs.

    Returns:
        The verb tag with the highest count (the first one encountered wins a
        tie), or ``None`` when the paragraph contains no verb at all.
    """
    verb_tags = [
        pos
        for sentence in sent_tokenize(paragraph)
        for _, pos in pos_tag(word_tokenize(sentence))
        if pos.startswith('VB')
    ]

    tag_counts = Counter(verb_tags)
    if not tag_counts:
        # max() on an empty Counter raises ValueError; report "no verbs" instead.
        return None

    return max(tag_counts, key=tag_counts.get)

def getting_task_linguistics_metrics(valid_data):
    """Compute word and sentence statistics of the reference outputs.

    The statistics describe what a response should look like, so both counts
    are taken from each record's reference summary (``"highlights"``, the key
    used for this dataset everywhere else in the file). Records without that
    key fall back to the previous layout: words from ``"instruction"`` and
    sentences from ``"output"``.

    Args:
        valid_data: Non-empty sequence of example dicts.

    Returns:
        Dict with ``min_*``, ``max_*`` and ``ave_*`` entries for ``words`` and
        ``sentences``.

    Raises:
        ValueError: If ``valid_data`` is empty.
    """
    if not valid_data:
        raise ValueError("valid_data must contain at least one example")

    words = []
    sentences = []
    for dt in valid_data:
        if "highlights" in dt:
            word_source = sentence_source = str(dt["highlights"])
        else:
            word_source, sentence_source = dt["instruction"], dt["output"]
        words.append(count_words(word_source))
        sentences.append(count_sentences(sentence_source))

    return {
        "min_words": min(words),
        "max_words": max(words),
        "ave_words": sum(words)/len(words),
        "min_sentences": min(sentences),
        "max_sentences": max(sentences),
        "ave_sentences": sum(sentences)/len(sentences)
    }

def consutruct_linguistics_metrics(dict_statistics):
    """Turn length statistics into a one-sentence length constraint.

    Every statistic is rounded down with ``math.floor`` before being placed in
    the sentence.
    """
    stat_names = (
        "min_sentences", "max_sentences",
        "min_words", "max_words",
        "ave_words", "ave_sentences",
    )
    floored = {name: math.floor(dict_statistics[name]) for name in stat_names}
    template = (
        "Your response must have from {min_sentences} to {max_sentences} sentences "
        "and from {min_words} to {max_words} words "
        "with an average of {ave_words} words and {ave_sentences} sentences."
    )
    return template.format(**floored)

# Notebook-style preview: compute and print the length statistics of the
# validation set, then build the constraint sentence from them. The sentence
# built on the last line is not stored; both values are computed again in the
# main experiment below.
linguistics_statistics = getting_task_linguistics_metrics(valid_data)
print(linguistics_statistics)
consutruct_linguistics_metrics(linguistics_statistics)

def get_generated_output_and_ablations(input, METRICS_DEFINITIONS, linguistics_attributes):
    """Generate a summary of ``input`` under every prompting variant.

    Args:
        input: Article text to summarize. (The name shadows the builtin but is
            kept because it is part of the function's signature.)
        METRICS_DEFINITIONS: Task-metric guideline text inserted into the
            attribute-based prompts.
        linguistics_attributes: Length-constraint sentence inserted into the
            linguistics-based prompts.

    Returns:
        A 7-tuple ``(zero_shot_answer, few_shot_answer, DEMONSTRATIONS,
        full_attributes, full_attributes_few_shot, only_linguistics,
        only_task_attributes)`` where ``DEMONSTRATIONS`` is the few-shot
        example text that was used.
    """
    # Construct demonstrations
    NUM_DEMONSTRATIONS = 3
    # Draw the examples without touching the module-level valid_data; shuffling
    # an alias of it would reorder the shared dataset on every call.
    demo_examples = random.sample(valid_data, min(NUM_DEMONSTRATIONS, len(valid_data)))
    DEMONSTRATIONS = "".join(
        f"""Input: {str(dt["article"])}\nOutput: {dt["highlights"]}\n"""
        for dt in demo_examples
    )

    ######################## ZERO SHOT ##############################
    zero_shot_prompt = f"""Summarize the highlights from the following article :\n{input}"""
    zero_shot_answer = get_mistral_answer(zero_shot_prompt)
    ########################## FEW SHOT #####################################
    few_shot_prompt = f"""Summarize the highlights from the following article :\n{DEMONSTRATIONS} Input: {input}\nOutput:"""
    few_shot_answer = get_mistral_answer(few_shot_prompt)
    ########################### FULL ATTRIBUTES ############################
    full_attributes_prompt = f"""Summarize the highlights from the following article. Your generated summary must strictly fulfill the following task metrics.

{str(METRICS_DEFINITIONS)}

Input: {input}

{linguistics_attributes}"""
    full_attributes = get_mistral_answer(full_attributes_prompt)
    ########################### FULL ATTRIBUTES + FEW SHOT ############################
    full_attributes_few_shot_prompt = f"""Summarize the highlights from the following article. Your generated summary must strictly fulfill the following task metrics. {linguistics_attributes}

{str(METRICS_DEFINITIONS)}

\n{DEMONSTRATIONS}Input: {input}\nOutput:"""
    full_attributes_few_shot = get_mistral_answer(full_attributes_few_shot_prompt)
    ############################## ONLY LINGUISTICS ###########################
    only_linguistics_prompt = f"""Summarize the highlights from the following article.

Input: {input}

{linguistics_attributes}"""
    only_linguistics = get_mistral_answer(only_linguistics_prompt)

    ############################# ONLY TASK ATTRIBUTES ############################
    only_task_attributes_prompt = f"""Summarize the highlights from the following article. Your generated summary must strictly fulfill the following task metrics.

{str(METRICS_DEFINITIONS)}

Input: {input}"""
    only_task_attributes = get_mistral_answer(only_task_attributes_prompt)

    return zero_shot_answer, few_shot_answer, DEMONSTRATIONS, full_attributes, full_attributes_few_shot, only_linguistics, only_task_attributes

################ MAIN EXPERIMENT STEP 1 ################

from tqdm import tqdm

# Task and linguistics attribute learning
# Number of demonstrations per metric-selection round, and number of rounds.
batch_size, num_iterations = 5, 5

print("######### Start learning task and linguistics metrics...")
learned_task_metrics = getting_task_metrics(valid_data, batch_size, num_iterations)
METRICS_DEFINITIONS, collected_metrics, MC_collected_scores = learned_task_metrics

# Show what was learned: metric names, their guideline text, their scores.
for learned_item in (collected_metrics, METRICS_DEFINITIONS, MC_collected_scores):
    print(learned_item)

# Length statistics of the validation set and the constraint sentence built
# from them.
linguistics_statistics = getting_task_linguistics_metrics(valid_data)
linguistics_attributes = consutruct_linguistics_metrics(linguistics_statistics)
EVALUATION_FORMAT = dict.fromkeys(collected_metrics, "1-5")
print("######### Finish learning task and linguistics attributes...")

################ MAIN EXPERIMENT STEP 2 ################
# One row per test example; the column order matches the CSV header written
# when the results are saved.
saved_data = []
for data in tqdm(test_data):
    gem_id = data["id"]
    # Named `article` rather than `input` so the builtin input() is not
    # rebound at module level.
    article = data["article"]
    ground_truth_output = data["highlights"]

    # Get the generated output for every prompting variant. The fourth value
    # returned is the full-attributes answer, stored here as generated_output.
    (
        zero_shot_answer,
        few_shot_answer,
        DEMONSTRATIONS,
        generated_output,
        full_attributes_few_shot,
        only_linguistics,
        only_task_attributes,
    ) = get_generated_output_and_ablations(article, METRICS_DEFINITIONS, linguistics_attributes)

    saved_data.append([
        gem_id, article, ground_truth_output,
        zero_shot_answer, few_shot_answer, DEMONSTRATIONS,
        str(collected_metrics), str(MC_collected_scores), str(METRICS_DEFINITIONS), str(linguistics_statistics),
        only_linguistics, only_task_attributes, generated_output, full_attributes_few_shot
    ])

# Save data
import csv

import os

OUTPUT_PATH = "output/cnn_longguide.csv"

# Create the target directory up front so a missing folder cannot discard the
# generated results at the very last step.
output_dir = os.path.dirname(OUTPUT_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# newline="" lets the csv module control line endings itself, which is needed
# for fields that contain line breaks (articles and summaries may).
with open(OUTPUT_PATH, "w", newline="") as file:
    csvwriter = csv.writer(file)
    csvwriter.writerow([
        "id", "article", "highlights",
        "zero_shot_answer", "few_shot_answer", "demonstrations",
        "MC_task_metrics", "MC_task_scores", "task_metrics_definitions", "linguistics_statistics",
        "only_linguistics", "only_task_attributes", "full_attributes", "full_attributes_few_shot"
    ])
    csvwriter.writerows(saved_data)

# Evaluation
import csv 
from tqdm import tqdm

from rouge_score import rouge_scorer
def compute_rougeL(generated, reference):
    """Return the ROUGE-L F-measure of ``generated`` against ``reference``.

    Args:
        generated: Model-produced text (the prediction).
        reference: Ground-truth text (the target).

    Returns:
        The ``fmeasure`` of the ``rougeL`` score, computed with stemming.
    """
    # Only ROUGE-L is read below, so no other variant is requested.
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    # RougeScorer.score expects (target, prediction). The F-measure is the
    # same either way, but the correct order keeps precision and recall
    # meaningful should they ever be read from this result.
    scores = scorer.score(reference, generated)
    return scores["rougeL"].fmeasure

# Running ROUGE-L totals, one per generation variant stored in the CSV.
zs_rouge = 0
few_rouge = 0
our_rouge = 0
our_few_rouge = 0
only_linguistics_rouge = 0
only_task_attributes_rouge = 0

cnt = 0
# newline="" lets the csv module handle line breaks inside quoted fields.
with open(OUTPUT_PATH, newline="") as file:
    csvreader = csv.reader(file)
    header = next(csvreader, None)  # None when the file is completely empty
    if header is not None:
        # Resolve each column position once instead of once per row.
        column = {
            name: header.index(name)
            for name in (
                "highlights",
                "zero_shot_answer",
                "few_shot_answer",
                "only_linguistics",
                "only_task_attributes",
                "full_attributes",
                "full_attributes_few_shot",
            )
        }
        for row in tqdm(csvreader):
            summary = row[column["highlights"]]

            zs_rouge += compute_rougeL(row[column["zero_shot_answer"]], summary)
            few_rouge += compute_rougeL(row[column["few_shot_answer"]], summary)

            only_linguistics_rouge += compute_rougeL(row[column["only_linguistics"]], summary)
            only_task_attributes_rouge += compute_rougeL(row[column["only_task_attributes"]], summary)
            our_rouge += compute_rougeL(row[column["full_attributes"]], summary)
            our_few_rouge += compute_rougeL(row[column["full_attributes_few_shot"]], summary)
            cnt += 1

print()
if cnt == 0:
    # Averaging over zero rows would divide by zero.
    print(f"No result rows found in {OUTPUT_PATH}; nothing to evaluate.")
else:
    print(f"ROUGE-L zs_rouge: {zs_rouge/cnt}")
    print(f"ROUGE-L few_rouge: {few_rouge/cnt}")
    print(f"ROUGE-L only_linguistics_rouge: {only_linguistics_rouge/cnt}")
    print(f"ROUGE-L only_task_attributes_rouge: {only_task_attributes_rouge/cnt}")
    print(f"ROUGE-L our_rouge: {our_rouge/cnt}")
    print(f"ROUGE-L our_few_rouge: {our_few_rouge/cnt}")