import os
import pickle
from pathlib import Path

import geopandas as gpd
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset

import Normalization.NormalizerBuilder as NormalizerBuilder
from Dataloader.ERA5 import ERA5
from Dataloader.MetaStation import MetaStation
from Dataloader.MixData import MixData
from EvaluateModel import evaluate_model
from Modules.GNN.MPNN import MPNN
from Network.ERA5Network import ERA5Network
from Source.Network.stationNetwork import stationNetwork
from Settings.Settings import ModelType
import xarray as xr
from torch.optim.lr_scheduler import ReduceLROnPlateau
def Run(args):
    """Train an MPNN on station/ERA5 data and periodically evaluate it.

    Reads the experiment configuration from ``args``, builds the station
    (and optionally ERA5) networks, creates one ``MixData`` dataset per year,
    trains for ``args.epochs`` epochs and writes losses, predictions, model
    checkpoints and a loss plot below ``args.data_path / args.output_saving_path``.

    Attributes read from ``args``: ``data_path``, ``output_saving_path``,
    ``show_progress_bar``, ``shapefile_path``, ``coords`` (only when
    ``shapefile_path`` is ``None``), ``back_hrs``, ``lead_hrs``,
    ``hidden_dim``, ``lr``, ``epochs``, ``batch_size``, ``eval_interval``,
    ``weight_decay``, ``model_type``, ``station_control_ratio``, ``n_years``,
    ``n_passing``, ``n_neighbors_m2m`` and ``n_neighbors_e2m``.

    Raises:
        ValueError: if ``args.model_type`` is not ``ModelType.GNN`` (the only
            model type for which datasets are built here).
    """
    data_path = Path(args.data_path)
    output_saving_path = data_path / args.output_saving_path
    output_saving_path.mkdir(exist_ok=True, parents=True)
    show_progress_bar = args.show_progress_bar
    shapefile_path = args.shapefile_path

    if shapefile_path is None:
        lon_low, lon_up, lat_low, lat_up = args.coords
    else:
        # Use the caller-supplied shapefile; overwriting it with an empty
        # string made read_file fail and hid the path from MetaStation.
        gdf = gpd.read_file(shapefile_path)
        # total_bounds is ordered (minx, miny, maxx, maxy).
        lon_low, lat_low, lon_up, lat_up = gdf.total_bounds

    # Pad the bounding box.
    # NOTE(review): the southern margin is 1 while the others are 5 - confirm
    # that this asymmetry is intentional.
    lon_low -= 5
    lon_up += 5
    lat_low -= 1
    lat_up += 5

    back_hrs = args.back_hrs
    lead_hrs = args.lead_hrs
    whole_len = back_hrs + 1

    station_len = whole_len
    ERA5_len = whole_len + lead_hrs
    hidden_dim = args.hidden_dim
    lr = args.lr
    epochs = args.epochs
    batch_size = args.batch_size
    eval_interval = args.eval_interval
    weight_decay = args.weight_decay
    model_type = args.model_type
    station_control_ratio = args.station_control_ratio
    n_years = args.n_years
    n_passing = args.n_passing
    n_neighbors_m2m = args.n_neighbors_m2m
    n_neighbors_e2m = args.n_neighbors_e2m

    figures_path = output_saving_path / 'figures'
    figures_path.mkdir(exist_ok=True, parents=True)

    print('Experiment Configuration', flush=True)

    for k, v in vars(args).items():
        print(f'{k}: {v}', flush=True)

    ##### Set Random Seed #####
    np.random.seed(42)
    # DataLoader shuffling and weight initialisation draw from torch's RNG,
    # so it has to be seeded as well for runs to be repeatable.
    torch.manual_seed(42)

    ##### Get Device #####
    if torch.cuda.is_available():
        device = 'cuda:0'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    ##### Load Data #####
    meta_station = MetaStation(lat_low, lat_up, lon_low, lon_up, n_years, station_control_ratio,
                               shapefile_path=shapefile_path, data_path=data_path)

    station_network = stationNetwork(meta_station, n_neighbors_m2m)
    n_stations = station_network.n_stations
    if n_neighbors_e2m > 0:
        era5_stations = ERA5(meta_station.lat_low, meta_station.lat_up, meta_station.lon_low, meta_station.lon_up, 2020,
                             region='Northeastern',
                             data_path=data_path).data

        era5_network = ERA5Network(era5_stations, station_network, n_neighbors_e2m)
    else:
        era5_network = None

    if model_type != ModelType.GNN:
        # Datasets are only built for the GNN; fail with a clear message
        # instead of a NameError further down.
        raise ValueError(f'Unsupported model_type {model_type!r}: only ModelType.GNN is implemented')

    years = list(range(2016, 2022))
    Data_List = [MixData(year, back_hrs, lead_hrs, meta_station, station_network, n_neighbors_m2m, era5_network,
                         data_path=data_path) for year in years]
    if era5_network is not None:
        era5_network.era5_pos = torch.Tensor(Data_List[0].station_data[['lon', 'lat']].to_dataarray().values.T)
        era5_network.era5_lons = torch.Tensor(Data_List[0].station_data[['lon']].to_dataarray().values.T)
        era5_network.era5_lats = torch.Tensor(Data_List[0].station_data[['lat']].to_dataarray().values.T)

    # Chronological split: the last year is the test set, the one before it
    # the validation set, and everything earlier is used for training.
    # Fixed [:5] / [5:6] / [6:7] slices left the test split empty for the six
    # years above, which ConcatDataset rejects.
    datasets = {
        'train': ConcatDataset(Data_List[:-2]),
        'val': ConcatDataset(Data_List[-2:-1]),
        'test': ConcatDataset(Data_List[-1:]),
    }

    # Only the training loader is shuffled.
    loaders = {
        split: DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'))
        for split, dataset in datasets.items()
    }
    n_dataset = {split: len(dataset) for split, dataset in datasets.items()}

    # Every split is evaluated on the same set of stations.
    n_train_stations = n_stations
    n_val_stations = n_stations
    n_test_stations = n_stations

    station_norm_dict, era5_norm_dict = NormalizerBuilder.get_normalizers(Data_List, era5_network)

    ##### Define Model #####
    model = MPNN(
        n_passing,
        lead_hrs=lead_hrs,
        n_node_features_m=1 * station_len,
        n_node_features_e=1 * ERA5_len,
        n_out_features=1,
        hidden_dim=hidden_dim
    ).to(device)

    nn_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print('Parameter Number: ', nn_params, flush=True)
    print(' ', flush=True)
    loss_function = nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    def call_evaluate(dataset, save):
        # Single pass over one split. The result is unpacked below as
        # (MAE_sum, MSE_sum, preds, targets) when save is true and as
        # (MAE_sum, MSE_sum) otherwise.
        # NOTE(review): presumably evaluate_model performs the optimisation
        # step when station_type == 'train' - confirm in EvaluateModel.
        return evaluate_model(
            model,
            loaders[dataset],
            station_norm_dict,
            era5_norm_dict,
            device,
            lead_hrs,
            loss_function=loss_function,
            optimizer=optimizer,
            save=save,
            model_type=model_type,
            station_type=dataset,
            show_progress_bar=show_progress_bar
        )

    ###### Training ######
    train_losses = []
    valid_losses = []
    test_losses = []
    test_tp_mses = []
    test_tp_maes = []

    min_valid_loss = float('inf')

    for epoch in range(epochs):
        # The test set is evaluated on the first epoch and then on every
        # eval_interval-th epoch.
        test_epoch = (epoch + 1) % eval_interval == 0 or epoch == 0

        if test_epoch:
            MAE_tp_sum, MSE_tp_sum, Pred_train, Target_train = call_evaluate('train', True)
        else:
            MAE_tp_sum, MSE_tp_sum = call_evaluate('train', False)

        train_loss = MSE_tp_sum / (n_dataset['train'] * n_train_stations)
        train_losses.append(train_loss)

        if test_epoch:
            MAE_tp_sum, MSE_tp_sum, Pred_val, Target_val = call_evaluate('val', True)
        else:
            MAE_tp_sum, MSE_tp_sum = call_evaluate('val', False)

        valid_loss = MSE_tp_sum / (n_dataset['val'] * n_val_stations)
        valid_losses.append(valid_loss)

        print('Epoch: %d train_loss[%.10f] valid_loss[%.10f]' % (epoch + 1, train_loss, valid_loss), flush=True)
        print(' ', flush=True)

        if not test_epoch:
            continue

        MAE_tp_sum, MSE_tp_sum, Pred, Target = call_evaluate('test', True)

        Preds = {'train': Pred_train, 'val': Pred_val, 'test': Pred}
        Targets = {'train': Target_train, 'val': Target_val, 'test': Target}

        test_tp_mae = MAE_tp_sum / (n_dataset['test'] * n_test_stations)
        test_tp_mse = MSE_tp_sum / (n_dataset['test'] * n_test_stations)

        test_tp_maes.append(test_tp_mae)
        test_tp_mses.append(test_tp_mse)
        # The plotted test loss is the same quantity as the test MSE.
        test_losses.append(test_tp_mse)

        if valid_loss < min_valid_loss:
            # New best validation loss: snapshot the network, targets,
            # predictions and the matching test metrics.
            min_valid_loss = valid_loss

            with open(output_saving_path / 'station_network_min.pkl', 'wb') as file:
                pickle.dump(station_network, file)

            with open(output_saving_path / 'Targets_min.pkl', 'wb') as file:
                pickle.dump(Targets, file)

            with open(output_saving_path / 'Preds_min.pkl', 'wb') as file:
                pickle.dump(Preds, file)

            np.save(output_saving_path / 'min_test_loss_mse.npy', test_tp_mse)
            np.save(output_saving_path / 'min_test_loss_mae.npy', test_tp_mae)
            np.save(output_saving_path / 'min_test_tp_mae.npy', test_tp_mae)
            np.save(output_saving_path / 'min_test_tp_mse.npy', test_tp_mse)

        print('Evaluation Report: test_tp_mae[%.3f] test_tp_mse[%.3f] ' % (
            test_tp_mae, test_tp_mse), flush=True)
        print(' ', flush=True)

        np.save(output_saving_path / f'station_train_test_tp_mae_epoch_{epoch + 1}.npy',
                MAE_tp_sum / n_dataset['test'])
        np.save(output_saving_path / f'station_train_test_tp_mse_epoch_{epoch + 1}.npy',
                MSE_tp_sum / n_dataset['test'])
        np.save(output_saving_path / f'station_train_test_preds_epoch_{epoch + 1}.npy', Pred)

        torch.save(model.state_dict(), output_saving_path / f'model_epoch_{epoch + 1}.pt')

    ##### Save #####
    train_losses = np.array(train_losses)
    np.save(output_saving_path / 'train_losses.npy', train_losses)

    valid_losses = np.array(valid_losses)
    np.save(output_saving_path / 'valid_losses.npy', valid_losses)

    np.save(output_saving_path / 'test_tp_mses.npy', test_tp_mses)

    np.save(output_saving_path / 'test_tp_maes.npy', test_tp_maes)

    ##### Plotting #####
    plot_metric(train_losses, valid_losses, test_losses, eval_interval, 'MSE', figures_path)


