import vllm
import random
import fishfarm
from datasets import load_dataset
from ai2_arc_task import Ai2ArcSample, Ai2ArcTask
from fishfarm.models.vllm_model import VLLMModel
from .base import Task, get_download_dir


choices = ["A", "B", "C", "D", "E"]
LLAMA3_COT = (
    "{% set loop_messages = messages %}"
    "{% for message in loop_messages %}"
    "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
    "\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
    "{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
    "{% endif %}"
    "{{ content }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + 'Let\\'s think step by step' }}"
    "{% endif %}"
) 


class AI2ArcTask(Task):
    def __init__(self,):
        # TODO mistral cot
        self.model_to_template = {
            'meta-llama/Meta-Llama-3-8B-Instruct': LLAMA3_COT,
            'mistralai/Mistral-7B-Instruct-v0.3': None,
        }
        self.system_msg = (
            "The following are multiple choice questions (with answers). "
            "Think step by step and then finish your answer "
            'with "the answer is (X)" where X is the correct letter choice.'
        )
        self.target_metric_train = "acc"
        self.target_metric_valid = self.target_metric_train
        self.target_metric_test = self.target_metric_train
        self.target_metric_transfer = self.target_metric_train
        self.has_transfer_split = True
        self.has_training_split = True


    def get_train_data(self,):
        train_eval, *test_evals = self.get_evaluator()
        train_data = train_eval.samples
        train_size = len(train_data)
        train_ix = range(0, train_size, 2)
        valid_ix = range(1, train_size, 2)
        return train_data, train_ix, valid_ix

    def get_rewards(self, res):
        rewards = [1. if x['correct'] else -1. for x in res.sample_details]
        return rewards
    
    def get_evaluator(self,):
        res = []
        for split in ['train', 'test']:
            dataset = load_dataset("allenai/ai2_arc", "ARC-Challenge", 
                                   split=split)
            dataset = load_dataset("allenai/ai2_arc", "ARC-Easy", split=split)
            samples = []
            for sample in dataset:
                options = []
                for opt in sample["choices"]['text']:
                    options.append(opt)
                # add options to the question
                question = sample["question"] + "\n"
                question += "Options:\n"
                for i, opt in enumerate(options):
                    question += "{}. {}\n".format(choices[i], opt)
                samples.append(
                    Ai2ArcSample(
                        question=question,
                        answer=sample["answerKey"],
                        options=options,
                        question_id=sample["id"],
                    )
                )    
            res.append(Ai2ArcTask(
                samples=samples,
                context_messages=[
                    fishfarm.Message("system", self.system_msg),
                ],
            ))
        dataset = load_dataset(
            "allenai/ai2_arc", "ARC-Challenge", split='test')
        samples = []
        for sample in dataset:
            options = []
            for opt in sample["choices"]['text']:
                options.append(opt)
            # add options to the question
            question = sample["question"] + "\n"
            question += "Options:\n"
            for i, opt in enumerate(options):
                question += "{}. {}\n".format(choices[i], opt)
            samples.append(
                Ai2ArcSample(
                    question=question,
                    answer=sample["answerKey"],
                    options=options,
                    question_id=sample["id"],
                )
            )    
        res.append(Ai2ArcTask(
            samples=samples,
            context_messages=[
                fishfarm.Message("system", self.system_msg),
            ],
        ))
        return tuple(res)
    
    def get_prompt(self, tokenizer, samples, ix, model_id):
        chat_template = self.model_to_template[model_id]
        context_msg = {'role': 'system', 'content': self.system_msg}
        user_msg = {'role': 'user', 'content': samples[ix].question}
        prompt = tokenizer.apply_chat_template(
            conversation=[context_msg, user_msg],
            chat_template=chat_template,
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt

    def get_vllm_model(self, model_id) -> VLLMModel:
        """Load a vLLM model."""
        model = vllm.LLM(
            model_id,
            max_model_len=1024,
            gpu_memory_utilization=0.8,
            enforce_eager=True,
            dtype='bfloat16',
            download_dir=get_download_dir(),
        )
        chat_template = self.model_to_template[model_id]
        # This may change with vLLM versions.
        m = model.llm_engine.model_executor.driver_worker.model_runner.model
        for _, param in m.named_parameters():
            param.requires_grad = False
        vllm_model = VLLMModel(
            model,
            sampling_params=vllm.SamplingParams(
                temperature=0,
                top_p=1,
                max_tokens=512,
                stop=["Instruction:", "Instruction", "Response:", "Response"],
                repetition_penalty=1.0,
            ),
            chat_template=chat_template,
        )
        return vllm_model


