import os
from typing import Optional
from copy import deepcopy
import json
from pathlib import Path

import numpy as np
import pandas as pd
from syntherela.typing import Tables
from syntherela.data import load_tables, remove_sdv_columns

from relbench.base import Database, Dataset, Table
from syntherela.metadata import Metadata


def update_stypes_cache(cache_dir: str, dataset_name: str, table_name: str, column_name: str):
    """Remove one column's entry from a cached ``stypes.json`` file.

    The file is looked up at ``<cache_dir>/<dataset_name>/stypes.json``. It is
    rewritten (JSON, ``indent=2``) only when ``table_name`` is a key of the
    loaded object and ``column_name`` is contained in that table's entry.
    A missing file, an I/O error, or invalid JSON results in a printed warning
    rather than an exception.

    Args:
        cache_dir: Root cache directory.
        dataset_name: Sub-directory of ``cache_dir`` holding the dataset cache.
        table_name: Top-level key of the table inside ``stypes.json``.
        column_name: Column key to drop from that table's entry.
    """
    cache_file = Path(cache_dir) / dataset_name / "stypes.json"

    if not cache_file.exists():
        print(f"Warning: stypes.json not found at {cache_file}")
        return

    try:
        with open(cache_file, "r") as src:
            stypes_by_table = json.load(src)

        has_entry = (
            table_name in stypes_by_table
            and column_name in stypes_by_table[table_name]
        )
        if not has_entry:
            # Nothing to remove: leave the file untouched.
            return

        stypes_by_table[table_name].pop(column_name)
        with open(cache_file, "w") as dst:
            json.dump(stypes_by_table, dst, indent=2)

        print(f"Updated stypes.json: removed '{column_name}' column from '{table_name}' table")
    except (json.JSONDecodeError, IOError) as err:
        print(f"Warning: Could not update stypes.json at {cache_file}: {err}")


def append_test_set(
    tables_train: Tables, tables_test: Tables, metadata: Metadata
) -> Tables:
    """Concatenate each train table with its test counterpart.

    Before concatenation every ``id``-sdtype column (per ``metadata``) is
    prefixed with ``"train_"`` / ``"test_"`` so ids from the two sources stay
    distinct. Rows are stacked train-first and the index is reset.

    NOTE: the prefixing is written back into the DataFrames held by
    ``tables_train`` and ``tables_test`` (they are mutated in place); callers
    that keep using ``tables_train`` afterwards see the prefixed ids.
    NOTE(review): values are formatted with an f-string, so a missing id may
    become a string such as ``"train_nan"`` — confirm that is acceptable.

    Args:
        tables_train: Training tables keyed by table name.
        tables_test: Test tables; must contain every key of ``tables_train``.
        metadata: Used to look up the ``id`` columns of each table.

    Returns:
        A new dict mapping table name to the concatenated DataFrame.
    """
    combined = {}
    for name in tables_train:
        train_df = tables_train[name]
        test_df = tables_test[name]
        for id_col in metadata.get_column_names(name, sdtype="id"):
            train_df[id_col] = train_df[id_col].apply(lambda value: f"train_{value}")
            test_df[id_col] = test_df[id_col].apply(lambda value: f"test_{value}")
        combined[name] = pd.concat([train_df, test_df], ignore_index=True)
    return combined


def cut_off_set(
    tables_train: Tables,
    metadata: Metadata,
    test_timestamp: pd.Timestamp,
    before: bool = True,
) -> Tables:
    """Filter every table by its datetime columns relative to ``test_timestamp``.

    For each ``datetime``-sdtype column (per ``metadata``) a row is kept when
    its value is ``< test_timestamp`` (``before=True``) or
    ``>= test_timestamp`` (``before=False``), or when the value is missing.
    Filters on several datetime columns are applied one after another.
    Tables with no datetime columns are returned as the same DataFrame object.

    Args:
        tables_train: Tables to filter (any split, despite the name).
        metadata: Used to look up each table's datetime columns.
        test_timestamp: Cut-off instant.
        before: Keep rows before the cut-off when True, at/after it otherwise.

    Returns:
        A new dict mapping table name to the filtered DataFrame.
    """
    filtered = {}
    for name, frame in tables_train.items():
        for dt_col in metadata.get_column_names(name, sdtype="datetime"):
            stamps = frame[dt_col]
            on_side = (stamps < test_timestamp) if before else (stamps >= test_timestamp)
            frame = frame[on_side | stamps.isna()]
        filtered[name] = frame
    return filtered


