import json
import os
import time
import dotenv
import traceback
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)
from datasets import load_from_disk, concatenate_datasets
from utils import extract_diff, string_to_bool
from argparse import ArgumentParser
import logging

# Emit INFO-and-above log records prefixed with a timestamp and the level name.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# Load variables from a .env file (if any) into the process environment; the
# inference functions below read their API keys from os.environ.
dotenv.load_dotenv()


# Largest prompt, measured as len(datum["input_ids"]), that openai_inference
# will send to each OpenAI model; longer prompts are recorded with no output.
MODEL_LIMITS = {
    "gpt-3.5-turbo-16k-0613": 16_000,
    "gpt-4-32k-0613": 31_000,
}


def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    model_args,
    existing_ids,
):
    """Query an OpenAI chat model for every instance and append results as JSON lines.

    Args:
        test_dataset: Iterable of records providing the keys ``"instance_id"``,
            ``"text"`` and ``"input_ids"``.
        model_name_or_path: Name of the OpenAI chat model to query.
        output_file: Path of the JSONL file to append to (opened in ``"a+"`` mode).
        model_args: Optional dict of model arguments. Only ``"temperature"``
            (default 0.7) and ``"top_p"`` (default 0.95) are used; other
            entries are not forwarded to the API.
        existing_ids: Collection of instance ids to skip.

    Each written line holds ``instance_id``, ``model_name_or_path``, ``text``,
    ``full_output`` and ``model_patch``. The last two are ``None`` when the
    prompt exceeds the model's entry in ``MODEL_LIMITS`` or when the API
    reports that the context length was exceeded.
    """
    import openai

    openai.api_key = os.environ.get("OPENAI_API_KEY", None)

    # Honour user-supplied sampling parameters instead of silently ignoring them.
    model_args = model_args or {}
    temperature = model_args.get("temperature", 0.7)
    top_p = model_args.get("top_p", 0.95)

    @retry(wait=wait_random_exponential(min=60, max=600), stop=stop_after_attempt(6))
    def call_chat(model_name_or_path, inputs):
        # The first line of the prompt is sent as the system message and the
        # remainder as the user message.
        system_messages, _, user_message = inputs.partition("\n")
        try:
            response = openai.ChatCompletion.create(
                model=model_name_or_path,
                messages=[
                    {"role": "system", "content": system_messages},
                    {"role": "user", "content": user_message},
                ],
                temperature=temperature,
                top_p=top_p,
            )
            return response.choices[0]['message']['content']
        except openai.error.InvalidRequestError as e:
            if e.code == 'context_length_exceeded':
                # Retrying cannot help here; report "no output" instead.
                logger.warning("Context length exceeded")
                return None
            raise

    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    # Models missing from MODEL_LIMITS get no local length pre-check (instead of
    # a KeyError); the API-side context-length error is still handled above.
    token_limit = MODEL_LIMITS.get(model_name_or_path)
    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue
            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text"] = f"{datum['text']}\n\n"
            if token_limit is not None and len(datum['input_ids']) > token_limit:
                output_dict["full_output"] = None
                output_dict["model_patch"] = None
            else:
                if model_name_or_path == 'gpt-4-32k-0613' and len(datum['input_ids']) <= 6000:
                    # Short prompts are sent to 'gpt-4-0613' even when the 32k
                    # model was requested; the recorded model name is unchanged.
                    completion = call_chat('gpt-4-0613', output_dict["text"])
                else:
                    completion = call_chat(
                        output_dict["model_name_or_path"], output_dict["text"]
                    )
                output_dict["full_output"] = completion
                # No completion means no patch, mirroring the over-limit branch.
                output_dict["model_patch"] = (
                    extract_diff(completion) if completion is not None else None
                )
            # Flush after every record so an interrupted run can be resumed.
            print(json.dumps(output_dict), file=f, flush=True)


