import os
import json
from typing import Dict, List, Optional, Tuple
from itertools import product
import asyncio
from tqdm import tqdm
from datetime import datetime   
from utils import *
from retrieval_base import retrieval_base
from retrieval_benchmark import retrieval_benchmark_base

class Heuristic_Retrieval(retrieval_base):
    '''
    Heuristic retrieval that repeatedly summarizes the readme, visits every link the summary reports, and records what was visited in each refinement iteration.

    The refinement times ideally should be set to the length of the trajectory - 1, since we count the readme as the first link in the trajectory and it is essentially iteration 0.

    For example, if the trajectory of a repo is [https://github.com/ValveSoftware/GameNetworkingSockets, https://github.com/ValveSoftware/GameNetworkingSockets/blob/master/BUILDING.md], then the refinement times should be 1.

    NOTE: constructing an instance immediately runs ``refine_links()``, i.e. all summarization work happens inside ``__init__``.
    NOTE(review): ``self.readme_content`` and ``self.readme_path`` are read but never assigned in this class; presumably ``retrieval_base.__init__`` sets them. ``summarize_text``, ``summarize_link``, ``read_file``, ``Extract_Build_Information_and_Links``, ``openai_client`` and ``MODEL_NAME`` are expected to come from the ``utils`` star import - confirm.
    '''
    def __init__(self, repo_dir, repo_full_name, refine_times = 1):
        '''
        Parameters:
        - repo_dir: forwarded to ``retrieval_base.__init__``.
        - repo_full_name: forwarded to ``retrieval_base.__init__``.
        - refine_times (int): number of refinement iterations run by ``refine_links``.
        '''
        super().__init__(repo_dir, repo_full_name)
        self.refine_times = refine_times
        self.trajectories_data = []
        self.refine_links()

    def predict_target_link(self, selected_trajectory=None):
        '''
        Since our design essentially would traverse through all the links that are found in all iterations, we should consider the prediction to be successful if the target link is within the list. However, here we would just predict the target link by asking the model to predict the target link given the readme content and the list of urls.

        Parameters:
        - selected_trajectory (list | tuple | None): a trajectory of links. When it is non-empty, its last element is returned directly and the model is not queried. Tuples are accepted because ``get_trajectory_from_trajectory_data`` builds trajectories with ``itertools.product``, which yields tuples.

        Returns:
        - The last link of ``selected_trajectory``, or the stripped model output when no (or an empty) trajectory is given.

        Raises:
        - TypeError: if ``selected_trajectory`` is neither None, a list, nor a tuple.
        '''
        if selected_trajectory is not None:
            # The previous check compared type(...) against the list instance [],
            # which is never equal, so every non-empty trajectory raised TypeError.
            if not isinstance(selected_trajectory, (list, tuple)):
                raise TypeError("selected_trajectory should be a list or tuple", selected_trajectory)
            if selected_trajectory:
                return selected_trajectory[-1]

        # all_links is otherwise only populated by get_trajectory_from_trajectory_data();
        # collect it here if that method has not been called yet.
        if getattr(self, 'all_links', None) is None:
            self.all_links = self._collect_all_links()

        response = openai_client.chat.completions.create(
            model = MODEL_NAME,
            messages = [
                {'role': 'system', 'content': "Given the readme of the repository and a list of urls, predict the target link. Your output should only be a single url and nothing else."},
                {'role': 'user', 'content': f"The readme content of the repository is: {self.readme_content}.\n The list of urls are: {self.all_links}."},
            ]
        )
        # The message content may be None; treat that as an empty prediction.
        content = response.choices[0].message.content
        return (content or "").strip()

    def predict_retrieval_trajectory(self):
        '''Return every candidate trajectory derived from the recorded refinement data.'''
        return self.get_trajectory_from_trajectory_data()

    def refine_links(self) -> List[Dict[str, object]]:
        '''
        Run ``self.refine_times`` refinement iterations and record the links visited in each.

        Each iteration summarizes the running summary text (the readme content on the first iteration), then summarizes every external and internal link reported by that summary and appends the results to the running text, which becomes the input of the next iteration.

        Returns:
        - ``self.trajectories_data``: one dict per iteration, of the form ``{"iteration": <1-based index>, "visits": [<visit record>, ...]}``.

        A link whose summarization raises is reported on stdout and skipped; it does not abort the iteration.
        '''
        readme_summary = self.readme_content

        # Clear or (re)init the trajectories on each refine_links call
        self.trajectories_data = []

        for i in range(self.refine_times):
            print("Refining the links for the", i+1, "time")
            # The first pass summarizes the raw readme; later passes are flagged as refinements.
            refine = i >= 1

            # Step 1: Summarize the current text to get instructions and links
            structured_readme = summarize_text(
                text=readme_summary,
                refine=refine,
                response_format=Extract_Build_Information_and_Links,
                readme_path=self.readme_path
            )

            print("structured_readme_content_and_links", structured_readme)
            readme_summary = structured_readme.Build_Instructions
            external_links = structured_readme.External_URLs
            internal_links = structured_readme.Internal_Paths

            # We'll keep a record of all visits in this iteration
            iteration_trajectory = {
                "iteration": i + 1,
                "visits": []
            }

            # Step 2: Summarize each external link
            for ext_link in external_links:
                try:
                    structured_ext_summary = asyncio.run(summarize_link(ext_link))
                    ext_build_instructions = structured_ext_summary.Build_Instructions
                    readme_summary += (
                        f"External link: {ext_link}\n\n "
                        f"Extracted information: {ext_build_instructions}\n\n "
                        f"Additional_External_Links: {structured_ext_summary.External_URLs}\n\n"
                    )

                    # Record the visit for the external link
                    iteration_trajectory["visits"].append({
                        "link_type": "external",
                        "link": ext_link,
                        "summary": ext_build_instructions,
                        "additional_external_links": structured_ext_summary.External_URLs
                    })
                except Exception as e:
                    # Best effort: one unreachable link must not abort the whole refinement.
                    print(f"Error summarizing external link {ext_link}: {e}")
                    continue

            # Step 3: Summarize each internal link
            for int_link in internal_links:
                try:
                    structured_int_summary = summarize_text(
                        text=read_file(int_link),
                        response_format=Extract_Build_Information_and_Links
                    )
                    int_build_instructions = structured_int_summary.Build_Instructions
                    readme_summary += (
                        f"Internal link: {int_link}\n "
                        f"Extracted information: {int_build_instructions}\n\n "
                        f"Additional_External_Links: {structured_int_summary.External_URLs}. "
                        f"Additional_Internal_Links: {structured_int_summary.Internal_Paths}\n\n"
                    )

                    # Record the visit for the internal link
                    iteration_trajectory["visits"].append({
                        "link_type": "internal",
                        "link": int_link,
                        "summary": int_build_instructions,
                        "additional_external_links": structured_int_summary.External_URLs,
                        "additional_internal_links": structured_int_summary.Internal_Paths
                    })
                except Exception as e:
                    # Best effort: one unreadable file must not abort the whole refinement.
                    print(f"Error summarizing internal link {int_link}: {e}")
                    continue

            # Add the iteration's visited links to the full class-level trajectories
            self.trajectories_data.append(iteration_trajectory)

            print("Summarized content after refining the links for the", i+1, "time")
            print("*"*50)
            print(readme_summary)
            print("*"*50)

        return self.trajectories_data

    def _links_per_iteration(self):
        '''Return, for each recorded iteration in order, the list of links visited in it.'''
        return [
            [visit['link'] for visit in refinement_data['visits']]
            for refinement_data in self.trajectories_data
        ]

    def _collect_all_links(self):
        '''Return every visited link across all iterations as one flat list, in visit order.'''
        return [link for links in self._links_per_iteration() for link in links]

    def get_trajectory_from_trajectory_data(self):
        '''
        Since our design essentially would traverse through all the links that are found in the previous iteration, here we record the visited urls in each iteration. Then we combine the links from each iteration by taking one link from each iteration while preserving the order of the iterations.

        Side effects: sets ``self.all_links`` (flat list of every visited link) and ``self.trajectories``.

        Returns:
        - A list of tuples (the Cartesian product of the per-iteration link lists). If any iteration recorded no visits the result is empty; if no iterations were recorded it contains a single empty tuple.
        '''
        links_by_iteration = self._links_per_iteration()

        self.all_links = [link for links in links_by_iteration for link in links]
        # NOTE: the product grows multiplicatively with the number of links per iteration.
        self.trajectories = list(product(*links_by_iteration))
        return self.trajectories


