import os
import csv
import json
import sys
import warnings

warnings.filterwarnings("ignore")
import torch
import numpy as np
from functools import partial
import random
import argparse
import pandas as pd
from dotenv import load_dotenv
from gift_eval.data import Dataset
from gluonts.ev.metrics import (
    MAE,
    MAPE,
    MASE,
    MSE,
    MSIS,
    ND,
    NRMSE,
    RMSE,
    SMAPE,
    MeanWeightedSumQuantileLoss,
)
from gluonts.model import evaluate_model
from gluonts.time_feature import get_seasonality

# from gift_leaderboard.src.utils import get_args, set_seed
from gift_wrapper import SSM_Gift_Wrapper
from gift_utils import ValidWrapper
from modeling_flowstate import FlowStateForPrediction


class Dummy_Parser:
    """Minimal stand-in for ``argparse.ArgumentParser``.

    Instead of parsing a command line, every registered option is stored on
    the parser object itself with its default value, and ``parse_args``
    returns that same object. This lets code read ``C.<option>`` without any
    command-line parsing taking place.
    """

    def add_argument(self, name, nargs=None, type=None, default=None, **kwargs):
        """Register option ``name`` by storing ``default`` as an attribute.

        The attribute name follows argparse's ``dest`` derivation: an explicit
        ``dest`` keyword wins; otherwise leading dashes are stripped and inner
        hyphens become underscores (``--batch-size`` -> ``batch_size``).
        ``nargs``, ``type`` and any other keyword (``help``, ``choices``, ...)
        are accepted for call compatibility and ignored.
        """
        dest = kwargs.get("dest")
        if dest is None:
            dest = name.lstrip("-").replace("-", "_")
        setattr(self, dest, default)

    def parse_args(self, args=None):
        """Return the parser itself as the namespace; ``args`` is ignored."""
        return self


# A real argparse parser is only used when this file is executed as a script.
# On import, the stand-in parser just records the defaults below and its
# parse_args() ignores the argument list it is given.
parser = argparse.ArgumentParser() if __name__ == "__main__" else Dummy_Parser()

workdir = "."
datasets_dir = "./datasets/gift-eval"  # Adjust accordingly

# (flag, type, default) for every supported option, in registration order.
_script_dir = os.path.dirname(os.path.realpath(__file__))
for _flag, _kind, _default in (
    ("--model_name", str, "FlowState"),
    ("--model_dir", str, os.path.join(workdir, "models")),
    ("--config_dir", str, os.path.join(workdir, "configs")),
    ("--out_dir", str, os.path.join(workdir, "results")),
    ("--datasets_dir", str, datasets_dir),
    ("--dataset_properties", str, os.path.join(_script_dir, "eval", "dataset_properties.json")),
    ("--batch_size", int, 64),
    ("--device", str, "cuda"),
    ("--seed", int, 0),
):
    parser.add_argument(_flag, nargs="?", type=_kind, default=_default)
# Keep the module namespace free of the loop temporaries.
del _script_dir, _flag, _kind, _default

# C is the module-wide configuration namespace read by the functions below.
C = parser.parse_args(sys.argv[1:])

# Load variables from a .env file into the environment.
load_dotenv()

# Space-separated dataset names; entries of the form "<name>/<freq>" carry an explicit frequency.
# `short_datasets` are evaluated for the "short" term, `med_long_datasets` additionally for "medium" and "long".
short_datasets = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/15T ett2/H ett2/D ett2/W jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
med_long_datasets = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"

# Column layout of the results CSV: identification columns, one column per
# reported metric (all prefixed with "eval_metrics/"), then dataset metadata.
base_row = [
    "dataset",
    "model",
    *(
        f"eval_metrics/{metric}"
        for metric in (
            "MSE[mean]",
            "MSE[0.5]",
            "MAE[mean]",
            "MAE[0.5]",
            "MASE[0.5]",
            "MAPE[0.5]",
            "sMAPE[0.5]",
            "MSIS",
            "RMSE[mean]",
            "NRMSE[mean]",
            "ND[0.5]",
            "mean_weighted_sum_quantile_loss",
        )
    ),
    "domain",
    "num_variates",
]


