import os, sys, json, copy, random, time
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'paper_decomposition'))
from common_utils import init_llm_client, llm_generation, extract_field
from prompt_store import instruction_prompts
from multiprocessing import Pool, cpu_count
from functools import partial


def sample_one_MDP_for_one_paper_from_hypothesis_components(inspirations, hypothesis_components, paper_name):
    """
    Build the MDP road for one paper from its v2 ``hypothesis_components``.

    The road always visits the inspirations in index order (0 -> 1 -> 2 -> ...),
    because the components are keyed by inspiration index.

    Args:
        inspirations: the paper's inspirations; only their count is used here.
        hypothesis_components: mapping from the stringified inspiration index to the
            delta hypothesis contributed by that inspiration,
            e.g. {"0": delta_0, "1": delta_1, ...}.
        paper_name: identifier included in assertion messages.

    Returns:
        [[0, delta_0], [1, delta_1], ...] in sequential order. An index with no
        entry in ``hypothesis_components`` gets an empty-string delta.

    Raises:
        AssertionError: if there are no inspirations, or if the number of
            inspirations differs from the number of hypothesis components.
    """
    num_insp = len(inspirations)
    assert num_insp >= 1, f"There should be at least one inspiration: {paper_name}"
    assert num_insp == len(hypothesis_components), \
        f"Mismatch: {num_insp} inspirations vs {len(hypothesis_components)} hypothesis_components: {paper_name}"

    # Keys are strings (JSON object keys), hence str(idx).
    return [[idx, hypothesis_components.get(str(idx), "")] for idx in range(num_insp)]


def sample_one_MDP_for_one_paper_from_road_collection(inspirations, road_collection, final_hyp, paper_name, rng=None):
    """
    Build one MDP road for a paper from its v1 ``road_collection``.

    The road is reconstructed backwards from ``final_hyp``. Each level of
    ``road_collection`` maps a stringified inspiration index to a pair whose
    element 0 is the hypothesis paired with the preceding step on the road and
    whose element 1 is the nested level consulted for that preceding step.

    The classic road (0 -> 1 -> ... -> n-1) is preferred. If the collection does
    not contain it, a random road is walked instead, choosing one available key
    per level. When the deepest level is empty and exactly one inspiration is
    still unused, that inspiration is inferred as the first step.

    Args:
        inspirations: the paper's inspirations; only their count is used here.
        road_collection: nested mapping described above (as loaded from JSON).
        final_hyp: hypothesis attached to the last step of the road.
        paper_name: identifier used in log/error messages.
        rng: optional object with a ``choice`` method (e.g. ``random.Random(seed)``)
            used for the random fallback. It defaults to the global ``random``
            module, which was the previous behavior.

    Returns:
        [[insp_id, hyp], ...] ordered from first step to last. The last hyp is ``final_hyp``.

    Raises:
        AssertionError: if ``inspirations`` is empty.
        ValueError: if neither the classic nor a random road can be built.
    """
    len_inspirations = len(inspirations)
    assert len_inspirations >= 1, f"There should be at least one inspiration for each paper: {paper_name}, {inspirations}"
    if len_inspirations == 1:
        return [[0, final_hyp]]

    classic_MDP_road = _classic_MDP_road_from_road_collection(len_inspirations, road_collection, final_hyp)
    if classic_MDP_road is not None:
        return classic_MDP_road
    return _random_MDP_road_from_road_collection(
        len_inspirations, road_collection, final_hyp, paper_name, random if rng is None else rng
    )


def _classic_MDP_road_from_road_collection(len_inspirations, road_collection, final_hyp):
    """Return the sequential road 0 -> 1 -> ... -> n-1, or None if road_collection lacks it."""
    # road_collection is only read, never mutated, so no copy is needed.
    road = []
    next_hyp = final_hyp
    level = road_collection
    # i is the inspiration index of the step being placed, walking from the last step back to the first.
    for i in range(len_inspirations - 1, -1, -1):
        road.append([i, next_hyp])
        if i == 0:
            break
        try:
            next_hyp, level = level[str(i)][0], level[str(i)][1]
        except (KeyError, IndexError, TypeError):
            # The required key is missing, or the level / entry is not indexable as expected.
            return None
    road.reverse()
    return road


