import torch
import lightning as pl

from sklearn.metrics import (
    accuracy_score, 
    jaccard_score,
    mean_absolute_error
)

import sys
sys.path.append("../..")

import yaml

import lightning as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint


from utils import (
    find_best_checkpoint,
    set_default_if_unset,
    save_model_parameters
)

from locationencoder import LocationEncoder3D
from data import ACE_TS_DataModule
import torch
import numpy as np
import random

from utils.utils import get_output_root, get_project_root
import pandas as pd
import wandb

import matplotlib.pyplot as plt

# Module-level side effect: importing this file switches PyTorch's float32
# matmul precision mode to 'medium' for the whole process.
# NOTE(review): the Trainer in fit_ace_ts is created with precision=64 --
# confirm this setting still has the intended effect there.
torch.set_float32_matmul_precision('medium')

def plot_spatial_avg_of_target_vs_pred(model, title):
    """Draw the model's spatially averaged prediction and target series together.

    Args:
        model: Object exposing ``spatial_avg_prediction_vector`` and
            ``spatial_avg_target_vector``; both are handed to ``Axes.plot``
            unchanged.
        title: Text used as the axes title.

    Returns:
        The matplotlib figure that was created. The figure is also displayed
        through ``plt.show()`` before it is returned.
    """
    figure, axes = plt.subplots(figsize=(10, 5))

    # Prediction is drawn before target so the default colour cycle assigns
    # the same colours as before.
    curves = (
        ("spatial_avg_prediction_vector", 'Avg Predicted'),
        ("spatial_avg_target_vector", 'Avg Target'),
    )
    for attribute_name, legend_label in curves:
        axes.plot(getattr(model, attribute_name), label=legend_label)

    axes.set(title=title, xlabel='Time', ylabel='Spatial Average')
    axes.legend()

    plt.show()
    return figure

def plot_error_of_spatial_avg_of_target_vs_pred(model, title):
    """Draw the signed gap between the spatially averaged prediction and target.

    Args:
        model: Object exposing ``spatial_avg_prediction_vector`` and
            ``spatial_avg_target_vector``; the plotted curve is the first
            minus the second.
        title: Text used as the axes title.

    Returns:
        The matplotlib figure that was created. The figure is also displayed
        through ``plt.show()`` before it is returned.
    """
    figure, axes = plt.subplots(figsize=(10, 5))

    # Signed error: positive where the prediction exceeds the target.
    signed_error = model.spatial_avg_prediction_vector - model.spatial_avg_target_vector
    axes.plot(signed_error, label='Avg Predicted - Avg Target')

    axes.set(title=title, xlabel='Time', ylabel='Spatial Average Error')
    axes.legend()

    plt.show()
    return figure

def on_test_epoch_start(self):
    """Reset every accumulator that the test hooks fill during a test epoch."""
    # Running per-timestep totals; each one starts as its own empty dict.
    for dict_attribute in ("pred_by_timestep", "target_by_timestep", "count_by_timestep"):
        setattr(self, dict_attribute, {})

    # Flat per-sample buffers (coordinates, times, predictions, targets);
    # each one starts as its own empty array.
    for array_attribute in ("lons", "lats", "times", "y_preds", "labels"):
        setattr(self, array_attribute, np.array([]))

def _sample_focus_indices(times, num_timesteps):
    """Return the row indices of up to ``num_timesteps`` random distinct timesteps."""
    unique_times = np.unique(times)
    # Draw without replacement (and never more than exist) so that no
    # timestep, and therefore no row, appears twice in the logged sample.
    sample_size = min(num_timesteps, unique_times.size)
    focus_timesteps = np.random.choice(unique_times, sample_size, replace=False)
    return np.concatenate([np.where(times == t)[0] for t in focus_timesteps])


