import logging
from typing import Dict, Optional, List
import os

import json
import logging
import os
import queue
import sys

from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from tqdm import tqdm
from typing import Union, List, Tuple, Any

import numpy as np
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.utils.data._utils.worker import ManagerWatchdog

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModel, is_torch_npu_available
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, QWEN2_5_VL_INPUTS_DOCSTRING, Qwen2_5_VLCausalLMOutputWithPast
from qwen_vl_utils import process_vision_info


# Module-level logger; MonoQwenInferenceModel reports worker start-up through it.
logger = logging.getLogger(__name__)



class GmeLLMRanker(nn.Module):
    """Point-wise reranker that scores prompts with a Qwen2-VL causal LM.

    Two scoring heads are supported:

    * If ``<model_path>/classifier.bin`` exists, a bias-free 2-way linear head is
      applied to the final-layer hidden state at the last sequence position;
      output column 0 is treated as the "true" logit and column 1 as "false".
    * Otherwise the LM-head logits at the last position are read at
      ``token_true_id`` / ``token_false_id``.

    ``inference_type`` selects what :meth:`process` returns:

    * ``'yes_or_no'``: probability of "true" from a softmax over (false, true).
    * ``'yes'``: the raw "true" logit.

    ``peft_path``, ``max_length``, ``attn_type`` and ``format_type`` are accepted
    for interface compatibility but are not used.
    """

    _INFERENCE_TYPES = ('yes_or_no', 'yes')

    def __init__(self, model_path, token_true_id, token_false_id, peft_path=None, max_length=8192, attn_type='causal', format_type='chat', inference_type='yes_or_no'):
        super().__init__()
        self.lm = Qwen2VLForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, attn_implementation="flash_attention_2")
        self.classifier = None
        classifier_path = os.path.join(model_path, 'classifier.bin')
        if os.path.exists(classifier_path):
            # lm_head.weight is (vocab, hidden); the classifier maps hidden -> 2 logits.
            hidden_size = self.lm.lm_head.weight.shape[1]
            self.classifier = torch.nn.Linear(hidden_size, 2, bias=False)
            # Load onto CPU first so a checkpoint saved from any device loads here;
            # weights_only avoids unpickling arbitrary objects from the file.
            state_dict = torch.load(classifier_path, map_location='cpu', weights_only=True)
            self.classifier.load_state_dict(state_dict)
            self.classifier = self.classifier.to(self.lm.device, dtype=self.lm.dtype)
        self.token_true_id = token_true_id
        self.token_false_id = token_false_id
        self.inference_type = inference_type
        self.lm.eval()

    @torch.no_grad()
    def process(self, inputs, **kwargs):
        """Score one batch.

        Args:
            inputs: Mapping of keyword arguments for the LM forward call (e.g. the
                processor output built by the tokenizer worker).
            **kwargs: Ignored.

        Returns:
            A list with one float per row of the batch (see class docstring).

        Raises:
            ValueError: if ``self.inference_type`` is not a supported value.
        """
        if self.inference_type not in self._INFERENCE_TYPES:
            raise ValueError(
                f"Unsupported inference_type {self.inference_type!r}; "
                f"expected one of {self._INFERENCE_TYPES}"
            )
        if self.classifier is not None:
            outputs = self.lm(**inputs, output_hidden_states=True)
            # Final layer, last position. This assumes left-padded inputs so that
            # position -1 is a real token for every row.
            last_hidden_state = outputs.hidden_states[-1][:, -1]
            logits = self.classifier(last_hidden_state)
            true_vector, false_vector = logits[:, 0], logits[:, 1]
        else:
            last_logits = self.lm(**inputs).logits[:, -1, :]
            true_vector = last_logits[:, self.token_true_id]
            false_vector = last_logits[:, self.token_false_id]

        if self.inference_type == 'yes':
            return true_vector.tolist()
        # Two-way softmax over (false, true); column 1 is P(true).
        pair_logits = torch.stack([false_vector, true_vector], dim=1)
        return torch.nn.functional.log_softmax(pair_logits, dim=1)[:, 1].exp().tolist()

        
