import json
from pathlib import Path
import string
import ollama
from data import snli_labels
from pydantic import BaseModel
import random
import util

def _stream(model, messages, do_print=True, **kwargs):
    """Send a streaming chat request to ollama and return the whole reply.

    Args:
        model: Model name forwarded to ``ollama.chat``.
        messages: Chat messages forwarded to ``ollama.chat``.
        do_print: When true, echo ``"assistant: "`` followed by every chunk
            to stdout as it arrives, then a final newline.
        **kwargs: Extra keyword arguments passed straight to ``ollama.chat``.

    Returns:
        The concatenation of the ``['message']['content']`` of every chunk.
    """
    chunks = ollama.chat(
        model=model,
        messages=messages,
        stream=True,
        **kwargs
    )

    # A single no-op stand-in replaces the repeated ``if do_print`` checks.
    echo = print if do_print else (lambda *args, **kw: None)

    reply = ''
    echo("assistant: ", end='')
    for chunk in chunks:
        piece = chunk['message']['content']
        reply += piece
        # Flush so partial output is visible while the model is generating.
        echo(piece, end='', flush=True)
    echo()

    return reply

def generate_one(model, message_history, user_message, print_history=True):
    """Print the conversation and stream one model reply to it.

    Args:
        model: Model name handed to ``_stream``.
        message_history: List of earlier chat messages.
        user_message: The new message appended after the history.
        print_history: When true, the history is printed before the new
            message; the new message itself is always printed.

    Returns:
        The reply text returned by ``_stream`` for history + new message.
    """
    shown_history = message_history if print_history else ()
    for msg in (*shown_history, user_message):
        util.print_message(msg)

    conversation = message_history + [user_message]
    return _stream(model, conversation)

def format_user_prompt(premise):
    """Wrap a premise in a user-role chat message.

    Args:
        premise: The premise text to show to the model.

    Returns:
        A dict with ``role`` set to ``"user"`` and ``content`` set to the
        premise prefixed by ``"Premise: "``.
    """
    content = "Premise: {}".format(premise)
    return dict(role="user", content=content)

def format_few_shot_example(ex):
    """Turn one few-shot example into a (user, assistant) message pair.

    Args:
        ex: Mapping with a ``'sentence1'`` premise and, for every label in
            ``snli_labels``, an entry ``ex[label]['sentence2']`` holding the
            hypothesis for that label.

    Returns:
        A two-element list: the user message built from the premise and an
        assistant message whose content is ``answer_template`` filled with
        the hypotheses.
    """
    user = format_user_prompt(ex['sentence1'])
    # The template is a JSON object, so every value must be JSON-escaped:
    # a raw quote, backslash or newline in a hypothesis would otherwise make
    # the demonstrated answer invalid JSON. ``json.dumps`` yields a quoted
    # string; the surrounding quotes are stripped because the template
    # already supplies them. ``ensure_ascii=False`` keeps non-ASCII text
    # unchanged, so plain sentences render exactly as before.
    hypotheses = {
        label: json.dumps(str(ex[label]['sentence2']), ensure_ascii=False)[1:-1]
        for label in snli_labels
    }
    model = {
        "role": "assistant",
        "content": answer_template.format(**hypotheses)
    }
    return [user, model]

# ``str.format`` template for an assistant answer: a JSON object with one
# hypothesis per label. Doubled braces are literal braces after formatting;
# the text starts with a newline.
answer_template = (
    "\n"
    "{{\n"
    '    "contradiction": "{contradiction}",\n'
    '    "entailment": "{entailment}",\n'
    '    "neutral": "{neutral}"\n'
    "}}"
)

# Shape of the model's answer: one generated hypothesis string per label.
# The script section below passes this model's JSON schema as ``format`` to
# the chat call. Documented with ``#`` comments rather than a docstring
# because pydantic copies a model docstring into the generated schema.
class AnswerFormat(BaseModel):
    contradiction: str  # hypothesis that contradicts the premise
    entailment: str     # hypothesis entailed by the premise
    neutral: str        # hypothesis neutral with respect to the premise