def anthropic_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    model_args,
    existing_ids,
):
    """Query an Anthropic completion model for every instance and append JSON lines.

    Args:
        test_dataset: Iterable of records providing the keys ``"instance_id"``
            and ``"text"``.
        model_name_or_path: Name of the Anthropic model to query.
        output_file: Path of the JSONL file to append to (opened in ``"a+"`` mode).
        model_args: Optional dict of extra keyword arguments forwarded to
            ``completions.create``. An ``"anthropic_api_key"`` entry, if
            present, is used as the API key instead of the ``ANTHROPIC_API_KEY``
            environment variable and is not forwarded.
        existing_ids: Collection of instance ids to skip.

    Each written line holds ``instance_id``, ``model_name_or_path``,
    ``text_inputs``, ``full_output`` and ``model_patch``. The last two are
    ``None`` when the API call failed.
    """
    from anthropic import HUMAN_PROMPT, AI_PROMPT, Anthropic

    # Work on a copy so that removing the API key does not mutate the caller's dict.
    model_args = dict(model_args or {})
    api_key = model_args.pop(
        "anthropic_api_key", os.environ.get("ANTHROPIC_API_KEY", None)
    )
    # One client is reused for every request instead of building one per call.
    client = Anthropic(api_key=api_key)

    # NOTE: every exception is caught inside anthropic_call, so the retry
    # policy below never observes a failure; each prompt gets a single attempt.
    @retry(wait=wait_random_exponential(min=60, max=600), stop=stop_after_attempt(6))
    def anthropic_call(inputs):
        try:
            completion = client.completions.create(
                model=model_name_or_path,
                max_tokens_to_sample=6000,
                prompt=inputs,
                **model_args,
            )
            return completion.completion
        except Exception as e:
            logger.error(e)
            logger.error(f"Inputs: {inputs}")
            traceback.print_exc()
            # Pause before moving on to the next prompt.
            time.sleep(20)
            return None

    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue
            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text_inputs"] = f"{HUMAN_PROMPT} {datum['text']}\n\n{AI_PROMPT}"
            completion = anthropic_call(output_dict["text_inputs"])
            output_dict["full_output"] = completion
            # A failed call yields no completion, hence no patch to extract.
            output_dict["model_patch"] = (
                extract_diff(completion) if completion is not None else None
            )
            # Flush after every record so an interrupted run can be resumed.
            print(json.dumps(output_dict), file=f, flush=True)


def _infer_model_arg_value(value):
    """Convert the textual value of one model argument into a Python object."""
    if value in {"True", "False"}:
        return value == "True"
    if value == "None":
        return None
    if value == "[]":
        return []
    if value == "{}":
        return {}
    # Allow an optional leading sign so negative numbers are recognised.
    unsigned = value[1:] if value.startswith(("+", "-")) else value
    # isdecimal() only accepts characters that int()/float() can parse, unlike
    # isnumeric(), which also accepts e.g. superscripts and vulgar fractions.
    if unsigned.isdecimal():
        return int(value)
    if unsigned.replace(".", "", 1).isdecimal():
        return float(value)
    # Strip one pair of matching surrounding quotes.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
        return value[1:-1]
    return value


def parse_model_args(model_args):
    """Parse a ``key=value,key=value`` string into a dict of typed values.

    Values are converted to ``bool`` (``True``/``False``), ``int``, ``float``,
    ``None``, an empty list (``[]``) or an empty dict (``{}``) when they match
    those spellings; a value wrapped in matching single or double quotes has
    the quotes removed; anything else is kept as a string. Whitespace around
    keys and values is stripped and empty segments (e.g. a trailing comma) are
    ignored. Only the first ``=`` of a segment separates key from value, so
    values may contain ``=``; they cannot contain commas.

    Args:
        model_args: The argument string, or ``None``.

    Returns:
        A dict mapping argument names to parsed values (empty for ``None``
        or an empty string).

    Raises:
        ValueError: If a non-empty segment contains no ``=``.
    """
    kwargs = {}
    if model_args is None:
        return kwargs
    for arg in model_args.split(","):
        arg = arg.strip()
        if not arg:
            continue
        key, sep, value = arg.partition("=")
        if not sep:
            raise ValueError(f"Invalid model argument {arg!r}; expected key=value")
        kwargs[key.strip()] = _infer_model_arg_value(value.strip())
    return kwargs