def _random_MDP_road_from_road_collection(len_inspirations, road_collection, final_hyp, paper_name, rng):
    """Walk road_collection backwards, picking a random available inspiration key per level."""
    road = []
    collected_keys = []
    next_hyp = final_hyp
    level = road_collection
    # countdown is only a step counter, not an inspiration index.
    for countdown in range(len_inspirations - 1, -1, -1):
        past_keys = list(level.keys())
        if not past_keys:
            if len_inspirations != len(collected_keys) + 1:
                print(f"The random MDP road does not exist: {paper_name}")
                # Explicit raise (not assert) so the failure survives `python -O`.
                raise ValueError(f"The random MDP road does not exist: {paper_name}, {road[::-1]}")
            # The deepest level may be empty; the one unused inspiration is the first step.
            all_keys = {str(key) for key in range(len_inspirations)}
            past_keys = sorted(all_keys - set(collected_keys))
            assert len(past_keys) == 1, f"There should be only one key left: {paper_name}, {past_keys}"
        chosen_key = rng.choice(past_keys)
        road.append([int(chosen_key), next_hyp])
        collected_keys.append(chosen_key)
        if countdown == 0:
            break
        next_hyp = level[str(chosen_key)][0]
        level = level[str(chosen_key)][1]
    road.reverse()
    return road


