import re
import random
import numpy as np
import ast
import operator
import pdb
import json
import sys
import os
import traceback
import math
import concurrent.futures
from itertools import combinations
sys.path.append('./')
import copy
from src.utils import ndcg_at_k, get_rank

# Retrieval backend configuration, read from environment variables at import time.
# RETRIEVAL_TYPE selects the search client below: 'bge' or 'ance' (FaissSearchClient)
# or 'bm25' (LuceneSearchClient). Any other value, including the 'xxx' placeholder
# default, raises during import.
retrieval = os.environ.get('RETRIEVAL_TYPE', 'xxx')
IP_ADDRESS = os.environ.get('IP_ADDRESS', 'xxx')
IP_PORT = os.environ.get('IP_PORT', 'xxx')
# '1' (default) = dense HNSW index; '0' = flat index. rerank() reads this to decide
# whether retrieval scores must be inverted before fusion.
USE_HNSW = os.environ.get('USE_HNSW', '1')
# Base URL of the search service; the single client built here is used by retriver_items().
server = f"http://{IP_ADDRESS}:{IP_PORT}"
if retrieval == 'bge':
    from search_launch.dense_bge.search_client import FaissSearchClient
    client_search = FaissSearchClient(server)
elif retrieval == 'ance':
    from search_launch.ance.search_client import FaissSearchClient
    client_search = FaissSearchClient(server)
elif retrieval == 'bm25':
    from search_launch.lucene.search_client import LuceneSearchClient
    client_search = LuceneSearchClient(server)
else:
    raise Exception("Please specify the correct retriever")
print(f"*************************Using {retrieval} retrieval on {server}*************************")

# Sample filter: keep only "difficult" samples, i.e. those whose target is not already ranked in the top 5.
def filter_standard(data):
    """Return True when a sample should be kept as a difficult example.

    ``data['_rank']`` is the target's rank. A rank of ``-1`` (presumably "not
    retrieved") and any rank above 5 keep the sample; every other rank of 5 or
    below filters it out.
    """
    rank = data['_rank']
    return not (rank <= 5) or rank == -1

def extract_solution(solution_str):
    """Extract the last ``<answer>...</answer>`` payload from a model response.

    Everything up to and including the response header ("Assistant:" or the
    ChatML "<|im_start|>assistant" marker) is dropped first.

    Returns:
        ``(answer, processed_str)``. ``answer`` is the stripped content of the last
        answer tag, or ``None`` when no header or no answer tag is found.
        ``processed_str`` is the response after the header, or the unmodified
        ``solution_str`` when no header is found.
    """
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1].strip()
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1].strip()
    else:
        print("[Error] Failed to locate model response header")
        # No header means nothing was stripped, so hand back the raw input.
        return None, solution_str

    # re.DOTALL lets the answer span multiple lines; the last occurrence wins.
    matches = re.findall(r'<answer>(.*?)</answer>', processed_str, re.DOTALL)
    if matches:
        return matches[-1].strip(), processed_str

    print("[Error] No valid answer tags found")
    return None, processed_str
        

