import os
import pandas as pd
import yaml
from utils.constants import Constants
# Absolute path of the directory containing this module; the search-result
# paths built in StationSpecSearchHyperparam.__init__ are rooted here.
current_dir = os.path.dirname(os.path.abspath(__file__))


class StationSpecSearchHyperparam:
    """Record per-station hyperparameter-search trials and keep the best one.

    Layout on disk (``<base_dir>`` is
    ``<module dir>/station_results_search/<target>/<seq_len>_<pred_len>/<model>/<phase>/<graph_type>``):

    * ``<base_dir>/<station>/results.csv`` -- one row per trial
      (:meth:`save_station_results`).
    * ``<base_dir>/best_station_param.yaml`` -- best row per station, merged
      across runs (:meth:`save_best_station_param`).
    * ``<base_dir>/best_results.csv`` -- per-station metrics plus an ``Avg``
      row, built from the YAML (:meth:`save_best_results_to_csv`).
    """

    # Metrics for which a larger value is better; every other metric is
    # treated as an error to be minimised.
    _HIGHER_IS_BETTER = frozenset({'nse', 'r2'})

    # Metric columns, in order, written to best_results.csv.
    _SUMMARY_METRICS = ('rmse', 'mae', 'mape', 'nse', 'r2')

    def __init__(self, args, station, metric='rmse'):
        """Resolve all result paths for ``station``.

        Args:
            args: Parsed arguments. Reads ``other_station``, ``base_flow``,
                ``target``, ``seq_len``, ``pred_len``, ``model`` and ``phase``
                here, and ``alpha``, ``global_lr`` and ``global_lr_factor``
                in :meth:`save_station_results`. Also passed to ``Constants``.
            station: Name of the station whose results this instance manages.
            metric: Results column used to select the best trial.
        """
        self._args = args
        self._target_station = station
        self._metric = metric
        self._constants = Constants(args)
        if self._args.other_station in ['flow', 'st_flow']:
            graph_type = self._args.other_station + '_' + self._args.base_flow
        else:
            graph_type = self._args.other_station
        # All outputs share this prefix; build it once.
        base_dir = os.path.join(current_dir, 'station_results_search', self._args.target,
                                f'{self._args.seq_len}_{self._args.pred_len}', self._args.model,
                                self._args.phase, graph_type)
        self._results_dir = os.path.join(base_dir, station)
        self._results_file_name = 'results.csv'
        self._yaml_path = os.path.join(base_dir, 'best_station_param.yaml')
        self._best_results_path = os.path.join(base_dir, 'best_results.csv')

    @classmethod
    def _is_better(cls, metric, candidate, incumbent):
        """Return True if ``candidate`` strictly beats ``incumbent`` on ``metric``.

        Ties return False so an existing entry is kept.
        """
        if metric in cls._HIGHER_IS_BETTER:
            return candidate > incumbent
        return candidate < incumbent

    def _read_best_results(self, metric):
        """Load this station's results CSV and store its best row in ``self.best_param``.

        If several rows tie for the best value, the first one is used.

        Raises:
            ValueError: If no row has a valid value for ``metric``
                (e.g. the CSV contains only a header).
        """
        results_path = os.path.join(self._results_dir, self._results_file_name)
        df = pd.read_csv(results_path)
        column = df[metric]
        best_value = column.max() if metric in self._HIGHER_IS_BETTER else column.min()
        best_line = df[column == best_value]
        if best_line.empty:
            raise ValueError(f'No valid {metric!r} values found in {results_path}')
        record = best_line.to_dict(orient='records')[0]
        # Convert any NumPy scalars to built-in Python types so yaml.dump emits
        # plain numbers that yaml.safe_load can read back.
        self.best_param = {key: (value.item() if hasattr(value, 'item') else value)
                           for key, value in record.items()}

    def _write_to_yaml(self, metric):
        """Merge ``self.best_param`` into the shared best-parameter YAML file.

        The target station's entry is replaced only if it is missing, lacks
        ``metric``, or ``self.best_param`` is strictly better on ``metric``.
        """
        best_yaml = None
        if os.path.exists(self._yaml_path):
            with open(self._yaml_path, 'r') as f:
                best_yaml = yaml.safe_load(f)
        if best_yaml is None:
            # First write (or an empty file): seed an entry for every station.
            best_yaml = {station: {} for station in self._constants.all_stations}

        stored = best_yaml.get(self._target_station) or {}
        incumbent = stored.get(metric)
        # Replace only when the new run is better. The previous comparison was
        # inverted and always kept the worse run.
        if incumbent is None or self._is_better(metric, self.best_param[metric], incumbent):
            best_yaml[self._target_station] = self.best_param

        with open(self._yaml_path, 'w') as f:
            yaml.dump(best_yaml, f)

    def save_station_results(self, rmse, mae, mape, nse, r2, loop):
        """Append one trial's metrics and hyperparameters to this station's results CSV.

        Creates the results directory and writes the CSV header on first use.

        Args:
            rmse, mae, mape, nse, r2: Metric values for the trial.
            loop: Value recorded in the ``global_loop`` column.
        """
        os.makedirs(self._results_dir, exist_ok=True)
        file_path = os.path.join(self._results_dir, self._results_file_name)
        write_header = not os.path.exists(file_path)
        row = (f'{rmse},{mae},{mape},{nse},{r2},{self._args.alpha},{loop},'
               f'{self._args.global_lr},{self._args.global_lr_factor}\n')
        # Append mode creates the file when it does not exist yet.
        with open(file_path, 'a') as f:
            if write_header:
                f.write('rmse,mae,mape,nse,r2,alpha,global_loop,global_lr,global_lr_factor\n')
            f.write(row)

    def _summary_rows(self, best_yaml):
        """Build the CSV lines for best_results.csv from a best-parameter mapping.

        Returns one ``station,rmse,mae,mape,nse,r2`` line per station in
        ``self._constants.all_stations``, followed by an ``Avg`` line.
        """
        stations = self._constants.all_stations
        totals = [0] * len(self._SUMMARY_METRICS)
        rows = []
        for station in stations:
            entry = best_yaml.get(station) or {}
            try:
                values = [entry[name] for name in self._SUMMARY_METRICS]
            except KeyError:
                # NOTE(review): stations without full results are written as
                # zeros and still count toward the averages below.
                values = [0] * len(self._SUMMARY_METRICS)
            totals = [total + value for total, value in zip(totals, values)]
            rows.append(','.join([f'{station}', *(f'{value}' for value in values)]))
        # Each average uses its own total (r2 previously reused the nse total).
        averages = [total / len(stations) for total in totals]
        rows.append(','.join(['Avg', *(f'{value}' for value in averages)]))
        return rows

    def save_best_results_to_csv(self):
        """Write per-station best metrics and their averages to best_results.csv."""
        with open(self._yaml_path, 'r') as f:
            best_yaml = yaml.safe_load(f)
        rows = self._summary_rows(best_yaml or {})
        with open(self._best_results_path, 'w') as f:
            f.write('\n'.join(rows))

    def save_best_station_param(self):
        """Select this station's best trial by ``self._metric`` and merge it into the YAML."""
        self._read_best_results(self._metric)
        self._write_to_yaml(self._metric)


if __name__ == '__main__':
    # Script entry: parse CLI options, then persist the best RMSE run for a
    # single hard-coded station.
    from main import parse_args

    StationSpecSearchHyperparam(
        parse_args(),
        'Ban Huai Khayuong',
        metric='rmse',
    ).save_best_station_param()
