import logging
from typing import Dict, Optional, List
import os

import heapq
import json
import logging
import os
import queue
import sys
import time
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass, field
from pathlib import Path
from tqdm import tqdm

import torch
from torch.utils.data._utils.worker import ManagerWatchdog
import numpy as np
import torch.distributed as dist
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from transformers.file_utils import ModelOutput
from utils import load_model
from PIL import Image
import pandas as pd

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Import-time side effect: hands the setting string 'expandable_segments:False'
# to an underscore-prefixed (private) hook in torch.cuda.memory.
# NOTE(review): presumably this turns off the CUDA caching allocator's
# expandable segments because tensors are passed between processes below —
# confirm the motivation and that the private hook exists in the pinned torch
# version.
torch.cuda.memory._set_allocator_settings('expandable_segments:False')
class PreprocessorWorker:
    """Wrap a preprocessor callable and run it as a queue-driven worker loop.

    The loop reads raw batches from an input queue, runs the preprocessor,
    moves the resulting tensors to a target device and forwards them to an
    output queue.
    """

    def __init__(self, preprocessor, qsize=4):
        """Store the preprocessor and the retention-window size.

        Args:
            preprocessor: Callable invoked as
                ``preprocessor(texts, images, is_query, instruction)``.
            qsize: Capacity of the queue the loop writes to. The loop keeps
                references to its ``qsize + 1`` most recent outputs.
                This value used to be ignored (always 8); it is now honoured
                so the window follows the configured queue capacity.
        """
        self.preprocessor = preprocessor
        self.qsize = qsize

    def preprocess(self, texts: list, images: list, is_query=True, instruction=None):
        """Run the wrapped preprocessor on one batch and return its output."""
        return self.preprocessor(texts, images, is_query, instruction)

    def _preprocess_loop(self, input_queue, output_queue, device):
        """Consume batches until a ``None`` sentinel arrives.

        Each item read from ``input_queue`` must be a tuple
        ``(n, batch_text, batch_image, is_query, instruction)``. For each one,
        ``(n, inputs)`` is put on ``output_queue``, where ``inputs`` is the
        preprocessor output with its tensors moved to ``device``.

        Args:
            input_queue: Queue yielding batch tuples, or ``None`` to stop.
            output_queue: Queue receiving ``(n, inputs)`` tuples.
            device: Target passed to ``.to(...)`` for every tensor-like value.
        """
        from collections import deque

        # Bounded window of the most recent outputs. Presumably this keeps
        # tensors alive until the consuming process has received them (see the
        # PyTorch notes on sharing CUDA tensors) — TODO confirm.
        # A non-positive size means "unbounded", matching queue.Queue.
        window = self.qsize + 1
        keep_alive = deque(maxlen=window if window > 0 else None)

        while True:
            r = input_queue.get()
            if r is None:
                break

            n, batch_text, batch_image, is_query, instruction = r
            inputs = self.preprocess(
                texts=batch_text, images=batch_image, is_query=is_query, instruction=instruction
            )
            for key in inputs:
                value = inputs[key]
                # None and plain lists are forwarded untouched.
                if value is None or isinstance(value, list):
                    continue
                if isinstance(value, dict):
                    # Plain dicts: move only the tensor entries, in place.
                    for k, v in value.items():
                        if isinstance(v, torch.Tensor):
                            value[k] = v.to(device)
                else:
                    # Anything else is expected to provide ``.to(device)``.
                    inputs[key] = value.to(device)
            output_queue.put((n, inputs))
            # Appending to a full deque silently drops the oldest entry.
            keep_alive.append(inputs)
            # Drop batch references before blocking on the next get().
            del r, n, batch_text, batch_image, is_query, instruction

        keep_alive.clear()
        return
