import importlib
from os.path import join

import torch
from torch.utils.data import DataLoader

import numpy as np
import scipy.signal as signal

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm

from tqdm import tqdm

from einops import rearrange


def get_dataset_evaluation(args):
    """
    Look up the evaluation class for the dataset named in ``args``.

    The module ``evaluate.<dataset>`` is imported dynamically and its
    ``DatasetEvaluation`` attribute is returned. ``<dataset>`` is
    ``args.dataset``, with ``args.variant`` appended when ``args`` has a
    ``variant`` attribute that is not equal to 0.

    Args:
        args: Object exposing a ``dataset`` attribute and, optionally, a
            ``variant`` attribute.

    Returns:
        The ``DatasetEvaluation`` attribute of the imported module
        (returned as-is, not instantiated).

    Raises:
        AttributeError: If ``args`` has no ``dataset`` attribute, or the
            imported module has no ``DatasetEvaluation`` attribute.
        ModuleNotFoundError: If ``evaluate.<dataset>`` cannot be imported.

    Example:
        evaluation_class = get_dataset_evaluation(args)
    """
    # A missing `variant` attribute is treated the same as variant 0.
    variant = getattr(args, 'variant', 0)
    if variant != 0:
        dataset = f'{args.dataset}{variant}'
    else:
        dataset = args.dataset
    dataset_evaluation_module = importlib.import_module(f'evaluate.{dataset}')
    return getattr(dataset_evaluation_module, 'DatasetEvaluation')


def get_evaluation_loaders(dataloaders, config):
    """
    Build evaluation loaders over the datasets of existing dataloaders.

    Each returned loader wraps the ``dataset`` of the corresponding input
    loader, without shuffling, with ``config.loader.batch_size`` as batch
    size and ``num_workers=0``.

    Args:
        dataloaders: Iterable of objects exposing a ``dataset`` attribute.
        config: Object exposing ``config.loader.batch_size``.

    Returns:
        List of ``DataLoader`` objects, one per input loader, in order.
    """
    eval_dataloaders = []
    for source_loader in dataloaders:
        eval_loader = DataLoader(
            source_loader.dataset,
            batch_size=config.loader.batch_size,
            shuffle=False,
            num_workers=0,
        )
        eval_dataloaders.append(eval_loader)
    return eval_dataloaders


def informer_MAE(pred, true):
    """Return the mean absolute error between ``pred`` and ``true``, averaged over all elements."""
    abs_error = np.abs(pred - true)
    return np.mean(abs_error)

def informer_MSE(pred, true):
    """Return the mean squared error between ``pred`` and ``true``, averaged over all elements."""
    error = pred - true
    return np.mean(error ** 2)

def informer_RMSE(pred, true):
    """Return the root of ``informer_MSE(pred, true)``."""
    mse = informer_MSE(pred, true)
    return np.sqrt(mse)


