
import os
import sys
import math
import numpy as np
import tempfile
import math
import numpy as np
import pandas as pd
from tsfm_public import (
    TimeSeriesPreprocessor,
    TinyTimeMixerForPrediction,
    TrackingCallback,
    count_parameters,
    get_datasets,
)
from tsfm_public.toolkit.lr_finder import optimal_lr_finder

import torch
from torch.utils.data import ConcatDataset

from transformers import Trainer, TrainingArguments
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

 
# Select the default torch device used by train_ttm.
# Being on macOS alone does not guarantee a usable MPS backend (for example
# on Macs without Metal support, or torch builds compiled without MPS), so
# the backend is probed before 'mps' is chosen; otherwise fall through to
# CUDA and finally to CPU.
if (
    sys.platform == 'darwin'
    and getattr(torch.backends, 'mps', None) is not None
    and torch.backends.mps.is_available()
):
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
 


# ------------------------------------------------------------------------------------------------------------------------------------------------------- #
# Data Preparation UNIVARIATE
# TODO: Improve for Multivariate

def prepare_dataset_for_TTM(object_numpy, labels_data, tsp):
    """Build one concatenated dataset from every window of a numpy array.

    Each slice ``object_numpy[:, w, :]`` is transposed, wrapped in a
    DataFrame whose columns are ``labels_data`` and handed to
    ``get_datasets`` with the split configuration ``{"train": 1, "test": 0}``
    and ``stride=1``. Only the first of the three datasets returned for each
    window is kept; the kept datasets are chained into a ``ConcatDataset``.

    Args:
        object_numpy: 2-D or 3-D array. A 2-D array receives a new leading
            axis of size 1 before processing. The original comment describes
            the 3-D layout as [channel, window, time] -- TODO confirm that
            2-D callers really pass (window, time) for a single channel.
        labels_data: Column labels for the per-window DataFrame; there must
            be one label per entry along axis 0 of the (expanded) array.
        tsp: Preprocessor object forwarded as the first argument of
            ``get_datasets``.

    Returns:
        A ``ConcatDataset`` with one member dataset per window.
    """
    series = object_numpy
    if series.ndim == 2:
        # No explicit leading axis: add one so the indexing below is uniform.
        series = np.expand_dims(series, axis=0)

    per_window_datasets = []
    window_count = series.shape[1]
    for window_idx in range(window_count):
        # Rows follow the last array axis, columns follow axis 0.
        frame = pd.DataFrame(series[:, window_idx, :].T, columns=labels_data)
        kept_split, _second_split, _third_split = get_datasets(
            tsp,
            frame,
            {"train": 1, "test": 0},
            stride=1,
        )
        per_window_datasets.append(kept_split)

    return ConcatDataset(per_window_datasets)


# ------------------------------------------------------------------------------------------------------------------------------------------------------- #
# CALL TTM 

def call_ttm(pretrained_model_name_or_path, revision, type, freeze_backbone):
    """Placeholder for a model-loading helper; not implemented yet.

    All four arguments are currently ignored. ``type`` shadows the builtin
    but is part of the existing signature, so it is kept as is.

    Returns:
        An empty tuple.
    """
    nothing_loaded = ()
    return nothing_loaded

# ------------------------------------------------------------------------------------------------------------------------------------------------------- #
# TRAINING 

