import os
import numpy as np
from utils.constants import Constants
import copy
import csv
from utils.metrics import metric
current_dir = os.path.dirname(os.path.abspath(__file__))


class MakeTable:
    """Build CSV tables of per-station prediction metrics.

    Predictions are read from ``pred.npy`` / ``true.npy`` / ``metrics.npy``
    files below ``results_path``. The sub-folder layout depends on
    ``run_type`` (see ``_get_station_path``).
    """

    # Order in which ``metric(pred, true)`` returns its values.
    _METRIC_RETURN_ORDER = ('mae', 'mse', 'rmse', 'mape', 'mspe', 'r2',
                            'nse', 'pbias', 'kge', 'flv', 'fhv')
    # Order of the metric columns in the CSV written by ``save_to_csv``.
    _CSV_METRIC_ORDER = ('nse', 'rmse', 'mae', 'pbias', 'kge', 'flv', 'fhv',
                         'mse', 'mape', 'mspe', 'r2')
    # Header row matching ``_CSV_METRIC_ORDER`` (prefixed by the station column).
    _CSV_HEADER = ('Station', 'NSE', 'RMSE', 'MAE', 'PBIAS', 'KGE', 'FLV',
                   'FHV', 'MSE', 'MAPE', 'MSPE', 'R2')

    def __init__(self, args, model, run_type, results_path, setting, csv_save_path=None,
                 target_station=None, train_phase=None, loop=None):
        """Store configuration and make sure the CSV output folder exists.

        Args:
            args: Parsed run arguments; ``station``, ``global_loop``, ``alpha``,
                ``other_station`` and ``base_flow`` are read from it.
            model: Model name used in output file names.
            run_type: One of ``'initial'``, ``'local_global'``, ``'for_gnn'``.
            results_path: Root folder containing the saved ``.npy`` results.
            setting: Experiment setting string; must contain ``'_sl'``.
            csv_save_path: Output folder; defaults to
                ``<this dir>/csv_files/sl<...>`` derived from ``setting``.
            target_station: Station to report, ``'all'``, or ``None`` to use
                ``args.station``.
            train_phase: Phase folder prefix for ``local_global``; defaults to
                ``'global'``.
            loop: Loop index for ``local_global``; defaults to ``args.global_loop``.
        """
        if csv_save_path is None:
            # e.g. '..._sl24_ll0_pl6_dm512...' -> 'sl24_ll0_pl6'
            sl_ll_pl_folder = 'sl' + setting.split('_sl')[1].split('_dm')[0]
            csv_save_path = os.path.join(current_dir, 'csv_files', sl_ll_pl_folder)

        self._args = args
        self._model = model
        self._run_type = run_type
        self._results_path = results_path
        self._test_results_path = 'test_results/'
        self._setting = setting
        self._csv_save_path = csv_save_path
        self._constants = Constants(self._args)
        self._target_station = target_station if target_station is not None else self._args.station

        os.makedirs(self._csv_save_path, exist_ok=True)
        self._all_stations = self._constants.all_stations
        self.train_phase = train_phase if train_phase is not None else 'global'
        self.loop = loop if loop is not None else self._args.global_loop

    def _get_station_path(self, station):
        """Return the results sub-path (relative to the results root) for ``station``.

        Raises:
            NotImplementedError: for an unknown ``run_type``.
        """
        if self._run_type in ('initial', 'for_gnn'):
            # The station name (spaces removed) is spliced in just before '_sl'.
            parts = self._setting.split('_sl')
            station_setting = parts[0] + station.replace(' ', '') + '_sl' + parts[1]
            return os.path.join(station_setting, self._run_type)
        if self._run_type == 'local_global':
            return os.path.join(self._setting, f"{self.train_phase}_{self.loop}", str(station))
        raise NotImplementedError(f"unsupported run_type: {self._run_type!r}")

    def _read_results(self, station, station_results_path):
        """Load ``(pred, true, metrics)`` arrays; all three are ``None`` if any file is missing."""
        file_path = os.path.join(self._results_path, station_results_path)
        pred_npy = os.path.join(str(file_path), "pred.npy")
        true_npy = os.path.join(str(file_path), "true.npy")
        metrics_npy = os.path.join(str(file_path), "metrics.npy")
        try:
            pred = np.load(pred_npy)
            true = np.load(true_npy)
            metrics = np.load(metrics_npy)
        except FileNotFoundError:
            print('file not found, return None')
            pred = None
            true = None
            metrics = None
        return pred, true, metrics

    def _get_single_station_results(self, station):
        """Return ``(pred, true)`` for ``station`` (both ``None`` when missing)."""
        station_results_path = self._get_station_path(station)
        pred, true, _metrics = self._read_results(station, station_results_path)
        return pred, true

    def _default_file_stem(self):
        """Build the default output file stem from model, run type and args."""
        stem = f"{self._model}_{self._run_type}"
        if self._run_type == 'initial':
            return stem
        stem += f"_alpha{self._args.alpha}_{self._args.other_station}"
        if self._args.other_station in ['flow', 'st_flow']:
            stem += f"_{self._args.base_flow}"
        return stem

    @classmethod
    def _station_metrics(cls, pred, true, column):
        """Return a ``{metric_name: value}`` dict for one station column.

        All values are ``np.nan`` when ``pred`` is ``None`` (results missing).
        Indexing assumes 3-D arrays with stations on the last axis.
        """
        if pred is None:
            return dict.fromkeys(cls._METRIC_RETURN_ORDER, np.nan)
        values = metric(pred[:, :, column], true[:, :, column])
        return dict(zip(cls._METRIC_RETURN_ORDER, values))

    def _collect_station_metrics(self):
        """Return a list of ``(station, metrics_dict)`` pairs for this run type."""
        if self._run_type in ['for_gnn', 'local_global']:
            # One combined result file; station i is column i of the last axis.
            if self._target_station == 'all':
                pred, true = self._get_single_station_results('all')
                stations = self._all_stations
            else:
                pred, true = self._get_single_station_results(self._target_station)
                stations = [self._target_station]
            return [(station, self._station_metrics(pred, true, station_i))
                    for station_i, station in enumerate(stations)]
        # One result file per station; its target is the last column.
        collected = []
        for station in self._all_stations:
            pred, true = self._get_single_station_results(station)
            collected.append((station, self._station_metrics(pred, true, -1)))
        return collected

    @classmethod
    def _summary_rows(cls, columns):
        """Return the ``'Mean'`` and ``'Median'`` rows for per-metric value lists."""
        mean_row = ['Mean'] + [np.mean(columns[name]) for name in cls._CSV_METRIC_ORDER]
        median_row = ['Median'] + [np.median(columns[name]) for name in cls._CSV_METRIC_ORDER]
        return [mean_row, median_row]

    def save_to_csv(self, file_save_name=None):
        """Write per-station metrics plus Mean/Median rows to a CSV file.

        The file is ``<csv_save_path>/<file_save_name>_<target_station>.csv``.
        Missing stations contribute NaN, which propagates into Mean/Median.
        """
        if file_save_name is None:
            file_save_name = self._default_file_stem()
        rows = [list(self._CSV_HEADER)]
        columns = {name: [] for name in self._CSV_METRIC_ORDER}
        for station, values in self._collect_station_metrics():
            for name in self._CSV_METRIC_ORDER:
                columns[name].append(values[name])
            rows.append([station] + [values[name] for name in self._CSV_METRIC_ORDER])
        rows.extend(self._summary_rows(columns))
        out_path = os.path.join(self._csv_save_path, f"{file_save_name}_{self._target_station}.csv")
        with open(out_path, "w", newline='') as f:
            csv.writer(f).writerows(rows)

    @staticmethod
    def _parse_rmse_lines(lines):
        """Return the (stripped) text following ``'RMSE '`` on every line containing it."""
        values = []
        for line in lines:
            _, sep, value = line.partition('RMSE ')
            if sep:
                values.append(value.strip())
        return values

    def collect_stations_rmse_results(self, file_save_name=None):
        """Gather RMSE values from ``rmse_*`` files under each station's test results.

        Writes ``[station, sub_folder, rmse]`` rows to
        ``<csv_save_path>/<file_save_name>``.
        """
        if file_save_name is None:
            file_save_name = f"station_rmse_{self._default_file_stem()}"
        rows = []
        for station in self._all_stations:
            station_results_path = self._get_station_path(station)
            phase_folder_path = os.path.join(current_dir, '..', self._test_results_path,
                                             station_results_path, self._run_type, 'global')
            # Only the immediate sub-folders; each is then searched recursively.
            _, sub_folders, _ = next(os.walk(str(phase_folder_path)), (None, [], []))
            for sub_folder in sub_folders:
                sub_folder_path = os.path.join(str(phase_folder_path), sub_folder)
                for root, _, files in os.walk(sub_folder_path):
                    for file_name in files:
                        if not file_name.startswith("rmse_"):
                            continue
                        # Join with the walk root so nested files resolve correctly.
                        with open(os.path.join(root, file_name), 'r') as f:
                            for rmse in self._parse_rmse_lines(f):
                                rows.append([station, sub_folder, rmse])
        # Written once after all stations are collected.
        with open(os.path.join(self._csv_save_path, file_save_name), "w", newline='') as f:
            csv.writer(f).writerows(rows)


