from torch import nn
import torch
from any_precision import AnyPrecisionForCausalLM
from .zoo import MultiPrecModelWrapper
from .clarification import check_text_type, check_single_dewey, dewey_what_text_type
from auto_gptq import AutoGPTQForCausalLM
from transformers.generation.utils import ModelOutput
from transformers import AutoTokenizer
from collections import defaultdict
from .scheduler.act_scheduler import ActScheduler
from any_precision.modules.AnyPrecisionLinear import AnyPrecisionLinear
import json

def set_hook(layer_idx, acts):
    """Create a forward hook that records statistics of a module's first input.

    ``acts[layer_idx]`` is created as an empty dict right away if it is not
    already present. Every time the returned hook fires it overwrites that
    dict's ``'mean'``, ``'var'`` (with ``correction=1``) and ``'acts'``
    entries, each computed from the first positional input and moved to CPU.

    Args:
        layer_idx: Key under which the statistics are stored in ``acts``.
        acts: Mutable mapping shared with the caller; updated in place.

    Returns:
        A callable ``(module, inputs, output) -> None`` suitable for
        ``nn.Module.register_forward_hook``.
    """
    acts.setdefault(layer_idx, {})

    def hook(m, inp, oup):
        first_input = inp[0]
        record = acts[layer_idx]
        record['mean'] = first_input.mean().cpu()
        record['var'] = first_input.var(correction=1).cpu()
        record['acts'] = first_input.detach().cpu()

    return hook
