from core.actor import *
from tqdm import tqdm
import itertools
import ray
import math


class Router:
    """Dispatch batches of samples to a pool of Ray model actors.

    The router controls the batch size and the resource allocation: it owns
    the worker pool (``self.llm_workers``), decides which model the actors
    load, how many actors are created (``args.workers``), how many GPUs each
    one reserves (``args.tensor_parallel``) and how many samples go into a
    single request (``args.batch_size``). Results returned by the actors are
    forwarded to every registered handler.
    """

    def __init__(self, args):
        """Store configuration from ``args``; no Ray actors are created here.

        ``args`` must expose ``workers``, ``tensor_parallel``, ``batch_size``,
        ``inference_model_modality``, ``refinement_model_modality`` and
        ``ranking_model_modality``. ``init_inference_model`` and
        ``init_refinement_model`` additionally read ``args.inference_model``
        and ``args.refinement_model``.
        """
        self.args = args
        self.workers = self.args.workers
        self.tensor_parallel = self.args.tensor_parallel
        self.batch_size = self.args.batch_size

        self.inference_model_modality = self.args.inference_model_modality
        self.refinement_model_modality = self.args.refinement_model_modality
        self.ranking_model_modality = self.args.ranking_model_modality

        self.llm_workers = []
        self.handlers = []

        # Identifies the model the current worker pool was built for (None when
        # nothing is loaded). One boolean per model is not enough: all models
        # share the same pool, so loading one evicts whichever came before.
        self._loaded_model = None

    def init_ray_if_needed(self, **kwargs):
        """Call ``ray.init(**kwargs)`` unless Ray is already initialised."""
        if not ray.is_initialized():
            print("[INFO] Ray not initialized. Initializing now...")
            ray.init(**kwargs)
        else:
            print("[INFO] Ray already initialized. Skipping ray.init().")

    def _kill_workers(self):
        """Terminate every actor in the pool so the GPUs it holds are freed."""
        for worker in self.llm_workers:
            # Dropping the handle only lets Ray garbage-collect the actor at
            # some later point; killing it makes the GPUs available right away
            # for the replacement actors requested next.
            ray.kill(worker)
        self.llm_workers = []

    def _replace_workers(self, core, *actor_args):
        """Replace the pool with ``self.workers`` new ``core`` actors.

        Each actor reserves ``self.tensor_parallel`` GPUs and is constructed
        with ``(self.args, actor_idx, *actor_args)``.
        """
        # Forget what was loaded first, so a failure below never leaves the
        # router believing a model is still being served.
        self._loaded_model = None
        self._kill_workers()
        gpu_core = core.options(num_gpus=self.tensor_parallel)
        self.llm_workers = [
            gpu_core.remote(self.args, actor_idx, *actor_args)
            for actor_idx in range(self.workers)
        ]

    def init_model(self, model_modality, model_path):
        """(Re)build the worker pool around a generation model.

        Always recreates the actors, even if the same model is already loaded.

        Args:
            model_modality: ``"language"``, ``"vision"`` or ``"omni"``.
            model_path: passed to each actor's constructor.

        Raises:
            ValueError: for an unsupported ``model_modality`` (raised before
                any Ray call is made).
        """
        if model_modality == "language":
            core = LanguageVLLM
        elif model_modality == "vision":
            core = VisionVLLM
        elif model_modality == "omni":
            core = OmniVLLM
        else:
            raise ValueError("Not support this model_modality..")

        self.init_ray_if_needed(address="auto", ignore_reinit_error=True)
        self._replace_workers(core, model_path)
        self._loaded_model = ("generate", model_modality, model_path)

    def _ensure_generation_model(self, model_modality, model_path):
        """Load the given generation model unless the pool already serves it."""
        if self._loaded_model != ("generate", model_modality, model_path):
            self.init_model(model_modality, model_path)

    def init_refinement_model(self):
        """Make the pool serve ``args.refinement_model`` (no-op if it already does)."""
        self._ensure_generation_model(
            self.refinement_model_modality, self.args.refinement_model
        )

    def init_inference_model(self):
        """Make the pool serve ``args.inference_model`` (no-op if it already does)."""
        self._ensure_generation_model(
            self.inference_model_modality, self.args.inference_model
        )

    def init_ranking_model(self):
        """(Re)build the worker pool around the scalar reward / ranking model.

        Like the original implementation, this always recreates the actors.

        Raises:
            ValueError: if ``ranking_model_modality`` is not ``"language"`` or
                ``"vision"`` (raised before any Ray call is made).
        """
        if self.ranking_model_modality == "language":
            core = ScalarRewardWorker
        elif self.ranking_model_modality == "vision":
            core = ScalarVisionRewardWorker
        else:
            raise ValueError("Not support this model_modality..")

        self.init_ray_if_needed(address="auto", ignore_reinit_error=True)
        # Kill the previous model before starting the ranking workers.
        self._replace_workers(core)
        self._loaded_model = ("ranking", self.ranking_model_modality)

    def run(self, dataset_iter, total=None):
        """Stream every sample of ``dataset_iter`` through the worker pool.

        Samples are taken ``self.batch_size`` at a time. Each worker has at
        most one batch in flight; whenever a batch finishes, its results are
        passed to the handlers and the same worker receives the next batch.

        Args:
            dataset_iter: any iterable of samples (lists work as well as
                iterators/generators).
            total: expected number of results, used for the progress bar;
                no bar is shown when it is falsy.

        Raises:
            RuntimeError: if no workers have been created yet.
        """
        if not self.llm_workers:
            raise RuntimeError(
                "Router has no workers; call one of the init_*_model methods first."
            )

        # iter() is essential: islice() on a list starts over from index 0 on
        # every call, which would resend the first batch forever.
        samples = iter(dataset_iter)

        def next_batch():
            return list(itertools.islice(samples, self.batch_size))

        # ObjectRef of each in-flight request -> the worker that is running it.
        in_flight = {}
        for worker in list(self.llm_workers):
            batch = next_batch()
            if not batch:
                break
            in_flight[worker.generate.remote(batch)] = worker

        pbar = tqdm(total=total, desc="Router Processed Prompts") if total else None
        try:
            while in_flight:
                done, _ = ray.wait(list(in_flight), num_returns=1)
                for fut in done:
                    worker = in_flight.pop(fut)
                    results = ray.get(fut)
                    self.handle_results(results)
                    if pbar is not None:
                        pbar.update(len(results))

                    # Keep the now-idle worker busy with the next batch.
                    batch = next_batch()
                    if batch:
                        in_flight[worker.generate.remote(batch)] = worker
        finally:
            if pbar is not None:
                pbar.close()

    def reset_handler(self):
        """Remove all registered result handlers."""
        self.handlers = []

    def add_handler(self, handler):
        """Register ``handler``; it is called with each finished batch's results."""
        self.handlers.append(handler)

    def handle_results(self, results):
        """Pass ``results`` to every registered handler, in registration order."""
        for handler in self.handlers:
            handler(results)
