import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import argparse
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset, load_from_disk
from tqdm.auto import tqdm
import ast
import black
from skelo.model.glicko2 import Glicko2Estimator
import warnings

# Suppress all warnings globally so that progress bars and status prints stay readable.
warnings.filterwarnings("ignore")
# Register tqdm's ``progress_apply`` on pandas objects (used when formatting solutions).
tqdm.pandas()


def _is_string_statement(stmt):
    """Return True if *stmt* is a bare string-literal expression statement (e.g. a docstring)."""
    return (
        isinstance(stmt, ast.Expr)
        and isinstance(stmt.value, ast.Constant)
        and isinstance(stmt.value.value, str)
    )


def format_and_remove_comments(code_str):
    """Normalize Python source: format it with black, then strip comments and docstrings.

    Comments disappear implicitly because ``ast`` does not retain them. Bare
    string-literal expression statements (docstrings and other stand-alone strings)
    are removed explicitly. Wherever that would leave a statement block empty,
    ``pass`` is inserted so the result is still valid Python.

    Args:
        code_str: Python source code.

    Returns:
        The normalized source as produced by ``ast.unparse``, or ``""`` if the code
        is rejected by black or by the running interpreter's parser.
    """
    try:
        formatted_code = black.format_str(code_str, mode=black.FileMode())
    except black.parsing.InvalidInput:
        return ""
    try:
        tree = ast.parse(formatted_code)
    except (SyntaxError, ValueError):
        # black's grammar can accept code the running interpreter cannot parse
        # (e.g. syntax newer than this Python). Older Pythons raise ValueError for null bytes.
        return ""

    # Walk every node and filter each statement list (body / orelse / finalbody / ...).
    # ast.walk queues a node's children before yielding it, so replacing a field here
    # does not disturb the traversal.
    for node in ast.walk(tree):
        for field, value in ast.iter_fields(node):
            if not (isinstance(value, list) and value and isinstance(value[0], ast.stmt)):
                continue
            kept = [stmt for stmt in value if not _is_string_statement(stmt)]
            if len(kept) != len(value):
                setattr(node, field, kept or [ast.Pass()])
    return ast.unparse(tree)


# Columns that hold the fitted Glicko-2 estimate for each problem.
_RATING_COLUMNS = ["rating", "rating_deviation", "rating_volatility"]
# Starting deviation and volatility seeded for every problem and author.
_INITIAL_DEVIATION = 350.0
_INITIAL_VOLATILITY = 0.06
# Default (rating, deviation, volatility) handed to the estimator and its rating model.
_INITIAL_VALUE = (1500.0, _INITIAL_DEVIATION, _INITIAL_VOLATILITY)
# Tags kept verbatim in the coarse "tag" column; every other tag maps to "others".
_MAIN_TAGS = ["greedy", "math", "implementation", "dp"]
_PROBLEM_PREFIX = "problem_"
_AUTHOR_PREFIX = "author_"


def _load_validated_problems():
    """Load the problem table, drop incomplete rows, and sort by (contestId, index)."""
    pdf = load_dataset(
        "mcding-org/Easy2Hard-Codeforces", "problem-v2", cache_dir="./cache"
    )["train"].to_pandas()

    # NOTE(review): only languages 1 and 2 are checked for "N/A"; confirm that
    # programmingLanguage_0 is intentionally exempt.
    pdf = pdf[
        (pdf["programmingLanguage_1"] != "N/A")
        & (pdf["programmingLanguage_2"] != "N/A")
    ]
    pdf = pdf[
        pdf["source_0"].notna() & pdf["source_1"].notna() & pdf["source_2"].notna()
    ]
    pdf = pdf[pdf["rating"].notna()]
    pdf = pdf[pdf["tags"].apply(lambda tags: len(tags) > 0)]
    pdf = pdf[pdf["problem"].apply(lambda p: p["sample-tests"] is not None)]
    pdf = pdf.drop(
        columns=[
            "points",
            "type",
            "programmingLanguage_0",
            "programmingLanguage_1",
            "programmingLanguage_2",
        ]
    ).rename(columns={"rating": "reference_rating"})
    # Deterministic order so that chunk slicing is stable across runs.
    return pdf.sort_values(by=["contestId", "index"]).reset_index(drop=True)