class Heuristic_Retrieval_Benchmark(retrieval_benchmark_base):
    """
    Benchmark harness that runs Heuristic_Retrieval on a benchmark item and scores the result.

    NOTE(review): the metric counters (``trajectory_accuracy``, ``trajectory_coverage``, ``trajectory_length``), ``output_retrieval_results_file_path`` and the helpers ``get_item``, ``get_ground_truth_trajectory`` and ``calculate_trajectory_coverage`` are not defined here; presumably they come from ``retrieval_benchmark_base`` - confirm.
    """
    def __init__(self, input_raw_data_path, output_benchmark_path, cloned_repos_dir, output_retrieval_results_file_path, pre_computed_benchmark_file_path=None, pre_computed_retrieval_results_path=None, refine_times = 3, **kwargs):
        """
        Parameters:
        - refine_times (int): number of refinement iterations passed to each Heuristic_Retrieval.
        - **kwargs: accepted but not used by this class.
        - All other parameters are forwarded positionally to the base class.
        """
        # Assigned before the base initializer runs.
        self.refine_times = refine_times
        # NOTE(review): the positional order below differs from this constructor's own
        # parameter order - verify it matches retrieval_benchmark_base.__init__.
        super().__init__(input_raw_data_path, output_benchmark_path, cloned_repos_dir, pre_computed_benchmark_file_path, output_retrieval_results_file_path, pre_computed_retrieval_results_path)

    def evaluate_trajectory(self, index, predicted_trajectories):
        """
        Evaluate the predicted trajectories against the ground truth trajectory for a given repository.

        Parameters:
        - index (int): The index of the repository in the retrieval benchmark dataset.
        - predicted_trajectories (List[Sequence[str]]): A list of predicted trajectories, where each trajectory is a list or tuple of URLs.

        Returns:
        - The selected trajectory, which is either the exact match with the ground truth trajectory or the one with the highest coverage. An empty list is returned when there are no predicted trajectories.

        Side effects:
        - Updates ``self.trajectory_accuracy``, ``self.trajectory_coverage`` and ``self.trajectory_length``.
        """
        ground_truth_trajectory = self.get_ground_truth_trajectory(index)
        # Heuristic_Retrieval builds trajectories with itertools.product, which yields
        # tuples; a tuple never compares equal to a list, so normalise both sides first.
        expected_links = list(ground_truth_trajectory)

        ### NOTE: This loop is for Heuristic Retrieval only, since Heuristic Retrieval would return all trajectories, as we assume LLM has seen all the links in the trajectory.
        trajectory_coverage_list = []

        for trajectory in predicted_trajectories:
            if list(trajectory) == expected_links:
                # If the trajectory is the same as the ground truth trajectory, return immediately
                self.trajectory_accuracy += 1
                self.trajectory_coverage += 1
                self.trajectory_length += len(trajectory)
                return trajectory

            # Otherwise, calculate the coverage of the predicted trajectory against ground truth trajectory
            trajectory_coverage_list.append(
                self.calculate_trajectory_coverage(predicted_trajectory=trajectory, ground_truth_trajectory=ground_truth_trajectory)
            )

        # No predictions at all: nothing to select and no metrics to update.
        if not trajectory_coverage_list:
            return []

        # If no trajectory is the same as the ground truth trajectory, return the trajectory with the highest coverage
        # The coverage will not be added to the accuracy but purely for selection usage, since we consider that partial trajectory provides 0 useful information for our task
        max_coverage = max(trajectory_coverage_list)
        # index() returns the first occurrence, so ties resolve to the earliest trajectory.
        selected_predicted_trajectory = predicted_trajectories[trajectory_coverage_list.index(max_coverage)]
        self.trajectory_length += len(selected_predicted_trajectory)
        self.trajectory_coverage += max_coverage
        return selected_predicted_trajectory

    def evaluate_target_link(self, index, predicted_target_link):
        """Delegate target-link evaluation to the base class unchanged."""
        return super().evaluate_target_link(index, predicted_target_link)

    def generate_single_retrieval_result(self, index):
        """
        Run heuristic retrieval for one benchmark item, score it, and write the result to a JSON file.

        Parameters:
        - index (int): The index of the repository in the retrieval benchmark dataset.

        Returns:
        - dict: the retrieval results that were written to disk, or an empty dict if any step raised.
        """
        try:
            benchmark_data = self.get_item(index)
            repo_dir = benchmark_data['repo_dir']
            # Strip a trailing slash so ".../owner/repo/" still yields "owner/repo".
            owner, repo = benchmark_data['repo_url'].rstrip('/').split('/')[-2:]
            repo_full_name = f"{owner}/{repo}"
            retrieval_class = Heuristic_Retrieval(repo_dir, repo_full_name, refine_times=self.refine_times)
            target_trajectories = retrieval_class.predict_retrieval_trajectory()
            # First we evaluate the trajectory, and get the selected predicted trajectory for the target link prediction
            selected_predicted_trajectory = self.evaluate_trajectory(index, target_trajectories)
            predicted_target_link = retrieval_class.predict_target_link(selected_predicted_trajectory)
            self.evaluate_target_link(index, predicted_target_link)

            retrieval_results = {
                "repo_name": benchmark_data['repo_name'],
                "repo_dir": benchmark_data['repo_dir'],
                "ground_truth_trajectory": benchmark_data['retrieval_trajectory'],
                "predicted_trajectory": selected_predicted_trajectory,
                "ground_truth_target_link": benchmark_data['retrieval_target_link'],
                "predicted_target_link": predicted_target_link,
                "refinement_process": retrieval_class.trajectories_data,
            }
            individual_retrieval_directory = os.path.dirname(self.output_retrieval_results_file_path)
            # Create the output directory up front so a finished retrieval is not lost
            # to a missing folder; dirname is '' when the path has no directory part.
            if individual_retrieval_directory:
                os.makedirs(individual_retrieval_directory, exist_ok=True)
            timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            individual_retrieval_file_path = os.path.join(individual_retrieval_directory, f"{benchmark_data['repo_name']}_retrieval_results_{timestamp}.json")
            with open(individual_retrieval_file_path, 'w', encoding='utf-8') as f:
                json.dump(retrieval_results, f, indent=4)
            print(f"Individual retrieval results saved to {individual_retrieval_file_path}")
            return retrieval_results

        except Exception as e:
            # Per-item boundary: report the failure and let the caller continue with other items.
            print(f"Error generating retrieval result for index {index}: {e}")
            return {}
        


    
