from pathlib import Path
import pandas as pd
import os
import json
import numpy as np
from sklearn.preprocessing import LabelEncoder
from dateutil import parser

# Root folder of the wikidb data; load_table() looks for "part-<n>" sub-folders
# directly below it.
WIKIDB_DIR = Path("/data/parietal/store3/soda/felefebv/gitlab/tarte/data/wikidb")
# Folder containing this script; main() writes its "classification/" and
# "regression/" outputs below it.
VALIDATION_DIR = Path(__file__).parent

# NOTE(review): `sizes` is read at import time but is not referenced anywhere
# else in the visible part of this file — confirm whether it is still needed.
sizes = pd.read_parquet(
    "/data/parietal/store3/soda/felefebv/gitlab/tarte/tarte/wikidb_sizes.parquet"
)

# Select relevant tasks
# Classification tasks: maps a table name to the column used as class label.
# Table names have the form "<id>_<TableName>": load_table() parses the integer
# before the first underscore to locate the database folder and uses the rest
# as the CSV file name.
classification_targets = {
    "47746_SurnameDetails": "LanguageOfOrigin",
    "73376_HISTORICAL_FIGURES": "PROFESSION",
    "42562_Geographer_Profiles": "Languages",
    "30417_ArtworksCatalog": "ArtworkType",
    "36100_MagicNarrativeMotifs": "CulturalOrigin",
    "87283_ParishChurchDetails": "Country",
    "70942_NotableTreesInformation": "TreeSpecies",
    "29832_Rafael_Individuals": "Nationality",
    "2053_StriatumScientificArticles": "JournalName",
    "7136_researcher_profile": "affiliated_institution",
    "7310_DecommissionedTransportStations": "Country",
    "97297_MusicAlbumsPublishedInUs": "MusicGenre",
    "92415_island_details": "country_name",
    "67195_SUB_POST_OFFICE_DETAILS": "ADMINISTRATIVE_TERRITORY",
    "9510_CreativeCommonsAuthors": "Gender",
    "66643_KindergartenLocations": "Country",
    "64477_NobleIndividuals": "Role",
    "56474_Sculpture_Instances": "Material_Used",
    "90741_MUSEUM_DETAILS": "COUNTRY",
    "473_HistoricBuildings": "CountryName",
    "97229_PhilosopherProfiles": "Languages",
    "7900_ArtistCopyrightRepresentation": "ArtistOccupation",
    "63797_SpringLocations": "CountryName",
    "65102_defender_profiles": "defender_position",
    "15542_FORWARD_PLAYERS": "SPORTS_TEAM",
    "70780_StateSchoolDetails": "Country",
}

# Regression tasks: maps a table name ("<id>_<TableName>", same convention as
# the classification tasks) to the column used as regression target.
# Targets of the tables also listed in `date_regressions` are date strings that
# main() converts to numbers; main() casts all targets to float and applies
# log10.
regression_targets = {
    "90930_RegisteredShips": "GrossTonnage",
    "19664_MunicipalDistrictCapitals": "PopulationCount",
    "66610_geopolitical_regions": "land_area",
    "89039_Business_Entity_Locations": "Population_Count",
    "53353_research_articles": "publication_date",
    "14012_ResearchArticleCitations": "PublicationDate",
    "14976_DrawingsCatalog": "ArtworkHeightCm",
    "88197_artworks_inventory": "artwork_width_cm",
    "3977_Eclipsing_Binary_Star_Instances": "Apparent_Magnitude",
    "62826_HISTORICAL_FIGURES": "BIRTH_DATE",
    "46159_DissolvedMunicipalityRecords": "DissolutionDate",
    "28324_ukrainian_village_instances": "elevation_meters",
    "94062_POET_PROFILES": "DEATH_DATE",
    "82939_Territorial_Entities": "Population_Count",
    "89439_WwiPersonnelProfiles": "BirthDate",
    "28146_Twinned_Cities": "Population",
}

