import argparse
import os
import sys
import traceback
import xml.etree.ElementTree as ET
import yaml

import boto3
import json
import numpy as np
import pandas as pd
import torch
from botocore.config import Config
from joblib import Parallel, delayed
from transformers import AutoTokenizer
from tqdm import tqdm

sys.path.append(os.getcwd())

from src.utils.model_loading import load_trained_reward_model


def extract_preference(text):
    """Convert a 1-based answer string into a 0-based preference index.

    Args:
        text: The answer to convert, e.g. ``"1"`` or ``"2"``. May be ``None``.

    Returns:
        ``int(text) - 1``, or ``-1`` when ``text`` cannot be converted to an
        integer.
    """
    try:
        return int(text) - 1
    except (TypeError, ValueError):
        # TypeError: text is None (or another non-numeric type);
        # ValueError: text is a string that does not spell an integer.
        return -1

# AWS regions to create Bedrock clients in, keyed by Bedrock model ID.
# TODO: check whether other regions (e.g. ap-southeast-1) are supported by now.
SUPPORTED_REGIONS = {
    "anthropic.claude-v2": [
        "us-east-1",
        "us-west-2",
        "eu-central-1",
    ],
    "anthropic.claude-instant-v1": [
        "us-east-1",
        "us-west-2",
        "eu-central-1",
        "ap-northeast-1",
    ],
}