def _mean_error_by_location(lons, lats, labels, y_preds):
    """Average the signed error ``target - prediction`` per unique (lon, lat) pair.

    Returns:
        tuple: ``(lonlat_pairs, errors)``. ``lonlat_pairs`` holds the unique
        (lon, lat) rows as returned by ``np.unique``; ``errors[k]`` is the
        mean of the per-sample mean errors of all samples at ``lonlat_pairs[k]``.
    """
    lonlat_pairs, pair_index = np.unique(
        np.stack((lons, lats), axis=-1), axis=0, return_inverse=True
    )
    # Flatten defensively: the shape of the inverse for ``axis != None`` has
    # differed between NumPy releases.
    pair_index = np.asarray(pair_index).reshape(-1)

    # Group the per-sample mean errors by location in one pass instead of
    # rescanning every sample for every location.
    per_sample_error = np.mean(labels - y_preds, axis=1)
    num_pairs = len(lonlat_pairs)
    error_sums = np.bincount(pair_index, weights=per_sample_error, minlength=num_pairs)
    sample_counts = np.bincount(pair_index, minlength=num_pairs)
    return lonlat_pairs, error_sums / sample_counts


def on_test_epoch_end(self):
    """Log sample results, per-location errors and average-error metrics/plots.

    Reads the accumulators filled during the test epoch (``lons``, ``lats``,
    ``times``, ``labels``, ``y_preds`` and the three ``*_by_timestep`` dicts)
    and then:

    * logs a table with all rows of up to 10 randomly chosen timesteps,
    * logs a table of the mean signed error per (lon, lat) pair,
    * stores ``spatial_avg_prediction_vector`` / ``spatial_avg_target_vector``
      on ``self`` (one entry per timestep, ordered by timestep) and logs them,
    * logs MAE/RMSE of the spatial and temporal averages via ``self.log``,
    * logs three figures.

    NOTE(review): tables are sent through ``self.logger.experiment.log`` and
    figures through ``wandb.log``, so this presumably requires an active wandb
    run / ``WandbLogger`` -- confirm before running with logging disabled.
    """
    # pick up to 10 distinct random timesteps and collect their row indices
    focus_indices = _sample_focus_indices(self.times, 10)

    # log data to wandb
    print("Logging table...")
    lonlatstime_cols = {
        "lons": self.lons[focus_indices],
        "lats": self.lats[focus_indices],
        "time": self.times[focus_indices],
    }
    label_cols = {
        f"target_{i}": self.labels[focus_indices, i]
        for i in range(self.labels.shape[1])
    }
    y_pred_cols = {
        f"y_pred_{i}": self.y_preds[focus_indices, i]
        for i in range(self.y_preds.shape[1])
    }

    df = pd.DataFrame({**lonlatstime_cols, **label_cols, **y_pred_cols})
    table = wandb.Table(dataframe=df)
    self.logger.experiment.log({"sample_test_results": table})

    # mean signed error (target - prediction) for every unique lon-lat pair
    lonlat_pairs, errors = _mean_error_by_location(
        self.lons, self.lats, self.labels, self.y_preds
    )

    temp_avg_df = pd.DataFrame({
        "lons": lonlat_pairs[:, 0],
        "lats": lonlat_pairs[:, 1],
        "overall_error": errors,
    })

    temp_avg_table = wandb.Table(dataframe=temp_avg_df)
    self.logger.experiment.log({"temporal_averages_table": temp_avg_table})

    # log metrics related to spatial and temporal averages of errors
    ## spatial averages
    # Walk the timesteps in sorted order: the dicts are filled in batch
    # order, but the vectors are plotted against "Time" below.
    timesteps = sorted(self.count_by_timestep)
    self.spatial_avg_prediction_vector = np.array([
        np.nanmean(self.pred_by_timestep[t] / self.count_by_timestep[t], axis=-1)
        for t in timesteps
    ])
    self.spatial_avg_target_vector = np.array([
        np.nanmean(self.target_by_timestep[t] / self.count_by_timestep[t], axis=-1)
        for t in timesteps
    ])

    # create pandas df of spatial averages
    spatial_average_df = pd.DataFrame({
        "avg_target": self.spatial_avg_target_vector,
        "avg_prediction": self.spatial_avg_prediction_vector,
    })
    # create wandb table of spatial averages
    spatial_average_table = wandb.Table(dataframe=spatial_average_df)

    # log table of spatial averages to wandb
    self.logger.experiment.log({"table of spatial averages": spatial_average_table})

    # create spatial average metrics
    spatial_avg_error = self.spatial_avg_target_vector - self.spatial_avg_prediction_vector
    MAE_of_spatial_average = np.nanmean(np.abs(spatial_avg_error))
    RMSE_of_spatial_average = np.sqrt(np.nanmean(spatial_avg_error**2))

    self.log("test_MAE_of_spatial_average", MAE_of_spatial_average, on_step=False, on_epoch=True)
    self.log("test_RMSE_of_spatial_average", RMSE_of_spatial_average, on_step=False, on_epoch=True)

    ## temporal averages
    MAE_of_temporal_average = np.mean(temp_avg_df["overall_error"].abs())
    RMSE_of_temporal_average = np.sqrt(np.mean(temp_avg_df["overall_error"]**2))

    self.log("test_MAE_of_temporal_average", MAE_of_temporal_average, on_step=False, on_epoch=True)
    self.log("test_RMSE_of_temporal_average", RMSE_of_temporal_average, on_step=False, on_epoch=True)

    # plot figures corresponding to average metrics
    ## spatial averages
    fig = plot_spatial_avg_of_target_vs_pred(self, title="8 Temperature Variables: Avg Predicted vs Avg Target")
    fig2 = plot_error_of_spatial_avg_of_target_vs_pred(self, title="8 Temperature Variables: Spatial Avg Predicted - Spatial Avg Target")

    wandb.log({"Spatial Avg of Predicted vs Target": wandb.Image(fig)})
    wandb.log({"Error of Spatial Averages": wandb.Image(fig2)})

    ## temporal averages
    fig3, ax3 = plt.subplots(figsize=(10, 5))
    ax3.set_title("Error of Temporal Averages")
    # x = longitude, y = latitude, matching the axis labels set below
    points = ax3.scatter(
        temp_avg_df["lons"],
        temp_avg_df["lats"],
        c=temp_avg_df["overall_error"],
        cmap='viridis',
        s=2,
        alpha=0.5,
    )
    ax3.set_xlabel("Longitude")
    ax3.set_ylabel("Latitude")
    # add colorbar
    fig3.colorbar(points, ax=ax3, orientation='vertical')
    plt.show()

    wandb.log({"Error of Temporal Averages": wandb.Image(fig3)})