# Regression tasks whose target column holds date strings. For these tables,
# main() parses the target with date_to_fractional_year() and replaces it by
# the number of years elapsed before 2025.
date_regressions = [
    "53353_research_articles",
    "14012_ResearchArticleCitations",
    "62826_HISTORICAL_FIGURES",
    "46159_DissolvedMunicipalityRecords",
    "94062_POET_PROFILES",
    "89439_WwiPersonnelProfiles",
]


def date_to_fractional_year(date_str):
    """Convert a date string to a fractional year (e.g. mid-1950 -> ~1950.5).

    The string is parsed with ``dateutil.parser.parse`` rather than
    ``pd.to_datetime`` because pandas' default nanosecond-resolution
    timestamps only cover roughly the years 1677-2262, which is too narrow
    for historical dates.

    Parameters
    ----------
    date_str : str
        Date to convert. A leading ``"-"`` marks a negative year: the year
        returned by the parser is negated.

    Returns
    -------
    float or None
        ``year + (day_of_year - 1) / days_in_year``, or ``None`` when the
        value cannot be parsed (invalid date, non-string input, ...).
    """
    # Local import: only needed to build the fixed parsing default below.
    from datetime import datetime

    try:
        # Without an explicit `default`, dateutil fills components missing
        # from the string (e.g. month/day of "1950") with *today's* date,
        # which makes the result depend on the day the script is run.
        # Falling back to January 1st keeps the conversion deterministic.
        parsed_date = parser.parse(date_str, default=datetime(1, 1, 1))
        year = parsed_date.year

        # The parsed datetime cannot carry a negative year, so the sign is
        # restored from the raw string.
        if date_str.startswith("-"):
            year = -year

        # Days in each month (non-leap year)
        days_in_month = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]

        # Gregorian leap-year rule. Python's modulo is non-negative for a
        # positive divisor, so the same test applies to negative years.
        if year % 4 == 0 and (year % 100 != 0 or year % 400 == 0):
            days_in_month[1] = 29

        # 1-based position of the date within its year
        day_of_year = sum(days_in_month[: parsed_date.month - 1]) + parsed_date.day

        # January 1st maps to exactly `year`; December 31st stays below
        # `year + 1`.
        return year + (day_of_year - 1) / sum(days_in_month)
    except Exception:
        # Deliberate best effort: any value that cannot be converted is
        # reported as missing instead of aborting the whole column.
        return None


def load_table(table_name, target_column):
    """Load one table and reduce it to entity, target and Wikidata-id columns.

    Parameters
    ----------
    table_name : str
        Name of the form ``"<id>_<TableName>"``. ``<id>`` selects the database
        folder (``WIKIDB_DIR / "part-<id // 20000>" / "<zero-padded id>..."``)
        and ``<TableName>.csv`` is the file read inside it.
    target_column : str
        Name of the column of the table to use as prediction target.

    Returns
    -------
    pandas.DataFrame
        Three columns: ``raw_entities`` (first column of the table),
        ``target`` (``target_column``) and ``wikidata_col_to_embed`` (first
        column of the same table in ``tables_with_item_ids``).

    Raises
    ------
    IndexError
        If no folder of the group starts with the zero-padded id.
    KeyError
        If ``target_column`` is not a column of the table.
    """
    # Split "<id>_<TableName>" at the first underscore only, since the table
    # name itself may contain underscores.
    idx_str, _, table_stem = table_name.partition("_")
    idx = int(idx_str)
    table_file = table_stem + ".csv"

    # Databases are grouped in "part-<n>" folders of 20000 ids each; the
    # database folder name starts with the id zero-padded to 5 digits.
    group_dir = WIKIDB_DIR / f"part-{idx // 20000}"
    prefix = str(idx).zfill(5)
    folder = [f for f in os.listdir(group_dir) if f.startswith(prefix)][0]
    database_dir = group_dir / folder

    df = pd.read_csv(database_dir / "tables" / table_file)
    df_ids = pd.read_csv(database_dir / "tables_with_item_ids" / table_file)

    # Get the target column index
    target_idx = df.columns.get_loc(target_column)

    # Keep only the first column and the target column. The explicit copy
    # makes the result independent from the full table, so adding a column
    # below does not write into a slice (no SettingWithCopyWarning).
    df = df.iloc[:, [0, target_idx]].copy()

    # Name the columns by position: a name-based rename would give both
    # columns the same name when the target is the first column.
    df.columns = ["raw_entities", "target"]

    # Add the wikidata id column (aligned on the row index)
    df["wikidata_col_to_embed"] = df_ids.iloc[:, 0]

    return df


