import math
import random

from openai import OpenAI
import json
import base64
from tqdm import tqdm
import os
import math
import argparse
from utils import *
from prompt import *
from model import *
from comcts import *
import pdb
import time
from collections import deque
from threading import Thread
import requests


import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers import MllamaForConditionalGeneration


API_URL = ""
API_KEY = ""

class Node:
    def __init__(self, step_text, prefix_steps, step_correctness=[], parent=None):
        """
        Initializes a node in the tree.

        step_text: Current step text.
        prefix_steps: Prefix steps text.
        step_correctness: List indicating correctness of steps (default is an empty list).
        parent: Parent node (default is None).
        """
        self.step_text = step_text  # current step text
        self.parent = parent  # parent node
        self.children = []  # children nodes
        self.visits = 0  
        self.value = 0  
        self.prefix_steps = prefix_steps    # prefix steps text
        self.step_correctness = step_correctness
        self.depth = len(self.step_correctness)
        self.text = self.prefix_steps + self.step_text

    def is_leaf(self):
        """
        Checks if the current node is a leaf node (i.e., it has no children).
        
        :return: True if the node is a leaf, False otherwise.
        """
        return len(self.children) == 0

    def best_child(self, exploration_weight=0.5):
        """
        Finds the child node with the highest UCB value in the subtree.
        
        :param exploration_weight: Weight parameter for exploration in UCB formula.
        :return: The best child node based on UCB value.
        """
        if self.is_leaf():
            return self

        best_value = -math.inf
        best_nodes = []

        for child in self.children:
            # Recursively find the best leaf node in the subtree
            best_leaf = child.best_child(exploration_weight)
            if best_leaf.is_leaf():
                ucb1 = (best_leaf.value +
                        exploration_weight * math.sqrt(math.log(best_leaf.parent.visits+1) / best_leaf.visits+1))

                # print(
                #     'UCB1 Calculation:',
                #     f'value={best_leaf.value}, parent_visits={best_leaf.parent.visits}, visits={best_leaf.visits},',
                #     f'ucb1={ucb1}, text={best_leaf.text}'
                # )

                # Update the best node list
                if ucb1 > best_value:
                    best_value = ucb1
                    best_nodes = [best_leaf]
                elif ucb1 == best_value:
                    best_nodes.append(best_leaf)

        # Return a random choice among the best nodes
        return random.choice(best_nodes)

    def add_child(self, step_text, prefix_steps, step_correctness):
        """
        Adds a child node to the current node.

        :param step_text: Step text for the child node.
        :param prefix_steps: Prefix steps for the child node.
        :param step_correctness: Correctness list for the child node.
        :return: The newly added child node.
        """
        child_node = Node(step_text, prefix_steps, step_correctness, parent=self)
        self.children.append(child_node)
        return child_node

    def update_visits(self):
        """
        Increments the visit count of the current node.
        """
        self.visits += 1

    def update_value(self, parent_visits, parent_value, new_value, new_visits):
        """
        Updates the value of the node based on the parent and new observations.

        :param parent_visits: Number of visits to the parent node.
        :param parent_value: Value of the parent node.
        :param new_value: New value to integrate into the node.
        :param new_visits: Number of new visits to incorporate.
        """
        total_visits = parent_visits + new_visits + self.visits
        self.value = (parent_visits * parent_value + new_value + self.value * self.visits) / total_visits