class TokenizeWorker:
    """Convert batches of reranking prompts into padded model inputs.

    ``self.tokenizer`` is the Hugging Face ``AutoProcessor`` loaded from
    ``tokenizer_path``. Its ``.tokenizer`` attribute is the text tokenizer,
    which is configured for left padding. :meth:`_tokenize_loop` is meant to be
    run as the target of a worker process.
    """

    def __init__(self, tokenizer_path, max_length=3200, qsize=4, format_type='chat', eod_token='<|im_end|>'):
        """
        Args:
            tokenizer_path: Path or hub id passed to ``AutoProcessor.from_pretrained``.
            max_length: Token budget. In ``'instruct'`` mode sequences are truncated
                to ``max_length - 1`` tokens before the end-of-document token is
                appended. In ``'chat'`` mode it is forwarded with truncation disabled.
            qsize: Depth of the downstream queue. ``_tokenize_loop`` keeps the last
                ``qsize + 1`` produced batches referenced.
            format_type: ``'chat'`` (chat template plus vision inputs) or
                ``'instruct'`` (plain text).
            eod_token: Token appended in ``'instruct'`` mode. ``None`` selects the
                tokenizer's EOS token.
        """
        # Pixel-count bounds for images: 256 and 1280 blocks of 28x28 pixels.
        min_pixels = 256 * 28 * 28
        max_pixels = 1280 * 28 * 28
        self.tokenizer = AutoProcessor.from_pretrained(tokenizer_path, min_pixels=min_pixels, max_pixels=max_pixels)
        # Left padding puts every row's last real token at position -1.
        self.tokenizer.tokenizer.padding_side = 'left'
        self.eod_token = eod_token
        if self.eod_token is not None:
            self.eod_id = self.tokenizer.tokenizer.convert_tokens_to_ids(self.eod_token)
        else:
            self.eod_id = self.tokenizer.tokenizer.eos_token_id
        self.max_length = max_length
        # Honour the caller's queue depth (previously hard-coded to 4).
        self.qsize = qsize
        self.format_type = format_type

    def tokenize(self, pairs: list):
        """Tokenize and pad one batch.

        Args:
            pairs: In ``'chat'`` mode, a list of chat-message lists. In
                ``'instruct'`` mode, a list of strings.

        Returns:
            The padded batch as PyTorch tensors.

        Raises:
            ValueError: if ``self.format_type`` is not ``'chat'`` or ``'instruct'``.
        """
        if self.format_type == 'chat':
            texts = [
                self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                for messages in pairs
            ]
            image_inputs, video_inputs = process_vision_info(pairs)
            return self.tokenizer(
                text=texts,
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
                truncation=False,
                max_length=self.max_length
            )
        if self.format_type == 'instruct':
            # Use the text tokenizer. The processor's first positional argument is
            # `images`, and the processor has no `pad` method.
            text_tokenizer = self.tokenizer.tokenizer
            # Reserve one slot for the end-of-document token appended below.
            encoded = text_tokenizer(pairs, max_length=self.max_length - 1, return_attention_mask=False, padding=False, truncation=True)
            encoded['input_ids'] = [input_ids + [self.eod_id] for input_ids in encoded['input_ids']]
            return text_tokenizer.pad(encoded, padding=True, return_attention_mask=True, return_tensors='pt')
        raise ValueError(f"Unsupported format_type {self.format_type!r}; expected 'chat' or 'instruct'")

    def _tokenize_loop(self, input_queue, output_queue, device):
        """Worker loop: tokenize ``(n, batch)`` items until a ``None`` sentinel arrives.

        Each batch is tokenized, moved to *device* and emitted as ``(n, inputs)``
        on *output_queue*.
        """
        # Keep the most recent qsize + 1 device batches referenced in this producer
        # process. Presumably this keeps tensors shared with the consumer process
        # alive until it has read them, as torch.multiprocessing requires for CUDA
        # tensors. The size assumes the downstream queue holds at most qsize items
        # plus one in use.
        keep_queue = queue.Queue(self.qsize + 1)

        while True:
            r = input_queue.get()
            if r is None:
                break

            n, batch = r
            inputs = self.tokenize(batch)
            inputs = inputs.to(device)
            output_queue.put((n, inputs))
            if keep_queue.full():
                k = keep_queue.get()
                del k
            keep_queue.put(inputs)
            del r, n, batch, inputs

        while not keep_queue.empty():
            i = keep_queue.get()
            del i
        return


