import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import xarray as xr
from Settings.Settings import ModelType

def evaluate_model(model,
                   data_loader,
                   station_norm_dict,
                   era5_norm_dict,
                   device,
                   lead_hrs,
                   loss_function=None,
                   optimizer=None,
                   save=False,
                   model_type=ModelType.GNN,
                   station_type='train',
                   show_progress_bar=False):
    """Run one pass of ``model`` over ``data_loader``, training or evaluating.

    When ``station_type == 'train'`` the model is put in train mode and every
    batch takes an optimizer step on ``loss_function(prediction, target)``.
    For any other ``station_type`` the model is put in eval mode and gradients
    are disabled.

    Each batch is a dict-like ``sample`` read with the keys ``'station_tp'``,
    ``'station_lon'``, ``'station_lat'`` and ``'k_edge_index'``. When
    ``era5_norm_dict`` is not None, it also reads ``'era5_tp'``,
    ``'era5_lon'``, ``'era5_lat'`` and ``'e2m_edge_index'``. The model input
    is ``station_tp`` with its last ``lead_hrs`` entries along dim 2 removed
    (presumably the time axis — TODO confirm). The target is the final entry
    along dim 2.

    Args:
        model: Module called as ``model(station_x, station_lon, station_lat,
            edge_index_m2m, era5_lon, era5_lat, era5_x, edge_index_e2m)``.
        data_loader: Iterable of batch dicts as described above.
        station_norm_dict: Not used by this function; kept for interface
            compatibility.
        era5_norm_dict: Only tested against ``None``. When it is not None,
            ERA5 inputs are loaded and passed to the model. Otherwise the
            ERA5 arguments are passed as ``None``.
        device: Device every batch tensor is moved to.
        lead_hrs: Number of trailing entries (dim 2) withheld from the input.
        loss_function: Required when training; called as
            ``loss_function(prediction, target)``.
        optimizer: Required when training.
        save: If True, also return all predictions and targets.
        model_type: ``ModelType.MLP`` / ``ModelType.MPNN_MLP`` reshape the
            ERA5 features to ``(b * s, t, v)``. Every other type reshapes
            them to ``(b, s, t * v)``.
        station_type: ``'train'`` selects training; anything else evaluates.
        show_progress_bar: Wrap the loader in ``tqdm``.

    Returns:
        ``(mae_sum, mse_sum)``: the sums (not means) of absolute and squared
        errors over every element of every batch. When ``save`` is True it
        returns ``(mae_sum, mse_sum, predictions, targets)``, where
        predictions and targets are numpy arrays concatenated along axis 0.

    Raises:
        ValueError: If training is requested without ``loss_function`` or
            ``optimizer``.
    """
    is_train = station_type == 'train'
    if is_train and (loss_function is None or optimizer is None):
        raise ValueError("Training requires both loss_function and optimizer.")

    mae_sum = 0.0
    mse_sum = 0.0
    pred_list = []
    target_list = []
    # MLP-style models consume ERA5 features per station; others get the
    # time/variable axes flattened together.
    flatten_era5_per_station = model_type in (ModelType.MLP, ModelType.MPNN_MLP)

    if is_train:
        model.train()
    else:
        model.eval()

    batches = tqdm(data_loader) if show_progress_bar else data_loader
    with torch.set_grad_enabled(is_train):
        for sample in batches:
            station_tp = sample['station_tp'].to(device)
            station_lon = sample['station_lon'].to(device)
            station_lat = sample['station_lat'].to(device)
            edge_index_m2m = sample['k_edge_index'].to(device)

            # Input: everything except the last ``lead_hrs`` entries on dim 2.
            # Target: the final entry, kept as a size-1 dim by the list index.
            n_steps = station_tp.shape[2]
            station_x = station_tp[:, :, :n_steps - lead_hrs].unsqueeze(3)
            y_tp = station_tp[:, :, [-1]]

            if era5_norm_dict is not None:
                # era5_lon must be bound on this path too: it is passed to
                # the model below alongside era5_lat.
                era5_lon = sample['era5_lon'].to(device)
                era5_lat = sample['era5_lat'].to(device)
                edge_index_e2m = sample['e2m_edge_index'].to(device)
                # era5_tp must be 3-D here: after unsqueeze(3) it unpacks
                # into exactly four dimensions.
                era5_x = sample['era5_tp'].to(device).unsqueeze(3)
                b, s, t, v = era5_x.shape
                if flatten_era5_per_station:
                    era5_x = era5_x.reshape(b * s, t, v)
                else:
                    era5_x = era5_x.reshape(b, s, t * v)
            else:
                era5_lon = era5_lat = era5_x = edge_index_e2m = None

            if is_train:
                optimizer.zero_grad()

            out = model(station_x,
                        station_lon,
                        station_lat,
                        edge_index_m2m,
                        era5_lon,
                        era5_lat,
                        era5_x,
                        edge_index_e2m)

            if is_train:
                loss = loss_function(out, y_tp)
                loss.backward()
                optimizer.step()

            # Metrics and saved arrays must not keep the autograd graph alive.
            out = out.detach()
            y_tp = y_tp.detach()

            if save:
                pred_list.append(out.cpu().numpy())
                target_list.append(y_tp.cpu().numpy())

            # Summed over all elements; callers normalise if they want means.
            mse_sum += F.mse_loss(out, y_tp, reduction='sum').item()
            mae_sum += F.l1_loss(out, y_tp, reduction='sum').item()

    if save:
        return (mae_sum,
                mse_sum,
                np.concatenate(pred_list, axis=0),
                np.concatenate(target_list, axis=0))
    return mae_sum, mse_sum