def forecast(model, dataloader, input_transform, mse_loss,
             device, return_true=True, output_transform=None,
             forecast_rmse=True, args=None):
    """
    Run ``model`` over ``dataloader`` and compute forecasting error metrics.

    Three groups of metrics are produced:

    * "transformed": per-sample RMSE / MSE / MAE between the
      output-transformed predictions and the targets exactly as yielded by
      the loader, averaged over all samples.
    * "forecast": the same metrics after mapping predictions and targets
      back through ``dataloader.dataset.inverse_transform`` when that
      succeeds, otherwise (if ``forecast_rmse``) after rescaling with the
      dataset's first standardization mean / std.
    * "informer": ``informer_RMSE`` / ``informer_MSE`` / ``informer_MAE``
      over all elements of the concatenated output-transformed predictions
      and loader targets.

    A single series is also stitched together from the batches (see the
    inline comment); this assumes the dataloader is not shuffled.

    Args:
        model: Module called as ``model(x, (args.lag, args.horizon,
            args.lag))``; must provide ``to``, ``eval`` and ``joint_train``.
        dataloader: Yields ``(x, y, *extras)`` batches. Its ``dataset`` may
            provide ``standardization`` and ``inverse_transform``.
        input_transform: Callable applied to ``x`` before the forward pass.
        mse_loss: Ignored. Kept for backward compatibility; an element-wise
            ``torch.nn.MSELoss(reduction='none')`` is always used.
        device: Device the model and inputs are moved to.
        return_true: If True, also return the stitched ground truth and the
            informer metrics.
        output_transform: Optional callable applied to the model output.
            ``None`` leaves the model output unchanged.
        forecast_rmse: If True, rescale by the standardization statistics
            whenever the inverse transform was not applied.
        args: Namespace providing ``features``, ``lag``, ``horizon`` and
            ``dataset``. Required in practice despite the ``None`` default.

    Returns:
        If ``return_true``:
            ``((y_pred, y_true), (rmse, mse, mae),
            (rmse_transformed, mse_transformed, mae_transformed),
            (rmse_informer, mse_informer, mae_informer))``
        otherwise:
            ``(y_pred, (rmse, mse, mae),
            (rmse_transformed, mse_transformed, mae_transformed))``
    """
    # Element-wise losses; every reduction is done explicitly below.
    mae_fn = torch.nn.L1Loss(reduction='none')
    mse_fn = torch.nn.MSELoss(reduction='none')

    try:
        mean = dataloader.dataset.standardization['means'][0]
        std  = dataloader.dataset.standardization['stds'][0]
    except AttributeError:
        # No standardization info on the dataset: rescaling is a no-op.
        mean, std = 0., 1.

    model.to(device)
    model.eval()

    total_y_true = []
    total_y_pred = []
    total_rmse = []
    total_mse  = []
    total_mae  = []

    total_rmse_transformed = []
    total_mse_transformed  = []
    total_mae_transformed  = []

    total_y_true_informer = []
    total_y_pred_informer = []

    model.joint_train()

    with torch.no_grad():
        for batch_ix, data in enumerate(tqdm(dataloader, leave=False, desc='Evaluation forecasting')):
            x, y, *z = data

            # For multivariate: fold the feature axis into the batch axis so
            # every feature is forecast as its own univariate series.
            if args.features == 'M':
                x = rearrange(x, 'b l d -> (b d) l').unsqueeze(-1)
                y = rearrange(y, 'b l d -> (b d) l').unsqueeze(-1)

            # Transform batch data
            x = input_transform(x)
            x = x.to(device)
            y_pred = model(x, (args.lag, args.horizon, args.lag))
            del x  # the input batch is not needed after the forward pass

            # Transform back batch data (identity when no transform is given)
            if output_transform is not None:
                y_pred = output_transform(y_pred)

            y_cpu = y.cpu()
            y_pred_cpu = y_pred.cpu()

            total_y_true_informer.append(y_cpu)
            total_y_pred_informer.append(y_pred_cpu)

            # Per-sample metrics in the transformed space (mean over dim 1).
            mse_transformed = mse_fn(y_pred_cpu, y_cpu).mean(1)
            total_rmse_transformed.append(torch.sqrt(mse_transformed))
            total_mse_transformed.append(mse_transformed)
            total_mae_transformed.append(mae_fn(y_pred_cpu, y_cpu).mean(1))

            # Best-effort inverse transform: on any failure fall back to the
            # standardization-based rescaling below. Both results are
            # committed together so a failure on the second call cannot
            # leave y_pred and y on different scales.
            reverse_transformed = False
            try:
                if args.dataset != 'arima':
                    loc = z[-1] if args.dataset == 'monash' else None
                    y_pred_inv = dataloader.dataset.inverse_transform(y_pred, loc)
                    y_inv      = dataloader.dataset.inverse_transform(y, loc)
                    y_pred, y = y_pred_inv, y_inv
                    # Already in original units: neutralise later rescaling.
                    mean = 0.
                    std = 1.
                    reverse_transformed = True
            except Exception as e:
                print(e)

            # Stitch one continuous series: the full first window, then only
            # the last time step of every following window (presumably
            # consecutive windows advance by one step -- TODO confirm).
            if batch_ix == 0:
                total_y_true.append(y[0, :, 0])
                total_y_true.append(y[1:, -1, 0])

                total_y_pred.append(y_pred[0, :, 0])
                total_y_pred.append(y_pred[1:, -1, 0])
            else:
                total_y_true.append(y[:, -1, 0])
                total_y_pred.append(y_pred[:, -1, 0])

            if forecast_rmse and not reverse_transformed:
                y_pred = y_pred * std + mean
                y      = y * std + mean

            y_cpu = y.cpu()
            y_pred_cpu = y_pred.cpu()

            # Per-sample metrics in the (inverse-transformed / rescaled) space.
            mse_batch = mse_fn(y_pred_cpu, y_cpu).mean(1)
            total_rmse.append(torch.sqrt(mse_batch))
            total_mse.append(mse_batch)
            total_mae.append(mae_fn(y_pred_cpu, y_cpu).mean(1))

            torch.cuda.empty_cache()

    total_y_true = torch.cat(total_y_true).cpu()
    total_y_pred = torch.cat(total_y_pred).cpu()

    total_y_true_informer = torch.cat(total_y_true_informer, dim=0).numpy()
    total_y_pred_informer = torch.cat(total_y_pred_informer, dim=0).numpy()

    if forecast_rmse:
        # No-op (mean=0, std=1) when the inverse transform was applied.
        total_y_true = total_y_true * std + mean
        total_y_pred = total_y_pred * std + mean

    # RMSE over the stitched series: take the mean *before* the square root
    # (sqrt of element-wise squared errors followed by mean would be the MAE).
    all_rmse = torch.sqrt(mse_fn(total_y_pred, total_y_true).mean())
    avg_rmse = torch.cat(total_rmse).mean()
    avg_mse  = torch.cat(total_mse).mean()
    avg_mae  = torch.cat(total_mae).mean()

    avg_rmse_transformed = torch.cat(total_rmse_transformed).mean()
    avg_mse_transformed  = torch.cat(total_mse_transformed).mean()
    avg_mae_transformed  = torch.cat(total_mae_transformed).mean()

    metrics_informer = (
        informer_RMSE(total_y_pred_informer, total_y_true_informer),
        informer_MSE(total_y_pred_informer, total_y_true_informer),
        informer_MAE(total_y_pred_informer, total_y_true_informer),
    )

    print('Informer metrics:')
    print(f'- RMSE: {metrics_informer[0]}')
    print(f'- MSE:  {metrics_informer[1]}')
    print(f'- MAE:  {metrics_informer[2]}')
    print(f'Dataset RMSE:        {all_rmse}')
    print(f'Batch-wise avg RMSE: {avg_rmse}')

    metrics = (avg_rmse, avg_mse, avg_mae)
    metrics_transformed = (avg_rmse_transformed, avg_mse_transformed, avg_mae_transformed)
    if return_true:
        return (total_y_pred, total_y_true), metrics, metrics_transformed, metrics_informer
    return total_y_pred, metrics, metrics_transformed


