import logging
from typing import Dict, Optional, List
import os

import json
import logging
import os
import queue
import sys

from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from tqdm import tqdm
from typing import Union, List, Tuple, Any

import numpy as np
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.utils.data._utils.worker import ManagerWatchdog

from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from model.jina_reranker_m0 import JinaVLForRanking
from PIL import Image
from transformers.image_utils import load_image
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class JinaRanker(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``JinaVLForRanking`` checkpoint.

    ``process`` turns the wrapped model's raw outputs into scores in (0, 1)
    by applying a sigmoid shifted by ``LOGIT_BIAS``.
    """

    def __init__(self, model_path):
        """Load the ranking model and switch it to eval mode.

        Args:
            model_path: Checkpoint location handed to
                ``JinaVLForRanking.from_pretrained``.
        """
        super().__init__()
        # The checkpoint is loaded in float16 with FlashAttention-2 requested.
        self.lm = JinaVLForRanking.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            attn_implementation="flash_attention_2",
        )
        self.lm.eval()
        self.model_path = model_path
        # Offset subtracted from the raw model outputs before the sigmoid.
        self.LOGIT_BIAS = 2.65

    @torch.no_grad()
    def process(self, inputs, **kwargs):
        """Score one tokenized batch.

        Args:
            inputs: Mapping of keyword arguments forwarded to the wrapped model.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A 1-D float32 NumPy array holding ``sigmoid(output - LOGIT_BIAS)``
            for every element of the flattened model output.
        """
        # Call the module itself instead of `.forward()` so that any hooks
        # registered on the wrapped model are executed.
        raw = self.lm(**inputs).view(-1).cpu().float().numpy()
        return 1.0 / (1.0 + np.exp(-(raw - self.LOGIT_BIAS)))

        
class TokenizeWorker:
    """Turn ``(prompt, doc_image, query_image)`` triples into model-ready batches.

    Wraps an ``AutoProcessor`` (left-padded) and appends one score token to
    every sequence.  ``_tokenize_loop`` is the body of a worker that reads raw
    batches from one queue and writes tokenized batches to another.
    """

    def __init__(self, tokenizer_path, max_length=10240, qsize=4, eod_id=100):
        """Load the processor and store the tokenization settings.

        Args:
            tokenizer_path: Location passed to ``AutoProcessor.from_pretrained``.
            max_length: Maximum sequence length including the appended score token.
            qsize: Base size of the keep-alive queue used by ``_tokenize_loop``.
            eod_id: Stored on the instance; not otherwise used by this class.
        """
        self.tokenizer = AutoProcessor.from_pretrained(
            tokenizer_path, max_pixels=602112, min_pixels=3136, trust_remote_code=True
        )
        self.tokenizer.tokenizer.padding_side = 'left'
        self.eod_id = eod_id
        self.max_length = max_length
        self.qsize = qsize
        # Token id appended to the end of every sequence in `tokenize`.
        self.score_token_id = 100

    def tokenize(self, pairs: list):
        """Tokenize one batch and append the score token.

        Args:
            pairs: List of ``(prompt, doc_image, query_image)`` triples; either
                image may be ``None``.

        Returns:
            The processor output with ``input_ids`` and ``attention_mask``
            extended by one column and every tensor pinned in host memory.
        """
        # One position is reserved for the score token appended below.
        max_length = self.max_length - 1
        prompts = [prompt for prompt, _, _ in pairs]
        doc_images = load_images([doc_image for _, doc_image, _ in pairs])
        query_images = load_images([query_image for _, _, query_image in pairs])

        # One image list per sample that has images, document image first.
        # Samples without images contribute no entry: a `None` inside the
        # nested list cannot be flattened by the processor, and looking only
        # at the first sample would silently drop the images of later ones.
        # NOTE(review): relies on the processor matching images to image
        # placeholders in order across the batch - confirm for this processor.
        batch_images = []
        for doc_image, query_image in zip(doc_images, query_images):
            sample_images = [img for img in (doc_image, query_image) if img is not None]
            if sample_images:
                batch_images.append(sample_images)
        if not batch_images:
            batch_images = None

        batch = self.tokenizer(
            text=prompts,
            images=batch_images,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        batch_size = batch["input_ids"].size(0)
        # `new_full` / `new_ones` inherit dtype and device from the tensors
        # they extend; a plain `torch.ones` is float and would promote the
        # whole attention mask to float on concatenation.
        score_column = batch["input_ids"].new_full((batch_size, 1), self.score_token_id)
        mask_column = batch["attention_mask"].new_ones((batch_size, 1))
        batch["input_ids"] = torch.cat([batch["input_ids"], score_column], dim=1)
        batch["attention_mask"] = torch.cat([batch["attention_mask"], mask_column], dim=1)
        batch.data = {k: v.pin_memory() for k, v in batch.data.items()}
        return batch

    def _tokenize_loop(self, input_queue, output_queue, device):
        """Tokenize batches from ``input_queue`` until a ``None`` sentinel arrives.

        Args:
            input_queue: Yields ``(n, batch)`` items, or ``None`` to stop.
            output_queue: Receives ``(n, tokenized_batch)`` with the batch
                moved to ``device``.
            device: Target device for the tokenized tensors.
        """
        # Holds references to the most recent `qsize + 1` batches after they
        # were handed to `output_queue`.
        # NOTE(review): presumably this keeps the device tensors alive in this
        # process while the consumer picks them up - confirm.
        keep_queue = queue.Queue(self.qsize + 1)

        while True:
            item = input_queue.get()
            if item is None:
                break

            n, batch = item
            inputs = self.tokenize(batch)
            inputs = inputs.to(device)
            output_queue.put((n, inputs))
            if keep_queue.full():
                oldest = keep_queue.get()
                del oldest
            keep_queue.put(inputs)
            del item, n, batch, inputs

        # Drop the remaining references before the worker returns.
        while not keep_queue.empty():
            leftover = keep_queue.get()
            del leftover
        return


def _encode_loop(model, input_queue, output_queue, device, qsize=4):
    model = model.to(device)
    watchdog = ManagerWatchdog()
    keep_queue = queue.Queue(qsize + 1)

    with torch.inference_mode():
        with torch.autocast(device_type=device.type, dtype=torch.float16):
            while watchdog.is_alive():
                r = input_queue.get()
                if r is None:
                    break
                n, inputs = r
                results = model.process(inputs=inputs)
                output_queue.put((n, results))
                if keep_queue.full():
                    i = keep_queue.get()
                    del i
                keep_queue.put(results)
                del r, n, inputs

    while not keep_queue.empty():
        i = keep_queue.get()
        del i
    del model, watchdog
    return


def _encode_loop_Jina(model, input_queue, output_queue, device, qsize=4):
    model = model.to(device)
    watchdog = ManagerWatchdog()
    keep_queue = queue.Queue(qsize + 1)

    with torch.inference_mode():
        with torch.autocast(device_type=device.type, dtype=torch.float16):
            while watchdog.is_alive():
                r = input_queue.get()
                if r is None:
                    break
                n, inputs = r
                results = model.process(inputs=inputs)
                output_queue.put((n, results))
                if keep_queue.full():
                    i = keep_queue.get()
                    del i
                keep_queue.put(results)
                del r, n, inputs

    while not keep_queue.empty():
        i = keep_queue.get()
        del i
    del model, watchdog
    return


def load_images(images, lazy_load: bool = True):
    """Resolve a list of image references into PIL images.

    ``None`` entries and existing ``PIL.Image.Image`` objects are passed
    through unchanged; every other entry is handed to ``load_image``.

    Args:
        images: Iterable of ``None``, PIL images, or values accepted by
            ``load_image``.
        lazy_load: When ``False``, each loaded image is copied and the
            original handle is closed; when ``True`` the loaded image is
            returned as is.

    Returns:
        A list with one entry per input, in the same order.
    """
    # The module-wide pixel limit is lifted only for the duration of this
    # call.  It is restored in `finally` so a failing load cannot leave the
    # limit disabled for the rest of the process.
    pil_max_px = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None
    try:
        images_batch = []
        for image in images:
            if image is None or isinstance(image, Image.Image):
                images_batch.append(image)
                continue
            pil_image = load_image(image)
            if lazy_load:
                images_batch.append(pil_image)
            else:
                images_batch.append(pil_image.copy())
                pil_image.close()
    finally:
        Image.MAX_IMAGE_PIXELS = pil_max_px

    return images_batch


class JinaRerankerInferenceModel:
    """Multi-GPU front end for the Jina reranker.

    For every visible CUDA device the constructor starts two processes from a
    ``spawn`` context: a tokenizer process (text queue -> input queue) and an
    encoder process (input queue -> output queue).  ``process`` distributes
    batches round-robin over the devices and collects the scores; ``stop``
    shuts the workers down.
    """

    # Seconds a blocking hand-off to a worker waits before the workers are
    # checked for liveness again.
    _WORKER_POLL_SECONDS = 1.0

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = 10240,
        normalized: bool = False,
        qsize: int = 4,
        instruction=None,
        format_type='chat',
        attn_type='causal',
        inference_type='yes_or_no'
    ) -> None:
        """Load tokenizer and model and start one worker pair per GPU.

        Args:
            model_name_or_path: Checkpoint used for both processor and model.
            max_length: Maximum sequence length; documents are truncated to
                half of it before prompt formatting.
            normalized: Accepted for compatibility; not used by this class.
            qsize: Capacity of each inter-process queue; also forwarded to
                the workers for their keep-alive queues.
            instruction: Stored on the instance; not used by this class.
            format_type: Stored on the instance; not used by this class.
            attn_type: Accepted for compatibility; not used by this class.
            inference_type: Accepted for compatibility; not used by this class.

        Raises:
            AssertionError: If no CUDA device is visible.
        """
        n_gpu = torch.cuda.device_count()
        # Checked before the checkpoint is loaded so the failure is immediate.
        assert n_gpu > 0, 'woho, no no no!'
        self.max_length = max_length
        # `qsize` is forwarded so the worker-side keep-alive queues match the
        # capacity of the inter-process queues created below.
        self.tokenizer = TokenizeWorker(model_name_or_path, max_length=max_length, qsize=qsize)
        self.instruction = instruction
        model = JinaRanker(model_name_or_path)
        self.model = model
        self.world_size = n_gpu
        self.mp_ctx = torch.multiprocessing.get_context('spawn')
        logger.info(f"We have {n_gpu=}, good. Starting worker processes.")
        self._text_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._input_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._output_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._devices = list()
        self._tokenize_wokers = list()
        self._encode_workers = list()
        self._stopped = False
        self.format_type = format_type
        for i, (tq, iq, oq) in enumerate(zip(self._text_queues, self._input_queues, self._output_queues)):
            device = torch.device(f'cuda:{i}')
            self._devices.append(device)
            w_t = self.mp_ctx.Process(
                target=self.tokenizer._tokenize_loop, name=f'tok_w_{i}', args=(tq, iq, device)
            )
            w_t.start()
            self._tokenize_wokers.append(w_t)
            w_e = self.mp_ctx.Process(
                target=_encode_loop_Jina, name=f'enc_w_{i}', args=(model, iq, oq, device, qsize)
            )
            w_e.start()
            self._encode_workers.append(w_e)
            logger.info(f"GPU {i} worker initiated.")
        self.prompt = None

    # NOTE: a `__del__` hook that called `stop()` is disabled in this class,
    # so the workers are only shut down by an explicit `stop()` call.

    def stop(self):
        """Send stop sentinels to all workers, then join and close them.

        Calling ``stop`` again after a successful shutdown is a no-op.
        """
        if getattr(self, '_stopped', False):
            return
        for qs in (self._text_queues, self._input_queues):
            for q in qs:
                q.put(None)
        for ws in (self._tokenize_wokers, self._encode_workers):
            for w in ws:
                w.join()
            for w in ws:
                w.close()
        # NOTE(review): the workers have already exited at this point, so
        # nothing consumes these sentinels; kept as in the original
        # implementation - confirm whether closing the queues was intended.
        for qs in (self._text_queues, self._input_queues, self._output_queues):
            for q in qs:
                q.put(None)
        self._stopped = True

    def _text_length(self, sent):
        """Return ``len(sent)``."""
        return len(sent)

    def truncation_doc(self, doc, max_length):
        """Limit ``doc`` to at most ``max_length`` tokens.

        Args:
            doc: Document text.
            max_length: Maximum number of tokens to keep.

        Returns:
            ``doc`` itself when it already fits, otherwise the decoded text
            of its first ``max_length`` tokens.
        """
        text_tokenizer = self.tokenizer.tokenizer.tokenizer
        tokens = text_tokenizer.encode(doc)
        if len(tokens) <= max_length:
            # Nothing to cut: skip the decode round trip and return the
            # caller's text untouched.
            return doc
        return text_tokenizer.decode(tokens[:max_length])

    def format_content(self, text, image, prefix='Query:'):
        """Build chat-style content entries for one text/image pair.

        Args:
            text: Optional text.
            image: Optional image path; emitted with a ``file://`` prefix.
            prefix: Leading text entry.

        Returns:
            A list of content dicts (prefix, then image, then text).  When
            both ``text`` and ``image`` are empty a single dict with empty
            text is returned instead of a list.
        """
        if not text and not image:
            # NOTE(review): this branch returns a dict while every other path
            # returns a list; kept as is because callers may rely on it.
            return {'type': 'text', 'text': ""}
        content = [{'type': 'text', 'text': prefix}]
        if image:
            content.append({'type': 'image', 'image': 'file://' + image})
        if text:
            content.append({'type': 'text', 'text': text})
        return content

    @staticmethod
    def _format_section(title, text, kind):
        """Render one ``**title**:`` section for a text, image or mixed input."""
        image_slot = "<|vision_start|><|image_pad|><|vision_end|>"
        if kind == 'image':
            return f"**{title}**:\n{image_slot}"
        if kind == 'mixed':
            return f"**{title}**:\n{image_slot}{text}"
        return f"**{title}**:\n{text}"

    def formatting_prompts_func(
        self,
        query: Optional[str] = None,
        doc: Optional[str] = None,
        query_type: str = 'text',
        doc_type: str = 'text',
        prefix_str: str = '',
    ) -> str:
        """Assemble the ranking prompt: document section first, then query.

        Args:
            query: Query text (ignored when ``query_type`` is ``'image'``).
            doc: Document text (ignored when ``doc_type`` is ``'image'``).
            query_type: ``'image'``, ``'mixed'`` or anything else for text.
            doc_type: ``'image'``, ``'mixed'`` or anything else for text.
            prefix_str: Optional line placed before the prompt.

        Returns:
            The formatted prompt string.
        """
        query_part = self._format_section('Query', query, query_type)
        doc_part = self._format_section('Document', doc, doc_type)

        prompt = doc_part + '\n' + query_part
        if prefix_str:
            prompt = prefix_str + '\n' + prompt
        return prompt

    def format_instruction(self, query_text, query_image_path, doc_text, doc_image_path):
        """Derive the input kinds for a pair and build its prompt.

        A side is ``'image'`` when its text is ``None``, ``'text'`` when its
        image path is ``None`` and ``'mixed'`` otherwise.  Document text is
        truncated to ``max_length // 2`` tokens before formatting.

        Raises:
            AssertionError: If both query text and query image are ``None``.
        """
        assert query_text is not None or query_image_path is not None

        def _kind(text, image_path):
            if text is None:
                return 'image'
            if image_path is None:
                return 'text'
            return 'mixed'

        query_type = _kind(query_text, query_image_path)
        doc_type = _kind(doc_text, doc_image_path)

        if doc_text is not None:
            doc_text = self.truncation_doc(doc_text, self.max_length // 2)
        return self.formatting_prompts_func(
            query=query_text, query_type=query_type, doc=doc_text, doc_type=doc_type
        )

    def _raise_if_worker_died(self):
        """Raise ``RuntimeError`` if any tokenizer or encoder process has exited."""
        for worker in (*self._tokenize_wokers, *self._encode_workers):
            if not worker.is_alive():
                raise RuntimeError(
                    f"worker process {worker.name!r} exited unexpectedly "
                    f"(exitcode={worker.exitcode}); no more results will arrive"
                )

    def _put_batch(self, rank, item):
        """Blocking put on the text queue of ``rank`` that notices dead workers."""
        while True:
            try:
                self._text_queues[rank].put(item, timeout=self._WORKER_POLL_SECONDS)
                return
            except queue.Full:
                # The queue stays full forever if its consumer has died.
                self._raise_if_worker_died()

    def process(
        self,
        pairs: List,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = True,
        batch_size: int = 8,
        return_dense=True,
        return_sparse=False,
        return_sparse_embedding=False,
        **kwargs
    ):
        """Score query/document pairs on all GPUs.

        Args:
            pairs: Sequence of ``(query_text, query_image, doc_text,
                doc_image)`` tuples.
            show_progress_bar: Show a progress bar; only honoured when there
                are more than 10 batches.
            batch_size: Number of pairs per batch.
            convert_to_numpy, return_dense, return_sparse,
            return_sparse_embedding, **kwargs: Accepted for interface
                compatibility; ignored.

        Returns:
            A flat list with the per-pair results of all batches, in input
            order.

        Raises:
            RuntimeError: If a worker process exits while results are
                still outstanding.
        """
        prompts, doc_images, query_images = [], [], []
        for query_text, query_image, doc_text, doc_image in pairs:
            doc_images.append(doc_image)
            query_images.append(query_image)
            prompts.append(self.format_instruction(query_text, query_image, doc_text, doc_image))
        items = list(zip(prompts, doc_images, query_images))

        num_texts = len(items)
        num_batches = num_texts // batch_size + int(num_texts % batch_size > 0)
        show_progress_bar = show_progress_bar and (num_batches > 10)
        pbar = tqdm(total=num_batches, disable=not show_progress_bar, mininterval=1, miniters=10)
        result_dict = dict()

        def _receive(oq, timeout=0.00125):
            """Fetch at most one finished batch; return whether one arrived."""
            try:
                n, scores = oq.get(timeout=timeout)
            except queue.Empty:
                return False
            result_dict[n] = scores
            pbar.update(1)
            return True

        try:
            for n, start in enumerate(range(0, num_texts, batch_size)):
                batch = items[start: start + batch_size]
                rank = n % self.world_size
                self._put_batch(rank, (n, batch))
                # Once every device has work, opportunistically drain one
                # result from the device that was just fed.
                if n >= self.world_size:
                    _receive(self._output_queues[rank])

            while len(result_dict) < num_batches:
                got_any = False
                for oq in self._output_queues:
                    got_any = _receive(oq) or got_any
                if not got_any:
                    # Without this check a crashed worker would make the
                    # loop spin forever waiting for its batches.
                    self._raise_if_worker_died()
        finally:
            pbar.close()

        results = []
        for n in range(num_batches):
            results.extend(result_dict[n])
        return results

