import json
import base64
from io import BytesIO
from pathlib import Path
from typing import Iterable, Generator, Any, Optional

from PIL import Image
import torch
import ray
from qwen_vl_utils import process_vision_info
from ray.util import ActorPool
from tqdm import tqdm

from generation_utils import get_logits_processor


@ray.remote(num_gpus=1)
class Inferer:
    """Ray actor that runs batched generation with a Qwen2-VL style model.

    Each actor is declared with ``num_gpus=1``. It loads one processor and one
    model from ``args.model`` and exposes ``generate`` / ``generate_batch``,
    which take dataset rows and return dicts holding the prompt token ids, the
    ids returned by ``model.generate`` and the decoded text.
    """

    def __init__(self, args):
        """Load the processor and model and cache the generation settings.

        Args:
            args: Namespace-like object. Attributes read here: ``model``,
                ``model_is_pgn``, ``resize_max_pixels`` and ``do_copy`` (both
                only for PGN models), ``template_type``, ``max_new_tokens``,
                ``max_image_size``, ``min_image_size``, ``use_bad_words``,
                ``repetition_penalty``, ``do_resize``, ``do_base64``,
                ``do_sample`` and, optionally, ``use_tool``.

        Raises:
            KeyError: If ``args.template_type`` is not a known template name.
        """
        import sys

        # Makes the project-local ``model`` package importable from the actor.
        sys.path.append("../")

        from transformers import (
            AutoProcessor,
            Qwen2_5_VLForConditionalGeneration,
            Qwen2VLForConditionalGeneration,
        )
        from model.qwen2_5_vl import (
            Qwen2_5_VL_PGNForConditionalGeneration,
            Qwen2_5_VLPointerProcessor,
        )

        # Both kwarg dicts start empty so that non-PGN models (which add no
        # extra processor arguments) can still be loaded.
        processor_kwargs = {}
        model_kwargs = {}
        if args.model_is_pgn:
            processor_class = Qwen2_5_VLPointerProcessor
            model_class = Qwen2_5_VL_PGNForConditionalGeneration
            processor_kwargs["max_pixels"] = (
                672**2 if args.resize_max_pixels else 1605632
            )
            model_kwargs["do_copy"] = args.do_copy
        else:
            processor_class = AutoProcessor
            # The checkpoint name decides between the 2.5 and the 2 model class.
            model_class = (
                Qwen2_5_VLForConditionalGeneration
                if "2.5" in args.model
                else Qwen2VLForConditionalGeneration
            )

        processor = processor_class.from_pretrained(
            args.model,
            padding_side="left",
            **processor_kwargs,
        )
        model = model_class.from_pretrained(
            args.model,
            device_map="auto",
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
            **model_kwargs,
        )
        model.eval()
        self.processor = processor
        self.processor.tokenizer.padding_side = "left"
        self.model = model

        # TODO: remove
        self.template = {
            "base": "{}\nPlease answer the question using a long-chain reasoning style and think step by step.",
            "none": "{}",
            "v1": "{}\nPlease answer the question clearly without reflection. Pay close attention to the content of the detected objects and do not guess their content.",
            "v2": "{}\nPlease answer the question without reflection and use short and non repetitive step-by-step reasoning.",
            "v3": "{}\nPlease answer the question. Use concise, non-repetitive reasoning without self-reflection.",
            "qwen2": "{}\nPlease try to answer the question with short words or phrases if possible.",
        }[args.template_type]
        self.max_new_tokens = args.max_new_tokens
        self.max_image_size = args.max_image_size
        self.min_image_size = args.min_image_size
        self.system_prompt = "You are a helpful assistant."

        self.bad_words_ids = []
        if args.use_bad_words:
            # A second tokenizer with add_prefix_space=True is only needed to
            # build the banned sequences, so it is loaded only on this path.
            tokenizer_with_prefix_space = processor.tokenizer.__class__.from_pretrained(
                args.model, add_prefix_space=True
            )
            self.bad_words_ids = [
                tokenizer_with_prefix_space(
                    [word], add_special_tokens=False
                ).input_ids[0]
                for word in ("Wait", "But wait", "But the problem")
            ]
        self.repetition_penalty = args.repetition_penalty
        self.do_resize = args.do_resize
        self.do_base64 = args.do_base64
        # Stored for completeness; __call__ currently always decodes greedily.
        self.do_sample = args.do_sample
        self.use_tool = getattr(args, "use_tool", False)

    def get_base64(self, image):
        """Return ``image`` encoded as a base64 JPEG string (no data-URI prefix).

        RGBA and palette images are converted to RGB first, mirroring
        ``decode``, because they cannot be written as JPEG.
        """
        if image.mode in ("RGBA", "P"):
            image = image.convert("RGB")
        buffer = BytesIO()
        image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def decode(self, x):
        """Decode a base64 string into a PIL image, converting RGBA/P to RGB."""
        x = base64.b64decode(x)
        x = Image.open(BytesIO(x))
        if x.mode in ("RGBA", "P"):
            x = x.convert("RGB")
        return x

    def preprocess(self, image, question, response: Optional[str] = None):
        """Build the chat prompt for one (image, question) pair.

        Args:
            image: PIL image for the user turn.
            question: Question text; inserted into the configured template.
            response: Optional partial assistant reply. When given, it is
                appended as an assistant turn and the trailing EOS token is
                removed from the rendered prompt.

        Returns:
            Tuple ``(text, messages)``: the rendered prompt string and the
            message list it was rendered from.
        """
        if self.do_resize:
            image = resize(image, self.max_image_size, self.min_image_size)

        image_rep = image
        if self.do_base64:
            # Encode -> decode -> encode: the image goes through JPEG
            # compression twice before being passed on as a data URI.
            # NOTE(review): presumably mirrors another pipeline - confirm.
            base64_str = self.get_base64(image)
            image_new = self.decode(base64_str)
            base64_str = self.get_base64(image_new)
            image_rep = f"data:image/jpeg;base64,{base64_str}"

        messages = [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_rep},
                    {"type": "text", "text": self.template.format(question)},
                ],
            },
        ]

        if response is None:
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return text, messages

        messages.append(
            {
                "role": "assistant",
                "content": [{"type": "text", "text": response}],
            }
        )
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        # Strip the closing EOS so the prompt ends inside the assistant turn.
        text = text.strip().removesuffix(self.processor.tokenizer.eos_token)
        return text, messages

    def __call__(self, images, questions, responses: Optional[list] = None):
        """Generate for a batch of images and questions.

        Args:
            images: Sequence of PIL images.
            questions: Sequence of question strings, aligned with ``images``.
            responses: Optional sequence of partial assistant replies (or
                ``None`` entries), aligned with ``images``.

        Returns:
            One dict per sample with keys ``"input"`` (prompt token ids, on
            CPU), ``"output"`` (ids returned by generation, on CPU) and
            ``"text"`` (decoded output with special tokens skipped).
        """
        if responses is None:
            responses = [None for _ in questions]
        texts, all_messages = zip(
            *[self.preprocess(*row) for row in zip(images, questions, responses)]
        )
        # One image per conversation: take the first image returned for each.
        image_inputs = [
            process_vision_info(messages)[0][0] for messages in all_messages
        ]
        inputs = self.processor(
            text=list(texts),
            images=image_inputs,
            videos=None,
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Greedy decoding; the repetition penalty is passed through generate().
        sampling_params = dict(
            do_sample=False,
            max_new_tokens=self.max_new_tokens,
            repetition_penalty=self.repetition_penalty,
        )
        if self.bad_words_ids:
            sampling_params["bad_words_ids"] = self.bad_words_ids

        self.model.eval()
        assert hasattr(self.model, "rope_deltas")
        if self.use_tool:
            generated_ids = self.generate_tool(inputs, sampling_params)
        else:
            # Clear the value left on the model by any previous call.
            self.model.rope_deltas = None
            with torch.inference_mode():
                generated_ids = self.model.generate(
                    **inputs,
                    use_cache=True,
                    eos_token_id=self.processor.tokenizer.eos_token_id,
                    pad_token_id=self.processor.tokenizer.pad_token_id,
                    **sampling_params,
                )

        output_texts = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return [
            {
                "input": input_ids,
                "output": output_ids,
                "text": output_text,
            }
            for input_ids, output_ids, output_text in zip(
                inputs.input_ids.cpu(), generated_ids.cpu(), output_texts
            )
        ]

    def generate_tool(self, inputs, sampling_params):
        """Tool-assisted generation; not implemented yet."""
        # detect bbox coord generation, halt, get patch, and re-run
        raise NotImplementedError()

    def generate(self, row):
        """Run generation for a single row and attach the row's ``"id"``.

        ``row`` must provide ``"decoded_image"``, ``"query"`` and ``"id"``.
        """
        outputs = self([row["decoded_image"].convert("RGB")], [row["query"]])[0]
        outputs["id"] = row["id"]
        return outputs

    def generate_batch(self, batch):
        """Run generation for a list of rows and attach each row's ``"id"``.

        Every row must provide ``"decoded_image"``, ``"query"`` and ``"id"``.
        If the first row has a ``"response"`` key, every row is expected to
        have one and the values are used as partial assistant replies.
        An empty batch yields an empty list.
        """
        if not batch:
            return []
        images = [row["decoded_image"].convert("RGB") for row in batch]
        texts = [row["query"] for row in batch]
        responses = None
        if "response" in batch[0]:
            responses = [row["response"] for row in batch]
        outputs = self(images, texts, responses)
        for row, output in zip(batch, outputs):
            output["id"] = row["id"]
        return outputs


def resize(
    img: Image.Image, max_size: Optional[int] = None, min_size: Optional[int] = None
) -> Image.Image:
    """Return a resized copy of ``img`` that respects optional size bounds.

    The aspect ratio is kept up to integer truncation. The minimum constraint
    is applied first, then the maximum; if both cannot hold, the maximum wins.

    Args:
        img: Source image.
        max_size: If given and either side exceeds it, the image is scaled
            down so that the longer side equals ``max_size``.
        min_size: If given and either side is below it, the image is scaled
            up so that the shorter side equals ``min_size``.

    Returns:
        A new image; ``img.resize`` is called even when the size is unchanged.
    """

    def _rescale(w, h, target, anchor):
        # Scale both sides by target / anchor. The side equal to ``anchor`` is
        # set to ``target`` exactly, because int(anchor * (target / anchor))
        # can come out one pixel short through float rounding. The other side
        # is clamped to 1 so very elongated images never get a 0-pixel side.
        factor = target / float(anchor)
        new_w = target if w == anchor else max(1, int(w * factor))
        new_h = target if h == anchor else max(1, int(h * factor))
        return new_w, new_h

    width, height = img.size

    # Handle min_size constraint: the shorter side is the anchor.
    if min_size is not None and (width < min_size or height < min_size):
        width, height = _rescale(width, height, min_size, min(width, height))

    # Handle max_size constraint: the longer side is the anchor.
    if max_size is not None and (width > max_size or height > max_size):
        width, height = _rescale(width, height, max_size, max(width, height))

    return img.resize((width, height))


def get_batches(data: Iterable[Any], batch_size: int) -> Generator[list, None, None]:
    """
    Group the items of `data` into consecutive lists of `batch_size` items.

    Args:
        data (Iterable): Any iterable; it is consumed exactly once, in order.
        batch_size (int): Number of items per emitted list.

    Yields:
        list: The next group of items. The final group may be shorter when
        the number of items is not a multiple of `batch_size`.
    """
    pending = []
    for element in data:
        pending.append(element)
        if len(pending) != batch_size:
            continue
        # A full group is ready: hand it out and start a fresh list.
        yield pending
        pending = []
    # Whatever is left over forms the (shorter) last group.
    if len(pending) > 0:
        yield pending


def extract_answer(text):
    """Extract the content of the last ``boxed{...}`` in the assistant reply.

    The assistant reply is the segment that follows the first
    ``"assistant\\n"`` marker in ``text`` (up to the next such marker, if any).
    The answer runs from the last ``"boxed{"`` to the last ``"}"`` after it,
    so nested braces inside the box are kept.

    Args:
        text: Decoded conversation text.

    Returns:
        The extracted answer string (empty if the box is never closed), or
        ``"Unknown"`` when there is no assistant marker or no ``boxed{``.
    """
    segments = text.split("assistant\n")
    if len(segments) < 2:
        # No assistant turn at all: nothing to extract.
        return "Unknown"
    response = segments[1]
    if "boxed{" not in response:
        return "Unknown"
    after_box = response.rpartition("boxed{")[2]
    # rpartition returns "" as the head when no "}" exists, i.e. an unclosed
    # box yields an empty answer.
    return after_box.rpartition("}")[0]


def infer(args, data, results, out_path, rerun_out_path=None):
    """Run inference over ``data`` on all visible GPUs and checkpoint results.

    One ``Inferer`` actor is created per GPU reported by
    ``torch.cuda.device_count()``. After every finished batch the whole
    ``results`` mapping is rewritten to the output file as JSON.

    In normal mode only rows whose id is not yet in ``results`` are run. In
    rerun mode (``args.rerun``) rows that have no result, or whose stored
    result has no usable answer, are run again with a doubled token budget;
    rows with a stored result are continued from a truncated version of the
    previous reply followed by a forced final-answer prefix.

    Args:
        args: Namespace-like config. ``do_sample`` is set to ``False`` and, in
            rerun mode, ``max_new_tokens`` is doubled in place. ``rerun`` and
            ``batch_size`` are read here; the object is also passed to every
            ``Inferer``.
        data: Iterable of row dicts with at least an ``"id"`` key.
        results: Dict mapping row id to generated text; updated in place.
        out_path: JSON output path for normal mode.
        rerun_out_path: JSON output path for rerun mode (required then).

    Raises:
        RuntimeError: If there is work to do but no CUDA device is visible.
    """
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    args.do_sample = False
    if args.rerun:
        assert rerun_out_path is not None
        args.max_new_tokens = args.max_new_tokens * 2

        def check_no_answer(response: str) -> bool:
            """Return True when the stored text has no usable final answer."""
            answer = extract_answer(response)
            if answer == "Unknown":
                return True
            lowered = answer.lower()
            return "cannot" in lowered or "not possible" in lowered

        not_parsed = {k: check_no_answer(v) for k, v in results.items()}

        def get_response(text: str) -> Optional[str]:
            """Build the partial assistant reply used to force an answer.

            Returns None when ``text`` contains no assistant turn, so that
            the row is generated from scratch instead.
            """
            segments = text.split("assistant\n")
            if len(segments) < 2:
                return None
            response = segments[1]

            if "**Final Answer**" in response:
                response = response.split("**Final Answer**")[0]

            # get until last sentence by \n
            lines = response.strip().split("\n")
            response = "\n".join(lines[:-1])

            # force answer
            FORCE_ANSWER_PREFIX = "\nBut since the reasoning is already too long, I should yield my final answer now based on the best guess.\n\n**Final Answer**\\[ \\boxed{"
            return response + FORCE_ANSWER_PREFIX

        # Rows without a previous result have nothing to continue from; they
        # get response=None, which is treated as a fresh generation.
        data = [
            {
                **row,
                "response": (
                    get_response(results[row["id"]])
                    if row["id"] in results
                    else None
                ),
            }
            for row in data
            if row["id"] not in results or not_parsed[row["id"]]
        ]
        out_path = rerun_out_path
        # The rerun file may live in a directory that does not exist yet.
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    if args.rerun:
        inputs = list(data)
    else:
        inputs = [row for row in data if row["id"] not in results]
    batches = list(get_batches(inputs, args.batch_size))
    if not batches:
        # Nothing to run: avoid loading a model on every GPU for no work.
        print("done")
        return

    num_gpus = torch.cuda.device_count()
    print(f"using {num_gpus} gpus")
    if num_gpus == 0:
        # An actor pool without actors would never process any batch.
        raise RuntimeError("infer requires at least one visible CUDA device")
    workers = [Inferer.remote(args) for _ in range(num_gpus)]
    pool = ActorPool(workers)

    def actor_call(actor, data):
        return actor.generate_batch.remote(data)

    print("running")
    lengths = []
    with tqdm(desc="inference", total=len(batches)) as pbar:
        for batch_result in pool.map_unordered(actor_call, batches):
            for result in batch_result:
                results[result["id"]] = result["text"]
                # Difference between returned ids and prompt ids per sample.
                lengths.append(len(result["output"]) - len(result["input"]))

            # Checkpoint after every batch so progress survives a crash.
            with open(out_path, "w") as f:
                json.dump(results, f)

            pbar.update(1)
            pbar.set_description(
                f"inference: mean length {sum(lengths) / len(lengths)}, max_length {max(lengths)}"
            )
    print("done")