def GetPredictor(pred_length, n_ch, freq, device="cpu", predictor=None, domain=None, nd=False):
    """Build an ``SSM_Gift_Wrapper`` around a FlowState model for one dataset configuration.

    If ``predictor`` is ``None`` the checkpoint ``C.model_dir/C.model_name`` is
    loaded; it is indexed with the keys ``"config"`` and ``"model"``. Otherwise
    the given object is reused and its ``config`` attribute is taken as is. In
    both cases ``config.min_context`` is set to 0 before wrapping.

    Args:
        pred_length: Forwarded positionally to the wrapper (the forecast horizon at the call site).
        n_ch: Forwarded to the wrapper as ``n_ch``.
        freq: Forwarded to the wrapper as ``f``.
        device: Device stored on a freshly loaded config and to which a freshly loaded model is moved.
        predictor: Already constructed model to reuse, or ``None`` to load from disk.
        domain: Forwarded to the wrapper as ``domain``.
        nd: Forwarded to the wrapper as ``no_daily``.

    Returns:
        Tuple ``(wrapper, config)``.
    """
    # NOTE(review): `device` is only used for the freshly loaded model, while the
    # checkpoint's map_location and the wrapper use the global C.device -- confirm
    # whether the two are meant to differ.
    if predictor is not None:
        config = predictor.config
    else:
        checkpoint_path = os.path.join(C.model_dir, C.model_name)
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=C.device)
        config = checkpoint["config"]
        config.device = device
        predictor = FlowStateForPrediction._from_config(config)
        predictor.model.load_state_dict(checkpoint["model"])
        predictor = predictor.to(device)

    config.min_context = 0
    wrapped = SSM_Gift_Wrapper(
        predictor,
        pred_length,
        n_ch=n_ch,
        batch_size=C.batch_size,
        f=freq,
        device=C.device,
        domain=domain,
        no_daily=nd,
    )
    return wrapped, config