def _encode_loop(model, input_queue, output_queue, device, qsize=4):
    """Worker-process loop that scores tokenized batches on *device*.

    Moves *model* to *device*, then repeatedly reads ``(n, inputs)`` from
    *input_queue*. It calls ``model.process(inputs=inputs)`` under
    ``torch.inference_mode`` and bfloat16 autocast, and writes ``(n, scores)`` to
    *output_queue*. It stops on a ``None`` item, or when
    ``ManagerWatchdog.is_alive()`` turns false (presumably the parent process
    has exited).

    Args:
        model: Object exposing ``to(device)`` and ``process(inputs=...)``.
        input_queue: Source of ``(n, inputs)`` items and the ``None`` sentinel.
        output_queue: Destination for ``(n, scores)`` items.
        device: ``torch.device`` to run on; its ``.type`` selects the autocast backend.
        qsize: Number of recent results kept referenced is ``qsize + 1``.
    """
    model = model.to(device)
    parent_watchdog = ManagerWatchdog()
    # Bounded buffer of the most recently emitted results, kept referenced here.
    recent = queue.Queue(qsize + 1)

    with torch.inference_mode(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        while parent_watchdog.is_alive():
            item = input_queue.get()
            if item is None:
                break
            batch_idx, batch_inputs = item
            batch_scores = model.process(inputs=batch_inputs)
            output_queue.put((batch_idx, batch_scores))
            if recent.full():
                recent.get()  # drop the oldest reference
            recent.put(batch_scores)
            del item, batch_idx, batch_inputs

    # Release every buffered reference before the process exits.
    while not recent.empty():
        recent.get()
    del model, parent_watchdog


class MonoQwenInferenceModel:
    """Multi-GPU front end for the Qwen2-VL point-wise reranker.

    On construction, one tokenize process and one encode process are spawned per
    visible CUDA device. They are connected by bounded queues::

        main --text queue--> TokenizeWorker._tokenize_loop --input queue-->
        _encode_loop --output queue--> main

    :meth:`process` formats (query, document) pairs as chat messages, deals the
    batches round-robin to the devices and reassembles the scores in input
    order. Call :meth:`stop` to shut the workers down.
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = 2048,
        normalized: bool = False,
        qsize: int=4,
        instruction=None,
        format_type='chat',
        attn_type='causal',
        inference_type='yes_or_no'
    ) -> None:
        """
        Args:
            model_name_or_path: Checkpoint used for both the processor and the ranker.
            max_length: Token budget forwarded to the tokenizer worker. Also the
                basis of :meth:`truncation_doc`.
            normalized: Accepted for interface compatibility; unused.
            qsize: Capacity of every inter-process queue, also forwarded to the
                workers so their keep-alive buffers match.
            instruction: Stored on ``self.instruction``. Not placed in the prompt
                by :meth:`format_instruction`.
            format_type: Forwarded to :class:`TokenizeWorker`.
            attn_type: Forwarded to :class:`GmeLLMRanker`.
            inference_type: Forwarded to :class:`GmeLLMRanker`.

        Raises:
            RuntimeError: if no CUDA device is visible.
        """
        n_gpu = torch.cuda.device_count()
        # Fail fast, before the expensive processor and model loads; an `assert`
        # would also be stripped under `python -O`.
        if n_gpu == 0:
            raise RuntimeError("MonoQwenInferenceModel needs at least one CUDA device; none is visible.")
        self.max_length = max_length
        self.tokenizer = TokenizeWorker(model_name_or_path, max_length=max_length, qsize=qsize, format_type=format_type)
        # TokenizeWorker -> .tokenizer (AutoProcessor) -> .tokenizer (text tokenizer)
        vocab = self.tokenizer.tokenizer.tokenizer.get_vocab()
        token_false_id = vocab["False"]
        token_true_id = vocab["True"]
        self.instruction = instruction
        if self.instruction is None:
            self.instruction = "Retrieval document that can answer user's query"
        model = GmeLLMRanker(model_name_or_path, token_false_id=token_false_id, token_true_id=token_true_id, attn_type=attn_type, inference_type=inference_type)
        self.model = model
        self.world_size = n_gpu
        # CUDA in subprocesses needs the 'spawn' (or 'forkserver') start method.
        self.mp_ctx = torch.multiprocessing.get_context('spawn')
        logger.info(f"We have {n_gpu=}, good. Starting worker processes.")
        self._text_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._input_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._output_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._devices = list()
        self._tokenize_wokers = list()  # (sic) name kept for compatibility
        self._encode_workers = list()
        self.format_type = format_type
        for i, (tq, iq, oq) in enumerate(zip(self._text_queues, self._input_queues, self._output_queues)):
            device = torch.device(f'cuda:{i}')
            self._devices.append(device)
            w_t = self.mp_ctx.Process(
                target=self.tokenizer._tokenize_loop, name=f'tok_w_{i}', args=(tq, iq, device)
            )
            w_t.start()
            self._tokenize_wokers.append(w_t)
            # Pass qsize so the encoder's keep-alive buffer matches the queue depth.
            w_e = self.mp_ctx.Process(
                target=_encode_loop, name=f'enc_w_{i}', args=(model, iq, oq, device, qsize)
            )
            w_e.start()
            self._encode_workers.append(w_e)
            logger.info(f"GPU {i} worker initiated.")
        self.prompt = None

    # NOTE: there is deliberately no ``__del__`` calling ``stop()``; call
    # ``stop()`` explicitly to shut the workers down.

    def stop(self):
        """Send the ``None`` sentinel to every worker, join them and close the queues."""
        for q in (*self._text_queues, *self._input_queues):
            q.put(None)
        for w in (*self._tokenize_wokers, *self._encode_workers):
            w.join()
            w.close()
        # No process reads these queues any more. Closing them releases their feeder
        # threads instead of enqueueing data nobody will consume.
        for q in (*self._text_queues, *self._input_queues, *self._output_queues):
            q.close()

    def _text_length(self, sent):
        """Sort key used by :meth:`process`.

        Pairs produced by :meth:`format_instruction` are single-message lists, so
        this is currently 1 for every pair.
        """
        return len(sent)

    def truncation_doc(self, doc):
        """Truncate *doc* to at most ``max_length // 2`` tokens and return it as text."""
        # The AutoProcessor has no encode(); use the text tokenizer underneath it.
        text_tokenizer = self.tokenizer.tokenizer.tokenizer
        tokens = text_tokenizer.encode(doc)[:self.max_length // 2]
        return text_tokenizer.decode(tokens)

    def format_content(self, text, image, prefix='Query:'):
        """Build a chat ``content`` list from optional text, image and prefix.

        Order is: prefix text (if ``prefix`` is not None), image, text. When both
        *text* and *image* are falsy, a single empty text part is returned.

        A string *image* that already starts with ``file://``, ``http://``,
        ``https://`` or ``data:`` is used as-is. Any other string is treated as a
        filesystem path and gets a ``file://`` prefix. Non-string images (for
        example in-memory image objects) are passed through unchanged.
        """
        if not text and not image:
            return [{'type': 'text', 'text': ""}]
        content = []
        if prefix is not None:
            content.append({'type': 'text', 'text': prefix})
        if image:
            uri_prefixes = ('file://', 'http://', 'https://', 'data:')
            if isinstance(image, str) and not image.startswith(uri_prefixes):
                image = 'file://' + image
            content.append({'type': 'image', 'image': image})
        if text:
            content.append({'type': 'text', 'text': text})
        return content

    def format_instruction(self, query_text, query_image_path, doc_text, doc_image_path):
        """Build the single user message that asks the model to judge relevance.

        The document content comes first, followed by the query content behind a
        fixed "answer True or False" prefix. *query_text* may be an
        ``(instruction, text)`` tuple. The instruction part, like
        ``self.instruction``, is accepted but not placed in the prompt.
        """
        if isinstance(query_text, tuple):
            _instruct, query_text = query_text
        query_prefix = "Assert the relevance of the previous document to the following query, answer True or False. The query is: "
        contents = (
            self.format_content(doc_text, doc_image_path, prefix=None)
            + self.format_content(query_text, query_image_path, prefix=query_prefix)
        )
        return [{"role": "user", "content": contents}]

    def process(
        self,
        pairs: List,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = True,
        batch_size: int = 8,
        return_dense=True,
        return_sparse=False,
        return_sparse_embedding=False,
        **kwargs
    ):
        """Score ``(query_text, query_image, doc_text, doc_image)`` tuples.

        Batches are dealt round-robin to the per-GPU workers. Finished batches are
        collected by index and the scores are returned in input order.
        ``convert_to_numpy``, ``return_dense``, ``return_sparse``,
        ``return_sparse_embedding`` and ``**kwargs`` are accepted for interface
        compatibility and ignored.

        Returns:
            A list with one score per input pair.
        """
        pairs = [self.format_instruction(query_text, query_image, doc_text, doc_image) for query_text, query_image, doc_text, doc_image in pairs]
        length_sorted_idx = np.argsort([-self._text_length(pair) for pair in pairs])
        pairs_sorted = [pairs[idx] for idx in length_sorted_idx]

        num_texts = len(pairs)
        num_batches = num_texts // batch_size + int(num_texts % batch_size > 0)
        result_dict = dict()
        show_progress_bar = show_progress_bar and (num_batches > 10)
        pbar = tqdm(total=num_batches, disable=not show_progress_bar, mininterval=1, miniters=10)

        def _receive(oq, timeout=0.00125):
            # Short poll: record one finished batch if one is available.
            try:
                n, scores = oq.get(timeout=timeout)
            except queue.Empty:
                return
            result_dict[n] = scores
            pbar.update(1)

        for n, i in enumerate(range(0, num_texts, batch_size)):
            rank = n % self.world_size
            self._text_queues[rank].put((n, pairs_sorted[i: i + batch_size]))
            # Once every device has work, poll one result per submission so the
            # bounded queues keep draining while we submit.
            if n >= self.world_size:
                _receive(self._output_queues[rank])
        while len(result_dict) < num_batches:
            for oq in self._output_queues:
                _receive(oq)
        pbar.close()

        results = []
        for n in range(num_batches):
            results.extend(result_dict[n])
        # Invert the sort permutation so scores line up with the input order.
        return [results[idx] for idx in np.argsort(length_sorted_idx)]