def _add_tag_columns(pdf, sorted_tags):
    """Add "tag" (most frequent of a row's tags, coarsened) and "detailed_tag" (least frequent).

    *sorted_tags* lists every tag from most to least frequent.
    """
    tag = pdf["tags"].apply(lambda tags: next(t for t in sorted_tags if t in tags))
    detailed_tag = pdf["tags"].apply(
        lambda tags: next(t for t in reversed(sorted_tags) if t in tags)
    )
    return pdf.assign(
        tag=tag.where(tag.isin(_MAIN_TAGS), "others"),
        detailed_tag=detailed_tag,
    )


def _explode_problem_column(pdf):
    """Expand the nested "problem" dict into flat columns and drop rows missing required parts."""
    problem = pdf["problem"]
    pdf = pdf.assign(
        problem_main=problem.apply(lambda p: p["main"]),
        problem_note=problem.apply(lambda p: p["note"]),
        input_spec=problem.apply(lambda p: p["input-specification"]),
        output_spec=problem.apply(lambda p: p["output-specification"]),
        sample_inputs=problem.apply(lambda p: p["sample-tests"]["input"]),
        sample_outputs=problem.apply(lambda p: p["sample-tests"]["output"]),
    )
    # Keep only statements whose "main" part is a single element, then unwrap it.
    pdf = pdf[pdf["problem_main"].apply(lambda main: len(main) == 1)]
    pdf = pdf.assign(
        problem_main=pdf["problem_main"].apply(lambda main: main[0]),
        problem_note=pdf["problem_note"].apply(lambda note: "" if note is None else note),
    )
    pdf = pdf[pdf["output_spec"].notna()]
    pdf = pdf[
        pdf["sample_inputs"].apply(lambda x: len(x) > 0)
        & pdf["sample_outputs"].apply(lambda x: len(x) > 0)
    ]
    return pdf.drop(columns=["problem"])


def _estimate_contest_ratings(pdf_contest, rdf_contest, sdf_contest):
    """Fit Glicko-2 ratings for one contest's problems.

    Each submission is a game between its author and the problem. The author wins
    when "passed" is truthy; otherwise the problem wins. Problems are seeded with
    their reference rating and authors with "oldRating" (presumably the pre-contest
    rating, per the column name).

    Returns:
        A (len(pdf_contest), 3) array of (rating, deviation, volatility), row-aligned
        with *pdf_contest*. A row is NaN if no fitted rating exists for that problem.
    """
    # One rating row per participant; the first occurrence wins.
    rdf_contest = rdf_contest.drop_duplicates(subset="handle", keep="first")
    # Only submissions to problems in this chunk by participants with a rating row.
    sdf_contest = sdf_contest[
        sdf_contest["problem"].isin(pdf_contest["index"].unique())
        & sdf_contest["author"].isin(rdf_contest["handle"].unique())
    ]

    problem_keys = _PROBLEM_PREFIX + sdf_contest["problem"]
    author_keys = _AUTHOR_PREFIX + sdf_contest["author"]
    passed = sdf_contest["passed"].astype(bool)
    game_records = pd.DataFrame(
        {
            "winner": author_keys.where(passed, problem_keys),
            "loser": problem_keys.where(passed, author_keys),
            "timestamp": sdf_contest["timestamp"],
        }
    )

    initial_time = game_records["timestamp"].min()
    estimator = Glicko2Estimator(
        key1_field="winner",
        key2_field="loser",
        timestamp_field="timestamp",
        initial_time=initial_time,
        initial_value=_INITIAL_VALUE,
    )
    # Build the rating model explicitly so it can be seeded with per-key ratings below.
    estimator.rating_model = estimator.RATING_MODEL_CLS(
        initial_value=_INITIAL_VALUE,
        initial_time=initial_time,
    )
    model = estimator.rating_model

    for problem_index, reference_rating in zip(
        pdf_contest["index"], pdf_contest["reference_rating"]
    ):
        model.add(
            key=_PROBLEM_PREFIX + problem_index,
            value=(reference_rating, _INITIAL_DEVIATION, _INITIAL_VOLATILITY),
        )
    for handle, old_rating in zip(rdf_contest["handle"], rdf_contest["oldRating"]):
        model.add(
            key=_AUTHOR_PREFIX + handle,
            value=(old_rating, _INITIAL_DEVIATION, _INITIAL_VOLATILITY),
        )

    # Replay games chronologically; ties are broken by winner and then loser key for determinism.
    records = game_records[["winner", "loser", "timestamp"]].values
    for winner, loser, timestamp in sorted(records, key=lambda r: (r[2], r[0], r[1])):
        model.update(winner, loser, timestamp)
    estimator._fit = True

    # Keep only each key's current rating (open-ended validity interval).
    ratings = model.to_frame()
    ratings = ratings[ratings["valid_to"].isna()].drop(columns=["valid_from", "valid_to"])
    ratings[_RATING_COLUMNS] = ratings["rating"].apply(pd.Series)
    problem_ratings = ratings[ratings["key"].str.startswith(_PROBLEM_PREFIX)]
    problem_ratings = problem_ratings.assign(
        index=problem_ratings["key"].str[len(_PROBLEM_PREFIX) :]
    )[["index", *_RATING_COLUMNS]]

    # A left merge preserves the row order of pdf_contest, so the result is row-aligned.
    merged = pdf_contest[["index"]].merge(problem_ratings, on="index", how="left")
    return merged[_RATING_COLUMNS].values