def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
    """Accumulate per-timestep sums and per-sample results for one test batch.

    Args:
        outputs: Mapping produced by the test step; only ``outputs["y_pred"]``
            is read.
        batch: ``(lonlats, label)`` pair. Columns 0, 1 and 2 of ``lonlats``
            are stored as longitude, latitude and timestep. ``label`` is
            de-normalised with ``self.stds`` / ``self.means`` before use.
        batch_idx: Unused; accepted for hook-signature compatibility.
        dataloader_idx: Unused; accepted for hook-signature compatibility.

    Side effects:
        Updates ``self.pred_by_timestep``, ``self.target_by_timestep`` and
        ``self.count_by_timestep`` (per-target-column sums keyed by timestep)
        and appends to ``self.lons``, ``self.lats``, ``self.times``,
        ``self.y_preds`` and ``self.labels``.
    """
    lonlats, label = batch
    # De-normalise on the label's own device; a hard-coded .cuda() fails on
    # CPU-only hosts and can pick the wrong GPU on multi-device runs.
    label = label * self.stds.to(label.device) + self.means.to(label.device)
    y_pred = outputs["y_pred"]

    # Convert each tensor to numpy once instead of once per timestep.
    lonlats_np = lonlats.detach().cpu().numpy()
    label_np = label.detach().cpu().numpy()
    y_pred_np = y_pred.detach().cpu().numpy()

    batch_times = lonlats_np[:, 2]
    num_targets = y_pred_np.shape[-1]

    for timestep in np.unique(batch_times):
        # True marks rows that do NOT belong to this timestep (masked out),
        # repeated across all target columns.
        mask = np.repeat(batch_times != timestep, num_targets).reshape(-1, num_targets)

        batch_totals = (
            (self.pred_by_timestep, np.ma.masked_array(y_pred_np, mask).sum(axis=0)),
            (self.target_by_timestep, np.ma.masked_array(label_np, mask).sum(axis=0)),
            (self.count_by_timestep, np.ma.masked_array(~mask, mask).sum(axis=0)),
        )
        for store, total in batch_totals:
            if timestep in store:
                store[timestep] += total
            else:
                store[timestep] = total

    # append the current batch to the internal data structures
    self.lons = np.append(self.lons, lonlats_np[:, 0])
    self.lats = np.append(self.lats, lonlats_np[:, 1])
    self.times = np.append(self.times, batch_times)

    # The buffers start as empty 1-D arrays, so the first batch replaces them
    # rather than being concatenated onto a mismatched shape.
    self.y_preds = np.concatenate((self.y_preds, y_pred_np), axis=0) if len(self.y_preds) > 0 else y_pred_np
    self.labels = np.concatenate((self.labels, label_np), axis=0) if len(self.labels) > 0 else label_np
    
