from utils import Model_Factory,ModelConfig
import torch
import concurrent.futures
import time
from nltk.corpus import words
import torch.nn.functional as F
import numpy as np


class ensemble_sevice(object):
    """Step-wise ensemble decoder: every model drafts a chunk, the models score the drafts.

    Each decoding step asks every active inference model for a continuation of at
    most ``each_step_max_generate_token`` tokens, lets the active models score all
    drafts, and appends the highest-scoring draft to the response.

    Models are duck-typed; the attributes/methods used here are:
      * ``model_type`` - identifier used as the model's key.
      * ``Speculative_generate(text, max_tokens)`` -> ``(text, model_name, if_add_space)``.
      * ``model_verify(text_list)`` -> ``(scores, model_name)``; ``scores`` is indexed
        by candidate position (presumably a 1-D torch tensor - confirm).
      * ``Clear_Cache()``, and ``device`` on the first verify model.
    """

    # Lazily-built set of NLTK English words shared by all instances (see ``if_a_word``).
    _english_word_set = None

    def __init__(self, inference_model_lists, verify_model_lists, max_generate_tokens, each_step_max_generate_token):
        """Store the model pools and generation limits.

        Args:
            inference_model_lists: models that draft (and, in ``generate``, score) continuations.
            verify_model_lists: models whose caches are cleared alongside the inference
                models; ``verify_model_lists[0].device`` is where step scores are summed.
            max_generate_tokens: generation stops once the response reaches this many
                space-separated pieces or this many CJK characters.
            each_step_max_generate_token: per-step budget passed to ``Speculative_generate``.
        """
        self.inference_model_lists = inference_model_lists
        self.verify_model_lists = verify_model_lists

        self.max_generate_tokens = max_generate_tokens
        self.each_step_max_generate_token = each_step_max_generate_token

        self.org_model_list = [model.model_type for model in self.inference_model_lists]
        # Stable integer id per model; used by the hit history and reward computation.
        self.model_id_dic = {name: idx for idx, name in enumerate(self.org_model_list)}

    def generate(self, input_text, short_decoder=True, Output_stream=False, Exit_Mechanism=False):
        """Generate a response to ``input_text`` by step-wise ensemble selection.

        Args:
            input_text: prompt fed to the models on the first step. Later steps feed only
                the previously selected chunk (the models presumably keep their own
                context between calls - confirm in ``Speculative_generate``).
            short_decoder: if True, stop as soon as any draft (checked in descending
                score order) contains a stop token, and emit that draft.
            Output_stream: print each selected chunk as soon as it is chosen.
            Exit_Mechanism: accumulate scores across steps and drop models whose softmax
                share of the accumulated scores falls below ``0.5 / number_of_models``.

        Returns:
            ``best_response`` when ``Exit_Mechanism`` is True, otherwise
            ``(best_response, model_hit_times)`` where ``model_hit_times`` counts how
            often each model's draft was selected.

        Raises:
            ValueError: if no inference models were configured.
        """
        if not self.inference_model_lists:
            raise ValueError("ensemble_sevice.generate requires at least one inference model.")

        max_tokens = self.max_generate_tokens
        Stop_token = ["</s>","<|eot_id|>","<|im_end|>","<eos>","<end_of_turn>\n<eos>","<|user|>","<|endoftext|>"]
        Fast_decoder = short_decoder
        # Hard-coded experiment switches carried over from the original implementation.
        prevent_incomplete = False       # drop the (possibly cut-off) last word of every draft
        dynamic_temperature = False      # soften the exit softmax by hit-history entropy
        Eliminate_self_judgment = True   # a verifier's score for its own draft is zeroed

        model_container = list(self.inference_model_lists)
        model_list = [model.model_type for model in model_container]
        # Exit threshold: half of the uniform share over the initial ensemble.
        self.exit_threshold = 0.5 * (1 / len(model_list))

        model_hit_times = {key: 0 for key in model_list}
        hit_list = []            # model id of the selected draft, one entry per step
        sum_all_scores = None    # scores accumulated across steps (Exit_Mechanism only)

        self.Clear_Cache()

        best_response = ""
        best_new_generation = input_text
        eos_flag = True
        while (eos_flag
               and len(best_response.split(" ")) < max_tokens
               and self.cn_str_count(best_response) < max_tokens):
            new_generations, _spaces = self.Parallel_Inference(best_new_generation, model_container)
            if prevent_incomplete:
                new_generation_list = [' '.join(new_generations[name].split(" ")[:-1]) for name in model_list]
            else:
                new_generation_list = [new_generations[name] for name in model_list]

            # NOTE: the active inference models double as verifiers here;
            # self.verify_model_lists only supplies the device scores are moved to.
            results_score = self.Parallel_verify(new_generation_list, model_container)

            if Eliminate_self_judgment:
                for verifier_name, scores in results_score.items():
                    if verifier_name in model_list:
                        scores[model_list.index(verifier_name)] = 0

            device = self.verify_model_lists[0].device
            step_scores = sum(scores.to(device) for scores in results_score.values())

            # Rank drafts by summed score; the top one is this step's hit.
            _, indices = torch.sort(step_scores, descending=True)
            hit_model_id = indices[0].item()
            hit_model_name = model_list[hit_model_id]
            hit_list.append(self.model_id_dic[hit_model_name])
            model_hit_times[hit_model_name] += 1

            if Exit_Mechanism:
                # Out-of-place add: never mutate a tensor that aliases this step's scores.
                sum_all_scores = step_scores if sum_all_scores is None else sum_all_scores + step_scores
                if dynamic_temperature:
                    temperature = self.calculate_dynamic_temperature(hit_list, model_list)
                    shares = F.softmax(sum_all_scores / (1 + temperature), dim=0)
                else:
                    shares = F.softmax(sum_all_scores, dim=0)
                exit_ids = set((shares < self.exit_threshold).nonzero(as_tuple=True)[0].tolist())
                model_container = [model for i, model in enumerate(model_container) if i not in exit_ids]
                model_list = [model.model_type for model in model_container]
                sum_all_scores = sum_all_scores[[i for i in range(sum_all_scores.size(0)) if i not in exit_ids]]

            if Fast_decoder:
                eos_text = self.Find_eos_token(new_generation_list=new_generation_list, indices=indices, Stop_token=Stop_token)
                if eos_text:
                    best_new_generation = eos_text
                    best_response = best_response + best_new_generation
                    if Output_stream:
                        print(best_new_generation + "\n", end='')
                    break

            best_new_generation = new_generation_list[hit_model_id]
            best_response = best_response + best_new_generation

            if not Fast_decoder and any(stop in best_new_generation for stop in Stop_token):
                eos_flag = False

            if Output_stream:
                print(best_new_generation, end='')

        self.Clear_Cache()
        if Exit_Mechanism:
            return best_response
        return best_response, model_hit_times

    def Parallel_Inference(self, input_text, inference_models):
        """Run ``Speculative_generate`` on every model concurrently.

        Returns:
            ``(results, spaces)``: dicts keyed by the model name each call reports,
            holding the drafted text and the ``if_add_space`` flag respectively. If two
            models report the same name, the one completing last wins. Both dicts are
            empty when ``inference_models`` is empty.
        """
        results = {}
        spaces = {}
        if not inference_models:
            # ThreadPoolExecutor rejects max_workers=0.
            return results, spaces
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(inference_models)) as executor:
            futures = [
                executor.submit(model.Speculative_generate, input_text, self.each_step_max_generate_token)
                for model in inference_models
            ]
            for future in concurrent.futures.as_completed(futures):
                result, model_name, if_add_space = future.result()
                results[model_name] = result
                spaces[model_name] = if_add_space
        return results, spaces

    def Parallel_verify(self, input_text_list, verify_models):
        """Run ``model_verify(input_text_list)`` on every model concurrently.

        Returns:
            Dict mapping each reported model name to the scores it returned
            (empty when ``verify_models`` is empty).
        """
        results = {}
        if not verify_models:
            # ThreadPoolExecutor rejects max_workers=0.
            return results
        with concurrent.futures.ThreadPoolExecutor(max_workers=len(verify_models)) as executor:
            futures = [executor.submit(model.model_verify, input_text_list) for model in verify_models]
            for future in concurrent.futures.as_completed(futures):
                scores, model_name = future.result()
                results[model_name] = scores
        return results

    def Clear_Cache(self):
        """Call ``Clear_Cache()`` exactly once on every distinct inference/verify model."""
        # dict.fromkeys de-duplicates like a set but keeps a deterministic order.
        for model in dict.fromkeys(list(self.inference_model_lists) + list(self.verify_model_lists)):
            model.Clear_Cache()

    def Find_eos_token(self, new_generation_list=None, indices=None, Stop_token=None):
        """Return the first draft, in ``indices`` order, that contains any stop token.

        ``indices`` may be a 1-D integer tensor or any iterable of ints. Returns None
        when no draft contains a stop token or when an argument is missing.
        """
        if new_generation_list is None or indices is None or not Stop_token:
            return None
        for index in indices:
            candidate = new_generation_list[int(index)]
            if any(stop in candidate for stop in Stop_token):
                return candidate
        return None

    def if_a_word(self, word):
        """Return True if ``word`` is in NLTK's ``words`` corpus (exact, case-sensitive)."""
        cls = type(self)
        if cls._english_word_set is None:
            # Build once: words.words() is a large list and list membership is O(n).
            cls._english_word_set = frozenset(words.words())
        return word in cls._english_word_set

    def cn_str_count(self, str):
        """Count characters of ``str`` in the CJK Unified Ideographs block (U+4E00..U+9FFF)."""
        # The parameter name shadows the builtin but is kept for keyword-call compatibility.
        return sum(1 for ch in str if '\u4e00' <= ch <= '\u9fff')

    def entropy(self, probability_list):
        """Return the Shannon entropy (base 2) of a probability distribution.

        Raises:
            ValueError: if any probability is outside [0, 1] or they do not sum to 1.
        """
        probabilities = np.asarray(probability_list, dtype=float)
        # Validate before dropping zeros, so negative entries are rejected, not ignored.
        if np.any(probabilities < 0) or np.any(probabilities > 1):
            raise ValueError("Probabilities should be between 0 and 1.")
        if not np.isclose(np.sum(probabilities), 1):
            raise ValueError("Probabilities should sum to 1.")
        # 0 * log2(0) is taken as 0, so zero entries contribute nothing.
        nonzero = probabilities[probabilities > 0]
        return -np.sum(nonzero * np.log2(nonzero))

    def calculate_dynamic_temperature(self, hit_list, model_list):
        """Map the normalized entropy of the recency-weighted hit distribution to a temperature.

        Normalized entropy <= 0.5 yields 0; above that the temperature is
        ``normalized - 0.5`` (so at most 0.5). With fewer than two active models there
        is no spread to measure and 0 is returned.
        """
        probabilities = self.calculate_rewards(hit_list, model_list)
        length = len(probabilities)
        if length <= 1:
            # Maximum entropy is 0 here; dividing by it would produce NaN.
            return 0

        entropy_value = self.entropy(probabilities)
        max_entropy = self.entropy([1 / length] * length)
        normalized = entropy_value / max_entropy

        if normalized <= 0.5:
            return 0
        return normalized - 0.5

    def calculate_rewards(self, hit_list, model_list):
        """Turn the hit history into a recency-weighted distribution over active models.

        Hits are grouped into windows of ``size`` steps counted back from the most
        recent one: the newest window scores 1 per hit, the next 3/4, then 2/4, and
        everything older 1/4. Hits by models not in ``model_list`` are ignored.

        Returns:
            List of probabilities for the active models, ordered by model id. If no
            active model has any hit, a uniform distribution is returned.
        """
        size = 4
        window_weights = (1, 3 / 4, 2 / 4)
        older_weight = 1 / 4

        rewards = {model_id: 0. for model_id in range(len(self.model_id_dic))}
        active_ids = {self.model_id_dic[name] for name in model_list}

        length = len(hit_list)
        for i, hit_system in enumerate(hit_list):
            if hit_system not in active_ids:
                continue
            window = (length - 1 - i) // size   # 0 for the most recent `size` hits
            rewards[hit_system] += window_weights[window] if window < len(window_weights) else older_weight

        active_rewards = [reward for model_id, reward in rewards.items() if model_id in active_ids]
        total_sum_rewards = sum(active_rewards)
        if total_sum_rewards == 0:
            # No evidence for any active model: fall back to a uniform distribution.
            return [1 / len(active_rewards)] * len(active_rewards) if active_rewards else []
        return [reward / total_sum_rewards for reward in active_rewards]