def plot_metric(train_losses, valid_losses, test_losses, eval_interval, metric_name, output_path, y_range=None):
    """Plot train/validation/test curves of one metric and save them as a PNG.

    Args:
        train_losses: per-epoch training values (one entry per epoch).
        valid_losses: per-epoch validation values, same length as
            ``train_losses``.
        test_losses: values recorded on test epochs only, i.e. epoch 1 and
            every ``eval_interval``-th epoch afterwards.
        eval_interval: number of epochs between two test evaluations.
        metric_name: label for the y axis and title; spaces are replaced by
            underscores to build the file name ``<metric_name>_plot.png``.
        output_path: directory the figure is written to.
        y_range: optional ``(low, high)`` pair used as y-axis limits.
    """
    fig, axs = plt.subplots(1, 1, figsize=(10, 5))

    epoch_axis = np.arange(1, len(train_losses) + 1)
    axs.plot(epoch_axis, train_losses, label='Train')
    axs.plot(epoch_axis, valid_losses, label='Valid')

    n_tests = len(test_losses)
    # Skip the test curve when nothing was recorded; an x axis of length 1
    # against zero y values would make matplotlib raise.
    if n_tests > 0:
        if eval_interval == 1:
            # Epoch 1 is both the "first epoch" and a multiple of the
            # interval, so the test epochs are simply 1, 2, 3, ...
            test_axis = np.arange(1, n_tests + 1)
        else:
            # Test epochs are 1, eval_interval, 2 * eval_interval, ...
            test_axis = np.concatenate([np.array([1]), np.arange(1, n_tests) * eval_interval])
        axs.plot(test_axis, test_losses, label='Test')

    if y_range is not None:
        axs.set_ylim(y_range[0], y_range[1])

    axs.legend()
    axs.grid()
    axs.set_xlabel('Epochs')
    axs.set_ylabel(metric_name)

    axs.set_title(metric_name + ' Plot')

    # Save and close this specific figure rather than pyplot's current one.
    fig.savefig(os.path.join(output_path, metric_name.replace(' ', '_') + '_plot.png'))
    plt.close(fig)
