from pathlib import Path
import pandas as pd
import numpy as np
from gluonts.dataset.common import ListDataset
from typing import List, Dict
import json

from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
from uni2ts.run import predict, predict_without_split
from werkzeug.utils import secure_filename
import os
import io
import base64
import matplotlib.pyplot as plt
from datetime import timedelta
from uni2ts.eval_util.plot import plot_single


def get_accuracy(forecast_value, label):
    """
    Return the mean absolute percentage error (MAPE) of a forecast.

    Despite the name this is an error measure (lower is better), expressed as
    a fraction, e.g. 0.1 means 10%.

    Args:
        forecast_value: forecast values (array-like).
        label: ground-truth values, same length as ``forecast_value``.

    Returns:
        MAPE as a float, computed over the positions where ``label`` is
        non-zero; ``nan`` when there is no such position (including empty input).
    """
    forecast_arr = np.asarray(forecast_value, dtype=float)
    label_arr = np.asarray(label, dtype=float)
    # Percentage error is undefined where the actual value is zero; skip those
    # points rather than letting inf/nan poison the whole mean.
    nonzero = label_arr != 0
    if not np.any(nonzero):
        return float("nan")
    abs_error = np.abs(forecast_arr[nonzero] - label_arr[nonzero])
    return float(np.mean(abs_error / np.abs(label_arr[nonzero])))


def get_highest_value(forecast_value, label):
    """
    Return the largest forecast value, or 0 when there are no values.

    Args:
        forecast_value: forecast values (list or numpy array).
        label: unused; kept so the signature matches the other metric helpers.

    Returns:
        The maximum of ``forecast_value``, or ``0`` for empty input.
    """
    values = np.asarray(forecast_value)
    # `not <ndarray>` raises for multi-element arrays, so test the size instead.
    if values.size == 0:
        return 0
    return values.max()


def calculate_the_portion_of_covered_values(forecast_upper_bound, forecast_lower_bound, label):
    """
    Return the fraction of ground-truth values inside the forecast interval.

    A value counts as covered when ``lower <= label <= upper``; a value lying
    exactly on a bound is inside the interval.

    Args:
        forecast_upper_bound: upper interval bound per step (array-like).
        forecast_lower_bound: lower interval bound per step (array-like).
        label: ground-truth values, same length as the bounds.

    Returns:
        Covered count divided by the number of forecast steps, or ``nan`` for
        empty input.
    """
    upper = np.asarray(forecast_upper_bound, dtype=float)
    lower = np.asarray(forecast_lower_bound, dtype=float)
    actual = np.asarray(label, dtype=float)
    if len(upper) == 0:
        return float("nan")
    covered = np.count_nonzero((lower <= actual) & (actual <= upper))
    return covered / len(upper)


def calculate_the_portion_of_overestimated_values(forecast_upper_bound, label):
    """
    Return the share of forecast points lying strictly above the label.

    Args:
        forecast_upper_bound: forecast values to compare (array-like).
        label: ground-truth values, same length as ``forecast_upper_bound``.

    Returns:
        Number of positions where ``forecast_upper_bound > label``, divided by
        the number of forecast points.
    """
    n_points = len(forecast_upper_bound)
    n_above = np.sum(label < forecast_upper_bound)
    return n_above / n_points


def compute_trend_alignment(forecast_value, label):
    """
    Return the fraction of consecutive steps whose direction matches the label's.

    Step directions are the signs of first differences (up, down or flat).
    Inputs are converted to plain arrays first, so a pandas Series label is
    differenced by position rather than aligned on its index.

    Args:
        forecast_value: forecast values (array-like).
        label: ground-truth values, same length as ``forecast_value``.

    Returns:
        Matching step directions divided by the number of steps (n - 1), or
        ``nan`` when fewer than two points are given.

    Raises:
        ValueError: if ``forecast_value`` and ``label`` differ in length.
    """
    forecast_steps = np.sign(np.diff(np.asarray(forecast_value, dtype=float)))
    label_steps = np.sign(np.diff(np.asarray(label, dtype=float)))
    if forecast_steps.shape != label_steps.shape:
        raise ValueError(
            f"forecast_value and label must have the same length "
            f"({forecast_steps.size + 1} vs {label_steps.size + 1})"
        )
    if forecast_steps.size == 0:
        return float("nan")
    return float(np.mean(forecast_steps == label_steps))