class PMPDForCausalLM(nn.Module):
    """A model based on AnyPrecision that can be used for adaptive precision scheduling during inference.

    Wraps either an ``AnyPrecisionForCausalLM`` (``use_anyprec=True``) or a
    ``MultiPrecModelWrapper`` and decodes token by token, asking a scheduler
    which precision to use for each forward pass.
    """

    # Substrings of a decoded token that mark the start of a new thinking step.
    _WAIT_WORDS = ("wait", "Wait", "Alternatively")
    # Score-list keys for each ``check_text_type`` category (passed to the scheduler).
    _TEXT_TYPE_SCORE_KEYS = ("calon_scores", "verion_scores", "calve_scores", "seek_scores")
    # ``dewey_what_text_type`` result -> score-list key; other results are not recorded.
    _DEWEY_SCORE_KEYS = {
        "problem_formulation": "problem_scores",
        "computation": "computation_scores",
        "verification": "verification_scores",
    }

    def __init__(self, model_path, scheduler=None, precisions=None, use_anyprec=True,
                 quantize_model_cls=AutoGPTQForCausalLM, Solution=True, split=False,
                 prune_func=None, reward_func=None):
        """Load the quantized model(s) and tokenizer.

        Args:
            model_path: Passed to ``AnyPrecisionForCausalLM.from_quantized`` or,
                when ``use_anyprec`` is False, to
                ``MultiPrecModelWrapper.from_quantized`` (a json file listing
                the quantized model paths). Also the default tokenizer path.
            scheduler: Precision scheduler; required by ``generate`` and
                ``pmpd_generate``. It must expose ``precisions``, ``reset()``
                and ``schedule(**kwargs)``.
            precisions: Precisions to load; forwarded to the loader.
            use_anyprec: Select the AnyPrecision loader vs. the multi-model wrapper.
            quantize_model_cls: Loader class forwarded to ``MultiPrecModelWrapper``.
            Solution: When False, "The answer is:" is force-appended after ``</think>``.
            split: Enable per-step scoring at wait-word tokens during the chain of thought.
            prune_func: Optional indexable of per-step score thresholds. A step
                score above ``prune_func[step]`` (or reaching step 19) forces the
                descent prompt.
            reward_func: Optional ``callable(system_prompt, question, step_text) -> score``.
        """
        super().__init__()
        if use_anyprec:
            self.model = AnyPrecisionForCausalLM.from_quantized(
                model_path,
                precisions=precisions
            )
            tokenizer_path = model_path
        else:
            # model path is a json file containing the paths to the quantized models
            self.model = MultiPrecModelWrapper.from_quantized(
                model_path,
                precisions,
                quantize_model_cls=quantize_model_cls)
            tokenizer_path = self.model.config.get('model_path') if self.model.config else model_path

        self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.scheduler = scheduler
        self.Solution = Solution
        self.prune_func = prune_func
        self.last_wait_position = 0  # Track the last position where wait words were found
        self.reward_func = reward_func
        self.split = split
        # Statistics accumulated across all questions.
        self.list_prob = []
        self.list_split_prob = []
        self.list_split_prob_15 = []
        # Per-question state (normally set by ``reset``). Initialised here so
        # ``generate`` does not fail with AttributeError if ``reset`` was never called.
        self.question = None
        self.system_prompt = None
        self.thinking_steps = -1
        self.one_question_list_prob = []
        self.one_question_split_prob = []
        self.one_question_split_prob_15 = []

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped model's ``forward``."""
        return self.model.forward(*args, **kwargs)

    def reset(self, question, system_prompt):
        """Start a new question.

        Stores the prompt pieces later passed to ``reward_func`` and clears the
        per-question step counter and probability statistics.
        """
        self.question = question
        self.system_prompt = system_prompt
        self.last_wait_position = 0
        self.thinking_steps = -1
        self.one_question_list_prob = []
        self.one_question_split_prob = []
        self.one_question_split_prob_15 = []

    def append_custom_content(self, input_ids, content, device, outputs=None, current_bit=None):
        """Force-feed ``content`` through the model and append its tokens to ``input_ids``.

        Args:
            input_ids: Current token ids; the new tokens are concatenated on the last dim.
            content: Text to tokenize (without special tokens) and append.
            device: Device to place the new token ids on.
            outputs: Previous model output whose ``past_key_values`` are continued, or None.
            current_bit: Precision for the forward pass over ``content``.

        Returns:
            ``(input_ids, output, num_tokens)``: the extended ids, the model output
            for the appended tokens, and the number of tokens appended.
        """
        content_tokens = self.tokenizer(content, return_tensors="pt", add_special_tokens=False)["input_ids"].to(device)
        past_key_values = outputs.past_key_values if outputs is not None else None
        output = self.model(content_tokens, precision=current_bit, use_cache=True, past_key_values=past_key_values)
        input_ids = torch.cat([input_ids, content_tokens], dim=-1)
        # content_tokens is (batch, seq_len): len() would return the batch size,
        # not the number of appended tokens.
        return input_ids, output, content_tokens.shape[-1]

    @classmethod
    def _new_score_buckets(cls):
        """Return fresh, empty per-category score lists keyed by scheduler kwarg name."""
        return {key: [] for key in (*cls._TEXT_TYPE_SCORE_KEYS, *cls._DEWEY_SCORE_KEYS.values())}

    @staticmethod
    def _sample_next_token(logits, temperature, top_p, top_k):
        """Sample one token id using temperature, top-p and top-k filtering.

        Args:
            logits: Last-position logits, i.e. ``outputs.logits[:, -1:]``. Batch
                size 1 is assumed because ``probs.squeeze()`` flattens it before
                ``torch.multinomial``.
            temperature: Divisor applied to the logits.
            top_p: Nucleus threshold; filtering is skipped when >= 1.0.
            top_k: Number of top logits kept; filtering is skipped when <= 0.

        Returns:
            ``(input_id, selected_prob)``: the sampled id with shape (1, 1) and
            its probability under the filtered, renormalised distribution.
        """
        logits = logits / temperature
        neg_inf = torch.tensor(float('-inf'), device=logits.device, dtype=logits.dtype)

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Map the mask from sorted order back to vocabulary order.
            indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
            indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
            logits = torch.where(indices_to_remove, neg_inf, logits)

        if top_k > 0:
            k = min(top_k, logits.size(-1))
            if k > 0:
                _, top_k_indices = torch.topk(logits, k, dim=-1)
                top_k_mask = torch.zeros_like(logits, dtype=torch.bool)
                top_k_mask.scatter_(-1, top_k_indices, torch.ones_like(top_k_indices, dtype=torch.bool))
                logits = torch.where(top_k_mask, logits, neg_inf)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Guard against NaN/inf (e.g. all logits filtered) before renormalising.
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
        input_id = torch.multinomial(probs.squeeze(), 1).unsqueeze(0)
        selected_prob = probs.gather(-1, input_id.unsqueeze(-1))[0, 0, 0].item()
        return input_id, selected_prob

    @torch.inference_mode()
    def generate(self, *args, **kwargs):
        """Decode with per-token precision scheduling.

        Many places index row 0, so batch size 1 is assumed.

        Keyword args read from ``kwargs``:
            input_ids (required): Prompt token ids.
            max_new_tokens / max_length: Generation budget (default 256). The
                check is ``new_token > max_new_tokens``, so one token beyond the
                budget can be produced. Force-appended content also counts.
            prefill_bit: Precision for the prompt pass (default: the model's highest).
            past_key_values: Optional cache to continue from.
            do_sample, temperature, top_p, top_k: Sampling controls; greedy when
                ``do_sample`` is False. ``min_p`` is accepted but not applied.
            descent_prompt: Text force-appended to end the chain of thought
                (default ``"</think>"``).
            sol_precision: Precision used after the scheduler returns 0 in split
                mode. Required in that case (KeyError otherwise).

        Per-question statistics are appended to the ``one_question_*`` lists, so
        call ``reset`` before each new question.

        Returns:
            ``ModelOutput`` with ``input_ids``, ``new_token``, ``precision_log``,
            ``cot_precision``, ``past_key_values`` and the probability statistics.
        """
        assert self.scheduler is not None, "Scheduler is not provided."
        input_ids = kwargs['input_ids'].clone()
        input_len = input_ids.shape[1]
        if 'max_new_tokens' in kwargs:
            max_new_tokens = kwargs['max_new_tokens']
        elif 'max_length' in kwargs:
            max_new_tokens = kwargs['max_length'] - input_len
        else:
            max_new_tokens = 256
        prefill_bit = kwargs.get('prefill_bit')
        past_key_values = kwargs.get('past_key_values')
        do_sample = kwargs.get('do_sample', False)
        temperature = kwargs.get('temperature', 1.0)
        top_p = kwargs.get('top_p', 1.0)
        top_k = kwargs.get('top_k', 0)
        descent_prompt = kwargs.get('descent_prompt', "</think>")

        if prefill_bit is not None:
            print(f"Prefill with bit {prefill_bit} model")
            outputs = self.model(input_ids, precision=prefill_bit, past_key_values=past_key_values, use_cache=True)
        else:
            max_precision = max(self.model.precisions)
            outputs = self.model(input_ids, precision=max_precision, past_key_values=past_key_values, use_cache=True)

        # Loop invariants. encode() may add special tokens, so take the last id.
        eos_id = self.tokenizer.eos_token_id
        end_think_id = self.tokenizer.encode("</think>")[-1]
        act_scheduler = isinstance(self.scheduler, ActScheduler)
        hooks = {}

        new_token = 0
        # input_ids[0, input_len:eos_scan_from] has already been checked for EOS.
        eos_scan_from = input_len
        current_bit = max(self.scheduler.precisions)
        schedule_dict = {}
        precision_log = defaultdict(int)
        cot_precision = defaultdict(int)
        is_cot = True
        is_solution = False
        scores = []
        score_buckets = self._new_score_buckets()
        self.scheduler.reset()

        while True:
            is_split = False
            text_type = None
            dewey_text_type = None

            # 1. Pick the next token from the last-position logits.
            last_logits = outputs.logits[:, -1:]
            if do_sample:
                input_id, selected_prob = self._sample_next_token(last_logits, temperature, top_p, top_k)
                if is_cot:
                    self.one_question_list_prob.append(selected_prob)
            else:
                input_id = last_logits.argmax(dim=-1)
                if is_cot:
                    probs = torch.nn.functional.softmax(last_logits, dim=-1)
                    # probs keeps the (batch, 1, vocab) rank of last_logits;
                    # gather needs an index of the same rank.
                    selected_prob = probs.gather(-1, input_id.unsqueeze(-1))[0, 0, 0].item()
                    self.one_question_list_prob.append(selected_prob)

            cur_token = self.tokenizer.decode(input_id[0])
            single_word_type = check_single_dewey(cur_token)

            # 2. Feed the token at the scheduled precision. ActScheduler also needs
            #    the inputs of every linear layer, captured via temporary hooks.
            if act_scheduler:
                acts = {}
                linear_layer_idx = 0
                for m in self.model.modules():
                    if isinstance(m, (nn.Linear, AnyPrecisionLinear)):
                        hooks[linear_layer_idx] = m.register_forward_hook(set_hook(linear_layer_idx, acts))
                        linear_layer_idx += 1
            try:
                outputs = self.model(input_id, precision=current_bit, use_cache=True, past_key_values=outputs.past_key_values)
            finally:
                # Detach the hooks even if the forward pass raises.
                if act_scheduler:
                    for h in hooks.values():
                        h.remove()
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            # 3. Stop on EOS in the first sequence, or when the budget is exceeded.
            if eos_id in input_ids[0, eos_scan_from:].tolist():
                break
            eos_scan_from = input_ids.shape[1]
            if new_token > max_new_tokens:
                break

            # 4. The chain of thought ends when </think> is emitted.
            if input_id[0] == end_think_id:
                is_cot = False
                is_solution = True
                if not self.Solution:
                    # Add "The answer is:" to guide direct answer generation.
                    input_ids, outputs, appended = self.append_custom_content(
                        input_ids, "The answer is:", input_ids.device,
                        outputs=outputs, current_bit=current_bit
                    )
                    new_token += appended
                    continue

            # 5. Wait-word tokens during the chain of thought open a new thinking step.
            has_wait_words = is_cot and any(word in cur_token for word in self._WAIT_WORDS)
            if has_wait_words:
                self.one_question_split_prob.append(selected_prob)
                if len(self.one_question_split_prob_15) < 15:
                    self.one_question_split_prob_15.append(selected_prob)

            if self.split and has_wait_words:
                is_split = True
                # Tokens from the previous wait token (for the first step, from the
                # last prompt token) up to, but excluding, the current one.
                tokens_between = input_ids[0, input_len + self.last_wait_position - 1: input_len + new_token - 1]
                decoded_text = self.tokenizer.decode(tokens_between)
                self.last_wait_position = new_token
                score = self.reward_func(self.system_prompt, self.question, decoded_text) if self.reward_func is not None else 0.0

                calon, verion, calve, seek = check_text_type(decoded_text)
                text_type = "calon" if calon else "verion" if verion else "calve" if calve else "seek"
                score_buckets[f"{text_type}_scores"].append(score)

                dewey_text_type = dewey_what_text_type(decoded_text)
                dewey_key = self._DEWEY_SCORE_KEYS.get(dewey_text_type)
                if dewey_key is not None:
                    score_buckets[dewey_key].append(score)
                self.thinking_steps += 1

                # Check if we should force terminate
                if self.prune_func is not None:
                    if self.thinking_steps == 19 or score > self.prune_func[self.thinking_steps]:
                        print("force terminate")
                        if score > self.prune_func[self.thinking_steps]:
                            print(f"Score: {score}, prune_func[thinking_steps]: {self.prune_func[self.thinking_steps]}")
                        else:
                            print("thinking steps is 19")
                        input_ids, outputs, appended = self.append_custom_content(
                            input_ids, descent_prompt, input_ids.device,
                            outputs=outputs, current_bit=current_bit
                        )
                        new_token += appended
                        is_cot = False
                        is_solution = True
                        continue
                else:
                    # NOTE(review): scores are only accumulated when no prune_func
                    # is configured; confirm this is intended.
                    scores.append(score)

            # 6. Ask the scheduler for the next precision.
            schedule_dict['index'] = new_token
            schedule_dict['past_key_values'] = outputs.past_key_values[-1]
            schedule_dict['logits'] = outputs.logits
            schedule_dict['precision'] = current_bit
            schedule_dict['cur_id'] = input_id[0]
            schedule_dict['cur_phase'] = "cot" if is_cot else "solution" if is_solution else "answer"
            schedule_dict['scores'] = scores
            schedule_dict['is_split'] = is_split
            schedule_dict['text_type'] = text_type
            schedule_dict['dewey_text_type'] = dewey_text_type
            schedule_dict['single_word_type'] = single_word_type
            schedule_dict['one_question_split_prob'] = self.one_question_split_prob
            if act_scheduler:
                if self.scheduler.all_layers:
                    # Interleaved [mean_0, var_0, mean_1, var_1, ...] over hooked layers.
                    schedule_dict['acts'] = torch.stack(
                        [stat for layer_acts in acts.values() for stat in (layer_acts['mean'], layer_acts['var'])]
                    )
                else:
                    schedule_dict['acts'] = acts[list(acts)[-1]]['acts']
            schedule_dict.update(score_buckets)

            prev_bit = current_bit
            current_bit = self.scheduler.schedule(**schedule_dict)
            if is_cot:
                cot_precision[current_bit] += 1
            precision_log[current_bit] += 1
            if prev_bit != current_bit:
                # Start fresh score histories whenever the precision changes.
                scores = []
                score_buckets = self._new_score_buckets()

            # In split mode a scheduled precision of 0 ends the chain of thought:
            # switch to the solution precision and force-append the descent prompt.
            if current_bit == 0 and self.split and is_cot:
                current_bit = kwargs['sol_precision']
                input_ids, outputs, appended = self.append_custom_content(
                    input_ids, descent_prompt, input_ids.device,
                    outputs=outputs, current_bit=current_bit
                )
                new_token += appended
                is_cot = False
                is_solution = True

        self.scheduler.reset()

        if self.one_question_list_prob:
            self.list_prob.append(sum(self.one_question_list_prob) / len(self.one_question_list_prob))

        if self.one_question_split_prob:
            # Average over (at most) the first five wait-word probabilities.
            first_five = self.one_question_split_prob[:5]
            self.list_split_prob.append(sum(first_five) / len(first_five))

        # Pad to exactly 15 entries (the list is capped at 15 during decoding).
        missing = 15 - len(self.one_question_split_prob_15)
        if missing > 0:
            self.one_question_split_prob_15.extend([0.0] * missing)
        self.list_split_prob_15.append(self.one_question_split_prob_15)

        return ModelOutput(
            input_ids=input_ids,
            new_token=new_token,
            precision_log=precision_log,
            cot_precision=cot_precision,
            past_key_values=outputs.past_key_values,
            one_question_list_prob=self.one_question_list_prob,
            list_prob=self.list_prob,
            list_split_prob=self.list_split_prob,
            one_question_split_prob_15=self.one_question_split_prob_15
        )

    # for gradio demo
    @torch.inference_mode()
    def pmpd_generate(self, input_ids, max_new_tokens=256):
        """Greedy, streaming decode with scheduled precision (Gradio demo).

        Prefills at the model's highest precision. After each token it yields
        ``(input_ids, current_bit)``, where ``current_bit`` is the precision used
        for that token's forward pass. It stops after EOS or once more than
        ``max_new_tokens`` tokens were generated.
        """
        assert self.scheduler is not None, "Scheduler is not provided."
        input_ids = input_ids.clone()
        eos_id = self.tokenizer.eos_token_id

        max_precision = max(self.model.precisions)
        print(f"Prefill with the highest bit {max_precision} model")
        outputs = self.model(input_ids, precision=max_precision, past_key_values=None, use_cache=True)

        new_token = 0
        current_bit = max(self.scheduler.precisions)
        schedule_dict = {}
        self.scheduler.reset()
        try:
            while True:
                input_id = outputs.logits[:, -1:].argmax(dim=-1)
                outputs = self.model(input_id, precision=current_bit, use_cache=True, past_key_values=outputs.past_key_values)
                input_ids = torch.cat([input_ids, input_id], dim=-1)
                new_token += 1

                yield input_ids, current_bit

                # Earlier tokens were checked on previous steps; only the newest can be EOS.
                if eos_id in input_id[0].tolist():
                    break
                if new_token > max_new_tokens:
                    break

                schedule_dict['index'] = new_token
                schedule_dict['past_key_values'] = outputs.past_key_values[-1]
                schedule_dict['logits'] = outputs.logits
                schedule_dict['precision'] = current_bit
                current_bit = self.scheduler.schedule(**schedule_dict)
        finally:
            # Also runs if the consumer stops iterating early.
            self.scheduler.reset()

    # for gradio demo
    @torch.inference_mode()
    def naive_generate(self, input_ids, precision, max_new_tokens=256):
        """Greedy, streaming decode at a single fixed ``precision`` (Gradio demo).

        Yields the full ``input_ids`` after each token. It stops after EOS or
        once more than ``max_new_tokens`` tokens were generated.
        """
        print(f"Using precision {precision}")
        input_ids = input_ids.clone()
        eos_id = self.tokenizer.eos_token_id
        outputs = self.model(input_ids, precision=precision, past_key_values=None, use_cache=True)
        new_token = 0
        while True:
            input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.model(input_id, precision=precision, use_cache=True, past_key_values=outputs.past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1

            yield input_ids

            # Earlier tokens were checked on previous steps; only the newest can be EOS.
            if eos_id in input_id[0].tolist():
                break
            if new_token > max_new_tokens:
                break