# Ensure the output directory exists.
# NOTE: this is a module-level statement, so it runs once at import time and
# uses the value C.out_dir holds at that moment.
os.makedirs(os.path.join(C.out_dir, "gift_eval"), exist_ok=True)


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode
    with benchmarking disabled.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def run(model=None, max_len=1e8, verbose=True, save=True, valid=False, model_name=None):
    """Evaluate a model on every configured dataset/term combination and collect the metrics.

    Every dataset in ``short_datasets`` is evaluated for the "short" term and
    every dataset in ``med_long_datasets`` additionally for "medium" and
    "long". With ``save=True`` each finished configuration is appended to
    ``<C.out_dir>/gift_eval/results.csv`` right away, and configurations whose
    ``dataset`` value is already in that file are skipped (their stored rows
    are copied into the returned frame), so an interrupted run can be resumed.

    Args:
        model: Model to evaluate. ``None`` makes ``GetPredictor`` load the
            checkpoint ``C.model_dir/C.model_name``.
        max_len: Currently unused; kept for backward compatibility.
        verbose: Print progress information.
        save: Read from / append to the results CSV.
        valid: Evaluate on ``ValidWrapper(dataset.validation_dataset, ...)``
            instead of ``dataset.test_data``.
        model_name: If given, overrides ``C.model_name`` (used to locate the
            checkpoint and written to the "model" column).

    Returns:
        ``pandas.DataFrame`` with the columns of ``base_row`` and one row per
        configuration (freshly evaluated or taken from the CSV).
    """
    set_seed(C.seed)
    if model_name is not None:
        C.model_name = model_name

    # Get union of short and med_long datasets (sorted for a stable order).
    med_long_names = set(med_long_datasets.split())
    all_datasets = sorted(set(short_datasets.split()) | med_long_names)
    with open(C.dataset_properties) as properties_file:
        dataset_properties_map = json.load(properties_file)

    # Instantiate the metrics
    metrics = [
        MSE(forecast_type="mean"),
        MSE(forecast_type=0.5),
        MAE(forecast_type="mean"),
        MAE(forecast_type=0.5),
        MASE(),
        MAPE(),
        SMAPE(),
        MSIS(),
        RMSE(),
        NRMSE(),
        ND(),
        MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
    ]
    # Keys of the evaluation result, in the order of the metric columns of base_row.
    metric_keys = (
        "MSE[mean]",
        "MSE[0.5]",
        "MAE[mean]",
        "MAE[0.5]",
        "MASE[0.5]",
        "MAPE[0.5]",
        "sMAPE[0.5]",
        "MSIS",
        "RMSE[mean]",
        "NRMSE[mean]",
        "ND[0.5]",
        "mean_weighted_sum_quantile_loss",
    )

    # ## Evaluation
    # Define the path for the CSV file
    csv_file_path = os.path.join(C.out_dir, "gift_eval", "results.csv").replace('.pt', '')

    # Dataset names whose key in the properties file differs from the dataset name.
    pretty_names = {
        "saugeenday": "saugeen",
        "temperature_rain_with_missing": "temperature_rain",
        "kdd_cup_2018_with_missing": "kdd_cup_2018",
        "car_parts_with_missing": "car_parts",
    }

    if save:
        # The directory created at import time may not match the path used here
        # (C.out_dir can change after import), so make sure it exists.
        os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
        if not os.path.exists(csv_file_path):
            with open(csv_file_path, "a", newline="") as csvfile:
                # Write the header
                csv.writer(csvfile).writerow(base_row)
        df_res_done = pd.read_csv(csv_file_path)
        done_datasets = df_res_done["dataset"].values
    else:
        df_res_done = None
        done_datasets = []
    done_configs = set(done_datasets)
    df_res = pd.DataFrame(columns=base_row)
    if verbose:
        print("Done datasets")
        print(done_datasets)

    excluded = []
    terms = ["short", "medium", "long"]
    # Wrapper returned by the previous GetPredictor call; None until the first one.
    predictor = None

    for ds_name in all_datasets:
        if ds_name in excluded:
            continue
        set_seed(C.seed)
        for term in terms:
            if term in ("medium", "long") and ds_name not in med_long_names:
                continue
            if verbose:
                print(f"Processing dataset: {ds_name}, term: {term}")

            # "<name>/<freq>" carries its frequency; otherwise look it up in the properties file.
            name_parts = ds_name.split("/")
            ds_key = name_parts[0].lower()
            ds_key = pretty_names.get(ds_key, ds_key)
            if len(name_parts) > 1:
                ds_freq = name_parts[1]
            else:
                ds_freq = dataset_properties_map[ds_key]["frequency"]
            ds_config = f"{ds_key}/{ds_freq}/{term}"

            if ds_config in done_configs:
                if verbose:
                    print(f"Done with {ds_config}. Skipping...")
                stored_rows = df_res_done.loc[df_res_done["dataset"] == ds_config]
                df_res = pd.concat([df_res, stored_rows], ignore_index=True)
                continue

            # Probe the dataset once to decide the flag: datasets whose target_dim
            # is not 1 are loaded with to_univariate=True.
            to_univariate = Dataset(name=ds_name, term=term, to_univariate=False).target_dim != 1
            dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)
            if valid:
                dst = ValidWrapper(dataset.validation_dataset, dataset.prediction_length)
            else:
                dst = dataset.test_data

            # 1-D targets count as one channel, otherwise the first axis is the
            # channel count; the last entry of the data decides.
            for entry in dst:
                target = entry[0]["target"]
                num_channels = 1 if len(target.shape) == 1 else target.shape[0]

            if verbose:
                print(f"Dataset: {ds_name}, Freq = {dataset.freq}, H = {dataset.prediction_length}")

            no_daily = (
                "l2c" in ds_name
            )  # necessary to get correct seasonality for bizitobs_l2c datasets (which have no daily cycles)
            ds_properties = dataset_properties_map[ds_key]
            # not necessary to reload the model for each dataset: after the first
            # configuration the model held by the previous wrapper is reused
            reusable_model = model if predictor is None else predictor.model
            predictor, _ = GetPredictor(
                pred_length=dataset.prediction_length,
                n_ch=num_channels,
                freq=dataset.freq,
                device=C.device,
                predictor=reusable_model,
                domain=ds_properties["domain"],
                nd=no_daily,
            )

            if verbose:
                print(f"Number of channels in the dataset {ds_name} =", num_channels)

            with torch.no_grad():
                # Evaluate
                res = evaluate_model(
                    predictor,
                    test_data=dst,
                    metrics=metrics,
                    batch_size=C.batch_size,
                    axis=None,
                    mask_invalid_label=True,
                    allow_nan_forecast=False,
                )
            if verbose:
                print(f'MASE: {res["MASE[0.5]"][0]}')

            # Same column order as base_row.
            row = [
                ds_config,
                C.model_name,
                *(res[key][0] for key in metric_keys),
                ds_properties["domain"],
                ds_properties["num_variates"],
            ]
            if save:
                # Append the results to the CSV file
                with open(csv_file_path, "a", newline="") as csvfile:
                    csv.writer(csvfile).writerow(row)
                if verbose:
                    print(f"Results for {ds_name} have been written to {csv_file_path}")
            df_res.loc[len(df_res)] = row

    # Print Results
    print(df_res.to_markdown())

    return df_res


if __name__ == "__main__":
    # Script entry point: evaluate the local checkpoint on CPU and store the results.
    # NOTE(review): these assignments run after the command line has been parsed,
    # so they overwrite any --model_dir/--model_name/--device/--batch_size given there.
    script_settings = {
        "model_dir": "./",
        "model_name": 'FlowState_ckpt.pt',
        "device": 'cpu',
        "batch_size": 64,
    }
    for option, value in script_settings.items():
        setattr(C, option, value)

    run(max_len=1e8, save=True, valid=False)