def balanced_sample(
    df, class_column="target", target_size=3000, min_per_class=50, random_state=0
):
    """Draw a class-aware random sample of about ``target_size`` rows.

    Each class first receives up to ``min_per_class`` rows; the remaining
    budget is then shared between classes proportionally to the rows they
    still have available. When ``min_per_class`` rows per class would exceed
    ``target_size``, every class instead receives at most
    ``target_size // num_classes`` rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to sample from.
    class_column : str
        Column holding the class labels. Rows with a missing label are never
        sampled, because ``value_counts`` ignores them.
    target_size : int
        Desired number of rows. The result can be slightly smaller because
        the proportional shares are rounded down.
    min_per_class : int
        Number of rows guaranteed to each class (or all of its rows if it has
        fewer).
    random_state : int
        Seed passed to every ``DataFrame.sample`` call. The global NumPy
        random state is left untouched.

    Returns
    -------
    pandas.DataFrame
        The sampled rows in shuffled order, keeping their original index
        labels. An empty frame with the same columns when ``df`` contains no
        class at all.
    """
    class_counts = df[class_column].value_counts()
    classes = class_counts.index.tolist()
    num_classes = len(classes)

    # Nothing to sample from (empty table or only missing labels):
    # pd.concat would raise on an empty list, so return an empty frame.
    if num_classes == 0:
        return df.iloc[:0].copy()

    # Determine the number of samples per class
    if num_classes * min_per_class <= target_size:
        # Enough space to ensure at least min_per_class per class
        per_class = {cls: min(min_per_class, class_counts[cls]) for cls in classes}

        # Distribute remaining samples proportionally
        remaining = target_size - sum(per_class.values())

        if remaining > 0:
            available = {cls: class_counts[cls] - per_class[cls] for cls in classes}
            available_total = sum(available.values())
            if available_total > 0:
                for cls in classes:
                    if available[cls] > 0:
                        # Share rounded down, capped by what the class has left
                        extra = int((available[cls] / available_total) * remaining)
                        per_class[cls] += min(extra, available[cls])
    else:
        # Not enough room, so sample as evenly as possible
        even_share = target_size // num_classes
        per_class = {cls: min(class_counts[cls], even_share) for cls in classes}

    # Sample each class independently; reproducibility comes from passing
    # random_state explicitly, not from seeding the global generator.
    sampled_dfs = [
        df[df[class_column] == cls].sample(n=per_class[cls], random_state=random_state)
        for cls in classes
    ]

    # Shuffle so that rows are not grouped by class
    return pd.concat(sampled_dfs).sample(frac=1, random_state=random_state)