def _split_index_at_forecast_start(forecast, idx):
    """Split ``idx`` into (timestamps >= forecast start, timestamps before it)."""
    start_date = forecast.start_date
    # A Period start date must become a Timestamp to compare with a DatetimeIndex.
    if hasattr(start_date, 'to_timestamp'):
        start_date = start_date.to_timestamp()
    if not isinstance(idx, pd.DatetimeIndex):
        idx = pd.DatetimeIndex(idx)
    return idx[idx >= start_date], idx[idx < start_date]


def _build_report(quantiles, ground_truth):
    """Compute the evaluation metrics and pair each with its explanatory text."""
    observed = np.asarray(ground_truth, dtype=float)
    n = len(observed)
    # Metrics only use the forecast steps that have a ground-truth value.
    median = quantiles["0.5"][:n]
    upper = quantiles["0.95"][:n]
    lower = quantiles["0.05"][:n]

    prediction_mape = get_accuracy(median, observed)
    highest_value = get_highest_value(median, observed)
    portion_covered = calculate_the_portion_of_covered_values(upper, lower, observed)
    # Bias is judged on the point (median) forecast: for a well-calibrated model
    # the 0.95 bound sits above the actual value almost always, so comparing
    # against it would say nothing about over/under-estimation.
    portion_overestimated = calculate_the_portion_of_overestimated_values(median, observed)
    trend_alignment = compute_trend_alignment(median, observed)
    print("Prediction MAPE:", prediction_mape, "Highest Value:", highest_value, "Portion Covered:", portion_covered, "Portion Overestimated:", portion_overestimated, "Trend Alignment:", trend_alignment)

    # Summary statistics of the ground truth for the forecast window.
    if observed.size:
        obs_max, obs_min = observed.max(), observed.min()
        obs_median, obs_mean = np.median(observed), observed.mean()
    else:
        obs_max = obs_min = obs_median = obs_mean = float("nan")

    # NOTE(review): key spellings ("obsered") are kept unchanged because
    # consumers of this payload presumably look them up verbatim.
    return {
        "prediction_mse": {
            "value": prediction_mape,
            "text": "measures the mean absolute percentage error (MAPE) between forecasted and actual values. Lower values indicate better model performance."
        },
        "highest_value": {
            "value": highest_value,
            "text": "represents the highest value in the predicted data. This helps identify peak performance or outliers."
        },
        "portion_covered": {
            "value": portion_covered,
            "text": "indicates the percentage of ground truth values covered by forecast confidence intervals. Higher percentages suggest better reliability."
        },
        "portion_overestimated": {
            "value": portion_overestimated,
            "text": "shows the percentage of forecasts that overestimated the actual values. This helps evaluate model bias."
        },
        "trend_alignment": {
            "value": trend_alignment,
            "text": "reflects the alignment between forecasted and actual trends. Higher values indicate better predictive consistency.",
        },
        "highest_obsered_value": {
            "value": obs_max,
            "text": "represents the highest observed value in the observed data. This helps identify peak performance or outliers."
        },
        "min_obsered_value": {
            "value": obs_min,
            "text": "represents the minimum observed value in the observed data. This helps identify peak performance or outliers."
        },
        "median_obsered_value": {
            "value": obs_median,
            "text": "represents the median observed value in the observed data. This helps identify peak performance or outliers."
        },
        "mean_obsered_value": {
            "value": obs_mean,
            "text": "represents the mean observed value in the observed data. This helps identify peak performance or outliers."
        },
    }


