"""
Dataset-specific evaluation of a forecasting model across multiple forecast horizons.
"""
import torch
import matplotlib.pyplot as plt

# Data
from dataloaders import initialize_data_functions
# Evaluation
from evaluate import get_evaluation_loaders, plot_transfer_from_coeffs, plot_forecast


class DatasetEvaluation:
    """Evaluate a forecasting model on a dataset at several forecast horizons.

    For each horizon, the data is reloaded with window size
    ``[lag, horizon, horizon]``, the model's horizon is reset, and
    ``plot_forecast`` is run on the last evaluation loader (labelled
    ``'test'``) to produce metrics and a matplotlib figure.

    Args:
        model: Model under evaluation; must expose ``reset_horizon(horizon)``.
        input_transform: Passed through to ``plot_forecast``.
        output_transform: Passed through to ``plot_forecast``.
        criterions: Mapping of loss functions; only ``criterions['mse']`` is
            used here.
        device: Passed through to ``plot_forecast``.
        config: Object with a dict-like ``dataset`` attribute (its ``'size'``
            entry is overwritten per horizon) and a ``loader`` attribute.
        args: Namespace-like object with at least ``horizon`` and ``lag``.
            NOTE: ``args.horizon`` is overwritten for every evaluated horizon.
        horizons: Horizons to evaluate. Defaults to
            ``int(args.horizon * 0.5 + 24 * i)`` for ``i`` in ``0..20``.
    """

    def __init__(self, model, input_transform, output_transform,
                 criterions, device, config, args, horizons=None):
        if horizons is None:
            # Default sweep: 21 horizons, 24 steps apart, starting near half
            # of the configured horizon.
            self.horizons = [int(args.horizon * 0.5 + 24 * ix) for ix in range(0, 21)]
        else:
            self.horizons = horizons

        self.lag = args.lag

        self.model = model
        self.input_transform = input_transform
        self.output_transform = output_transform
        self.criterions = criterions
        self.device = device

        self.args = args
        self.config = config

        self.splits = ['train', 'val', 'test']

    def evaluate(self, wandb=None):
        """Evaluate every horizon in ``self.horizons``.

        Args:
            wandb: Optional logger exposing ``log(dict)`` (e.g. the ``wandb``
                module). When given, each horizon's figure is logged under
                ``'forecast_plot-horizon_{h}'`` and the collected metrics dict
                is logged once at the end.

        Returns:
            dict mapping ``'horizon_{h}'`` to the metrics returned by
            ``evaluate_horizon`` for that horizon.
        """
        horizon_metrics = {}
        for horizon in self.horizons:
            print(f'HORIZON: {horizon}')
            metrics, forecast_plots = self.evaluate_horizon(horizon)

            horizon_metrics[f'horizon_{horizon}'] = metrics
            if wandb is not None:
                wandb.log({f'forecast_plot-horizon_{horizon}': forecast_plots})
            # The figure is not returned to the caller, so release it here.
            # pyplot keeps every open figure alive; the default sweep opens 21,
            # which exceeds matplotlib's `figure.max_open_warning` (20) and
            # accumulates memory across horizons. Only close real Figures:
            # `plt.close(None)` would close the *current* figure instead.
            if isinstance(forecast_plots, plt.Figure):
                plt.close(forecast_plots)

        if wandb is not None:
            wandb.log(horizon_metrics)
        return horizon_metrics

    def evaluate_horizon(self, horizon):
        """Reload data for ``horizon``, reset the model, then plot and score.

        Side effects: sets ``self.args.horizon`` and
        ``self.config.dataset['size']`` for this horizon and calls
        ``self.model.reset_horizon(horizon)``; none of these are restored.

        Args:
            horizon: Forecast horizon to evaluate.

        Returns:
            ``(split_metrics, fig)``: ``split_metrics`` is whatever
            ``plot_forecast`` returns, and ``fig`` is the matplotlib figure
            whose axes were passed to it. The caller owns ``fig`` and should
            ``plt.close(fig)`` when finished with it.
        """
        # Point the data pipeline at this horizon before loading.
        self.args.horizon = horizon
        self.config.dataset['size'] = [self.args.lag, horizon, horizon]
        load_data, _ = initialize_data_functions(self.args)
        eval_loaders, _ = load_data(self.config.dataset, self.config.loader, self.args)
        eval_loaders = get_evaluation_loaders(eval_loaders, self.config)

        # Evaluate model
        self.model.reset_horizon(horizon)
        # One axis per split (train, val, test). Only the last loader --
        # paired with the label self.splits[-1] ('test') -- is passed to
        # plot_forecast. NOTE(review): confirm how plot_forecast fills `axes`.
        n_plots = len(self.splits)
        fig, axes = plt.subplots(1, n_plots, figsize=(6.4 * n_plots, 4.8))
        split_metrics = plot_forecast(self.model, [eval_loaders[-1]], [self.splits[-1]],
                                      self.input_transform, self.criterions['mse'], self.device,
                                      forecast_rmse=True, output_transform=self.output_transform,
                                      args=self.args, axes=axes, show=False, save=False)
        # Release cached CUDA allocator memory between horizons.
        torch.cuda.empty_cache()
        return split_metrics, fig
        
        
        
        
    

