import json
import random
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import hydra
from omegaconf import OmegaConf
from tqdm import tqdm

import lm_understanding.prompting.text as text_utils
from lm_understanding.data import template_path
from lm_understanding.models import load_model
from lm_understanding.models.gpt import ChatCompleter, message
from lm_understanding.question_template import (QuestionTemplate,
                                                answer_template_questions,
                                                measure_template_diversity)


@dataclass
class FewShotExample:
    """One worked example shown to the model as a user/assistant exchange.

    Instances are built by ``load_few_shot_examples`` and consumed by
    ``make_prompt``.
    """
    # Criteria listed in the example's "Create 1 template" user request.
    instructions: List[str]
    # Template serialized as the assistant's JSON reply to that request.
    template: QuestionTemplate


def load_few_shot_examples(path: Path) -> List[FewShotExample]:
    """Load the few-shot examples stored in a JSON file.

    Args:
        path: JSON file whose top-level ``examples`` key holds a list of
            objects. Each object must provide an ``instructions`` entry and
            is also passed whole to ``QuestionTemplate.from_dict``.

    Returns:
        One ``FewShotExample`` per entry, in file order.

    Raises:
        KeyError: If ``examples`` or an entry's ``instructions`` is missing.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding,
    # so non-ASCII example text is read the same way on every system.
    entries = json.loads(path.read_text(encoding='utf-8'))['examples']
    return [
        FewShotExample(entry['instructions'], QuestionTemplate.from_dict(entry))
        for entry in entries
    ]


def template_example(template: QuestionTemplate) -> str:
    """Render a template as the JSON text used for an assistant reply.

    The template is converted to a dict, its ``template_id`` entry is
    removed, and the remainder is wrapped as the single element of a
    ``templates`` list.

    Args:
        template: Dataclass instance to serialize.

    Returns:
        The JSON document, indented by four spaces with non-ASCII characters
        left unescaped.

    Raises:
        KeyError: If the converted dict has no ``template_id`` entry.
    """
    fields = asdict(template)
    del fields['template_id']
    return json.dumps({'templates': [fields]}, ensure_ascii=False, indent=4)
 

# General rules for every generated template. make_prompt embeds this text in
# the first user message as the "GENERAL INSTRUCTIONS", and the final request
# built by last_message_text refers back to it by that name.
TEMPLATE_INSTRUCTIONS = """- The template should be a paragraph that ends with a simple yes or no question.
- Try to generate diverse and creative templates.
- Five factors of the paragraph that significantly influence the answer to the final question have been replaced by variables (a, b, c, d, and e).
- The template should contain only the factors: [a], [b], [c], [d], and [e]. DO NOT include [f] onward.
- If the same factor appears multiple times in the paragraph, replace each instance with the appropriate bracketed letter.
- Each factor should have exactly 15 possible values.
- The values should be diverse and interesting.
- The values should significantly influence the final answer."""


def last_message_text(task_instructions, task_premises, templates_per_prompt, example_template_text):
    """Build the text of the final user request in the prompt.

    Args:
        task_instructions: Criteria for the requested templates, formatted
            with ``text_utils.unordered_list``.
        task_premises: Example questions of the desired type, formatted the
            same way.
        templates_per_prompt: Number of templates to ask for.
        example_template_text: Optional example template; its section is
            included only when this value is truthy.

    Returns:
        The request text, with its sections separated by blank lines.
    """
    criteria = text_utils.unordered_list(task_instructions)
    premises = text_utils.unordered_list(task_premises)
    sections = [
        f"Create {templates_per_prompt} templates that meet the following criteria:\n{criteria}",
        f"Here are some examples of the general type of question I'm thinking about:\n{premises}",
    ]
    if example_template_text:
        sections.append(f'Here is an example template:\n{example_template_text}')
    sections.append("Follow the same GENERAL INSTRUCTIONS as the previous templates.")
    return '\n\n'.join(sections)


def make_prompt(task_instructions: List[str], task_premises: List[str], few_shot_examples: List[FewShotExample], templates_per_prompt: int, example_template_text: Optional[str] = None) -> List[Dict[str, str]]:
    """Assemble the chat messages for one template-generation request.

    The conversation opens with a system message, a user message carrying
    ``TEMPLATE_INSTRUCTIONS`` and an assistant confirmation. Each few-shot
    example then adds a user request and an assistant reply, and the final
    user message is produced by ``last_message_text``.

    Args:
        task_instructions: Criteria for the requested templates.
        task_premises: Example questions of the desired type.
        few_shot_examples: Examples to replay as earlier exchanges.
        templates_per_prompt: Number of templates requested in the last message.
        example_template_text: Optional example template for the last message.

    Returns:
        The messages, in conversation order, as built by ``message``.
    """
    system_text = "You are a helpful assistant. Your task is to create paragraph templates given a set of instructions about the type of template to create. Return the template(s) formatted as JSON."
    general_text = f"For each of the following template requests, use the following GENERAL INSTRUCTIONS:\n{TEMPLATE_INSTRUCTIONS}"
    conversation = [
        message('system', system_text),
        message('user', general_text),
        message('assistant', "Confirmed. I will follow these instructions for each template I create."),
    ]

    # Replay every few-shot example as a request followed by its answer.
    for example in few_shot_examples:
        request = f"Create 1 template that meets the following criteria:\n{text_utils.unordered_list(example.instructions)}"
        reply = template_example(example.template)
        conversation.append(message('user', request))
        conversation.append(message('assistant', reply))

    final_request = last_message_text(task_instructions, task_premises, templates_per_prompt, example_template_text)
    conversation.append(message('user', final_request))
    return conversation


def parse_completion(completion_text) -> List[QuestionTemplate]:
    """Extract the valid question templates from one model completion.

    Parsing is best-effort. A completion that is not JSON, is not a JSON
    object, or has no ``templates`` list yields an empty result instead of
    raising. Entries that ``QuestionTemplate.from_dict`` rejects are printed
    and skipped.

    Args:
        completion_text: Raw completion text, expected to be a JSON object
            with a ``templates`` list.

    Returns:
        The parsed templates whose ``valid`` attribute is truthy.
    """
    try:
        data = json.loads(completion_text)
    except (TypeError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError, TypeError covers non-text
        # input, and RecursionError covers pathologically nested output.
        return []

    # Valid JSON can still have the wrong shape (e.g. a bare list or null).
    if not isinstance(data, dict):
        return []
    raw_templates = data.get('templates')
    if not isinstance(raw_templates, list):
        return []

    templates = []
    for d in raw_templates:
        try:
            templates.append(QuestionTemplate.from_dict(d))
        except Exception:
            # Catch Exception rather than everything, so Ctrl-C and
            # interpreter exit still propagate.
            print('Not a valid template:')
            print(d)
    return [t for t in templates if t.valid]


def measure_positive_answer_fractions(templates: List[QuestionTemplate], config):
    """Collect ``positive_fraction`` for the leading templates.

    Loads the model named by ``config.filter.model`` and runs
    ``answer_template_questions`` on each of the first
    ``config.filter.max_templates`` templates, with a progress bar.

    Args:
        templates: Candidate templates; only the leading slice is evaluated.
        config: Configuration providing ``filter.model``,
            ``filter.max_templates``, ``filter.samples`` and ``seed``.

    Returns:
        The ``positive_fraction`` of each evaluation result, in template order.
    """
    completer, prompter = load_model(config.filter.model)
    candidates = templates[:config.filter.max_templates]
    progress = tqdm(candidates, desc=f'Evaluating Templates: {completer.name}')

    # Run every evaluation first; the fractions are read afterwards.
    evaluations = []
    for template in progress:
        evaluation = answer_template_questions(
            template,
            template.variables,
            completer,
            n_questions=config.filter.samples,
            prompter=prompter,
            seed=config.seed,
            sample_fast=True,
        )
        evaluations.append(evaluation)
    return [evaluation.positive_fraction for evaluation in evaluations]


def save_templates(templates, save_dir, task_name, run_info):
    """Write the templates to ``question_templates.json`` in ``save_dir``.

    Args:
        templates: Dataclass instances to serialize with ``asdict``.
        save_dir: Output directory, created with its parents if missing.
        task_name: Currently unused; kept so existing callers keep working.
        run_info: Currently unused; kept so existing callers keep working.
    """
    save_dir.mkdir(exist_ok=True, parents=True)
    payload = dict(questions=[asdict(t) for t in templates])
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly rather than in the locale encoding.
    with open(save_dir / 'question_templates.json', 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=4, ensure_ascii=False)


def save_results(templates: List[QuestionTemplate], config, chat_gpt: ChatCompleter, save_dir: Path):
    """Save the generated templates together with metadata about the run.

    Writes the templates through ``save_templates`` and a ``run_info.json``
    file holding the completer description, template diversity, the
    fraction of requested templates that were kept, and the configuration.

    Args:
        templates: Templates that survived parsing and validation.
        config: Run configuration; ``prompts``, ``templates_per_prompt`` and
            ``completions_per_prompt`` give the number of templates requested.
        chat_gpt: Completer used for generation; its string form and
            ``info`` attribute are recorded.
        save_dir: Output directory, created with its parents if missing.
    """
    save_dir.mkdir(exist_ok=True, parents=True)
    template_diversity = measure_template_diversity(templates)
    requested_templates = config.prompts * config.templates_per_prompt * config.completions_per_prompt
    valid_fraction = len(templates) / requested_templates

    run_info: Dict[str, Any] = dict(
        completer=str(chat_gpt),
        completer_info=chat_gpt.info,
        template_diversity=template_diversity,
        valid_fraction=valid_fraction
    )
    save_templates(templates, save_dir, config.task.name, run_info)

    # The configuration is added only after save_templates has been called.
    run_info.update(dict(config=OmegaConf.to_container(config))) # type: ignore

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly rather than in the locale encoding.
    with open(save_dir / 'run_info.json', 'w', encoding='utf-8') as f:
        json.dump(run_info, f, indent=4, ensure_ascii=False)


@hydra.main(config_path='config', config_name='create_templates.yaml', version_base='1.2')
def main(config):
    """Generate question templates for the configured task and save them.

    Reads the task description and few-shot examples from the ``examples``
    directory, builds ``config.prompts`` prompts with randomly sampled
    example premises, sends them to the chat completer, parses the
    completions into templates, assigns ``<task_name>_<index>`` ids and
    writes the results to the directory returned by ``template_path``.

    Args:
        config: Hydra configuration for the run.
    """
    task_name = config.task.name
    data_path = Path('examples')
    # Read JSON as UTF-8 explicitly instead of the platform's locale encoding.
    task_data = json.loads((data_path / 'task_data.json').read_text(encoding='utf-8'))['tasks'][task_name]
    few_shot_examples = load_few_shot_examples(data_path / 'few_shot_examples.json')[:config.few_shot_examples]
    save_dir = template_path(config, new=True)
    save_dir.mkdir(exist_ok=True, parents=True)

    def log_temp_completions(completions: List[str]) -> None:
        """Dump raw completions and their parsed templates under ``save_dir/tmp``."""
        tmp_save_dir: Path = save_dir / 'tmp'
        tmp_save_dir.mkdir(exist_ok=True, parents=True)
        # Model output may contain non-ASCII text; write it as UTF-8 explicitly.
        (tmp_save_dir / 'completions.txt').write_text('\n'.join(completions), encoding='utf-8')
        templates = []
        for completion in completions:
            templates.extend(parse_completion(completion))
        save_templates(templates, tmp_save_dir, task_name, dict())

    # log_temp_completions is handed to the completer as its log_fn.
    chat_gpt = ChatCompleter(model_name='gpt-4', temperature=config.completer_temperature, max_tokens=4096, n=config.completions_per_prompt, log_fn=log_temp_completions, verbose=True)

    # The optional example template is the same for every prompt.
    example_template_text = task_data.get('example_template')
    prompts = []
    for _ in range(config.prompts):
        # Each prompt gets its own random subset of example premises.
        premises = random.sample(task_data['example_premises'], config.example_premises)
        prompt = make_prompt(task_data['instructions'], premises, few_shot_examples, config.templates_per_prompt, example_template_text)
        prompts.append(prompt)
    completions = chat_gpt(prompts)
    templates = []
    for completion in completions:
        templates.extend(parse_completion(completion.text))

    for i, t in enumerate(templates):
        t.template_id = f'{task_name}_{i}'

    save_results(templates, config, chat_gpt, save_dir)


# Run the Hydra-decorated entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()