def process_forecasts(forecast, label, idx, inp) -> Dict:
    """
    Process GluonTS forecasts into a format suitable for frontend visualization.

    Args:
        forecast: forecast object exposing ``start_date`` and ``quantile(q)``,
            queried here with string quantile levels such as ``"0.5"``.
        label: object whose ``'target'`` entry holds the ground-truth values
            for the forecast window.
        idx: timestamps of the series; converted to ``pd.DatetimeIndex`` if needed.
            Entries at or after the forecast start form the forecast index.
        inp: mapping whose ``'target'`` entry holds the model input series.

    Returns:
        Dict with the median prediction, the ground truth, formatted forecast
        and context timestamps (context limited to the last 200), the last 200
        input values, 50%/90% prediction intervals as [upper, lower] pairs,
        and a ``report`` dict whose entries each hold a ``value`` and a ``text``.
    """
    forecast_idx, context_idx = _split_index_at_forecast_start(forecast, idx)
    print("Forecast Index:", len(forecast_idx))
    print("Context Index:", len(context_idx))

    # Query each quantile once instead of recomputing it for every metric.
    quantiles = {q: forecast.quantile(q) for q in ("0.05", "0.25", "0.5", "0.75", "0.95")}

    # NOTE(review): "contex_index" spelling kept for existing consumers.
    return {
        "prediction": quantiles["0.5"],
        "ground_truth": label['target'],
        "index": forecast_idx.strftime('%Y-%m-%d %H:%M:%S').tolist(),
        "original_series": inp['target'][-200:],
        "contex_index": context_idx.strftime('%Y-%m-%d %H:%M:%S').tolist()[-200:],
        "prediction_intervals": {
            "0.5": [quantiles["0.75"], quantiles["0.25"]],
            "0.9": [quantiles["0.95"], quantiles["0.05"]],
        },
        "report": _build_report(quantiles, label['target']),
    }


def forecast_and_analyze(
    csv_path="/root/uni2ts-main/uploaded_files/BE.csv",
    index_range=(52217, 52417),
    horizon=50,
):
    """
    Forecast a slice of a CSV series with Moirai and build the frontend payload.

    Args:
        csv_path: CSV file whose first column is parsed as a datetime index.
        index_range: (start, end) row positions of the context window passed
            to the model; ``end`` is clipped to the file length.
        horizon: number of steps to forecast; the same number of rows right
            after the context window is used as ground truth.

    Returns:
        The dict produced by ``process_forecasts``.
    """
    target = ' Prices'  # note the leading space in the column name
    covariates = [' Prices']
    SIZE = "large"  # model size: choose from {'small', 'base', 'large'}
    PDT = horizon  # prediction length: any positive integer
    CTX = 200  # context length: any positive integer
    PSZ = "auto"  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
    BSZ = 32  # batch size: any positive integer
    TEST = horizon  # test set length: any positive integer

    full_df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    min_index = int(index_range[0])
    max_index = min(int(index_range[1]), len(full_df))
    df = full_df.iloc[min_index:max_index]
    # The ground truth must be sliced from the full frame: slicing the already
    # truncated `df` at `max_index` always produced an empty frame.
    df_label = full_df.iloc[max_index:max_index + horizon].rename(columns={target: 'target'})
    print("indexRange:", min_index, max_index)

    forecast, _label, inp = predict_without_split(df, target, covariates, CTX, PDT, TEST, SIZE, BSZ, PSZ)
    # Include the ground-truth timestamps so process_forecasts can locate the
    # forecast window. Assumes predict_without_split forecasts the steps right
    # after `df` ends — TODO confirm.
    full_idx = df.index.append(df_label.index)
    result = process_forecasts(forecast, df_label, full_idx, inp)
    print(result.keys())
    return result



if __name__ == "__main__":
    # Script entry point: run the forecasting demo end to end.
    forecast_and_analyze()