def _encode_loop(model, input_queue, output_queue, device, dtype=torch.float16, qsize=8):
    model = model.to(device)
    watchdog = ManagerWatchdog()
    keep_queue = queue.Queue(qsize + 1)

    with torch.inference_mode():
        with torch.autocast(device_type=device.type, dtype=dtype):
            while watchdog.is_alive():
                r = input_queue.get()
                if r is None:
                    break
                n, inputs = r
                results = model.encode(image=inputs['image'], text=inputs['text'])
                output_queue.put((n, results))
                if keep_queue.full():
                    try:
                        i = keep_queue.get()
                        del i
                    except queue.Empty:
                        pass
                keep_queue.put(results)
                del r, n, inputs

    while not keep_queue.empty():
        i = keep_queue.get()
        del i
    del model, watchdog
    return


class InferenceModelMultiGPU:
    """Embed a corpus with one preprocess process and one encode process per GPU.

    For every visible CUDA device three bounded queues are created and chained:
    text queue -> preprocessor process -> input queue -> encoder process ->
    output queue. ``encode`` distributes batches round-robin over the text
    queues and collects results from the output queues.
    """

    def __init__(
        self,
        config: dict,
    ) -> None:
        """Load the model and start the worker processes.

        Args:
            config: Mapping passed to ``load_model``. Keys read here:
                ``'dtype'`` (``'fp16'``, ``'bf16'`` or ``'fp32'``; anything
                else falls back to fp16) and ``'qsize'`` (queue capacity,
                default 8).

        Raises:
            AssertionError: If no CUDA device is visible.
        """
        n_gpu = torch.cuda.device_count()
        # Fail before the (potentially expensive) model load.
        assert n_gpu > 0, 'woho, no no no!'
        model, preprocessor = load_model(config)

        dtype_by_name = {
            'fp16': torch.float16,
            'bf16': torch.bfloat16,
            'fp32': torch.float32,
        }
        requested = config.get('dtype', 'fp16')
        # Unknown or non-string values fall back to fp16.
        dtype = dtype_by_name.get(requested, torch.float16) if isinstance(requested, str) else torch.float16

        self.model = model
        self.instruction = None
        qsize = config.get('qsize', 8)
        self.preprocessor = PreprocessorWorker(preprocessor, qsize=qsize)
        self.world_size = n_gpu
        self.mp_ctx = torch.multiprocessing.get_context('spawn')
        self._text_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._input_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._output_queues = [self.mp_ctx.Queue(qsize) for _ in range(n_gpu)]
        self._devices = []
        self._preprocessor_wokers = []
        self._encode_workers = []
        self._stopped = False
        for i, (tq, iq, oq) in enumerate(zip(self._text_queues, self._input_queues, self._output_queues)):
            device = torch.device(f'cuda:{i}')
            self._devices.append(device)
            w_t = self.mp_ctx.Process(
                target=self.preprocessor._preprocess_loop, name=f'tok_w_{i}', args=(tq, iq, device)
            )
            w_t.start()
            self._preprocessor_wokers.append(w_t)
            # qsize is forwarded so the encoder's retention window matches the
            # capacity of the output queue it writes to.
            w_e = self.mp_ctx.Process(
                target=_encode_loop, name=f'enc_w_{i}', args=(model, iq, oq, device, dtype, qsize)
            )
            w_e.start()
            self._encode_workers.append(w_e)
            logger.info('GPU %s worker initiated.', i)

    def stop(self):
        """Shut down all workers and release the queues.

        Sends a ``None`` sentinel on every text and input queue, joins and
        closes every worker, then closes all queues. Calling it again is a
        no-op.
        """
        if getattr(self, '_stopped', False):
            return
        self._stopped = True

        # The sentinel makes each worker loop exit.
        for queues in (self._text_queues, self._input_queues):
            for q in queues:
                q.put(None)
        for workers in (self._preprocessor_wokers, self._encode_workers):
            for w in workers:
                w.join()
            for w in workers:
                w.close()
        # Nobody reads the queues once the workers are gone, so close them
        # rather than enqueueing more sentinels, which could block on a full
        # queue.
        for queues in (self._text_queues, self._input_queues, self._output_queues):
            for q in queues:
                q.close()

    def _text_length(self, item):
        """Return ``item['length']`` when present, otherwise 0."""
        return item['length'] if 'length' in item else 0

    @staticmethod
    def _receive(oq, result_dict, pbar, timeout=0.00125):
        """Try to pull one ``(n, embedding)`` from ``oq`` into ``result_dict``.

        Returns silently if nothing arrives within ``timeout`` seconds.
        """
        try:
            n, embed = oq.get(timeout=timeout)
        except queue.Empty:
            return
        result_dict[n] = embed
        pbar.update(1)

    def _raise_if_worker_died(self):
        """Raise ``RuntimeError`` if any worker process is no longer alive."""
        for worker in (*self._preprocessor_wokers, *self._encode_workers):
            if not worker.is_alive():
                raise RuntimeError(
                    f'worker {worker.name!r} exited unexpectedly '
                    f'(exit code {worker.exitcode}); cannot finish encoding'
                )

    @staticmethod
    def _restore_order(batch_embeddings, length_sorted_idx):
        """Concatenate per-batch embeddings and undo the length sort.

        Args:
            batch_embeddings: Tensors in batch order; concatenated along dim 0.
            length_sorted_idx: ``length_sorted_idx[j]`` is the corpus index of
                the item at sorted position ``j``.

        Returns:
            A float64 numpy array whose row ``i`` belongs to corpus item ``i``.
        """
        # Cast to double before leaving torch: numpy cannot represent
        # bfloat16, and float64 is what the earlier list-based conversion
        # produced.
        stacked = torch.cat([emb.cpu() for emb in batch_embeddings]).double().numpy()
        return stacked[np.argsort(length_sorted_idx)]

    def encode(
        self,
        corpus,
        show_progress_bar: bool = True,
        batch_size: int = 512,
        is_query=True,
        instruction=None,
        **kwargs
    ):
        """Embed every item of ``corpus`` and return the vectors in corpus order.

        Items are sorted by descending ``'length'`` (0 when the key is absent),
        split into batches and sent round-robin to the per-GPU text queues.

        Args:
            corpus: Sequence of mappings with ``'text'`` and ``'image'`` keys
                and an optional ``'length'`` key.
            show_progress_bar: Show a tqdm bar counting finished batches.
            batch_size: Number of items per batch; must be positive.
            is_query: Forwarded to the preprocessor.
            instruction: Forwarded to the preprocessor; ``self.instruction``
                is used when ``None``.
            **kwargs: Accepted and ignored.

        Returns:
            ``{'dense_embeddings': ndarray}`` with one row per corpus item
            (an empty array for an empty corpus).

        Raises:
            ValueError: If ``batch_size`` is not positive.
            RuntimeError: If a worker process dies while results are pending.
        """
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        total_number = len(corpus)
        if total_number == 0:
            # Nothing to send to the workers; torch.cat would reject an empty list.
            return {'dense_embeddings': np.array([])}

        length_sorted_idx = np.argsort([-self._text_length(item) for item in corpus]).tolist()
        num_batches = -(-total_number // batch_size)  # ceiling division
        if instruction is None:
            instruction = self.instruction

        result_dict = {}
        pbar = tqdm(total=num_batches, disable=not show_progress_bar, mininterval=1, miniters=10)
        try:
            for n, start in enumerate(range(0, total_number, batch_size)):
                batch = [corpus[idx] for idx in length_sorted_idx[start: start + batch_size]]
                batch_text = [item['text'] for item in batch]
                batch_image = [item['image'] for item in batch]
                rank = n % self.world_size
                self._text_queues[rank].put((n, batch_text, batch_image, is_query, instruction))
                # Once every GPU has work, opportunistically collect a result
                # from the queue just fed.
                if n >= self.world_size:
                    self._receive(self._output_queues[rank], result_dict, pbar)

            while len(result_dict) < num_batches:
                received_before = len(result_dict)
                for oq in self._output_queues:
                    self._receive(oq, result_dict, pbar)
                # Only pay for liveness checks when a full sweep made no progress.
                if len(result_dict) == received_before:
                    self._raise_if_worker_died()
        finally:
            pbar.close()

        results = [result_dict[n] for n in range(num_batches)]
        return {'dense_embeddings': self._restore_order(results, length_sorted_idx)}