def main(
    dataset_path,
    model_name_or_path,
    shard_id,
    num_shards,
    output_dir,
    model_args,
):
    """Run API-model inference over the test splits of a dataset saved on disk.

    Results are appended to
    ``<output_dir>/<model>__<dataset>[__shard-<id>__num_shards-<n>].jsonl``.
    Instance ids already present in that file are skipped, so an interrupted
    run can be resumed by re-running with the same arguments.

    Args:
        dataset_path: Directory passed to ``load_from_disk``; only splits whose
            name contains ``"test"`` are used.
        model_name_or_path: ``"claude-2"`` or a name starting with ``"gpt"``.
        shard_id: Index of the shard to process, or ``None``.
        num_shards: Total number of shards, or ``None``. Sharding is applied
            only when both ``shard_id`` and ``num_shards`` are given.
        output_dir: Directory for the output file; created if missing.
        model_args: Raw model-argument string handed to ``parse_model_args``.

    Raises:
        ValueError: If ``model_name_or_path`` names an unsupported model.
    """
    if shard_id is None and num_shards is not None:
        logger.warning(f"Received num_shards={num_shards} but shard_id is None, ignoring")
    if shard_id is not None and num_shards is None:
        logger.warning(f"Received shard_id={shard_id} but num_shards is None, ignoring")
    use_shards = shard_id is not None and num_shards is not None

    # Reject unsupported models before doing any expensive dataset work.
    if model_name_or_path in {"claude-2"}:
        run_inference = anthropic_inference
    elif model_name_or_path.startswith("gpt"):
        run_inference = openai_inference
    else:
        raise ValueError(f"Invalid model name or path {model_name_or_path}")

    model_args = parse_model_args(model_args)

    # For ".../<name>/checkpoint-*" style paths use the parent directory name.
    model_path = Path(model_name_or_path)
    if "checkpoint" in model_path.name:
        model_nickname = model_path.parent.name
    else:
        model_nickname = model_path.name
    output_file = f"{model_nickname}__{Path(dataset_path).name}"
    if use_shards:
        output_file += f"__shard-{shard_id}__num_shards-{num_shards}"
    # Make sure the target directory exists so opening the file cannot fail later.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = Path(output_dir, output_file + ".jsonl")
    logger.info(f"Will write to {output_file}")

    # Collect ids finished by a previous run so they are not queried again.
    existing_ids = set()
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                existing_ids.add(data["instance_id"])
    logger.info(f"Read {len(existing_ids)} already completed ids")

    dataset = load_from_disk(dataset_path)
    load_splits = [split for split in dataset.keys() if 'test' in split]
    dataset = concatenate_datasets([dataset[split] for split in load_splits])
    # Process instances in order of increasing whitespace-token count.
    lens = np.array([len(x['text'].split()) for x in dataset])
    dataset = dataset.select(np.argsort(lens))
    # Shard BEFORE dropping completed ids: shard membership must not depend on
    # how much of this shard is already done, otherwise a resumed shard would
    # pick up instances that belong to other shards.
    if use_shards:
        dataset = dataset.shard(num_shards, shard_id, contiguous=True)
    if existing_ids:
        dataset = dataset.filter(lambda x: x['instance_id'] not in existing_ids)

    run_inference(
        test_dataset=dataset,
        model_name_or_path=model_name_or_path,
        output_file=output_file,
        model_args=model_args,
        existing_ids=existing_ids,
    )
    logger.info("Done!")


# Command-line entry point: every option maps onto the parameter of main()
# with the same name.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to the directory containing a datasets dataset",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        # main() cannot run without a model name, so fail at parse time.
        required=True,
        help="Name of the model to query: 'claude-2' or an OpenAI model whose name starts with 'gpt'",
    )
    parser.add_argument(
        "--shard_id",
        type=int,
        default=None,
        help="Shard id to process. If None, process all shards.",
    )
    parser.add_argument(
        "--num_shards",
        type=int,
        default=None,
        help="Number of shards. If None, process all shards.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path to the directory where the output .jsonl file is written.",
    )
    parser.add_argument(
        "--model_args",
        type=str,
        default=None,
        help="List of model arguments separated by commas. (e.g. 'top_p=0.95,temperature=0.80')",
    )
    args = parser.parse_args()
    main(**vars(args))
