import os, sys, json, copy, random, time
import pandas as pd

# Function: load every Web of Science export (xls/xlsx) in a directory, keep the rows that have a title,
#   an abstract and a publication year, drop exact duplicates, and group the papers by year
# Input:
#   raw_data_dir: the directory where the xls/xlsx files are stored (downloaded from Web of Science);
#     each sheet must contain the columns 'Article Title', 'Abstract' and 'Publication Year'
# Output:
#   dict_all_ttl_by_year: {year (int): [[title, abstract, year], ...]}; empty dict if no usable rows were found
def load_title_abstract_and_year_from_web_of_science(raw_data_dir):
    print("\n\nLoading title and abstract from Web of Science...")
    required_cols = ['Article Title', 'Abstract', 'Publication Year']

    all_ttl_abs = []
    # sorted() makes the output order (and which copy of a duplicate is kept) independent of os.listdir order
    for cur_file in sorted(os.listdir(raw_data_dir)):
        # skip non-Excel files and editor lock files ('.~lock...' from LibreOffice, '~$...' from Excel)
        if not cur_file.endswith(('.xlsx', '.xls')) or cur_file.startswith(('.~', '~$')):
            print("cur_file is not a xlsx or xls file:", cur_file)
            continue
        cur_file_full_path = os.path.join(raw_data_dir, cur_file)
        # legacy .xls files are read with the xlrd engine; .xlsx uses pandas' default engine
        if cur_file.endswith('.xlsx'):
            df = pd.read_excel(cur_file_full_path)
        else:
            df = pd.read_excel(cur_file_full_path, engine='xlrd')
        # Drop incomplete rows BEFORE converting the year: int(NaN) raises ValueError
        df = df[required_cols].dropna()
        for title, abstract, year in df.itertuples(index=False, name=None):
            all_ttl_abs.append([str(title).strip(), str(abstract).strip(), int(year)])
    print("len(all_ttl_abs):", len(all_ttl_abs))

    # get rid of repeated [title, abstract, year] entries, keeping the first occurrence
    all_ttl_abs = [list(item) for item in dict.fromkeys(tuple(item) for item in all_ttl_abs)]
    print("len(all_ttl_abs) (no superficial repetition):", len(all_ttl_abs))
    if all_ttl_abs:
        print("all_ttl_abs[0]:", all_ttl_abs[0])

    # dict_all_ttl_by_year: {year: [[title, abstract, year], ...]}
    dict_all_ttl_by_year = {}
    for cur_ttl_abs in all_ttl_abs:
        dict_all_ttl_by_year.setdefault(cur_ttl_abs[2], []).append(cur_ttl_abs)
    # print the key and corresponding length of dict_all_ttl_by_year
    print("================================================")
    for cur_year, cur_papers in dict_all_ttl_by_year.items():
        print("year: {}, length: {}".format(cur_year, len(cur_papers)))

    return dict_all_ttl_by_year


# Function: load the decomposed papers (neuroscience papers, usable as negative inspirations) from the sft qa data.
#   Each *.json file name starts with the year of the paper that was decomposed ("<year>_...", "0000" is treated as 2019).
#   The paper itself is filed under its own year; its decomposed inspiration papers are filed under the PREVIOUS year,
#   so that sampling "papers before year Y" never retrieves an inspiration from the future of a year-Y paper.
# Input:
#   sft_qa_data_dirs: a list of directories of the sft qa data, or a single directory string (for backward compatibility);
#     each *.json file must contain "title", "abstract" and "inspiration" (a list of dicts with "found_title" and
#     "found_abstract"); non-.json files are ignored
# Output:
#   dict_all_ttl_by_year: {year (int): [[title, abstract, year], ...]}; every paper year and the year before it are keys
# Raises:
#   ValueError: if a .json file name's year prefix is not 0000 or 2019-2026
def load_title_abstract_and_year_from_pubmed_decomposition(sft_qa_data_dirs):
    print("\n\nLoading title and abstract from PubMed decomposition...")
    # Handle both single string and list inputs for backward compatibility
    if isinstance(sft_qa_data_dirs, str):
        sft_qa_data_dirs = [sft_qa_data_dirs]
    allowed_year_prefixes = {"0000", "2019", "2020", "2021", "2022", "2023", "2024", "2025", "2026"}

    dict_all_ttl_by_year = {}
    for sft_qa_data_dir in sft_qa_data_dirs:
        # sorted() keeps the per-year list order deterministic across platforms
        for cur_file in sorted(os.listdir(sft_qa_data_dir)):
            if not cur_file.endswith('.json'):
                continue
            # validate the year prefix with an explicit raise (asserts are stripped under `python -O`)
            year_prefix = cur_file.split('_')[0]
            if year_prefix not in allowed_year_prefixes:
                raise ValueError("Paper year is not in the expected range: {}".format(year_prefix))
            paper_year = 2019 if year_prefix == "0000" else int(year_prefix)

            cur_file_full_path = os.path.join(sft_qa_data_dir, cur_file)
            # JSON is UTF-8 by spec; don't depend on the platform's default encoding
            with open(cur_file_full_path, 'r', encoding='utf-8') as f:
                cur_data = json.load(f)

            year_papers = dict_all_ttl_by_year.setdefault(paper_year, [])
            prev_year_papers = dict_all_ttl_by_year.setdefault(paper_year - 1, [])
            # add the paper to decompose itself
            year_papers.append([cur_data["title"], cur_data["abstract"], paper_year])
            # add the decomposed inspiration papers (assumed to be from the previous year)
            for cur_inspiration in cur_data["inspiration"]:
                prev_year_papers.append([cur_inspiration["found_title"], cur_inspiration["found_abstract"], paper_year - 1])
    # print the key and corresponding length of dict_all_ttl_by_year
    print("================================================")
    for cur_year, cur_papers in dict_all_ttl_by_year.items():
        print("year: {}, length: {}".format(cur_year, len(cur_papers)))
    return dict_all_ttl_by_year


