import re
import random
import numpy as np
import ast
import operator
import pdb
import json
import sys
import os
import traceback
import concurrent.futures
from itertools import combinations
sys.path.append('./')
import copy
from src.utils import ndcg_at_k, get_rank
from utils import ndcg_at_k_rank_list, map_at_k_rank_list
# Legacy hard-coded BGE client setup, kept commented out for reference.
# The search client is now built from environment variables further down.
# from similarity_launch.dense_bge.similarity_client import BGESimilarityClient
# # Search API client
# client_search = FaissSearchClient("http://10.11.1.78:8000")
# # Similarity API client, using the same model as the search service
# client_sim = BGESimilarityClient(("http://10.11.1.78:8001"))

# Retriever configuration, read from environment variables at import time.
# RETRIEVAL_TYPE must be one of 'bge', 'ance' or 'bm25'; any other value
# (including the 'xxx' placeholder default) makes the import fail below.
retrieval = os.environ.get('RETRIEVAL_TYPE', 'xxx')
IP_ADDRESS = os.environ.get('IP_ADDRESS', 'xxx')
IP_PORT = os.environ.get('IP_PORT', 'xxx')
# USE_HNSW == '0' makes rerank() apply the same reciprocal score transform
# that it applies to bm25 scores.
USE_HNSW = os.environ.get('USE_HNSW', '1')
# Base URL of the search service.
server = f"http://{IP_ADDRESS}:{IP_PORT}"
# Import only the client that matches the selected retriever and bind it to
# the module-level name `client_search`.
if retrieval == 'bge':
    from search_launch.dense_bge.search_client import FaissSearchClient
    client_search = FaissSearchClient(server)
elif retrieval == 'ance':
    from search_launch.ance.search_client import FaissSearchClient
    client_search = FaissSearchClient(server)
elif retrieval == 'bm25':
    from search_launch.lucene.search_client import LuceneSearchClient
    client_search = LuceneSearchClient(server)
else:
    # Fail fast: nothing below can work without a search client.
    raise Exception("Please specify the correct retriever")
print(f"*************************Using {retrieval} retrieval on {server}*************************")

# Filter difficult ones
def filter_standard(data):
    """Return True when a sample should be kept, False when it is filtered out.

    A sample is dropped when its ``'_score'`` is at least 1.7.

    NOTE: an older rank-based rule (drop when ``'_rank'`` is in 1..5) used to
    sit after an unconditional ``return`` and was therefore never executed; it
    has been removed so the code matches the behaviour that actually runs.

    Args:
        data: Mapping that contains a numeric ``'_score'`` entry.

    Returns:
        bool: False if ``data['_score'] >= 1.7``, True otherwise.
    """
    return not data['_score'] >= 1.7

def extract_solution(solution_str):
    """Extract the last ``<answer>...</answer>`` payload from a model transcript.

    The text before the assistant header (``"Assistant:"`` or
    ``"<|im_start|>assistant"``) is discarded first.

    Args:
        solution_str: Full transcript, including the prompt part.

    Returns:
        tuple: ``(answer, processed_str)`` where ``answer`` is the stripped
        content of the last answer tag, or None when no header or no answer
        tag is found. ``processed_str`` is the text after the header; when no
        header is found it is the unmodified input.
    """
    # Keep only what follows the first assistant header.
    if "Assistant:" in solution_str:
        processed_str = solution_str.split("Assistant:", 1)[1].strip()
    elif "<|im_start|>assistant" in solution_str:
        processed_str = solution_str.split("<|im_start|>assistant", 1)[1].strip()
    else:
        print("[Error] Failed to locate model response header")
        # No header means nothing was processed: hand back the raw input
        # instead of referencing a variable that was never assigned.
        return None, solution_str

    # re.DOTALL lets an answer span multiple lines; the last match wins.
    answer_pattern = r'<answer>(.*?)</answer>'
    matches = re.findall(answer_pattern, processed_str, re.DOTALL)

    if matches:
        return matches[-1].strip(), processed_str

    print("[Error] No valid answer tags found")
    return None, processed_str
        