if __name__ == '__main__':
    import argparse
    import yaml
    parser = argparse.ArgumentParser(description='Water Flow Prediction')
    base_configs = os.path.join(current_dir, '../', 'configs/Water.Level/base_configs1.yaml')
    # Load the YAML config and use its values as the parser defaults.
    with open(base_configs, 'r') as f:
        yaml_config = yaml.safe_load(f)
    parser.set_defaults(**yaml_config)

    # Derive the data-to-time directory from the config (parse once, reuse).
    base_args = parser.parse_args()
    data_time_dir = ('../' + base_args.dataset_path + base_args.target + '/'
                     + base_args.data_time_path)
    parser.set_defaults(data_time_dir=data_time_dir)
    args = parser.parse_args()
    args.verbose = 1
    args.model = 'DMT'
    args.other_station = 'lag_correlation'
    results_path = os.path.join('..', 'results')

    args.alpha = 0.7
    setting = f'{args.model}__sl24_ll0_pl6_dm512_nh8_el1_dl1_df2048_fc1_ebtimeF_dtTrue_alpha{args.alpha}_exp'
    # MakeTable's signature is (args, model, run_type, results_path, setting).
    # Fall back to 'initial' when the config defines no run_type.
    # NOTE(review): confirm the intended run type for this setting.
    run_type = getattr(args, 'run_type', 'initial')
    make_table = MakeTable(args, args.model, run_type, results_path, setting)

    make_table.save_to_csv()