class CoMCTS: 
    def __init__(self, args, step_text, prefix_steps, max_iterations=15):
        self.root = Node(step_text, prefix_steps)
        self.max_iterations = max_iterations
        self.args = args

    def search(self, data, client, activated_models, model_dict, ans_file, failed_search_file):
        iteration = 0       

        ###### fix question_process #### 
        question = question_process(data)
        ###### fix question_process #### 

        # gt_answer = data["conversations"][1]['value']
        gt_answer = data["gt_answer"]

        img_path = data["image_path"]

        base64_image = encode_image(img_path)
        temperature = 0.9

        while True:
            print(f'Start the {iteration} round of search')            
            iteration += len(activated_models)

            ## Select Node
            node = self.root
            while not node.is_leaf():
                node = node.best_child(self.args.exploration_weight)
            prefix_steps = reformat_reasoning_prefix(node.text)

            # init comcts_dict
            comcts_dict = {activated_model: {'valid': 1} for activated_model in activated_models}

            ## Expansion: Generate responses for each model
            for model_name in activated_models:
                response = self._generate_model_response(
                    model_name, question, prefix_steps, base64_image, temperature, model_dict, img_path, iteration, client
                )
                comcts_dict[model_name]['response'] = response if response else ''

            # Validate responses
            for model_name in activated_models:
                response = comcts_dict[model_name]['response']
                if not check_validity(response):
                    comcts_dict[model_name]['valid'] = -1

            ## Simulation, Error Positioning
            all_correctness = self._determine_correctness(
                comcts_dict, client, question, gt_answer, activated_models, temperature
            )
            if len(all_correctness) == 0:
                print(f'{all_correctness=}')
                continue

            expand_node = node
            if 1 in all_correctness:
                comcts_dict = self._process_correct_paths(
                    model_dict, comcts_dict, expand_node, question, gt_answer, img_path, base64_image, temperature, activated_models, client, prefix_steps
                )
                comcts_dict['image'] = data['image']
                comcts_dict['question'] = question
                comcts_dict['prefix_prompt'] = prefix_steps
                comcts_dict['conversations'] = data['conversations']
                ans_file.write(json.dumps(comcts_dict) + "\n")
                ans_file.flush()
                break
            else:
                self._process_incorrect_paths(
                    model_dict, comcts_dict, expand_node, question, gt_answer, img_path, base64_image, temperature, activated_models, client, prefix_steps
                )

            if iteration >= self.max_iterations:
                for model_name in activated_models:
                    data[model_name] = {'response': comcts_dict[model_name]['response'], 'valid': comcts_dict[model_name]['valid']}
                data['prefix_prompt'] = prefix_steps
                failed_search_file.write(json.dumps(data) + "\n")
                failed_search_file.flush()
                break

    def _generate_model_response(self, model_name, question, prefix_steps, base64_image, temperature, model_dict, img_path, iteration, client):
        """Generate model-specific responses."""
        open_source_prefix_steps = "### Image Description:" if prefix_steps == '' else ''
        try:
            if model_name == 'gpt-4o':
                if iteration == 0:
                    return gpt_forward(client, PROMPT.format(question=question), base64_image, temperature)
                return gpt_forward(client, GPT_PREFIX_PROMPT.format(question=question, reasoning_prefix=prefix_steps), base64_image, temperature)
            elif 'qwen2_vl' in model_name:
                return modified_qwen_response(
                    qwen2_vl_forward(
                        model_dict[model_name]['model'],
                        model_dict[model_name]['processor'],
                        PROMPT.format(question=question),
                        open_source_prefix_steps + prefix_steps,
                        img_path
                    )
                )
            elif 'llama_vision' in model_name:

                return modified_llama_response(
                    llama_forward(
                        model_dict[model_name]['model'],
                        model_dict[model_name]['processor'],
                        PROMPT.format(question=question),
                        open_source_prefix_steps + prefix_steps,
                        img_path
                    )
                )
        except Exception as e:
            print(f"Error generating response for {model_name}: {e}")
            time.sleep(1)
            # #### fix this remove assert False part ### 
            # assert False 
            return None


    def _determine_correctness(self, comcts_dict, client, question, gt_answer, activated_models, temperature):
        """determine correctness."""
        all_correctness = []
        print(f'{comcts_dict=}')
        print(f'-'*20)
        print(f'{question=}')
        print(f'-'*20)
        print(f'{gt_answer=}')
        print(f'-'*20)
        for model_name in activated_models:
            if comcts_dict[model_name]['valid'] == -1:
                continue
            response = comcts_dict[model_name]['response']
            print(f'{response=}')
            print(f'-'*20)
            model_answer = response.split('### The final answer is:')[-1].strip()
            print(f'{model_answer=}')
            print(f'-'*20)
            while True:
                try:
                    judge_output = get_deepseek_response(JUDGE_PROMPT.format(question=question, model_answer=model_answer, gt_answer=gt_answer),  API_KEY, API_URL, temperature)
                    break
                except Exception as e:
                    time.sleep(1)
                    print(e)
            print(f'{judge_output=}')
            print(f'-'*20)
            is_correct = get_correctness(judge_output)

            all_correctness.append(is_correct)
            comcts_dict[model_name]['is_correct'] = is_correct
        
        return all_correctness
        

    def _process_correct_paths(self, model_dict, comcts_dict, expand_node, question, gt_answer, img_path, base64_image, temperature, activated_models, client, prefix_steps):
        """Handle scenarios where correct paths are found."""
        for model_name in activated_models:
            if comcts_dict[model_name]['valid'] == -1:
                continue
            depth = get_depth(comcts_dict[model_name]['response'])

            is_correct = comcts_dict[model_name]['is_correct']
            while True:
                max_try_count = 3
                try_count = 0
                try:
                    step_correctness_response = get_deepseek_response(LOCATE_ERROR_PROMPT_WO_IMAGE.format(question=question, reasoning=comcts_dict[model_name]['response'], gt=gt_answer), API_KEY, API_URL, temperature)
                    step_correctness = step_correctness_to_list(step_correctness_response, depth=depth)
                    if step_correctness != [-2] or try_count > max_try_count:
                        break
                    try_count += 1
                except Exception as e:
                    time.sleep(1)
                    print(e)

            if step_correctness == [-2]:
                comcts_dict[model_name]['valid'] = -1

            prefix_steps_depth = get_depth(expand_node.text)
            suffix_steps_depth = get_depth(comcts_dict[model_name]['response']) - 1 # remove final answer
            new_step = ''
            current_node = expand_node
            new_prefix_steps = prefix_steps
            for i in range(prefix_steps_depth, suffix_steps_depth):  
                new_prefix_steps = new_prefix_steps + new_step
                new_step = get_step(comcts_dict[model_name]['response'], i+1)
                current_node = current_node.add_child(step_text=new_step, prefix_steps=new_prefix_steps, step_correctness=step_correctness[:(i+1)])
            
            ## Backpropagation
            # leaf node
            up_node = current_node
            depth_diff = suffix_steps_depth - prefix_steps_depth
            step_value = []
            for idx in range(suffix_steps_depth, 0, -1):
                if idx > prefix_steps_depth:
                    # new node
                    new_value = sum(step_correctness[prefix_steps_depth:idx])
                    up_node.update_value(parent_visits=expand_node.visits, parent_value=expand_node.value, new_value=new_value, new_visits=idx-prefix_steps_depth)
                    up_node.update_visits()
                else:
                    new_value = step_correctness[idx-1]
                    up_node.update_value(parent_visits=up_node.parent.visits, parent_value=up_node.parent.value, new_value=new_value, new_visits=1)
                    up_node.update_visits()

                step_value.insert(0, round(up_node.value,3))
                up_node = up_node.parent

            value = (current_node.value +
                    self.args.exploration_weight * math.sqrt(math.log(current_node.parent.visits+1) / current_node.visits+1))

            comcts_dict[model_name] = {'response': comcts_dict[model_name]['response'], "value": round(value,3), 'step_value': step_value, "is_correct": is_correct, 'valid': comcts_dict[model_name]['valid']}
        
        return comcts_dict
        

    def _process_incorrect_paths(self, model_dict, comcts_dict, expand_node, question, gt_answer, img_path, base64_image, temperature, activated_models, client, prefix_steps):
        """Handle scenarios where correct paths are not found."""
        for model_name in activated_models:
            if comcts_dict[model_name]['valid'] == -1:
                continue

            depth = get_depth(comcts_dict[model_name]['response'])
            while True:
                max_try_count = 3
                try_count = 0
                try:
                    # step_correctness_response = gpt_forward(client, LOCATE_ERROR_PROMPT.format(question=question, reasoning=comcts_dict[model_name]['response'], gt=gt_answer), base64_image, temperature)
                    step_correctness_response = get_deepseek_response(LOCATE_ERROR_PROMPT_WO_IMAGE.format(question=question, reasoning=comcts_dict[model_name]['response'], gt=gt_answer), API_KEY, API_URL, temperature)
                    step_correctness = step_correctness_to_list(step_correctness_response, depth=depth)
                    if step_correctness != [-2] or try_count > max_try_count:
                        break
                    try_count += 1
                except Exception as e:
                    time.sleep(1)
                    print(e)

            if len(step_correctness) == 0:
                comcts_dict[model_name]['valid'] = -1
                continue
            if step_correctness[0] == -2:
                comcts_dict[model_name]['valid'] = -1
                continue

            comcts_dict[model_name] = {'response': comcts_dict[model_name]['response'], 'step_correctness': step_correctness, 'depth':depth, 'valid': comcts_dict[model_name]['valid']}
            

            # Prune the first node smaller than the threshold.
            comcts_dict[model_name] = prune_tree(comcts_dict[model_name], start_index=get_depth(expand_node.text), threshold=self.args.threshold)
            if comcts_dict[model_name]['valid'] == -1:
                up_node = expand_node
                while up_node.parent is not None:
                    # do not update the root node
                    up_node.update_visits()
                    up_node = up_node.parent
                continue

            pruned_response = comcts_dict[model_name]['response']
            updated_step_correctness = comcts_dict[model_name]['step_correctness']

            # add nodes
            prefix_steps_depth = get_depth(expand_node.text)
            pruned_steps_depth = get_depth(pruned_response)
            new_step = ''
            current_node = expand_node
            new_prefix_steps = prefix_steps
            # print(prefix_steps_depth, pruned_steps_depth, comcts_dict)
            for i in range(prefix_steps_depth, pruned_steps_depth):
                new_prefix_steps = new_prefix_steps + new_step
                new_step = get_step(pruned_response, i+1)
                current_node = current_node.add_child(step_text=new_step, prefix_steps=new_prefix_steps, step_correctness=updated_step_correctness[:(i+1)])

            ## Backpropagation
            # bottom-up update node
            up_node = current_node
            depth_diff = pruned_steps_depth - prefix_steps_depth
            for idx in range(pruned_steps_depth, 0, -1):
                if idx > prefix_steps_depth:
                    # new node
                    new_value = sum(updated_step_correctness[prefix_steps_depth:idx])
                    up_node.update_value(parent_visits=expand_node.visits, parent_value=expand_node.value, new_value=new_value, new_visits=idx-prefix_steps_depth)
                    up_node.update_visits()
                else:
                    new_value = updated_step_correctness[idx-1]
                    up_node.update_value(parent_visits=up_node.parent.visits, parent_value=up_node.parent.value, new_value=new_value, new_visits=1)
                    up_node.update_visits()

                up_node = up_node.parent