def main():
    """Build the classification and regression validation datasets.

    For every task of ``classification_targets`` the table is loaded, rare
    classes are dropped, the table is down-sampled to about 3000 rows, and
    the label-encoded result is written to
    ``VALIDATION_DIR / "classification"`` together with the JSON list of
    class names (position in the list = encoded label).

    For every task of ``regression_targets`` the target is converted to a
    float (dates become a number of years before 2025), log10-transformed,
    cleaned from missing / non-finite values, down-sampled to at most 3000
    rows and written to ``VALIDATION_DIR / "regression"``.

    Raises
    ------
    ValueError
        If a classification table has fewer than two classes.
    """
    classification_dir = VALIDATION_DIR / "classification"
    regression_dir = VALIDATION_DIR / "regression"
    # Make sure the output folders exist before writing into them.
    classification_dir.mkdir(parents=True, exist_ok=True)
    regression_dir.mkdir(parents=True, exist_ok=True)

    ## Process classification tasks
    print("Processing classification tasks")
    for table_name, target in classification_targets.items():
        print(f"    {table_name}")

        # Load the table
        table = load_table(table_name, target)

        # Remove rows with missing targets
        table = table.dropna(subset=["target"]).reset_index(drop=True)

        # Keep the classes that have more than 50 occurrences, or more than
        # 90% of the second class cardinality
        counts = table["target"].value_counts()
        num_classes = len(counts)
        if num_classes < 2:
            raise ValueError(
                f"{table_name}: target column {target!r} has {num_classes} "
                "class(es); at least 2 are required for classification"
            )
        # `counts` is indexed by class label and sorted by decreasing count,
        # so the second most frequent class must be read by position (.iloc).
        threshold = min(50, int(0.9 * counts.iloc[1]))
        classes_to_remove = counts[counts < threshold].index
        # Keep at most 30 classes (the 30 most frequent ones)
        if num_classes - len(classes_to_remove) > 30:
            classes_to_remove = counts.index[30:]

        table = table[~table["target"].isin(classes_to_remove)].reset_index(drop=True)

        # Reduce the size of the tables that are too large
        if len(table) > 3000:
            # Sample 3000 rows and guarantee that all classes have enough occurrences
            table = balanced_sample(table, target_size=3000, min_per_class=threshold)

        kept_counts = table["target"].value_counts()
        print(f"        Number of rows: {len(table)}")
        print(f"        Number of classes: {len(kept_counts)}")
        print(f"        Most populated class: {kept_counts.max()}")
        print(f"        Least populated class: {kept_counts.min()}")

        # Encode classification targets
        le = LabelEncoder()
        table["target"] = le.fit_transform(table["target"])
        with open(classification_dir / f"{table_name}_classes.json", "w") as f:
            json.dump(le.classes_.tolist(), f)

        # Save the table
        table.to_parquet(classification_dir / f"{table_name}.parquet", index=False)

    ## Process regression tasks
    print("Processing regression tasks")
    reference_year = 2025
    for table_name, target in regression_targets.items():
        print(f"    {table_name}")

        # Load the table
        table = load_table(table_name, target)

        # Process date targets
        if table_name in date_regressions:
            # Drop the explicit "+" sign and turn unknown month/day ("-00-00")
            # into January 1st so that the date can be parsed.
            table["target"] = table["target"].str.removeprefix("+")
            table["target"] = table["target"].str.replace(
                r"(-?\d{1,4})-00-00", r"\1-01-01", regex=True
            )
            table["target"] = table["target"].apply(date_to_fractional_year)
            # Unparseable dates come back as None; casting to float first
            # turns them into NaN so that the subtraction cannot fail.
            table["target"] = reference_year - table["target"].astype(float)

        # Rescale targets. Non-positive values give NaN or -inf here and are
        # removed just below, hence the silenced floating-point warnings.
        table["target"] = table["target"].astype(float)
        with np.errstate(divide="ignore", invalid="ignore"):
            table["target"] = np.log10(table["target"])

        # Remove rows with missing targets
        table = table.dropna(subset=["target"]).reset_index(drop=True)

        # Remove rows with non-finite targets
        table = table[np.isfinite(table["target"])].reset_index(drop=True)

        # Reduce the size of the tables that are too large
        if len(table) > 3000:
            # Sample randomly 3000 rows
            table = table.sample(n=3000, random_state=0).reset_index(drop=True)

        # Save the table
        table.to_parquet(regression_dir / f"{table_name}.parquet", index=False)


# Run the dataset generation only when the file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