def validate_response_structure(processed_str: str, do_print: bool = False) -> bool:
    """Check that a response has exactly one think block followed by one rewrite block.

    The response is valid when each of ``<think>``, ``</think>``,
    ``<rewrite>`` and ``</rewrite>`` occurs exactly once and their first
    occurrences appear in that order.

    Args:
        processed_str: Response string from the model.
        do_print: When True, print per-tag diagnostics and error messages.

    Returns:
        Boolean indicating whether all formatting requirements are met.
    """
    if do_print:
        print("\n[Structure Validation]")
    validation_passed = True

    # Required tags and how many times each must occur.
    tags = {
        'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'answer_start': ('<rewrite>', 1),
        'answer_end': ('</rewrite>', 1),
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        # find() gives -1 for a missing tag; the count check already fails then.
        positions[tag_name] = pos = processed_str.find(tag_str)

        if do_print:
            print(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            if do_print:
                print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order.
    if (positions['think_start'] > positions['think_end'] or
        positions['think_end'] > positions['answer_start'] or
        positions['answer_start'] > positions['answer_end']):
        if do_print:
            # The message names the tags that are actually validated above.
            print("  [Error] Incorrect tag order: Expected <think>...</think><rewrite>...</rewrite>")
        validation_passed = False
    else:
        if do_print:
            print("  Tag sequence validation passed")

    return validation_passed


def check_json_format(json_str, do_print=False):
    """Return True if ``json_str`` is a JSON object that has a ``"query"`` key.

    Args:
        json_str: Candidate JSON text. Empty or None counts as invalid.
        do_print: When True, print the reason for a rejection.

    Returns:
        bool: True only for a syntactically valid JSON *object* containing
        every required key.
    """
    if not json_str:
        if do_print:
            print("[Error] Empty JSON string")
        return False

    try:
        data = json.loads(json_str)
    except json.JSONDecodeError:
        if do_print:
            print("[Error] JSON decoding failed")
        return False

    # Valid JSON may also be a string, number, list, ... Without this check a
    # JSON string containing "query" would pass the key test by substring
    # match, and a JSON number would raise TypeError on the `in` test.
    if not isinstance(data, dict):
        if do_print:
            print("[Error] Missing required keys in JSON")
        return False

    required_keys = {"query"}
    if not all(key in data for key in required_keys):
        if do_print:
            print("[Error] Missing required keys in JSON")
        return False

    return True

def retriver_items(query, top_k=300, threads=16):
    """Run one query through the module-level search client and return its hits.

    The client call returns a pair; only the first element (the results) is
    handed back. The second element is discarded (a search time, judging by
    the name used for it in the earlier version of this function).
    """
    hits, _elapsed = client_search.search(query, top_k, threads)
    return hits

def cal_score_1(rank_list, k=100):
    """Average a piecewise-linear rank score over all target ranks.

    Per-rank score:
      * rank < 1 (e.g. -1 for "not retrieved"): 0
      * 1 <= rank < 10: ``1 + (10 - rank) / 9``, i.e. 2.0 at rank 1
      * rank >= 10: ``(100 - rank) / 90``, clamped at 0, i.e. 1.0 at rank 10

    Args:
        rank_list: Ranks of the target documents.
        k: Unused. The thresholds 10 and 100 are fixed regardless of ``k``;
            the parameter is kept for signature compatibility.

    Returns:
        float: Mean score in [0, 2]; 0.0 for an empty ``rank_list``.
    """
    # An empty list used to divide by zero; treat "no targets" as score 0.
    if len(rank_list) == 0:
        return 0.0

    total_score = 0.0
    for rank in rank_list:
        if rank < 1:
            score = 0
        elif rank < 10:
            score = 1 + (10 - rank) / 9
        else:
            # Ranks beyond 100 would go negative, so clamp at 0.
            score = max((100 - rank) / 90, 0)
        total_score += score
    return total_score / len(rank_list)

def cal_score(rank_list, k):
    '''
    Score a list of target ranks, e.g. rank_list = [1] or [1, 2].

    This is the single switch point for the scoring function: it currently
    delegates to cal_score_1, which averages a per-rank score over the
    whole rank_list.
    '''
    scorer = cal_score_1
    return scorer(rank_list, k)

def cal_h3(score_list, rank_threshold=10): # Top 10
    """Fuse one document's (score, rank) observations into a sort key.

    Sums 1/rank over the observations whose rank is within ``rank_threshold``.
    If none qualifies, falls back to 1/(smallest rank seen, capped at 100000).

    Returns:
        tuple: (1 / summed reciprocal ranks, smallest score seen). The score
        is also capped at the 100000 sentinel. Smaller is better for both.
    """
    reciprocal_sum = 0
    lowest_rank = 100000
    lowest_score = 100000
    for value, position in score_list:
        if position <= rank_threshold:
            reciprocal_sum += 1 / position
        if value < lowest_score:
            lowest_score = value
        if position < lowest_rank:
            lowest_rank = position

    # No observation inside the threshold: use the best rank seen instead.
    if reciprocal_sum == 0:
        reciprocal_sum = 1 / lowest_rank
    return 1 / reciprocal_sum, lowest_score

def rerank(results_init, top_k) -> list[int]:
    '''
    Fuse several per-query result lists into one ranked list of document ids.

    args:
    results_init: list[list[[id, score]]], one ranked hit list per query
    top_k: maximum number of document ids to return
    return:
    Document ids in fused rank order, e.g., [123, 456, 789]

    Scores are treated as "smaller is better". When the retriever is bm25 or
    USE_HNSW == '0', every score is replaced by its reciprocal first; the
    transformed rows are new lists, so the caller's data is not modified.
    NOTE(review): a score of exactly 0 would raise ZeroDivisionError here.
    '''
    needs_reciprocal = retrieval == 'bm25' or USE_HNSW == '0'
    if needs_reciprocal:
        per_query = [[[doc, 1 / value] for doc, value in hits] for hits in results_init]
    else:
        per_query = results_init

    # doc id -> list of (score, 1-based rank), one entry per list it appears in.
    observations = {}
    for hits in per_query:
        for position, (doc, value) in enumerate(hits, start=1):
            observations.setdefault(doc, []).append((value, position))

    # cal_h3 yields (fused rank value, min score); sort ascending on both.
    fused = {doc: cal_h3(pairs) for doc, pairs in observations.items()}
    ordered = sorted(fused.items(), key=lambda entry: (entry[1][0], entry[1][1]))
    return [doc for doc, _ in ordered[:top_k]]


def calculate_similarity(queries) -> list:
    """Return the similarity of every unordered pair of queries.

    Pairs follow the order of ``itertools.combinations(queries, 2)`` and are
    scored concurrently by up to 8 worker threads.

    NOTE(review): this relies on a module-level ``client_sim``; the only
    assignment visible in this file is commented out, so confirm it is
    defined before calling with two or more queries.

    Raises:
        Exception: if fewer than two queries are given.
    """
    if len(queries) < 2:
        raise Exception("Similarity calculation requires at least two queries")

    def _pair_score(pair):
        left, right = pair
        similarity, _ = client_sim.calculate_similarity(left, right)
        return similarity

    pairs = list(combinations(queries, 2))
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        return list(pool.map(_pair_score, pairs))

def is_r10(rank_list, score):
    """Return ``map_at_k_rank_list(rank_list, 10)`` for the given target ranks.

    ``score`` is accepted for call-site compatibility but is not used.
    """
    cutoff = 10
    return map_at_k_rank_list(rank_list, cutoff)




import itertools
def calculate_score_foreach_combination(results, target, top_k, queries=None, original_query="", original_query_results=None) -> tuple[float, list, list, dict]:
    '''
    Score the merged retrieval result of all generated queries, and every
    non-empty subset of them, against the ground-truth documents.

    args
    results: one retrieval hit list per generated query
    target: ground-truth document ids (gt)
    top_k: cut-off passed to rerank / get_rank / cal_score
    queries: query strings aligned with `results` (only the first
        len(results) entries are read); used to build 'best_query' and
        'single_query'
    original_query: the un-rewritten query; "" disables the comparison
    original_query_results: hit lists retrieved for the original query, in the
        same list-of-hit-lists shape as `results`

    return
    (score, rank, pred_results, more_returns)
    score: score of the full query set, forced to 0 when it is lower than the
        original query's score
    rank: target ranks for the full query set (not the best subset's ranks)
    pred_results: fused, sorted document ids for the full query set
    more_returns: diagnostics dict (best subset, per-query values, ...)

    Subsets are enumerated as bit masks: bit j set means results[j] is part
    of the subset. Mask 0 is the empty subset (all ranks -1, score 0) and the
    last mask is the full set. The work is 2 ** len(results) rerank calls.
    '''
    queries = [] if queries is None else queries
    original_query_results = [] if original_query_results is None else original_query_results

    length = len(results)
    subset_total = 1 << length
    # "Not retrieved" marker: every target rank is -1.
    negative_rank_list = [-1 for _ in range(len(target))]

    def popcount(mask):
        """Number of queries in the subset encoded by `mask`."""
        return bin(mask).count("1")

    # Index == subset mask. Entry 0 is the empty subset.
    pred_results_list = [[]]
    rank_list = [negative_rank_list]
    rank_score_list = [0]
    for mask in range(1, subset_total):
        sub_results = [results[j] for j in range(length) if (1 << j) & mask]
        subset_pred = rerank(sub_results, top_k)
        subset_rank = get_rank(subset_pred, target, top_k)
        if isinstance(subset_rank, int):  # single target: normalise to a list
            subset_rank = [subset_rank]
        pred_results_list.append(subset_pred)
        rank_list.append(subset_rank)
        rank_score_list.append(cal_score(subset_rank, top_k))

    # Values for the full query set, i.e. what the model actually produced.
    merged_rank = rank_list[-1]
    merged_score = rank_score_list[-1]

    # Several subsets can share a rank, so "best" is decided by score. The
    # full set keeps the title unless another subset is strictly better.
    best_rank, best_score = merged_rank, merged_score
    for mask, subset_score in enumerate(rank_score_list):
        if subset_score > best_score:
            best_rank = rank_list[mask]
            best_score = subset_score

    more_returns = {}
    more_returns['best_rank'] = best_rank
    more_returns['rank'] = merged_rank
    more_returns['best_rank_is_valid'] = bool(best_rank != negative_rank_list)  # Within top-k
    more_returns['rank_is_valid'] = bool(merged_rank != negative_rank_list)  # Within top-k
    more_returns['query_count'] = length

    # min_idx: smallest subset reaching the top score (first one on ties).
    # max_count: largest subset size seen among the subsets that passed the
    # running-maximum test below; this is what 'best_rank_query_count' reports.
    more_returns['best_rank_query_count'] = -1
    min_count, max_val = 10000, -10000
    min_idx = -1
    max_count = -1
    for mask, subset_score in enumerate(rank_score_list):
        if subset_score != 0 and subset_score >= max_val:
            bits = popcount(mask)
            max_count = max(max_count, bits)
            if subset_score > max_val:
                max_val = subset_score
                min_count = bits
                min_idx = mask
            elif bits < min_count:
                min_count = bits
                min_idx = mask
    if max_val != -10000:
        more_returns['best_rank_query_count'] = max_count

    # Queries of the best subset joined by '%%'; "" when no subset scored.
    more_returns['best_query'] = ""
    if min_idx != -1:
        best_query = ""
        for position in range(length):
            if (min_idx >> position) & 1:
                if best_query != "":
                    best_query += '%%'
                best_query += queries[position]
        more_returns['best_query'] = best_query
    # True iff the full set itself is the smallest top-scoring subset.
    more_returns['is_only_best_rank'] = min_idx == (len(rank_list) - 1)
    # True iff the full set ties the top score and that score is positive.
    more_returns['is_best_rank'] = bool(best_score == merged_score and best_score > 0)

    # Per-query values: mask 1 << i is the subset holding only query i.
    more_returns['single_query_r10'] = []
    more_returns['single_query_rank_list'] = []
    more_returns['single_query'] = []
    for position in range(length):
        more_returns['single_query'].append(queries[position])
        single_rank = rank_list[1 << position]
        single_r10 = is_r10(single_rank, rank_score_list[1 << position])
        more_returns['single_query_r10'].append(single_r10)
        more_returns['single_query_rank_list'].append(single_rank)
    # default=0 keeps the degenerate "no queries" call from raising.
    more_returns['worse_query_r10'] = min(more_returns['single_query_r10'], default=0)
    more_returns['best_query_r10'] = max(more_returns['single_query_r10'], default=0)
    more_returns['merged_query_r10'] = is_r10(merged_rank, merged_score)

    # Best score over all subsets.
    mx_score = max([-1] + rank_score_list)
    score = merged_score

    if original_query != "":
        # Rank the ORIGINAL query's own retrieval results. This used to reuse
        # the merged prediction left over from a loop variable, so the
        # comparison below could never trigger.
        original_pred = rerank(original_query_results, top_k) if original_query_results else []
        if original_pred:
            rank_original = get_rank(original_pred, target, top_k)
            if isinstance(rank_original, int):  # single target
                rank_original = [rank_original]
        else:
            # Nothing retrieved for the original query: no target is ranked.
            rank_original = list(negative_rank_list)
        score_original = cal_score(rank_original, top_k)
        more_returns['original_query_rank'] = rank_original
        more_returns['original_query_score'] = score_original
        # Heavy penalty when the rewrite is worse than the original query.
        if mx_score < score_original:
            score = 0
        elif score < score_original:
            score = 0

    # Return the merged rank explicitly; previously the per-query loop variable
    # shadowed it and the last single query's rank was returned instead.
    return score, merged_rank, pred_results_list[-1], more_returns

def get_query(response):
    """Return the text inside the first ``<rewrite>...</rewrite>`` block.

    Works for both response modes: if the response has no ``<rewrite>`` tag
    (non-think mode), the whole response is returned stripped. A missing
    closing tag is tolerated; everything after the opening tag is used.

    Args:
        response: Model output string.

    Returns:
        str: The stripped rewrite payload.
    """
    parts = response.split('<rewrite>')
    if len(parts) < 2:
        # No opening tag: explicit check instead of a bare except around
        # an IndexError.
        return response.strip()
    return parts[1].split('</rewrite>')[0].strip()

def results_is_empty(results):
    """Return True if at least one per-query result equals the empty list."""
    return any(hits == [] for hits in results)

def calculate_answer_score_new(json_str, label, scores, top_k, extra_info, do_print=False) -> tuple[float, list, dict]:
    """Retrieve with the rewritten query/queries and score them against ``label``.

    Args:
        json_str: Model output; the rewrite payload is taken from it with
            ``get_query`` and must be a JSON object with a string ``"query"``.
            Several queries are separated by ``'%%'``.
        label: Ground-truth document ids.
        scores: Unused. Kept for signature compatibility.
        top_k: Number of documents to retrieve and rank per query.
        extra_info: Optional mapping; a non-blank ``'question'`` entry is
            retrieved alongside as the original query for comparison.
        do_print: When True, print a debug summary.

    Returns:
        tuple: ``(score, rank, more_returns)``. Every failure path returns the
        all ``-1`` rank list and an empty dict, with score
        ``-3`` for an invalid payload,
        ``-1`` for too many (> 6) or malformed sub-queries, or an IndexError
        during scoring, and
        ``0`` when any generated query retrieved nothing.
    """
    # "Not retrieved" marker: every target rank is -1.
    negative_rank_list = [-1 for _ in range(len(label))]
    target = label

    # The original question is optional; any missing or odd extra_info simply
    # disables the comparison.
    try:
        original_query = extra_info['question']
        add_original_query = original_query.strip() != ""
    except (TypeError, KeyError, IndexError, AttributeError):
        original_query = ""
        add_original_query = False

    query = get_query(json_str)
    if not check_json_format(query, do_print=False):
        return -3, negative_rank_list, {}
    payload = json.loads(query)
    query = payload['query'] if isinstance(payload, dict) else None
    if not isinstance(query, str):
        # e.g. {"query": 5}: the '%%' handling below needs a string.
        return -3, negative_rank_list, {}

    if '%%' in query:
        queries = query.split('%%')
        # A blank part or a stray '%' means the separator was misused.
        has_empty = any(item.strip() == "" or '%' in item for item in queries)
    else:
        queries = [query]
        has_empty = False
    if len(queries) > 6 or has_empty:
        return -1, negative_rank_list, {}

    # The original query is retrieved in the same batch but kept apart below.
    search_queries = queries + [original_query] if add_original_query else queries

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        results = list(executor.map(
            lambda single_query: retriver_items(single_query, top_k=top_k, threads=32),
            search_queries
        ))

    if add_original_query:
        original_query_results = [results[-1]]
        results = results[:-1]
    else:
        original_query = ""
        original_query_results = []

    if results_is_empty(results):
        # Return the rank as a list: a bare -1 never equals the caller's
        # negative rank list and would be counted as a hit.
        return 0, negative_rank_list, {}

    try:
        score, rank, pred_results, more_returns = calculate_score_foreach_combination(
            results, target, top_k, queries, original_query, original_query_results)
    except IndexError:
        error_msg = traceback.format_exc()
        print(f"Index error occurred, details:\n{error_msg}")
        # Nothing was computed; report a failed sample instead of falling
        # through to unbound variables.
        return -1, negative_rank_list, {}

    if do_print:
        if extra_info is not None and 'question' in extra_info:
            print("Original query: ", extra_info['question'])
        print("Rewrite result: ", query)
        print(f"target: {target}")
        print(f"results: {pred_results[:10]}")
        print(f"answer_score: {score}")
        print(f"rank: {rank}")
        print(f"Current generated query count and whether it's maximum: {more_returns['is_best_rank']}")
        print('\n\n')
    return score, rank, more_returns

def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Reward entry point: score one model response against the ground truth.

    Args:
        data_source: Dataset tag; if it contains ``'think'`` the response must
            pass ``validate_response_structure`` before it is scored.
        solution_str: Model output, used as-is.
        ground_truth: List of target document ids, e.g. ``['32331']``.
        extra_info: Optional mapping forwarded to the scorer.

    Returns:
        dict with keys ``score``, ``acc`` (1 if the rank differs from the
        all ``-1`` list), ``pred`` (the raw response), ``rank`` and
        ``more_returns``. An invalid think format yields score ``-3.0``,
        acc 0, an all ``-1`` rank and an empty ``more_returns``.
    """
    # Print debug output for roughly one in 256 samples.
    do_print = random.randint(1, 256) == 1
    top_k = 100
    # Per-document weights; unused downstream, all samples use weight 1.
    scores = [1 for _ in range(len(ground_truth))]
    # "Not retrieved" marker: every target rank is -1.
    negative_rank_list = [-1 for _ in range(len(ground_truth))]

    label = ground_truth
    answer_text = solution_str

    if 'think' in data_source:
        if not validate_response_structure(answer_text, do_print):
            return {
                "score": -3.0,
                "acc": 0,
                "pred": answer_text,
                # Same shape as every other failure path.
                "rank": negative_rank_list,
                # Nothing has been scored yet; this used to reference
                # more_returns before assignment.
                "more_returns": {}
            }

    total_score, rank, more_returns = calculate_answer_score_new(answer_text, label, scores, top_k, extra_info, do_print)
    acc = 1 if rank != negative_rank_list else 0
    return {
        "score": total_score,
        "acc": acc,
        "pred": answer_text,
        "rank": rank,
        "more_returns": more_returns
    }

# Manual smoke test: scores one hard-coded response against the configured
# search service (the module-level retriever setup must have succeeded).
if __name__ == '__main__':
    # Example with an assistant header and <answer> tags; overwritten by the next line.
    solution_str = """<|im_start|>assistant:  <answer>{"query": "Microstructural development of human"}</answer>"""
    # The response actually scored: a bare JSON payload with no tags.
    solution_str = """{"query": "Is rugby league player Alan Davies married"}"""
    # Dict-style ground truth; overwritten by the list form on the next line.
    ground_truth = {'target': '4983'}
    ground_truth = [13372305]
    # Empty data_source: the think-format validation branch is skipped.
    scores = compute_score("", solution_str, ground_truth)
    print(scores)