## For Testing
if __name__ == '__main__':
    # Manual smoke test: rebuild the flat link list from a saved refinement dump and
    # ask the model to pick the target link for the catboost readme.
    #
    # repo_dir = '/mnt/midnight/steven_zhang/LLM_assisted_compilation/cloned_repos/catboost'
    # data = Heuristic_Retrieval(repo_dir, refine_times=3)
    # data.refine_links()
    # print(data.predict_retrieval_trajectory())

    # Load the saved refinement data and close the file before doing anything else.
    with open("/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/data/temp.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    # One list of visited links per refinement iteration, in iteration order.
    trajectories = [
        [visit['link'] for visit in refinement_data['visits']]
        for refinement_data in data
    ]

    # combinations = [list(p) for p in product(*trajectories)]
    # from pprint import pprint
    # pprint(combinations)
    all_links = [item for sublist in trajectories for item in sublist]
    print(all_links)

    with open("/mnt/midnight/steven_zhang/LLM_assisted_compilation/cloned_repos/catboost/README.md", 'r', encoding="utf-8") as f:
        catboost_readme = f.read()

    response = openai_client.chat.completions.create(
        model = MODEL_NAME,
        messages = [
            {'role': 'system', 'content': "Given the readme of the repository and a list of urls, predict the target link. Your output should only be a single url and nothing else."},
            {'role': 'user', 'content': f"The readme content of the repository is: {catboost_readme}.\n The list of urls are: {all_links}."},
        ]
    )
    output = response.choices[0].message.content.strip()
    print(output)