def get_tables_and_metadata(
    dataset: str, method: str, run_id: int
) -> tuple[Tables, Metadata]:
    """Load the tables for ``dataset`` produced by ``method`` plus its metadata.

    ``method == "ORIGINAL"`` reads ``data/original/<dataset>``; any other
    method reads ``data/synthetic/<dataset>/<method>/<run_id>/sample1``. The
    metadata is always taken from ``data/original/<dataset>/metadata.json``.

    Returns:
        ``(tables, metadata)``.
    """
    if method == "ORIGINAL":
        tables_dir = os.path.join("data", "original", dataset)
    else:
        tables_dir = os.path.join(
            "data", "synthetic", dataset, method, str(run_id), "sample1"
        )

    metadata = Metadata.load_from_json(
        os.path.join("data", "original", dataset, "metadata.json")
    )
    return load_tables(tables_dir, metadata), metadata


def keep_only_seen_values(
    tables: Tables, tables_test: Tables, metadata: Metadata
) -> Tables:
    """Restrict discrete columns of ``tables`` to the values seen in ``tables_test``.

    Every ``categorical`` or ``boolean`` column (per ``metadata``) of each
    table is converted to a ``pd.Categorical`` over the string form of its
    values, with the categories taken from the non-missing values of the same
    column in ``tables_test``. Values not among those categories become NaN.

    ``tables`` is modified in place and also returned.

    Args:
        tables: Tables to align (typically the training split).
        tables_test: Reference tables; must contain every discrete column of
            ``tables``.
        metadata: Used to look up categorical/boolean columns per table.

    Returns:
        ``tables``, with its discrete columns converted.
    """
    for table in tables.keys():
        # Look the metadata up once per table rather than twice per column.
        discrete_columns = set(
            metadata.get_column_names(table_name=table, sdtype="categorical")
        ) | set(metadata.get_column_names(table_name=table, sdtype="boolean"))

        for column in tables[table].columns:
            if column not in discrete_columns:
                continue

            # Drop missing values: pd.isna covers NaN, None, NaT and pd.NA
            # (a plain `v == v` test raises on pd.NA and keeps None).
            # dict.fromkeys de-duplicates while keeping first-seen order,
            # because distinct values can stringify identically (e.g. 1 and
            # "1") and pd.Categorical rejects duplicate categories.
            # NOTE(review): empty strings are NOT filtered out here; confirm
            # that is intended.
            categories = list(
                dict.fromkeys(
                    str(value)
                    for value in tables_test[table][column].unique()
                    if not pd.isna(value)
                )
            )
            tables[table][column] = pd.Categorical(
                tables[table][column].astype(str), categories=categories
            )

    return tables