class HCReasoningTrace:
    """
    Generate simulated per-step reasoning traces ("HC" reasoning traces) from SFT QA data.

    Each ``*.json`` file in ``sft_qa_data_dir`` describes one paper. For every such
    file that has no same-named file in ``sft_HC_reasoning_trace_dir`` yet:

    - an MDP road over the paper's inspirations is built;
    - for each step, the LLM is asked for a simulated reasoning trace plus a hypothesis label;
    - the result is written to ``sft_HC_reasoning_trace_dir`` under the input filename.
    """

    def __init__(self, api_type, api_key, base_url, model_name, sft_qa_data_dir, sft_HC_reasoning_trace_dir, data_format):
        """
        Args:
            api_type, api_key, base_url: forwarded to ``init_llm_client``.
            model_name: model name passed to ``llm_generation``.
            sft_qa_data_dir: directory containing the input SFT QA ``*.json`` files.
            sft_HC_reasoning_trace_dir: output directory; created if it does not exist.
            data_format: "v1" for road_collection format, "v2" for hypothesis_components format.

        Raises:
            AssertionError: if ``data_format`` is not "v1" or "v2".

        Side effects: creates the output directory, lists both directories and
        prints progress statistics (via ``find_unprocessed_paper``).
        """
        self.api_type = api_type
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.data_format = data_format

        assert data_format in ["v1", "v2"], f"data_format must be 'v1' or 'v2', got: {data_format}"

        self.client = init_llm_client(api_type, api_key, base_url)

        self.sft_qa_data_dir = sft_qa_data_dir
        self.sft_HC_reasoning_trace_dir = sft_HC_reasoning_trace_dir
        os.makedirs(self.sft_HC_reasoning_trace_dir, exist_ok=True)

        # Find all unprocessed papers
        self.unprocessed_paper_list = self.find_unprocessed_paper()

    def find_unprocessed_paper(self):
        """
        Find all papers that have SFT QA data but no reasoning trace file yet.

        A paper counts as processed when a file with exactly the same name exists
        in ``self.sft_HC_reasoning_trace_dir``. Only ``*.json`` files are considered.

        Ordering (for reproducible processing):

        - Files named ``<year>_<pmid>...json`` come first, sorted by (year, pmid),
          with year ``0000`` treated as 2020 and pmid compared as a string.
        - Files without a numeric ``<year>_`` prefix (e.g. ``29428771.json``) come
          after them. Previously such names crashed the sort in ``int()``.
        - Remaining ties are broken by the full filename.

        Returns:
            list[str]: unprocessed input filenames in that order.
        """
        all_sft_qa_files = [name for name in os.listdir(self.sft_qa_data_dir) if name.endswith(".json")]

        def get_sort_key(filename):
            year_str, sep, _ = filename.partition('_')
            if sep and year_str.isdecimal():
                year = 2020 if year_str == '0000' else int(year_str)  # Special case: 0000 means 2020
                pmid = filename.split('_')[1].split('.')[0]
                return (0, year, pmid, filename)
            # No "<year>_" prefix: group after the dated files, ordered by name.
            return (1, 0, "", filename)

        all_sft_qa_files.sort(key=get_sort_key)

        # Compare full filenames (year_pmid.json). A ".json.tmp" leftover from an
        # interrupted write does not end with ".json", so it does not count as processed.
        processed_files = {
            name for name in os.listdir(self.sft_HC_reasoning_trace_dir) if name.endswith(".json")
        }

        unprocessed_paper_list = [name for name in all_sft_qa_files if name not in processed_files]

        # Print statistics
        total_sft_qa = len(all_sft_qa_files)
        total_processed = len(processed_files)
        total_unprocessed = len(unprocessed_paper_list)

        print(f"Total SFT QA data files: {total_sft_qa}")
        print(f"Total processed (with reasoning trace): {total_processed}")
        print(f"Total unprocessed (need reasoning trace): {total_unprocessed}")

        if total_unprocessed == 0:
            print("All papers have been processed.")
        else:
            print(f"First unprocessed paper: {unprocessed_paper_list[0]}")
            if total_unprocessed > 1:
                print(f"Last unprocessed paper: {unprocessed_paper_list[-1]}")

        return unprocessed_paper_list


    def generate_reasoning_trace_all_paper(self, num_workers=None):
        """
        Process every paper in ``self.unprocessed_paper_list`` with a process pool.

        Args:
            num_workers: pool size; defaults to ``cpu_count()``. The value is clamped
                to [1, number of papers], so no idle workers are spawned and a
                non-positive value does not make ``Pool`` raise.
        """
        if not self.unprocessed_paper_list:
            print("No papers to process.")
            return

        total = len(self.unprocessed_paper_list)
        if num_workers is None:
            num_workers = cpu_count()
        num_workers = max(1, min(num_workers, total))

        print(f"Processing {total} papers using {num_workers} workers")

        # Bind everything a worker needs so that only picklable values cross the process boundary.
        process_func = partial(
            self._process_paper_worker,
            api_type=self.api_type,
            api_key=self.api_key,
            base_url=self.base_url,
            model_name=self.model_name,
            sft_qa_data_dir=self.sft_qa_data_dir,
            sft_HC_reasoning_trace_dir=self.sft_HC_reasoning_trace_dir,
            data_format=self.data_format
        )

        start_time = time.time()
        completed = 0

        with Pool(processes=num_workers) as pool:
            for paper_name in pool.imap_unordered(process_func, self.unprocessed_paper_list):
                completed += 1
                remaining = total - completed
                elapsed = time.time() - start_time
                avg_time = elapsed / completed  # Wall-clock time per completion (already includes parallelism)
                eta = avg_time * remaining  # No division by num_workers - avg_time already reflects throughput

                print(f"[{completed}/{total}] Processed: {paper_name} | "
                      f"Avg: {avg_time:.1f}s/paper | ETA: {eta/60:.1f}min")

        total_time = time.time() - start_time
        print(f"\n=== Completed ===")
        print(f"Total papers: {total}")
        print(f"Total time: {total_time/60:.1f} minutes")
        print(f"Average: {total_time/total:.1f}s/paper")

    @staticmethod
    def _process_paper_worker(paper_name, api_type, api_key, base_url, model_name,
                              sft_qa_data_dir, sft_HC_reasoning_trace_dir, data_format):
        """
        Process a single paper inside a pool worker.

        Builds a minimal instance (bypassing ``__init__`` so the directory scan is not
        repeated) with its own LLM client, and runs ``generate_reasoning_trace_per_paper``.
        Any exception is printed with its traceback rather than propagated. The paper
        name is returned in both cases so the parent's progress counter advances.
        """
        try:
            # Import here to avoid issues with multiprocessing
            from common_utils import init_llm_client

            worker = object.__new__(HCReasoningTrace)
            worker.api_type = api_type
            worker.api_key = api_key
            worker.base_url = base_url
            worker.model_name = model_name
            worker.data_format = data_format
            worker.client = init_llm_client(api_type, api_key, base_url)
            worker.sft_qa_data_dir = sft_qa_data_dir
            worker.sft_HC_reasoning_trace_dir = sft_HC_reasoning_trace_dir

            worker.generate_reasoning_trace_per_paper(paper_name)
            return paper_name

        except Exception as e:
            print(f"Error processing {paper_name}: {e}")
            import traceback
            traceback.print_exc()
            return paper_name  # Return paper name even on error for progress tracking

    def generate_reasoning_trace_per_paper(self, paper_name):
        """
        Generate and save the reasoning trace for one paper.

        Args:
            paper_name: filename, relative to ``self.sft_qa_data_dir``, of the paper's
                SFT QA JSON.

        Returns:
            list: one entry per MDP step,
            ``[insp_id, prev_hyp, found_title, found_abstract, next_hyp, reasoning_trace, hypothesis_label]``.
            The same list is written as JSON to ``self.sft_HC_reasoning_trace_dir/paper_name``.
        """
        sft_qa_data_file_path = os.path.join(self.sft_qa_data_dir, paper_name)
        with open(sft_qa_data_file_path, "r", encoding="utf-8") as f:
            sft_qa_data = json.load(f)
        research_question = sft_qa_data["research_question"]
        survey = sft_qa_data["background_survey"]
        hypothesis = sft_qa_data["fine_grained_hypothesis"]
        inspirations = sft_qa_data["inspiration"]

        is_v2_format = (self.data_format == "v2")
        if is_v2_format:
            # v2 format: use delta hypotheses directly (sequential order: 0->1->2->...)
            hypothesis_components = sft_qa_data["hypothesis_components"]
            MDP_road = sample_one_MDP_for_one_paper_from_hypothesis_components(inspirations, hypothesis_components, paper_name)
        else:
            # v1 format: use road_collection to traverse intermediate hypotheses (prefers sequential order)
            road_collection = sft_qa_data["road_collection"]
            MDP_road = sample_one_MDP_for_one_paper_from_road_collection(inspirations, road_collection, hypothesis, paper_name)

        MDP_road_with_reasoning_trace = []
        for i, (insp_id, next_hyp) in enumerate(MDP_road):
            cur_insp = inspirations[insp_id]
            found_title = cur_insp["found_title"]
            found_abstract = cur_insp["found_abstract"]

            if i == 0:
                prev_hyp = None
            elif is_v2_format:
                # v2: steps hold deltas, so the previous state is all earlier deltas joined.
                prev_hyp = "\n\n".join(delta for _, delta in MDP_road[:i])
            else:
                # v1: steps hold cumulative hypotheses, so the previous state is the previous step.
                prev_hyp = MDP_road[i - 1][1]

            reasoning_trace, hypothesis_label = self.generate_reasoning_trace_per_step(
                research_question, survey, prev_hyp, cur_insp, next_hyp, is_v2_format=is_v2_format
            )
            if reasoning_trace is None or hypothesis_label is None:
                print(f"Warning: extraction failed for {paper_name}, step {i} (inspiration {insp_id}); saving null fields.")
            MDP_road_with_reasoning_trace.append([insp_id, prev_hyp, found_title, found_abstract, next_hyp, reasoning_trace, hypothesis_label])

        # find_unprocessed_paper treats any existing "<name>.json" as done. Write to a
        # temporary file and rename it into place, so a crash mid-write cannot leave a
        # truncated file that would never be regenerated.
        output_path = os.path.join(self.sft_HC_reasoning_trace_dir, paper_name)
        tmp_output_path = output_path + ".tmp"
        with open(tmp_output_path, "w", encoding="utf-8") as f:
            json.dump(MDP_road_with_reasoning_trace, f, indent=2)
        os.replace(tmp_output_path, output_path)

        return MDP_road_with_reasoning_trace


    def generate_reasoning_trace_per_step(self, research_question, survey, prev_hyp, cur_insp, next_hyp, is_v2_format=False):
        """
        Ask the LLM for the reasoning trace and hypothesis label of one MDP step.

        Args:
            research_question, survey: paper-level context inserted into the prompt.
            prev_hyp: hypothesis state before this step, or None for the first step.
            cur_insp: inspiration dict. It must contain "found_title",
                "found_abstract", "insp" and "relation".
            next_hyp: target hypothesis for this step (a delta in v2, cumulative in v1).
            is_v2_format: selects the v2 delta prompt and the "Delta Hypothesis" field.

        Returns:
            (reasoning_trace, hypothesis_label). If extraction still fails after
            ``max_retries`` attempts, the values from the last attempt are returned,
            and either may be None.
        """
        if prev_hyp is None:
            prev_hyp = "No previous hypothesis."
        insp_title = cur_insp["found_title"]
        insp_abstract = cur_insp["found_abstract"]
        groundtruth_insp = cur_insp["insp"]
        groundtruth_insp_relation = cur_insp["relation"]

        if is_v2_format:
            # v2: Use delta hypothesis prompt, LLM generates hypothesis with novel names filtered
            prompts = instruction_prompts("generate_reasoning_trace_per_step_v2_delta")
            extraction_field = "Delta Hypothesis"
        else:
            # v1: Use original prompt
            prompts = instruction_prompts("generate_reasoning_trace_per_step_updated_recall")
            extraction_field = "Hypothesis"

        full_prompt = prompts[0] + research_question + prompts[1] + survey + prompts[2] + prev_hyp + prompts[3] + insp_title + prompts[4] + insp_abstract + prompts[5] + groundtruth_insp + prompts[6] + groundtruth_insp_relation + prompts[7] + next_hyp + prompts[8]

        # Retry generation until both fields can be extracted, up to max_retries attempts.
        max_retries = 30
        for attempt in range(max_retries):
            generation = llm_generation(full_prompt, self.model_name, self.client, temperature=0.1, api_type=self.api_type)
            # Use strict extraction to avoid contamination in extracted fields
            reasoning_trace = extract_field(generation, "Simulated Reasoning Trace", expected_type='text', strict_extraction=True)
            hypothesis_label = extract_field(generation, extraction_field, expected_type='text', strict_extraction=True)

            # For both v1 and v2, use LLM-extracted hypothesis_label (LLM handles novel name filtering via prompt)
            if reasoning_trace is not None and hypothesis_label is not None:
                return reasoning_trace, hypothesis_label

            print(f"Extraction failed, retrying... {attempt + 1}/{max_retries}")

        print(f"Extraction failed after {max_retries} attempts.")
        return reasoning_trace, hypothesis_label