# Function: load the title and abstract from all sources (Web of Science and PubMed decomposition) and merge them into one dictionary
# Input:
#   wos_raw_data_dir: the directory of the Web of Science raw data
#   sft_qa_data_dirs: a list of directories of the sft qa data, or a single directory string (for backward compatibility)
# Output:
#   dict_all_ttl_by_year_all: {year: [[title, abstract, year], ...]} with keys in ascending order; within each year the
#     PubMed-decomposition papers come first, followed by the Web of Science papers (no cross-source deduplication)
def load_title_abstract_and_year_from_all_sources(wos_raw_data_dir, sft_qa_data_dirs):
    # load the title and abstract from Web of Science and PubMed decomposition
    dict_all_ttl_by_year_wos = load_title_abstract_and_year_from_web_of_science(wos_raw_data_dir)
    dict_all_ttl_by_year_pubmed = load_title_abstract_and_year_from_pubmed_decomposition(sft_qa_data_dirs)
    # merge the two dictionaries. Copy only the per-year lists (cheap) rather than deep-copying every
    # [title, abstract, year] entry: the entries are never mutated here, so sharing them is safe.
    merged = {cur_year: list(cur_papers) for cur_year, cur_papers in dict_all_ttl_by_year_pubmed.items()}
    for cur_year, cur_papers in dict_all_ttl_by_year_wos.items():
        merged.setdefault(cur_year, []).extend(cur_papers)
    # sort the years in ascending order
    dict_all_ttl_by_year_all = dict(sorted(merged.items()))
    # print the key and corresponding length of dict_all_ttl_by_year_all
    print("================================================")
    for cur_year, cur_papers in dict_all_ttl_by_year_all.items():
        print("year: {}, length: {}".format(cur_year, len(cur_papers)))
    return dict_all_ttl_by_year_all


# Function: randomly sample k papers published strictly before `year`, optionally excluding banned papers
# Input:
#   paper_all: {year: [[title, abstract, year], ...]} (year keys may be int or numeric str)
#   year: int; only papers whose key is < year are candidates
#   k: int; number of papers to sample (a negative k raises ValueError from random.sample)
#   ban_list: list of [title, abstract] pairs to exclude from sampling (optional). Matching is case-insensitive and
#     ignores surrounding whitespace; a paper is banned if its title equals a banned title, or if its abstract is longer
#     than 50 characters and equals a banned abstract
# Output:
#   sampled_papers: [[title, abstract, year], ...]; if fewer than k candidates remain, all of them are returned
#     (in their original order) with a printed warning
def random_sample_k_papers_from_paper_all_before_year(paper_all, year, k, ban_list=None):
    sampled_papers = []
    for cur_year, cur_papers in paper_all.items():
        # Convert to int for comparison if it's a string
        cur_year_int = int(cur_year) if isinstance(cur_year, str) else cur_year
        if cur_year_int < year:
            sampled_papers.extend(cur_papers)

    # Filter out banned papers if ban_list is provided.
    # Normalize the ban list once into sets so each paper is checked in O(1) instead of scanning the whole list.
    if ban_list:
        banned_titles = set()
        banned_abstracts = set()
        for banned_title, banned_abstract in ban_list:
            if banned_title:
                banned_titles.add(banned_title.lower().strip())
            if banned_abstract:
                banned_abstracts.add(banned_abstract.lower().strip())

        filtered_papers = []
        for paper in sampled_papers:
            paper_title = paper[0].lower().strip() if paper[0] else ""
            paper_abstract = paper[1].lower().strip() if paper[1] else ""
            # Match by title, or by abstract only when it is substantial (short abstracts collide too easily)
            if paper_title and paper_title in banned_titles:
                continue
            if len(paper_abstract) > 50 and paper_abstract in banned_abstracts:
                continue
            filtered_papers.append(paper)
        sampled_papers = filtered_papers

    if len(sampled_papers) < k:
        print(f"Warning: Only {len(sampled_papers)} papers available before year {year} (after filtering), but {k} requested")
        # Return all available papers if there are fewer than k
        return sampled_papers
    try:
        sampled_papers = random.sample(sampled_papers, k)
    except ValueError as e:
        # only reachable for a negative k, since the population size was checked above
        print(f"Error sampling papers: {e}")
        print(f"sampled_papers: {len(sampled_papers)}")
        print(f"k: {k}")
        print(f"year: {year}")
        raise
    return sampled_papers