class RossmannDataset(Dataset):
    """Rossmann store sales exposed as a relbench ``Dataset``.

    ``make_db`` builds a two-table database (``store`` and ``historical``)
    from the training tables of ``method``/``run_id``, or from the original
    tables in ``data/original/rossmann`` when ``type == "test"``.
    """

    name = "rossmann_subsampled"
    val_timestamp = pd.Timestamp("2014-09-20")
    test_timestamp = pd.Timestamp("2014-10-01")

    # Window applied to the built database via ``db.from_`` / ``db.upto``.
    from_timestamp = pd.Timestamp("2014-07-31")
    upto_timestamp = pd.Timestamp("2014-11-01")

    def __init__(
        self,
        predict_column_task_config: Optional[dict] = None,
        method: str = "ORIGINAL",
        run_id: int = 0,
        type: str = "train",
        cache_dir: Optional[str] = os.path.expanduser("~/.cache/relbench_examples"),
    ):
        """
        Args:
            predict_column_task_config: Not used by this class. Defaults to
                ``None`` rather than a shared mutable ``{}``.
            method: ``"ORIGINAL"`` or a synthetic-data method name; selects
                where the training tables are loaded from.
            run_id: Run index of the synthetic sample.
            type: ``"test"`` builds the database from the original tables;
                any other value uses the training tables.
            cache_dir: Passed to the base ``Dataset`` constructor.
        """
        super().__init__(cache_dir)
        self.method = method
        self.run_id = run_id
        self.type = type
        # NOTE(review): reset after the base constructor, presumably so the
        # built database is not cached on disk — confirm.
        self.cache_dir = None

    def make_db(self) -> Database:
        """Build the ``store``/``historical`` database for the configured split."""
        tables_train, metadata = get_tables_and_metadata(
            self.name, self.method, self.run_id
        )

        # Test tables come from data/original/rossmann (not the subsampled dir).
        tables_test = load_tables(
            os.path.join("data", "original", "rossmann"), metadata
        )
        tables_test, metadata = remove_sdv_columns(tables_test, metadata)

        # Align train categorical/boolean columns with the values seen in test.
        tables_train = keep_only_seen_values(tables_train, tables_test, metadata)

        tables = tables_test if self.type == "test" else tables_train

        store_df = tables["store"]
        historical_df = tables["historical"]
        historical_df["Date"] = pd.to_datetime(historical_df["Date"], format="%Y-%m-%d")

        db = Database(
            table_dict={
                "store": Table(
                    df=store_df,
                    fkey_col_to_pkey_table={},
                    pkey_col="Store",
                ),
                "historical": Table(
                    df=historical_df,
                    fkey_col_to_pkey_table={
                        "Store": "store",
                    },
                    pkey_col="Id",
                    time_col="Date",
                ),
            }
        )

        db = db.from_(self.from_timestamp)
        db = db.upto(self.upto_timestamp)

        return db