def main(dataset_save_path, chunk_index, chunk_size):
    """Build chunk ``chunk_index`` of the Easy2Hard Codeforces dataset and save it to disk.

    Steps:
        1. Load and validate the problem table.
        2. Assign tags using frequencies over the whole table.
        3. Slice rows ``[chunk_index * chunk_size, (chunk_index + 1) * chunk_size)``.
        4. Flatten the problem statement.
        5. Fit Glicko-2 problem ratings per contest from ``./data/Codeforces/rating-filtered``
           and ``./data/Codeforces/status-filtered``.
        6. Normalize the three reference solutions.
        7. Write ``<dataset_save_path>/chunk_<chunk_index>``.

    Nothing is done if that directory already exists or the slice is empty.

    Raises:
        ValueError: if ``chunk_index`` is missing/negative or ``chunk_size`` is not positive.
    """
    if chunk_index is None or chunk_index < 0:
        raise ValueError(f"chunk_index must be a non-negative integer, got {chunk_index!r}")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    chunk_path = os.path.join(dataset_save_path, f"chunk_{chunk_index}")
    if os.path.exists(chunk_path):
        print(f"Chunk {chunk_index} already exists, skipping.")
        return

    pdf = _load_validated_problems()

    # Rank tags over the whole validated table (not just this chunk) so that
    # "tag"/"detailed_tag" are consistent across chunks and independent of chunk_size.
    sorted_tags = pdf["tags"].explode().value_counts().keys().to_list()

    # Select this chunk; .copy() makes later column writes act on an independent frame.
    start = chunk_index * chunk_size
    pdf = pdf.iloc[start : start + chunk_size].copy()
    if len(pdf) == 0:
        print(f"Chunk {chunk_index} is empty, skipping.")
        return

    pdf = _add_tag_columns(pdf, sorted_tags)
    pdf = _explode_problem_column(pdf)

    rdf = load_from_disk("./data/Codeforces/rating-filtered").to_pandas()
    sdf = load_from_disk("./data/Codeforces/status-filtered").to_pandas()

    # Difficulty: fit ratings contest by contest and write them back in place.
    pdf = pdf.assign(rating=np.nan, rating_deviation=np.nan, rating_volatility=np.nan)
    for contest_id in tqdm(sorted(pdf["contestId"].unique())):
        mask = pdf["contestId"] == contest_id
        pdf.loc[mask, _RATING_COLUMNS] = _estimate_contest_ratings(
            pdf[mask],
            rdf[rdf["contestId"] == contest_id],
            sdf[sdf["contestId"] == contest_id],
        )

    # Normalize the reference solutions (format + strip comments/docstrings).
    for column in ("source_0", "source_1", "source_2"):
        pdf[column] = pdf[column].progress_apply(format_and_remove_comments)

    Dataset.from_pandas(pdf.reset_index(drop=True)).save_to_disk(chunk_path)


if __name__ == "__main__":
    # Command-line entry point: build a single chunk of the dataset.
    argparser = argparse.ArgumentParser(
        description="Build one chunk of the Easy2Hard Codeforces dataset."
    )
    argparser.add_argument(
        "--dataset_save_path",
        type=str,
        default="./data/Codeforces/chunks/",
        help="Directory under which chunk_<index> is saved",
    )
    # Required: without it main() would receive None and fail while slicing the chunk.
    argparser.add_argument(
        "--chunk_index",
        type=int,
        required=True,
        help="Chunk index to process",
    )
    argparser.add_argument(
        "--chunk_size",
        type=int,
        default=500,
        help="Chunk size to process",
    )
    args = argparser.parse_args()
    main(
        args.dataset_save_path,
        args.chunk_index,
        args.chunk_size,
    )