def fit_ace_ts(args):
    """
    Fits a ``LocationEncoder3D`` on the data provided by ``ACE_TS_DataModule``.

    Args:
        args (Namespace): Run configuration. The attributes read here are
            ``seed``, ``datamodule_args``, ``locationencoder_args``, ``lr``,
            ``wd``, ``patience``, ``regression``, ``max_epochs``,
            ``min_radius``, ``log_wandb`` (plus ``wandb_project`` and
            ``output_root`` when it is truthy), ``accelerator`` and ``gpus``.
            ``args.locationencoder_args`` is mutated: its
            ``"number_of_timesteps"`` entry is set from the data module.

    Returns:
        tuple: ``(locationencoder, trainer, datamodule)`` -- the trained
               model, the trainer used to fit it, and the data module.
    """
    # set the seed for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # create the data module
    datamodule = ACE_TS_DataModule(**args.datamodule_args)
    datamodule.prepare_data()

    # "harmonize" hparams and args
    hparams = {
        'optimizer': {'lr': args.lr, 'wd': args.wd},
        'legendre_polys': args.locationencoder_args["legendre_polys"],
        'harmonics_calculation': args.locationencoder_args["harmonics_calculation"],
        'num_classes': len(args.datamodule_args["variable_selection"]),
        'patience': args.patience,
        'regression': args.regression,
        'max_epochs': args.max_epochs,
        'min_radius': args.min_radius,
    }
    hparams = set_default_if_unset(hparams, "max_radius", 360)

    args.locationencoder_args["number_of_timesteps"] = datamodule.num_timesteps

    # create the location encoder model
    locationencoder = LocationEncoder3D(
        **args.locationencoder_args,
        hparams=hparams
    )

    # MONKEY PATCH: route the model's test hooks to the module-level
    # functions in this file. The batch hook also accepts ``dataloader_idx``
    # so it keeps working if the hook is invoked with that extra argument.
    locationencoder.on_test_epoch_start = lambda: on_test_epoch_start(locationencoder)
    locationencoder.on_test_epoch_end = lambda: on_test_epoch_end(locationencoder)
    locationencoder.on_test_batch_end = (
        lambda outputs, batch, batch_idx, dataloader_idx=0: on_test_batch_end(
            locationencoder, outputs, batch, batch_idx, dataloader_idx=dataloader_idx
        )
    )

    # configure training
    # NOTE: checkpoint resuming and ModelCheckpoint saving are currently
    # disabled, so ``args.resume_ckpt_from_results_dir`` and
    # ``args.save_model`` are not read here.
    callbacks = [
        EarlyStopping(monitor="val_loss", mode="min", patience=args.patience)
    ]

    logger = None
    if args.log_wandb:
        logger = pl.pytorch.loggers.WandbLogger(
            project=args.wandb_project,
            save_dir=args.output_root,
        )

    # -1 (or [-1]) means "let Lightning pick the devices"
    if args.gpus == -1 or args.gpus == [-1]:
        devices = 'auto'
    else:
        devices = args.gpus

    # use GPU if it is available (this overrides ``args.accelerator``)
    accelerator = args.accelerator
    if torch.cuda.is_available():
        accelerator = 'gpu'

    print(f"using gpus: {devices}")

    # set up the data module early so its normalisation statistics can be
    # copied onto the model before training starts
    datamodule.setup('fit')
    locationencoder.means = datamodule.means
    locationencoder.stds = datamodule.stds
    locationencoder.mean_var = (locationencoder.stds**2).mean()

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        log_every_n_steps=5,
        callbacks=callbacks,
        accelerator=accelerator,
        devices=devices,
        logger=logger,
        precision=64,
        num_sanity_val_steps=0,
    )

    trainer.fit(model=locationencoder, datamodule=datamodule)

    # Save model parameters after fitting
    save_model_parameters(trainer, locationencoder)

    return (locationencoder, trainer, datamodule)