class AirbnbDataset(Dataset):
    """Airbnb users/sessions exposed as a relbench ``Dataset``.

    ``make_db`` builds a two-table database (``users`` and ``sessions``).
    Training tables keep rows whose datetime columns fall before
    ``test_timestamp``. Original test tables keep rows at or after it. In
    ``"test"`` mode both are concatenated, with id columns prefixed
    ``train_`` / ``test_``.
    """

    name = "airbnb-simplified_subsampled"
    val_timestamp = pd.Timestamp("2014-05-15")
    test_timestamp = pd.Timestamp("2014-06-01")

    # Window applied to the built database via ``db.from_`` / ``db.upto``.
    from_timestamp = pd.Timestamp("2014-01-01")
    upto_timestamp = pd.Timestamp("2014-07-01")

    def __init__(
        self,
        predict_column_task_config: Optional[dict] = None,
        method: str = "ORIGINAL",
        run_id: int = 0,
        type: str = "train",
        cache_dir: Optional[str] = os.path.expanduser("~/.cache/relbench_examples"),
    ):
        """
        Args:
            predict_column_task_config: Not used by this class. Defaults to
                ``None`` rather than a shared mutable ``{}``.
            method: ``"ORIGINAL"`` or a synthetic-data method name; selects
                where the training tables are loaded from.
            run_id: Run index of the synthetic sample.
            type: ``"test"`` builds the concatenated train+test database;
                any other value uses the training tables only.
            cache_dir: Passed to the base ``Dataset`` constructor; when not
                ``None``, its ``stypes.json`` is also updated.
        """
        super().__init__(cache_dir)
        self.method = method
        self.run_id = run_id
        self.type = type

        # make_db pops users.date_first_booking, so drop it from any cached
        # stypes.json as well.
        if cache_dir is not None:
            update_stypes_cache(cache_dir, self.name, "users", "date_first_booking")
        # NOTE(review): reset after the base constructor, presumably so the
        # built database is not cached on disk — confirm.
        self.cache_dir = None

    def make_db(self) -> Database:
        """Build the ``users``/``sessions`` database for the configured split."""
        tables_train, metadata = get_tables_and_metadata(
            self.name, self.method, self.run_id
        )

        tables_test = load_tables(
            os.path.join("data", "original", self.name, "test"), metadata
        )

        tables_test, metadata = remove_sdv_columns(tables_test, metadata)
        # Test split: keep rows at/after test_timestamp (or with missing dates).
        tables_test = cut_off_set(tables_test, metadata, self.test_timestamp, False)

        tables_train = keep_only_seen_values(tables_train, tables_test, metadata)
        # Train split: keep rows before test_timestamp (or with missing dates).
        tables_train = cut_off_set(tables_train, metadata, self.test_timestamp)

        # NOTE: append_test_set writes the "train_"/"test_" id prefixes back into
        # the DataFrames of tables_train/tables_test, so the train-mode database
        # below also carries "train_"-prefixed ids. Keep this call before the
        # split selection to preserve that behaviour.
        tables_test = append_test_set(tables_train, tables_test, metadata)

        tables = tables_test if self.type == "test" else tables_train

        users_df = tables["users"]
        sessions_df = tables["sessions"]

        users_df.pop("date_first_booking")

        # Binary target: True where the destination equals "NDF".
        users_df["country_destination"] = users_df["country_destination"] == "NDF"

        db = Database(
            table_dict={
                "users": Table(
                    df=users_df,
                    fkey_col_to_pkey_table={},
                    pkey_col="id",
                    time_col="date_account_created",
                ),
                "sessions": Table(
                    df=sessions_df,
                    fkey_col_to_pkey_table={
                        "user_id": "users",
                    },
                ),
            }
        )

        db = db.from_(self.from_timestamp)
        db = db.upto(self.upto_timestamp)

        return db


class WalmartDataset(Dataset):
    """Walmart sales exposed as a relbench ``Dataset``.

    ``make_db`` builds a three-table database (``depts``, ``stores``,
    ``features``) from the training tables of ``method``/``run_id``, or from
    the original tables in ``data/original/walmart`` when ``type == "test"``.
    """

    name = "walmart_subsampled"
    val_timestamp = pd.Timestamp("2012-01-24")
    test_timestamp = pd.Timestamp("2012-02-01")

    # Window applied to the built database via ``db.from_`` / ``db.upto``.
    from_timestamp = pd.Timestamp("2012-01-01")
    upto_timestamp = pd.Timestamp("2012-03-01")

    def __init__(
        self,
        predict_column_task_config: Optional[dict] = None,
        method: str = "ORIGINAL",
        run_id: int = 0,
        type: str = "train",
        cache_dir: Optional[str] = os.path.expanduser("~/.cache/relbench_examples"),
    ):
        """
        Args:
            predict_column_task_config: Not used by this class. Defaults to
                ``None`` rather than a shared mutable ``{}``.
            method: ``"ORIGINAL"`` or a synthetic-data method name; selects
                where the training tables are loaded from.
            run_id: Run index of the synthetic sample.
            type: ``"test"`` builds the database from the original tables;
                any other value uses the training tables.
            cache_dir: Passed to the base ``Dataset`` constructor; when not
                ``None``, its ``stypes.json`` is also updated.
        """
        super().__init__(cache_dir)
        self.method = method
        self.run_id = run_id
        self.type = type
        # make_db drops depts.Dept, so remove it from any cached stypes.json.
        if cache_dir is not None:
            update_stypes_cache(cache_dir, self.name, "depts", "Dept")
        # NOTE(review): reset after the base constructor, presumably so the
        # built database is not cached on disk — confirm.
        self.cache_dir = None

    def make_db(self) -> Database:
        """Build the ``depts``/``stores``/``features`` database for the split."""
        tables_train, metadata = get_tables_and_metadata(
            self.name, self.method, self.run_id
        )
        tables_test = load_tables(os.path.join("data", "original", "walmart"), metadata)
        tables_test, metadata = remove_sdv_columns(tables_test, metadata)

        # keep_only_seen_values is deliberately not applied for this dataset.

        tables = tables_test if self.type == "test" else tables_train

        depts_df = tables["depts"]
        stores_df = tables["stores"]
        features_df = tables["features"]

        depts_df["Date"] = pd.to_datetime(depts_df["Date"], format="%Y-%m-%d")
        depts_df = depts_df.sort_values(by=["Store", "Dept", "Date"], ascending=True)
        features_df = features_df.sort_values(by=["Store", "Date"], ascending=True)
        # Row-number primary key, assigned after sorting.
        depts_df["primary_key"] = range(len(depts_df))

        depts_df = depts_df.drop(columns=["Dept"])

        db = Database(
            table_dict={
                "depts": Table(
                    df=depts_df,
                    fkey_col_to_pkey_table={
                        "Store": "stores",
                    },
                    time_col="Date",
                    pkey_col="primary_key",
                ),
                "stores": Table(
                    df=stores_df,
                    fkey_col_to_pkey_table={},
                    pkey_col="Store",
                ),
                "features": Table(
                    df=features_df,
                    fkey_col_to_pkey_table={
                        "Store": "stores",
                    },
                    time_col="Date",
                ),
            }
        )

        db = db.from_(self.from_timestamp)
        db = db.upto(self.upto_timestamp)

        return db