def train_ttm(
    training_dataset,
    validation_dataset,
    TTM_model,
    TTM_MODEL_REVISION,
    decoder_mode,
    timeseriespreprocessor,
    OUTPUT_MODEL_PATH,
    device=device,
    num_epochs=200,
    batch_size=64,
    dataloader_num_workers=8,
    early_stopping_patience=5,
    early_stopping_threshold=0,
    float16=False,
):
    """Fine-tune a pretrained TinyTimeMixer with a frozen backbone and save it.

    Steps, in order: load the pretrained model, optionally cast it to half
    precision, freeze every backbone parameter, ask ``optimal_lr_finder`` for
    a learning rate, train with a Hugging Face ``Trainer`` (AdamW + OneCycleLR,
    early stopping on ``eval_loss``), then save the model.

    Args:
        training_dataset: Dataset used for the learning-rate search, for
            training, and (through ``len``) to size the LR schedule.
        validation_dataset: Dataset evaluated at the end of every epoch.
        TTM_model: Passed as ``pretrained_model_name_or_path``.
        TTM_MODEL_REVISION: Passed as ``revision``.
        decoder_mode: Forwarded to ``from_pretrained`` unchanged.
        timeseriespreprocessor: Object providing ``num_input_channels``,
            ``prediction_channel_indices`` and ``exogenous_channel_indices``.
        OUTPUT_MODEL_PATH: Directory receiving the final model; checkpoints
            go to its ``output`` sub-directory and logs to ``logs``.
        device: Device the freshly loaded model is moved to. Defaults to the
            module-level ``device``.
        num_epochs: Number of training epochs (also used by the scheduler).
        batch_size: Per-device batch size for training and evaluation.
        dataloader_num_workers: Worker count handed to ``TrainingArguments``.
        early_stopping_patience: Forwarded to ``EarlyStoppingCallback``.
        early_stopping_threshold: Forwarded to ``EarlyStoppingCallback``.
        float16: When true, ``.half()`` is called on the model after loading.

    Returns:
        None. The fine-tuned model is written to ``OUTPUT_MODEL_PATH``.
    """
    # ---------------------------------------------------------------------- #
    # Load the pretrained model
    print('loaded model: ', TTM_model, TTM_MODEL_REVISION)
    print('decoder mode: ', decoder_mode)
    # The exogenous-mixing options (fcm_context_length, fcm_use_mixer,
    # fcm_mix_layers, enable_forecast_channel_mixing, fcm_prepend_past) are
    # deliberately not overridden here.
    pretrained = TinyTimeMixerForPrediction.from_pretrained(
        pretrained_model_name_or_path=TTM_model,
        revision=TTM_MODEL_REVISION,
        num_input_channels=timeseriespreprocessor.num_input_channels,
        decoder_mode=decoder_mode,
        prediction_channel_indices=timeseriespreprocessor.prediction_channel_indices,
        exogenous_channel_indices=timeseriespreprocessor.exogenous_channel_indices,
    )
    model = pretrained.to(device)
    print(device)

    if float16:
        model.half()

    # ---------------------------------------------------------------------- #
    # Freeze the backbone so that only the remaining parameters are trained
    print(
        "Number of params before freezing backbone",
        count_parameters(model),
    )
    for backbone_param in model.backbone.parameters():
        backbone_param.requires_grad = False
    print(
        "Number of params after freezing the backbone",
        count_parameters(model),
    )

    # ---------------------------------------------------------------------- #
    # Report the run settings and search for a learning rate
    for label, value in (
        ("num epochs =", num_epochs),
        ("batch size =", batch_size),
        ("dataloader num workers  =", dataloader_num_workers),
        ("device  =", device),
    ):
        print(label, value)

    # The search is always invoked with device='cpu', whatever `device` is.
    learning_rate, model = optimal_lr_finder(
        model,
        training_dataset,
        batch_size=batch_size,
        device='cpu',
        enable_prefix_tuning=False,
    )
    print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)

    # ---------------------------------------------------------------------- #
    # Trainer configuration
    print(f"Using learning rate = {learning_rate}")
    training_config = {
        "output_dir": os.path.join(OUTPUT_MODEL_PATH, "output"),
        "overwrite_output_dir": True,
        "learning_rate": learning_rate,
        "num_train_epochs": num_epochs,
        "do_eval": True,
        "eval_strategy": "epoch",
        "per_device_train_batch_size": batch_size,
        "per_device_eval_batch_size": batch_size,
        "dataloader_num_workers": dataloader_num_workers,
        # NOTE(review): None is not the string "none"; depending on the
        # transformers version it may enable every installed reporting
        # integration -- confirm this is intended.
        "report_to": None,
        "save_strategy": "epoch",
        "logging_strategy": "epoch",
        "save_total_limit": 1,
        "logging_dir": os.path.join(OUTPUT_MODEL_PATH, "logs"),
        # Reload the best checkpoint (lowest eval_loss) when training ends.
        "load_best_model_at_end": True,
        "metric_for_best_model": "eval_loss",
        "greater_is_better": False,
        "use_mps_device": sys.platform == 'darwin',
    }
    trainer_args = TrainingArguments(**training_config)

    # Stop once eval_loss has not improved by more than the threshold for
    # `early_stopping_patience` consecutive evaluations.
    stop_early = EarlyStoppingCallback(
        early_stopping_patience=early_stopping_patience,
        early_stopping_threshold=early_stopping_threshold,
    )
    tracker = TrackingCallback()

    # Optimizer plus a one-cycle schedule peaking at the suggested rate.
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    batches_per_epoch = math.ceil(len(training_dataset) / batch_size)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=batches_per_epoch,
    )

    # ---------------------------------------------------------------------- #
    # Fine-tune
    trainer = Trainer(
        model=model,
        args=trainer_args,
        train_dataset=training_dataset,
        eval_dataset=validation_dataset,
        callbacks=[stop_early, tracker],
        optimizers=(optimizer, scheduler),
    )
    trainer.train()

    # ---------------------------------------------------------------------- #
    # Save the fine-tuned model
    os.makedirs(OUTPUT_MODEL_PATH, exist_ok=True)
    trainer.save_model(OUTPUT_MODEL_PATH)