def annotate(df: pd.DataFrame, batch_size=2, n_jobs=8, aws_region=None, modelId="anthropic.claude-instant-v1") -> pd.DataFrame:
    """Ask a Claude model on AWS Bedrock which of two summaries is better.

    The rows of ``df`` are split into consecutive groups of ``batch_size``.
    Each group is sent as a single few-shot prompt in which ``row['label']``
    is presented as summary 1 and ``row['output']`` as summary 2. The parsed
    answers are written to the ``claude_preference`` column of a copy of
    ``df``.

    Args:
        df: Frame with the columns ``prompt``, ``label`` and ``output``.
        batch_size: Number of rows packed into one prompt.
        n_jobs: Number of joblib threads issuing requests.
        aws_region: Region for a single Bedrock client. When falsy, one client
            is created per region in ``SUPPORTED_REGIONS[modelId]`` and the
            batches are spread over them round-robin.
        modelId: Bedrock model ID passed to ``invoke_model``.

    Returns:
        A copy of ``df``. For every processed batch ``claude_preference``
        holds ``extract_preference(answer)`` per row, or ``-1`` for the whole
        batch when the reply could not be parsed or has the wrong number of
        answers. If sending requests fails (or is interrupted), the traceback
        is printed and the partially annotated copy is returned.
    """
    new_df = df.copy()

    def annotate_batch(rows, bedrock):
        # Documents are numbered from 3 because the few-shot examples inside
        # the prompt below already occupy doc_1 and doc_2.
        prompts = "\n".join([
            f"""
<doc_{i + 3}>
{row['prompt']}
</doc_{i + 3}>

<doc_{i + 3}_summary_1>
{row['label']}
</doc_{i + 3}_summary_1>

<doc_{i + 3}_summary_2>
{row['output']}
</doc_{i + 3}_summary_2>"""
            for i, row in rows.reset_index(drop=True).iterrows()
        ])
        # Closing tag of the last requested answer: used both as the stop
        # sequence and to re-close the completion before XML parsing.
        last_closing_tag = f"</doc_{len(rows) + 2}_chosen_summary>"
        body = json.dumps({
            "prompt": f"""\n\nHuman: You are a helpful assistant that selects the best summary out of two answers. The summary is good if it is accurate, coherent and covers the most important parts of the text. The summaries are presented in a random order. Write a response with the number that corresponds to a better summary without any additional text. For example, <doc_1_chosen_summary>2</doc_1_chosen_summary> means that for the first document, the second summary is better while <doc_2_chosen_summary>1</doc_2_chosen_summary> means that for the second document, the first summary is better.

<doc_1>
Subreddit: r/relationships 
TITLE: Screwed up with boss... what should I do? 
POST: I'm 20 f, my boss is around 50 years old, also f. So I have two jobs, and the schedules for both jobs are made on a weekly basis. One of my jobs I have had for three years, the other one I have had for a month and a bit. I forgot to give my schedule from one job to my boss at my other job, and so I was not scheduled for this week. I didn't realize why I had not been put on the schedule until now. My question is, since I royally screwed up, what can I do to redeem myself? I don't want to call my boss today because it is a Sunday and she has the day off. Mistakes aren't easily forgiven where I work, as far as I can tell, and the boss often makes comments about how the employees should be scared of her. I have screwed up at previous jobs (little things) but my boss was less intimidating than my current one, so I am not sure how to handle this situation.
TL;DR: 
</doc_1>

<doc_1_summary_1>
screwed up at work by not giving the boss my schedule from my other job, am not scheduled this week, what should I say in order to apologize to my (scary/intimidating) boss?
</doc_1_summary_1>

<doc_1_summary_2>
Screwed up with boss... what should I do?
</doc_1_summary_2>

<doc_2>
Subreddit: r/relationships 
TITLE: I am a [18 M] she is a [17 F] and I don't know how to read this relationship?
POST: We've known each other through school but never talked until we had a class together. I asked her out after about a week, we went out for food, laughed, flirted etc etc. I asked her out again same situation, everything going great. Now it's three weeks later (midst of exams) and I'm starting to get the feeling that she's not thinking of this as a "relationship" in the conventional bf/gf sense. I'm new to the whole dating game and wondering a few things. Do I need to ask if we're together as bf/gf or is that implied by asking her out multiple times? Should I straight up ask if she likes me the way I like her? I know what the friend zone is and I'm not feeling there, yet, but we're having a movie night tomorrow just her and I. I plan on seeing what's what then by trying to get physical, not hardcore, just simple stuff like leg touches, cuddling etc. Oh and finally, to add to my lengthy situation, On our third "date" (studying for an exam) I drove her home and attempted to kiss her but got the infamous cheek. I didn't make a big deal about it I just moved past said "got your things?" and politely said goodbye preceding to wait for her to get inside her house. I've been told I shouldn't fret about ONE kiss rejection cause there could be multiple reasons for it but at the same time that made me think more about the friend zone possibility. Any advice or similar situations and how you solved the problem would be smashing! Thanks in advance.
TL;DR: 
</doc_2>

<doc_2_summary_1>
We've known each other through school but never talked until we had a class together. I asked her out after about a week, we went out for food, laughed, flirted etc etc.
</doc_2_summary_1>

<doc_2_summary_2>
Been on three dates with a girl, getting the feeling she's not interested in a relationship in the traditional sense. Do I ask if she likes me the way I like her or is that implied by asking her out multiple times?
</doc_2_summary_2>

A:
<doc_1_chosen_summary>1</doc_1_chosen_summary>
<doc_2_chosen_summary>2</doc_2_chosen_summary>

H:
{prompts}

A:
<doc_3_chosen_summary>""",
            "max_tokens_to_sample": 300,
            "temperature": 0,
            "top_p": 1,
            "stop_sequences": [last_closing_tag]
        })

        response = bedrock.invoke_model(
            body=body,
            modelId=modelId,
            accept="application/json",
            contentType="application/json",
        )
        completion = json.loads(response.get("body").read()).get("completion")

        # The prompt ends with the first opening tag and generation stops at
        # the last closing tag, so both are glued back on to obtain a sequence
        # of complete <doc_N_chosen_summary> elements.
        response_body = "<doc_3_chosen_summary>" + completion + last_closing_tag
        try:
            root = ET.fromstring("<root>" + response_body + "</root>")
            preferences = [extract_preference(x.text) for x in root]
            if len(preferences) != len(rows):
                raise ValueError("Corrupted output from Claude")
        except (ET.ParseError, ValueError):
            # Malformed XML or wrong number of answers: mark the whole batch
            # as unparseable instead of failing the run.
            preferences = [-1] * len(rows)
        # NOTE(review): worker threads write into the shared ``new_df`` here;
        # confirm this is safe for the pandas version in use.
        new_df.loc[rows.index, "claude_preference"] = preferences

    # Very large retry budget with adaptive backoff and 300 s timeouts.
    config = Config(retries={"max_attempts": 1000000, "mode": "adaptive"}, connect_timeout=300, read_timeout=300)
    if aws_region:
        clients = [boto3.client(service_name="bedrock-runtime", config=config, region_name=aws_region)]
    else:
        clients = [boto3.client(service_name="bedrock-runtime", config=config, region_name=region) for region in SUPPORTED_REGIONS[modelId]]
    try:
        Parallel(n_jobs=n_jobs, prefer="threads")(
            delayed(annotate_batch)(rows, clients[i % len(clients)])
            for i, rows in tqdm(df.groupby(np.arange(len(df)) // batch_size))
        )
    except (Exception, KeyboardInterrupt):
        # Best effort: KeyboardInterrupt is caught deliberately so that an
        # interrupted run still returns the batches annotated so far.
        traceback.print_exc()
        print("Error while sending the request to Claude, returning intermediate outputs")
    return new_df


def gold_reward_model_score(dataset: pd.DataFrame, model_directory: str):
    """Score each row's ``output`` and ``label`` with the gold reward model.

    The frame is split into one strided shard per visible CUDA device
    (``dataset[i::devices]``); each shard is scored by its own joblib worker
    on ``cuda:i`` and the shards are concatenated and re-sorted by index.

    Model locations are built from the ``DATA_DIR`` (default ``"."``) and
    ``DATASET`` (default ``"tldr"``) environment variables plus
    ``model_directory``.

    Added columns:
        gold_rm_score_output: reward model logit for ``prompt + " " + output``.
        gold_rm_score_label: reward model logit for ``prompt + " " + label``.
        gold_rm_preference: 1 if the output score is higher, 0 if it is
            lower, -1 if both scores are equal.
    """
    data_root = os.getenv("DATA_DIR", ".")
    dataset_name = os.getenv("DATASET", "tldr")
    models_root = os.path.join(data_root, "data/models", dataset_name)
    rm_adapter_path = os.path.join(models_root, "reward-models", model_directory)
    sft_model_path = os.path.join(models_root, "sft-models", model_directory)
    # NOTE(review): with zero CUDA devices this yields n_jobs=0 and no shards;
    # confirm how callers are expected to handle a CPU-only machine.
    n_devices = torch.cuda.device_count()

    def _score_shard(shard, device):
        # Each worker loads its own tokenizer and reward model on its device.
        tqdm.pandas(desc=f"{device}: Gold reward model annotations")
        tokenizer = AutoTokenizer.from_pretrained(sft_model_path)
        reward_model = load_trained_reward_model(rm_adapter_path, sft_model_path, device, tokenizer.pad_token_id)
        reward_model.eval()

        def _score_row(row, target_name="output"):
            text = row["prompt"].strip() + " " + row[target_name].strip()
            input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(device)
            return reward_model(input_ids).logits.item()

        with torch.no_grad():
            shard["gold_rm_score_output"] = shard.progress_apply(_score_row, axis=1, target_name="output")
            shard["gold_rm_score_label"] = shard.progress_apply(_score_row, axis=1, target_name="label")

        # 1 when output wins, 0 when label wins, -1 on an exact tie.
        margin = shard.gold_rm_score_output - shard.gold_rm_score_label
        shard["gold_rm_preference"] = (margin > 0).astype(int) - (margin == 0).astype(int)
        return shard

    shards = Parallel(n_jobs=n_devices)(
        delayed(_score_shard)(dataset[idx::n_devices].copy(), device=f"cuda:{idx}")
        for idx in range(n_devices)
    )
    return pd.concat(shards).sort_index()


if __name__ == "__main__":
    # Script entry point: load a dataset of (prompt, label, output) rows,
    # annotate it with either the gold reward model or Claude, and write the
    # result to "annotated_outputs.json" next to the input file.
    DATA_DIR = os.getenv("DATA_DIR", ".")
    DATASET = os.getenv("DATASET", "tldr")
    CONFIG_DIR = os.getenv("CONFIG_DIR", "gpt2")
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", "-d", type=str, default=None, help="Dataset to annotate, must have columns 'prompt', 'label', and 'output'")
    parser.add_argument("--batch_size", "-b", type=int, default=2, help="Number of samples to annotate in a single prompt")
    parser.add_argument("--n_jobs", "-n", type=int, default=8, help="Number of jobs to parallelize annotation requests")
    parser.add_argument("--model_id", "-m", type=str, default="anthropic.claude-instant-v1", help="Model ID as listed in AWS Bedrock", choices=["anthropic.claude-v2", "anthropic.claude-instant-v1"])
    parser.add_argument("--gold_rm", "-g", type=str, default="t", help="Annotate with Gold RM instead")
    args = parser.parse_args()

    if args.dataset_path is None:
        # Default to the outputs of the SFT run described by the active config.
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # fine for trusted project configs, consider yaml.safe_load otherwise.
        with open(os.path.join("configs", CONFIG_DIR, "sft.yaml")) as config_file:
            sft_config = yaml.load(config_file, yaml.Loader)
        args.dataset_path = os.path.join(DATA_DIR, "data/datasets", DATASET, sft_config["output_directory"], "outputs.json")

    try:
        df = pd.read_json(args.dataset_path)
    except ValueError:
        # Not a single JSON document: retry as JSON Lines.
        df = pd.read_json(args.dataset_path, lines=True)

    # Any --gold_rm value containing the letter "t" selects the gold reward model.
    if "t" in args.gold_rm:
        with open(os.path.join("configs", CONFIG_DIR, "ppo.yaml")) as config_file:
            gold_rm_directory = yaml.load(config_file, yaml.Loader)["gold_rm_directory"]
        df = gold_reward_model_score(df, gold_rm_directory)
    elif "claude_preference" not in df:
        df = annotate(df, batch_size=args.batch_size, n_jobs=args.n_jobs, modelId=args.model_id)
    elif df["claude_preference"].isna().any():
        # Resume: only rows still missing a preference are sent to Claude.
        missing = df[df["claude_preference"].isna()]
        annotated_df = annotate(missing, batch_size=args.batch_size, n_jobs=args.n_jobs, modelId=args.model_id)
        df.update(annotated_df)
    else:
        print("All samples are already annotated")
    df.to_json(os.path.join(os.path.dirname(args.dataset_path), "annotated_outputs.json"), orient="records")