def plot_forecast(model, dataloaders, splits, input_transform, 
                  mse_loss, device, forecast_rmse, output_transform=None,
                  args=None, axes=None, show=True, save=False):
    """
    Forecast every split, plot prediction vs. ground truth, collect metrics.

    For each split, ``forecast`` is run on the dataloader at the same index
    and the returned series are drawn on that split's axis, with the
    forecast RMSE in the title.

    Args:
        model, input_transform, mse_loss, device, forecast_rmse,
        output_transform, args: Forwarded unchanged to ``forecast``.
        dataloaders: Sequence of dataloaders, index-aligned with ``splits``.
        splits: Sequence of split names; used for titles and as keys of the
            returned dict.
        axes: Optional matplotlib axes to draw on: either an indexable
            collection with one axis per split or a single ``Axes`` object.
            When None, a new figure with one column per split is created.
        show: Currently unused by this function.
        save: Currently unused by this function.

    Returns:
        Dict mapping each split name to a dict with keys ``rmse``, ``mse``,
        ``mae``, their ``*_transformed`` counterparts and their
        ``*_informer`` counterparts.
    """
    if axes is None:
        # squeeze=False keeps a 2-D array even for a single split, so that
        # indexing by split works uniformly.
        _, axes = plt.subplots(1, len(splits),
                               figsize=(6.4 * len(splits), 4.8),
                               squeeze=False)
        axes = axes[0]
    elif not hasattr(axes, '__getitem__'):
        # A single Axes object was passed in; make it indexable.
        axes = [axes]

    split_metrics = {split: {} for split in splits}
    for split_ix, split in enumerate(splits):
        print(f'Split: {split}')
        (y_hat, y_true), metrics, metrics_t, metrics_i = forecast(
            model,
            dataloaders[split_ix],
            input_transform,
            mse_loss,
            device,
            return_true=True,
            forecast_rmse=forecast_rmse,
            output_transform=output_transform,
            args=args,
        )
        ax = axes[split_ix]
        ax.plot(y_true, label='ground-truth', color='tab:orange')
        ax.plot(y_hat, label='prediction', color='tab:blue', linestyle='--')

        ax.legend()
        ax.set_title(f'{split} split forecasts, forecast RMSE: {metrics[0]:.3f}', size=15)

        split_metrics[split] = {
            'rmse': metrics[0], 'mse': metrics[1], 'mae': metrics[2],
            'rmse_transformed': metrics_t[0], 'mse_transformed': metrics_t[1], 'mae_transformed': metrics_t[2],
            'rmse_informer': metrics_i[0], 'mse_informer': metrics_i[1], 'mae_informer': metrics_i[2],
        }
        print('------------------------')
    return split_metrics
   