# ------------------------------------------------------------------------------------------------------------------------------------------------------- #
# EVALUATION

from sklearn.metrics import mean_squared_error

def compute_rmse(y_true, y_pred):
    """Return the root of ``mean_squared_error(y_true, y_pred)``."""
    return np.sqrt(mean_squared_error(y_true, y_pred))

def compute_mase(true, pred):
    """Compute a MASE-style (Mean Absolute Scaled Error) score.

    The mean absolute error between ``true`` and ``pred`` is divided by the
    mean absolute one-step difference of ``true`` along its first axis (the
    error of a naive "repeat the previous value" forecast). Note that the
    scale is taken from ``true`` itself, not from a separate history.

    Args:
        true: Array-like of observed values (lists are accepted).
        pred: Array-like of predictions, broadcastable against ``true``.

    Returns:
        The scaled error, or ``np.nan`` when the scale is undefined: fewer
        than two observations, or a constant ``true`` series (zero scale).
    """
    true = np.asarray(true)
    pred = np.asarray(pred)

    # Subtract first so incompatible shapes still raise, as they always did.
    abs_errors = np.abs(true - pred)

    # With fewer than two observations there is no one-step difference;
    # return nan directly rather than taking the mean of an empty slice.
    if true.shape[0] < 2:
        return np.nan

    numerator = np.mean(abs_errors)
    denominator = np.mean(np.abs(true[1:] - true[:-1]))  # naive one-step forecast error
    return numerator / denominator if denominator != 0 else np.nan

def compute_wql(true, pred, quantiles=(0.1, 0.5, 0.9)):
    """Compute the mean quantile (pinball) loss averaged over ``quantiles``.

    The same point forecast ``pred`` is scored at every quantile level, and
    the result is not normalised by the magnitude of ``true``.

    Args:
        true: Array-like of observed values (lists are accepted).
        pred: Array-like of predictions, broadcastable against ``true``.
        quantiles: Iterable of quantile levels. The default is an immutable
            tuple so it cannot be modified between calls.

    Returns:
        The pinball loss averaged over elements and over quantile levels.

    Raises:
        ZeroDivisionError: If ``quantiles`` is empty.
    """
    levels = tuple(quantiles)
    # The residuals do not depend on the quantile level: compute them once.
    errors = np.asarray(true) - np.asarray(pred)
    total_loss = sum(
        np.mean(np.maximum(q * errors, (q - 1) * errors)) for q in levels
    )
    return total_loss / len(levels)

def compute_h1_distance(y_true, y_pred, dt=1.0, rescale_derivative=True):
    """
    Sobolev-H1-style distance between a reference curve and a predicted one.

    The result is ``sqrt(MSE(values) + MSE(derivatives))``, where the
    derivatives are finite-difference estimates from ``np.gradient``.

    Parameters:
    - y_true: array-like, reference curve (flattened to 1D before use)
    - y_pred: array-like, predicted curve (flattened to 1D before use)
    - dt: float, sample spacing handed to ``np.gradient`` (default 1.0)
    - rescale_derivative: bool, when true the derivative term is multiplied by
      ``(std(y_true) / std(d y_true))**2`` so that it is on a scale comparable
      to the value term; skipped when the reference derivative has zero spread

    Returns:
    - h1_distance: float, the H1 distance
    """
    reference = np.asarray(y_true).reshape(-1)
    estimate = np.asarray(y_pred).reshape(-1)

    # Value term: mean squared difference of the curves themselves.
    value_term = mean_squared_error(reference, estimate)

    # Derivative term: mean squared difference of the finite-difference slopes.
    reference_slope = np.gradient(reference, dt)
    estimate_slope = np.gradient(estimate, dt)
    slope_term = mean_squared_error(reference_slope, estimate_slope)

    if rescale_derivative:
        slope_spread = np.std(reference_slope)
        # A flat reference derivative gives no usable scale; leave the term as is.
        if slope_spread > 0:
            ratio = np.std(reference) / slope_spread
            slope_term = slope_term * ratio ** 2

    return np.sqrt(value_term + slope_term)