class F1Dataset(Dataset):
    """Formula 1 data exposed as a relbench ``Dataset``.

    ``make_db`` builds a nine-table database from the training tables of
    ``method``/``run_id``, or from the original tables in ``data/original/f1``
    when ``type == "test"``. No ``from_`` / ``upto`` window is applied.
    """

    name = "f1_subsampled"
    val_timestamp = pd.Timestamp("2005-01-01")
    test_timestamp = pd.Timestamp("2010-01-01")

    def __init__(
        self,
        predict_column_task_config: Optional[dict] = None,
        method: str = "ORIGINAL",
        run_id: int = 0,
        type: str = "train",
        cache_dir: Optional[str] = os.path.expanduser("~/.cache/relbench_examples"),
    ):
        """
        Args:
            predict_column_task_config: Not used by this class. Defaults to
                ``None`` rather than a shared mutable ``{}``.
            method: ``"ORIGINAL"`` or a synthetic-data method name; selects
                where the training tables are loaded from.
            run_id: Run index of the synthetic sample.
            type: ``"test"`` builds the database from the original tables;
                any other value uses the training tables.
            cache_dir: Passed to the base ``Dataset`` constructor; when not
                ``None``, its ``stypes.json`` is also updated.
        """
        super().__init__(cache_dir)
        self.method = method
        self.run_id = run_id
        self.type = type

        # make_db pops races.year and races.datetime, so drop them from any
        # cached stypes.json as well.
        if cache_dir is not None:
            update_stypes_cache(cache_dir, self.name, "races", "year")
            update_stypes_cache(cache_dir, self.name, "races", "datetime")
        # NOTE(review): reset after the base constructor, presumably so the
        # built database is not cached on disk — confirm.
        self.cache_dir = None

    def make_db(self) -> Database:
        """Build the F1 database for the configured split."""
        tables_train, metadata = get_tables_and_metadata(
            self.name, self.method, self.run_id
        )
        tables_test = load_tables(os.path.join("data", "original", "f1"), metadata)
        tables_test, metadata = remove_sdv_columns(tables_test, metadata)

        # Align train categorical/boolean columns with the values seen in test.
        tables_train = keep_only_seen_values(tables_train, tables_test, metadata)

        tables = tables_test if self.type == "test" else tables_train

        circuits = tables["circuits"]
        drivers = tables["drivers"]
        results = tables["results"]
        races = tables["races"]
        standings = tables["standings"]
        constructors = tables["constructors"]
        constructor_results = tables["constructor_results"]
        constructor_standings = tables["constructor_standings"]
        qualifying = tables["qualifying"]

        # races: drop "year" and turn "datetime" into the "date" time column.
        races.pop("year")
        races["date"] = pd.to_datetime(races.pop("datetime"))

        # qualifying takes its timestamp from its race, shifted one day earlier
        # because qualifying happens the day before the main race.
        qualifying = qualifying.merge(
            races[["raceId", "date"]], on="raceId", how="left"
        )
        qualifying["date"] = qualifying["date"] - pd.Timedelta(days=1)

        # Replace literal "\N" placeholders with NaN in results.
        results = results.replace(r"^\\N$", np.nan, regex=True)

        # Same for circuits (notably `alt`), then parse `alt` as float.
        circuits = circuits.replace(r"^\\N$", np.nan, regex=True)
        circuits["alt"] = circuits["alt"].astype(float)

        # Coerce numeric result columns; unparsable values become NaN.
        for numeric_col in (
            "rank",
            "number",
            "grid",
            "position",
            "points",
            "laps",
            "milliseconds",
            "fastestLap",
        ):
            results[numeric_col] = pd.to_numeric(results[numeric_col], errors="coerce")

        drivers["dob"] = pd.to_datetime(drivers["dob"])

        db_tables = {}

        db_tables["races"] = Table(
            df=pd.DataFrame(races),
            fkey_col_to_pkey_table={
                "circuitId": "circuits",
            },
            pkey_col="raceId",
            time_col="date",
        )

        db_tables["circuits"] = Table(
            df=pd.DataFrame(circuits),
            fkey_col_to_pkey_table={},
            pkey_col="circuitId",
            time_col=None,
        )

        db_tables["drivers"] = Table(
            df=pd.DataFrame(drivers),
            fkey_col_to_pkey_table={},
            pkey_col="driverId",
            time_col=None,
        )

        # NOTE(review): results/standings/constructor_* use time_col="date" but
        # no "date" column is added to them here — they presumably already
        # carry one in the loaded data; confirm.
        db_tables["results"] = Table(
            df=pd.DataFrame(results),
            fkey_col_to_pkey_table={
                "raceId": "races",
                "driverId": "drivers",
                "constructorId": "constructors",
            },
            pkey_col="resultId",
            time_col="date",
        )

        db_tables["standings"] = Table(
            df=pd.DataFrame(standings),
            fkey_col_to_pkey_table={"raceId": "races", "driverId": "drivers"},
            pkey_col="driverStandingsId",
            time_col="date",
        )

        db_tables["constructors"] = Table(
            df=pd.DataFrame(constructors),
            fkey_col_to_pkey_table={},
            pkey_col="constructorId",
            time_col=None,
        )

        db_tables["constructor_results"] = Table(
            df=pd.DataFrame(constructor_results),
            fkey_col_to_pkey_table={"raceId": "races", "constructorId": "constructors"},
            pkey_col="constructorResultsId",
            time_col="date",
        )

        db_tables["constructor_standings"] = Table(
            df=pd.DataFrame(constructor_standings),
            fkey_col_to_pkey_table={"raceId": "races", "constructorId": "constructors"},
            pkey_col="constructorStandingsId",
            time_col="date",
        )

        db_tables["qualifying"] = Table(
            df=pd.DataFrame(qualifying),
            fkey_col_to_pkey_table={
                "raceId": "races",
                "driverId": "drivers",
                "constructorId": "constructors",
            },
            pkey_col="qualifyId",
            time_col="date",
        )

        return Database(db_tables)