def validate_response_structure(processed_str: str, do_print: bool = False) -> bool:
    """Check that a response has one ``<think>`` block followed by one ``<rewrite>`` block.

    Args:
        processed_str: Processed response string from the model.
        do_print: When True, print a per-tag diagnostic trace.

    Returns:
        True when each of ``<think>``, ``</think>``, ``<rewrite>`` and ``</rewrite>``
        occurs exactly once and in that order, False otherwise.
    """
    if do_print:
        print("\n[Structure Validation]")
    validation_passed = True

    # Every tag must occur exactly this many times.
    tags = {
        'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'rewrite_start': ('<rewrite>', 1),
        'rewrite_end': ('</rewrite>', 1),
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)

        if do_print:
            print(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            if do_print:
                print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Order check on first occurrences; a missing tag (-1) has already failed the count check.
    if (positions['think_start'] > positions['think_end'] or
            positions['think_end'] > positions['rewrite_start'] or
            positions['rewrite_start'] > positions['rewrite_end']):
        if do_print:
            print("  [Error] Incorrect tag order: Expected <think>...</think><rewrite>...</rewrite>")
        validation_passed = False
    elif do_print:
        print("  Tag sequence validation passed")

    return validation_passed

def check_json_format(json_str, do_print=False):
    """Return True if *json_str* is a JSON object of the exact form ``{"query": <str>}``.

    Rules:
    1. The text must decode (``json.loads``) to a dict.
    2. The dict must contain exactly one key, ``'query'``.
    3. ``payload['query']`` must be a string.

    Bad JSON text yields False instead of raising; with ``do_print`` set, the
    reason for the rejection is printed.
    """
    def reject(reason):
        if do_print:
            print(reason)
        return False

    if not json_str:
        return reject("[Error] Empty JSON string")

    try:
        payload = json.loads(json_str)
    except json.JSONDecodeError as exc:
        return reject(f"[Error] JSON decoding failed: {exc}")

    if not isinstance(payload, dict):
        return reject(f"[Error] JSON should parse to dict, got {type(payload)}: {payload}")

    key_names = list(payload.keys())
    if len(payload) != 1:
        return reject(f"[Error] JSON should have exactly 1 key, got {len(payload)} keys: {key_names}")
    if 'query' not in payload:
        return reject(f"[Error] JSON should have 'query' key, got: {key_names}")

    value = payload['query']
    if not isinstance(value, str):
        return reject(f"[Error] Value of 'query' should be string, got {type(value)}: {value}")
    return True



def retriver_items(query, top_k=300, threads=16):
    """Send one query to the shared search client and return just its hits.

    ``client_search.search`` yields a ``(results, search_time)`` pair; the timing
    half is discarded.
    """
    hits, _elapsed = client_search.search(query, top_k, threads)
    return hits

def cal_score_1(rank_list, k=100, eta=0.6):
    """Combine the ranks of all target documents into one reward.

    Each rank maps to a per-target score:
      * rank < 1 (e.g. -1, not retrieved): 0
      * 1 <= rank < 10: ``1 + (10 - rank) / 9``, which is 2.0 at rank 1 and about 1.11 at rank 9
      * rank >= 10: ``(100 - rank) / 90``, which is 1.0 at rank 10 and clamped to 0 from rank 100 on

    With one target, the reward is that score. With several, the scores are sorted
    ascending and weighted ``eta**1, eta**2, ...``, so the worst target gets the
    largest weight.

    Args:
        rank_list: rank of each target document.
        k: unused; kept for interface compatibility.
        eta: geometric discount applied to multi-target lists.

    Returns:
        The combined reward; 0.0 for an empty ``rank_list``.
    """
    score_list = []
    for rank in rank_list:
        if rank < 1:
            score = 0
        elif rank < 10:
            score = 1 + (10 - rank) / 9
        else:
            score = (100 - rank) / 90
        score_list.append(max(score, 0))

    if not score_list:
        # No targets: previously an IndexError on score_list[0].
        return 0.0
    if len(score_list) == 1:
        return score_list[0]

    score_list.sort()
    total_score = 0.0
    # Explicit loop (not sum()) keeps the exact left-to-right float accumulation.
    for i, s in enumerate(score_list, start=1):
        total_score += (eta ** i) * s
    return total_score

def cal_score(rank_list, k):
    """Turn the target ranks of one retrieval run into a scalar reward.

    This is the single switch point for the scoring rule. It currently delegates
    to ``cal_score_1``, which combines the ranks of all targets.

    Args:
        rank_list: rank of each target document, e.g. ``[1]`` or ``[1, 2]``;
            ``-1`` is the "not retrieved" marker used elsewhere in this module.
        k: retrieval depth, forwarded unchanged.
    """
    # Swap the callee here to experiment with a different scoring function.
    return cal_score_1(rank_list, k)

def cal_h3(score_list):
    """Fuse one document's appearances across several result lists.

    Args:
        score_list: non-empty list of ``(score, rank)`` pairs, one per result list
            the document appeared in; ``rank`` is its 1-based position there.

    Returns:
        ``(fused_rank, min_score)``. ``fused_rank = 1 / sum(1 / rank)`` is smaller
        (better) the more lists the document appears in and the higher it ranks.
        ``min_score`` is the smallest score among the appearances.

    Raises:
        ZeroDivisionError: if ``score_list`` is empty.
    """
    pairs = list(score_list)
    reciprocal_sum = 0
    for _, rank in pairs:
        reciprocal_sum += 1 / rank
    # Real minimum; the old 100000 sentinel capped the result when every score exceeded it.
    # 1 / reciprocal_sum is evaluated first, so empty input still raises ZeroDivisionError.
    return 1 / reciprocal_sum, min(score for score, _ in pairs)

def rerank(results_init, top_k) -> list[int]:
    """Fuse several retrieval result lists into one ranking of document ids.

    Args:
        results_init: one result list per query. Each entry is a ``(doc_id, score)``
            pair, and the entry's 1-based position in its list is used as that
            document's rank for the query. The input is not modified.
        top_k: number of fused document ids to return.

    Returns:
        Up to ``top_k`` document ids, best first, e.g. ``[123, 456, 789]``.

    Documents are ordered by the ``cal_h3`` pair ``(1 / sum(1 / rank), min score)``,
    so the score only breaks ties between equal fused ranks. For dense retrieval
    with an HNSW index the score is a distance (smaller is better). For BM25, or a
    flat index (``USE_HNSW == '0'``), larger is better, so scores are replaced by
    ``1 / score`` first.
    NOTE(review): this assumes strictly positive scores. A zero score would divide
    by zero, and a negative one would become the best tie-breaker.
    """
    needs_inversion = retrieval == 'bm25' or USE_HNSW == '0'
    if needs_inversion:
        # Fresh rows so the caller's lists stay untouched.
        per_query = [[[doc_id, 1 / raw_score] for doc_id, raw_score in row] for row in results_init]
    else:
        per_query = results_init

    # doc_id -> [(score, position-in-that-list), ...] across all queries.
    appearances = {}
    for row in per_query:
        for position, (doc_id, dist) in enumerate(row, start=1):
            appearances.setdefault(doc_id, []).append((dist, position))

    fused = {doc_id: cal_h3(hits) for doc_id, hits in appearances.items()}
    ordered = sorted(fused.items(), key=lambda item: (item[1][0], item[1][1]))
    return [doc_id for doc_id, _ in ordered[:top_k]]


def calculate_similarity(queries, client=None) -> list:
    """Compute the similarity of every unordered pair of queries.

    Args:
        queries: at least two query strings.
        client: object exposing ``calculate_similarity(q1, q2) -> (similarity, extra)``.
            Defaults to a module-level ``client_sim``. NOTE(review): no such name is
            created in this module as visible, so callers must pass ``client`` or
            define ``client_sim`` first.

    Returns:
        One similarity per pair, in ``itertools.combinations(queries, 2)`` order.

    Raises:
        Exception: fewer than two queries were given.
        NameError: no ``client`` was passed and no module-level ``client_sim`` exists.
    """
    if len(queries) < 2:
        raise Exception("Similarity calculation requires at least two queries")
    if client is None:
        client = globals().get('client_sim')
        if client is None:
            raise NameError("client_sim is not defined; pass a similarity client via client=")

    def pair_similarity(pair):
        similarity, _ = client.calculate_similarity(*pair)
        return similarity

    # Pairs are independent remote calls, so fan them out over a small thread pool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        return list(executor.map(pair_similarity, combinations(queries, 2)))

import itertools
def _popcount(mask):
    """Return the number of set bits in the non-negative integer *mask*."""
    return bin(mask).count("1")


def _select_subset(items, mask):
    """Return the elements of *items* whose index bit is set in *mask*, in index order."""
    return [item for index, item in enumerate(items) if (mask >> index) & 1]


def _rank_as_list(rank):
    """Normalize a ``get_rank`` result: wrap a bare int (single target) in a list."""
    return [rank] if isinstance(rank, int) else rank


def calculate_score_foreach_combination(results, target, top_k, queries=None, original_query="", original_query_results=None) -> tuple[float, list, list, dict]:
    """Score every non-empty subset of the rewritten queries' retrieval results.

    Each subset of ``results``, enumerated as a bitmask over query indices, is fused
    with ``rerank`` and scored with ``cal_score`` on the ranks of ``target``. The
    reward is the best score over all subsets.

    Args:
        results: one retrieval result list per rewritten query.
        target: ground-truth document ids.
        top_k: retrieval / rerank depth.
        queries: rewritten query strings, index-aligned with ``results``; used to
            spell out ``best_query``. Trailing entries beyond ``len(results)`` are ignored.
        original_query: the unmodified question. When non-empty and
            ``original_query_results`` is given, the reward is zeroed if no subset
            of the rewrite scores at least as well as the original question.
        original_query_results: ``[result_list]`` retrieved for ``original_query``.

    Returns:
        ``(score, rank, pred_results, more_returns)``:
          * score: best subset score, or 0 if it falls below the original query's score.
          * rank: target ranks for the full set of rewritten queries (not the best subset).
          * pred_results: fused document ids for the full set.
          * more_returns: diagnostics with keys 'best_rank', 'rank', 'best_rank_is_valid',
            'rank_is_valid', 'query_count', 'best_rank_query_count', 'best_query' and
            'is_best_rank'. When the original query is scored, 'original_query_rank'
            and 'original_query_score' are added.
    """
    queries = [] if queries is None else queries
    original_query_results = [] if original_query_results is None else original_query_results

    length = len(results)
    full_mask = (1 << length) - 1
    negative_rank_list = [-1 for _ in range(len(target))]  # every target missed

    # Lists are indexed by subset bitmask; mask 0 (empty subset) retrieves nothing and scores 0.
    pred_results_list = [[]]
    rank_list = [negative_rank_list]
    rank_score_list = [0]
    for mask in range(1, full_mask + 1):
        subset_pred = rerank(_select_subset(results, mask), top_k)
        subset_rank = _rank_as_list(get_rank(subset_pred, target, top_k))
        pred_results_list.append(subset_pred)
        rank_list.append(subset_rank)
        rank_score_list.append(cal_score(subset_rank, top_k))

    # The last mask is the full set, i.e. exactly the queries the model produced.
    rank = rank_list[-1]
    full_score = rank_score_list[-1]

    # Ranks of the highest-scoring subset: the full set wins ties; otherwise the
    # first (lowest-mask) subset reaching the maximum.
    best_rank, best_score = rank, full_score
    for subset_rank, subset_score in zip(rank_list, rank_score_list):
        if subset_score > best_score:
            best_rank, best_score = subset_rank, subset_score

    # Smallest subset (fewest queries; lowest mask on ties) reaching the highest
    # non-zero score. best_mask stays -1 when every subset scores 0.
    best_mask, best_mask_score, best_mask_count = -1, -10000, 10000
    for mask, subset_score in enumerate(rank_score_list):
        if subset_score == 0 or subset_score < best_mask_score:
            continue
        count = _popcount(mask)
        if subset_score > best_mask_score or count < best_mask_count:
            best_mask, best_mask_score, best_mask_count = mask, subset_score, count

    best_query = ""
    if best_mask != -1:
        best_query = "%%".join(queries[j] for j in range(length) if (best_mask >> j) & 1)

    more_returns = {
        'best_rank': best_rank,
        'rank': rank,
        'best_rank_is_valid': best_rank != negative_rank_list,  # some target retrieved
        'rank_is_valid': rank != negative_rank_list,
        'query_count': length,
        'best_rank_query_count': best_mask_count if best_mask != -1 else -1,
        'best_query': best_query,
        # True only when the full set is the only subset reaching the highest score.
        'is_best_rank': best_mask == full_mask,
    }

    score = max(rank_score_list)

    # Baseline: score the original question's own results. Previously this reused the
    # full-set predictions, so the penalty below could never trigger.
    if original_query != "" and original_query_results:
        original_pred = rerank(original_query_results, top_k)
        rank_original = _rank_as_list(get_rank(original_pred, target, top_k))
        score_original = cal_score(rank_original, top_k)
        more_returns['original_query_rank'] = rank_original
        more_returns['original_query_score'] = score_original
        # The rewrite must not do worse than searching with the original question.
        if score < score_original:
            score = 0

    return score, rank, pred_results_list[-1], more_returns

def get_query(response):
    """Return the text of the first ``<rewrite>`` span, stripped.

    Only the text between the first ``<rewrite>`` and the next ``<rewrite>`` or
    ``</rewrite>`` is kept; if the closing tag is missing, the rest of the response
    is kept. A response with no ``<rewrite>`` tag (non-think mode) is returned whole,
    stripped.
    """
    segments = response.split('<rewrite>')
    if len(segments) < 2:
        # No <rewrite> tag: the whole response is the query.
        return response.strip()
    return segments[1].split('</rewrite>')[0].strip()

def results_is_empty(results):
    """Return True if at least one query's result list equals ``[]``."""
    return any(result == [] for result in results)

def calculate_answer_score_new(json_str, label, scores, top_k, extra_info, do_print=False) -> tuple:
    """Score one model response by retrieving with its rewritten queries.

    Args:
        json_str: raw model response. The payload is taken from the
            ``<rewrite>...</rewrite>`` span (or the whole response if that tag is
            absent) and must be ``{"query": "..."}``. Several sub-queries are
            separated by ``%%``.
        label: ground-truth document ids.
        scores: unused; kept for interface compatibility.
        top_k: retrieval depth.
        extra_info: sample metadata. A non-blank ``extra_info['question']`` is also
            retrieved, so the rewrite can be compared against the original question.
        do_print: print a debugging trace.

    Returns:
        ``(score, rank, more_returns)``. Early exits return all-``-1`` ranks and an
        empty ``more_returns``, with these scores:
          * ``-3``: invalid JSON payload.
          * ``-1``: more than four sub-queries, or an empty or malformed one.
          * ``0``: any sub-query retrieved nothing, or scoring raised IndexError.
    """
    target = label
    negative_rank_list = [-1 for _ in range(len(label))]  # every target missed

    # Optional baseline: the unmodified question, if present and non-blank.
    original_query = ""
    try:
        question = extra_info['question']
        if question.strip() != "":
            original_query = question
    except (KeyError, TypeError, AttributeError):
        pass  # no usable question -> no baseline comparison

    query = get_query(json_str)
    if not check_json_format(query, do_print=False):
        return -3, negative_rank_list, {}
    query = json.loads(query)['query']

    if '%%' in query:
        queries = query.split('%%')
        # Reject empty parts and any leftover '%' (a malformed separator such as '%%%').
        if any(item.strip() == "" or '%' in item for item in queries):
            return -1, negative_rank_list, {}
        queries = [item.strip() for item in queries]
    else:
        queries = [query]
    if len(queries) > 4:
        return -1, negative_rank_list, {}

    # Retrieve for every rewrite (plus the baseline question, last) in parallel,
    # without mutating the list of rewrites.
    search_queries = (queries + [original_query]) if original_query else queries
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        all_results = list(executor.map(
            lambda q: retriver_items(q, top_k=top_k, threads=32),
            search_queries,
        ))

    if original_query:
        original_query_results = [all_results[-1]]
        results = all_results[:-1]
    else:
        original_query_results = []
        results = all_results

    if results_is_empty(results):
        return 0, negative_rank_list, {}

    try:
        score, rank, pred_results, more_returns = calculate_score_foreach_combination(
            results, target, top_k, queries, original_query, original_query_results)
    except IndexError:
        # Scoring outputs are unavailable: log the traceback and return a neutral reward.
        print(f"Index error occurred, details:\n{traceback.format_exc()}")
        return 0, negative_rank_list, {}

    if do_print:
        if extra_info and 'question' in extra_info:
            print("Original query: ", extra_info['question'])
        print("Rewrite result: ", query)
        print(f"target: {target}")
        print(f"results: {pred_results[:10]}")
        print(f"answer_score: {score}")
        print(f"rank: {rank}")
        print(f"Current generated query count and whether it's maximum: {more_returns['is_best_rank']}")
        print('\n\n')
    return score, rank, more_returns

def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Reward entry point: score one generated query rewrite.

    Args:
        data_source: dataset name. If it contains ``'think'``, the response must pass
            ``validate_response_structure`` (one <think> block, then one <rewrite> block).
        solution_str: the model's raw response.
        ground_truth: list of target document ids.
        extra_info: sample metadata. ``extra_info['split'] == 'test'`` evaluates at
            depth 10 instead of 100. A missing ``extra_info`` or ``'split'`` is
            treated as training. The optional ``'question'`` key is the original
            query used as a baseline.

    Returns:
        A dict with three keys:
          * ``score``: the reward; -3.0 for a malformed think-mode response.
          * ``acc``: 1 unless every target rank is -1.
          * ``pred``: the raw response.
    """
    # Roughly 1 in 256 calls prints a verbose debugging trace.
    do_print = random.randint(1, 256) == 1
    # Default was None, which crashed on extra_info['split']; treat it as empty metadata.
    extra_info = {} if extra_info is None else extra_info
    split = extra_info.get('split', 'train')  # train or test
    top_k = 10 if split == 'test' else 100

    # Per-document weights; unused downstream because all targets weigh the same.
    scores = [1 for _ in range(len(ground_truth))]
    negative_rank_list = [-1 for _ in range(len(ground_truth))]  # every target missed

    label = ground_truth  # e.g. ['32331'], target document ids
    answer_text = solution_str  # the model output is used directly as the rewrite
    # Think-mode data must carry a well-formed <think>...</think><rewrite>...</rewrite> structure.
    if 'think' in data_source:
        if not validate_response_structure(answer_text, do_print):
            return {
                "score": -3.0,
                "acc": 0,
                "pred": answer_text,
            }

    total_score, rank, _ = calculate_answer_score_new(answer_text, label, scores, top_k, extra_info, do_print)
    acc = 1 if rank != negative_rank_list else 0
    return {
        "score": total_score,
        "acc": acc,
        "pred": answer_text,
    }

if __name__ == '__main__':
    # Smoke test against the configured search service (RETRIEVAL_TYPE, IP_ADDRESS
    # and IP_PORT must be set). For non-think data the response is the JSON payload itself.
    solution_str = """{"query": "Is rugby league player Alan Davies married"}"""
    ground_truth = [13372305]
    # extra_info is passed explicitly: compute_score reads its 'split' entry.
    scores = compute_score("", solution_str, ground_truth, extra_info={'split': 'train'})
    print(scores)