if __name__ == '__main__':
    # Script entry point. For each selected SNLI item, the item's hypothesis
    # (``sentence2``) is used as a new premise and the model is asked for one
    # hypothesis per label. Results are appended to a JSONL file so an
    # interrupted run can be resumed; a metadata file guards against mixing
    # outputs produced under different settings.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('model', choices=util.get_available_models())
    parser.add_argument('snli_dir', type=Path)
    parser.add_argument('snli_split', type=str)
    parser.add_argument('output_dir', type=Path)
    parser.add_argument('--num-samples', type=int, default=None)
    parser.add_argument('--use-samples-from', type=Path, default=None)
    parser.add_argument('--system-prompt', type=Path, default='prompts/generate-nli-prompt.txt')
    parser.add_argument('--few-shot-examples', type=Path, default='prompts/perfect-snli-examples.json')
    args = parser.parse_args()

    # Queried once and reused in the metadata below.
    model_info = ollama.show(args.model)['model_info']

    args.output_dir.mkdir(parents=True, exist_ok=True)
    output_file = args.output_dir/f'{args.model}_{args.snli_split}.jsonl'
    metadata_file = args.output_dir/f'{args.model}_{args.snli_split}_meta.json'

    with args.few_shot_examples.open() as f:
        examples = json.load(f)
    # Sets give constant-time membership tests in the input filter below.
    example_ids = {item['captionID'] for item in examples}
    example_messages = []
    for ex in examples:
        example_messages += format_few_shot_example(ex)

    system_prompt = {'role': 'system', 'content': args.system_prompt.read_text()}
    messages = [system_prompt] + example_messages
    answer_format = AnswerFormat.model_json_schema()

    # Items already present in the output file are skipped on a rerun.
    if output_file.exists():
        with output_file.open() as f:
            output_data = [json.loads(line) for line in f]
    else:
        output_data = []
    already_done_ids = {item['premiseID'] for item in output_data}

    metadata = {
        'model_info': model_info,
        'few_shot_examples': examples,
        'system_prompt': system_prompt,
        'snli_split': args.snli_split
    }

    # Refuse to append to outputs that were generated under other settings.
    if metadata_file.exists():
        with metadata_file.open() as f:
            existing_metadata = json.load(f)
        if existing_metadata != metadata:
            raise ValueError("Metadata mismatch with previous run.")
    else:
        with metadata_file.open('w') as f:
            json.dump(metadata, f)

    snli_file = args.snli_dir/f"snli_1.0_{args.snli_split}.jsonl"

    def _input_filter(item):
        """Keep items that are neither few-shot premises nor already done."""
        return (
            item['captionID'] not in example_ids        # exclude items with premises used in examples
            and item['pairID'] not in already_done_ids  # exclude items we already generated from
        )

    if args.use_samples_from:
        # Restrict to the premises that a previous run generated from.
        prev_samples = util.load_jsonl(args.use_samples_from)
        prev_ids = {item['premiseID'] for item in prev_samples}

        def input_filter(item):
            """Keep fresh items that also appear in the previous run."""
            return _input_filter(item) and item['pairID'] in prev_ids
    else:
        input_filter = _input_filter

    with snli_file.open() as f:
        input_data = [item for item in map(json.loads, f) if input_filter(item)]
    random.shuffle(input_data)
    # A ``None`` limit slices to the full list.
    input_data = input_data[:args.num_samples]
    input_data_len = len(input_data)

    with output_file.open('a') as f:
        for i, item in enumerate(input_data):
            print(f"{'-'*15} {i+1}/{input_data_len} {'-'*15}\n", flush=True)
            user_message = format_user_prompt(item['sentence2'])
            # The shared prompt (system + few-shot) is printed only once.
            if i == 0:
                for m in messages:
                    util.print_message(m)
            util.print_message(user_message)
            raw_response = _stream(args.model, messages + [user_message], format=answer_format)
            response = json.loads(raw_response)

            # Build every row before writing, so a response that lacks a
            # label raises before any partial group reaches the output file.
            rows = [
                {
                    'captionID': item['captionID'], # ID of original premise
                    'premiseID': item['pairID'],    # ID of new premise (=ID of original pair)
                    'pairID': item['pairID'] + f"#{args.model}_{label[0]}",  # ID for the new generated item
                    'model_label': label,           # label the model was told to generate for
                    'sentence1': item['sentence2'], # premise for the generated example
                    'sentence2': response[label],   # generated hypothesis
                }
                for label in snli_labels
            ]
            for out in rows:
                util.write_jsonl(out, f)