if __name__ == "__main__":
    # --- LLM endpoint (fill in before running) ---
    api_type = 0
    api_key = "<YOUR_API_KEY>"
    base_url = "<YOUR_API_URL>"
    model_name = "R1-Distill-Qwen-32B"

    # Suffix shared by the input and output directory names
    # (values used so far include "run8" and "2025_October").
    output_dir_postfix = "2025_October"
    # Input schema: "v1" -> road_collection, "v2" -> hypothesis_components.
    data_format = "v2"

    # MODIFY THESE PATHS: root folders for the SFT QA data and the generated traces.
    folder_sft_qa_data = "folder_sft_qa_data"
    folder_sft_hc_reasoning_trace = "folder_sft_hc_reasoning_trace"

    sft_qa_data_dir = f"{folder_sft_qa_data}/pubmed_sft_qa_data_v2_{output_dir_postfix}"
    output_reasoning_trace_dir = f"{folder_sft_hc_reasoning_trace}/pubmed_sft_HC_reasoning_trace_v2_{output_dir_postfix}"

    hc_reasoning_trace = HCReasoningTrace(
        api_type,
        api_key,
        base_url,
        model_name,
        sft_qa_data_dir,
        output_reasoning_trace_dir,
        data_format=data_format,
    )
    hc_reasoning_trace.generate_reasoning_trace_all_paper(num_workers=5)
    # To debug a single paper instead, call e.g.:
    #   hc_reasoning_trace.generate_reasoning_trace_per_paper("29428771.json")
    