class BerkaDataset(Dataset):
    """Berka financial data exposed as a relbench ``Dataset``.

    ``make_db`` builds an eight-table database from the training tables of
    ``method``/``run_id``, or from the original tables in
    ``data/original/Berka`` when ``type == "test"``. No ``from_`` / ``upto``
    window is applied.
    """

    name = "Berka_subsampled"
    val_timestamp = pd.Timestamp("1997-01-01")
    test_timestamp = pd.Timestamp("1998-01-01")

    def __init__(
        self,
        predict_column_task_config: Optional[dict] = None,
        method: str = "ORIGINAL",
        run_id: int = 0,
        type: str = "train",
        cache_dir: Optional[str] = os.path.expanduser("~/.cache/relbench_examples"),
    ):
        """
        Args:
            predict_column_task_config: Not used by this class. Defaults to
                ``None`` rather than a shared mutable ``{}``.
            method: ``"ORIGINAL"`` or a synthetic-data method name; selects
                where the training tables are loaded from.
            run_id: Run index of the synthetic sample.
            type: ``"test"`` builds the database from the original tables;
                any other value uses the training tables.
            cache_dir: Passed to the base ``Dataset`` constructor.
        """
        super().__init__(cache_dir)
        self.method = method
        self.run_id = run_id
        self.type = type
        # NOTE(review): reset after the base constructor, presumably so the
        # built database is not cached on disk — confirm.
        self.cache_dir = None

    def make_db(self) -> Database:
        """Build the Berka database for the configured split."""
        tables_train, metadata = get_tables_and_metadata(
            self.name, self.method, self.run_id
        )
        tables_test = load_tables(os.path.join("data", "original", "Berka"), metadata)
        tables_test, metadata = remove_sdv_columns(tables_test, metadata)

        # Align train categorical/boolean columns with the values seen in test.
        tables_train = keep_only_seen_values(tables_train, tables_test, metadata)

        tables = tables_test if self.type == "test" else tables_train

        account = tables["account"]
        card = tables["card"]
        client = tables["client"]
        disp = tables["disp"]
        district = tables["district"]
        loan = tables["loan"]
        order = tables["order"]
        trans = tables["trans"]

        # Collapse loan status to good (A, C -> 0) vs bad (B, D -> 1); any other
        # value becomes NaN via the final map.
        def remap_status(x):
            if x == "C":
                return "A"
            elif x == "D":
                return "B"
            else:
                return x

        loan["status"] = loan["status"].apply(remap_status)
        loan["status"] = loan["status"].map({"A": 0, "B": 1})

        db_tables = {}

        db_tables["account"] = Table(
            df=pd.DataFrame(account),
            fkey_col_to_pkey_table={
                "district_id": "district",
            },
            pkey_col="account_id",
            time_col="date",
        )

        db_tables["card"] = Table(
            df=pd.DataFrame(card),
            fkey_col_to_pkey_table={
                "disp_id": "disp",
            },
            pkey_col="card_id",
            time_col="issued",
        )

        db_tables["client"] = Table(
            df=pd.DataFrame(client),
            fkey_col_to_pkey_table={
                "district_id": "district",
            },
            pkey_col="client_id",
            time_col=None,
        )

        db_tables["disp"] = Table(
            df=pd.DataFrame(disp),
            fkey_col_to_pkey_table={
                "client_id": "client",
                "account_id": "account",
            },
            pkey_col="disp_id",
            time_col=None,
        )

        db_tables["district"] = Table(
            df=pd.DataFrame(district),
            fkey_col_to_pkey_table={},
            pkey_col="district_id",
            time_col=None,
        )

        db_tables["loan"] = Table(
            df=pd.DataFrame(loan),
            fkey_col_to_pkey_table={
                "account_id": "account",
            },
            pkey_col="loan_id",
            time_col="date",
        )

        db_tables["order"] = Table(
            df=pd.DataFrame(order),
            fkey_col_to_pkey_table={
                "account_id": "account",
            },
            pkey_col="order_id",
            time_col=None,
        )

        db_tables["trans"] = Table(
            df=pd.DataFrame(trans),
            fkey_col_to_pkey_table={
                "account_id": "account",
            },
            pkey_col="trans_id",
            time_col="date",
        )

        return Database(db_tables)