def evaluate_ttm_model(test_dataset, model_path, tsp, seed=None):
    """Evaluate a saved TinyTimeMixer model and compute per-channel metrics.

    The model is loaded from ``model_path`` and run through a Hugging Face
    ``Trainer`` twice (``evaluate`` for the loss, ``predict`` for the
    forecasts). RMSE, MASE, WQL and the H1 distance are then computed for
    every (sample, channel) pair and averaged.

    Args:
        test_dataset: Dataset accepted by ``Trainer.evaluate``/``predict``
            whose items expose a ``'future_values'`` tensor.
        model_path: Location handed to
            ``TinyTimeMixerForPrediction.from_pretrained``.
        tsp: Unused; kept so existing callers keep working.
        seed: Optional random seed. When given it is forwarded to
            ``TrainingArguments(seed=...)``; when ``None`` the
            ``TrainingArguments`` default applies.

    Returns:
        dict with:
        - ``eval_loss``: loss reported by ``Trainer.evaluate``
        - ``avg_rmse`` / ``avg_mase`` / ``avg_wql`` / ``avg_h1``: means over
          all samples and channels
        - ``rmse_list`` / ``mase_list`` / ``wql_list`` / ``h1_list``: nested
          lists indexed ``[sample][channel]``
        - ``true_val``: stacked ``future_values`` tensor (ground truth)
        - ``pred_val``: predictions array, indexed ``[sample, step, channel]``
        - ``y_pred``: the same predictions as a list with one array per sample
    """
    import shutil  # local import: only needed to clean up the scratch directory

    # Load the saved model
    loaded_model = TinyTimeMixerForPrediction.from_pretrained(model_path)

    eval_arg_overrides = {"per_device_eval_batch_size": 64}
    if seed is not None:
        # Honour the caller's seed instead of silently ignoring it.
        eval_arg_overrides["seed"] = seed

    # Trainer requires an output directory; use a scratch one and always
    # remove it afterwards so repeated evaluations do not leak directories.
    temp_dir = tempfile.mkdtemp()
    try:
        loaded_model_trainer = Trainer(
            model=loaded_model,
            args=TrainingArguments(output_dir=temp_dir, **eval_arg_overrides),
        )
        eval_loss = loaded_model_trainer.evaluate(test_dataset)['eval_loss']
        print(f"Eval Loss: {eval_loss}")

        # Predicted forecasts: first element of the prediction outputs
        pred_val = loaded_model_trainer.predict(test_dataset).predictions[0]
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)

    # Ground-truth forecasts
    future_values = [item['future_values'] for item in test_dataset]
    true_val = torch.stack(future_values)
    # Convert once rather than once per (sample, channel) pair.
    true_np = np.array(true_val)

    # Per-sample, per-channel metrics.
    # NOTE(review): assumes channel i of the predictions lines up with channel
    # i of future_values -- confirm when only a subset of channels is predicted.
    rmse_list, mase_list, wql_list, h1_list = [], [], [], []
    n_channels = pred_val.shape[2]
    for sample in range(true_val.shape[0]):
        rmse_row, mase_row, wql_row, h1_row = [], [], [], []
        for channel in range(n_channels):
            y_true = true_np[sample, :, channel]
            y_hat = np.asarray(pred_val[sample, :, channel])
            rmse_row.append(compute_rmse(y_true, y_hat))
            mase_row.append(compute_mase(y_true, y_hat))
            wql_row.append(compute_wql(y_true, y_hat))
            h1_row.append(compute_h1_distance(y_true, y_hat))
        rmse_list.append(rmse_row)
        mase_list.append(mase_row)
        wql_list.append(wql_row)
        h1_list.append(h1_row)

    # Predictions as a list, one array per sample (previously left empty).
    y_pred_list = list(pred_val)

    # Average the metrics across all samples and channels
    avg_rmse = np.mean(rmse_list)
    avg_mase = np.mean(mase_list)
    avg_wql = np.mean(wql_list)
    avg_h1 = np.mean(h1_list)

    print(f"Average RMSE: {avg_rmse}")
    print(f"Average MASE: {avg_mase}")
    print(f"Average WQL: {avg_wql}")
    print(f"Average SobolevH1: {avg_h1}")

    return {
        "eval_loss": eval_loss,
        "avg_rmse": avg_rmse,
        "avg_mase": avg_mase,
        "avg_wql": avg_wql,
        "avg_h1": avg_h1,
        "rmse_list": rmse_list,
        "mase_list": mase_list,
        "wql_list": wql_list,
        "h1_list": h1_list,
        "true_val": true_val,   # Ground-truth forecast windows
        "pred_val": pred_val,   # Predicted forecast windows (array)
        "y_pred": y_pred_list,  # Predicted forecast windows (list, one per sample)
    }
