import os
import pandas as pd
import yaml
from utils.constants_LamaH import Constants

# Absolute directory containing this module; the result and config paths used
# below are built relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))


class StationSpecSearchHyperparam:
    """Record hyper-parameter search results per station and track the best ones.

    Each search run for the target station is appended to that station's
    ``results.csv``.  The best row of every station (according to a metric) is
    kept in a shared ``best_station_param.yaml``, which can be flattened into
    ``best_results.csv``.  All files live below
    ``station_results_search/<data>/<target>/pl<pred_len>/<model>/<run_type>/<graph_type>``
    inside the directory of this module.

    ``'nse'`` is treated as higher-is-better; every other metric name is
    treated as lower-is-better.
    """

    # Metrics for which a larger value wins; anything else is minimised.
    _HIGHER_IS_BETTER = ('nse',)
    # Metrics copied from the YAML table into the summary CSV, in column order.
    _SUMMARY_METRICS = ('rmse', 'mae', 'mape', 'nse')
    # Header line of every per-station results file.
    _RESULTS_HEADER = 'rmse,mae,mape,nse,alpha,global_loop,global_lr,global_lr_factor\n'

    def __init__(self, args, station, metric='nse'):
        """Resolve all file locations for ``station`` under the given configuration.

        Args:
            args: Namespace-like run configuration.  Attributes read here are
                ``other_station``, ``base_flow`` (only when ``other_station``
                is ``'flow'`` or ``'st_flow'``), ``data``, ``target``,
                ``pred_len``, ``model`` and ``run_type``.
            station: Identifier of the target station; stored as ``str``.
            metric: Metric used by :meth:`save_best_station_param`.
        """
        self._args = args
        self._target_station = str(station)
        self._metric = metric
        self._constants = Constants(args)
        if args.other_station in ('flow', 'st_flow'):
            graph_type = args.other_station + '_' + args.base_flow
        else:
            graph_type = args.other_station
        # Directory shared by every station of this configuration.
        self._base_dir = os.path.join(current_dir, 'station_results_search', args.data, args.target,
                                      f'pl{args.pred_len}', args.model, args.run_type, graph_type)
        self._results_dir = os.path.join(self._base_dir, self._target_station)
        self._results_file_name = 'results.csv'
        self._yaml_path = os.path.join(self._base_dir, 'best_station_param.yaml')
        self._best_results_path = os.path.join(self._base_dir, 'best_results.csv')

    @classmethod
    def _is_better(cls, metric, candidate, incumbent):
        """Return True when ``candidate`` strictly beats ``incumbent`` for ``metric``."""
        if metric in cls._HIGHER_IS_BETTER:
            return candidate > incumbent
        return candidate < incumbent

    @classmethod
    def _best_row(cls, df, metric):
        """Return the best row of ``df`` for ``metric`` as a dict, or ``None``.

        ``None`` means no row has a comparable value (no rows at all, or the
        metric column holds only NaN).  On ties the first matching row wins.

        Raises:
            KeyError: If ``df`` has no column named ``metric``.
        """
        scores = df[metric]
        best_score = scores.max() if metric in cls._HIGHER_IS_BETTER else scores.min()
        matches = df[scores == best_score].to_dict(orient='records')
        return matches[0] if matches else None

    @classmethod
    def _merge_best(cls, best_yaml, station_key, candidate, metric):
        """Store ``candidate`` for ``station_key`` unless the stored entry is at least as good.

        A station without an entry, or whose entry lacks ``metric``, always
        receives ``candidate``.  ``best_yaml`` is modified in place.
        """
        incumbent = best_yaml.get(station_key) or {}
        if metric not in incumbent or cls._is_better(metric, candidate[metric], incumbent[metric]):
            best_yaml[station_key] = candidate

    def _load_best_yaml(self):
        """Load the best-parameter table, or build an empty one keyed by every station."""
        if os.path.exists(self._yaml_path):
            with open(self._yaml_path, 'r') as f:
                loaded = yaml.safe_load(f)
            # An empty YAML document loads as None; fall through to a fresh table.
            if loaded is not None:
                return loaded
        return {str(station): {} for station in self._constants.all_stations}

    def _dump_best_yaml(self, best_yaml):
        """Write the best-parameter table, creating its directory when needed."""
        os.makedirs(os.path.dirname(self._yaml_path), exist_ok=True)
        with open(self._yaml_path, 'w') as f:
            yaml.dump(best_yaml, f)

    def _read_best_results(self, metric):
        """Load the target station's results and keep its best row in ``self.best_param``.

        Raises:
            FileNotFoundError: If the station has no results file yet.
            KeyError: If the results file has no ``metric`` column.
            IndexError: If the results file has no row with a valid ``metric`` value.
        """
        results_path = os.path.join(self._results_dir, self._results_file_name)
        best = self._best_row(pd.read_csv(results_path), metric)
        if best is None:
            raise IndexError(f'no row with a valid {metric!r} value in {results_path}')
        self.best_param = best

    def _write_to_yaml(self, metric):
        """Merge ``self.best_param`` into the YAML table for the target station.

        The stored entry is replaced only when ``self.best_param`` is strictly
        better for ``metric`` (or when nothing comparable is stored yet).
        """
        best_yaml = self._load_best_yaml()
        self._merge_best(best_yaml, self._target_station, self.best_param, metric)
        self._dump_best_yaml(best_yaml)

    def save_all_to_yame(self, metric):
        """Refresh the YAML table from the results file of every known station.

        Stations without a results file, or whose file has no row with a valid
        ``metric`` value, are left untouched.

        Raises:
            KeyError: If a results file has no ``metric`` column.
        """
        best_yaml = self._load_best_yaml()
        for station in self._constants.all_stations:
            station_str = str(station)
            # Same directory layout that save_station_results writes to.
            results_path = os.path.join(self._base_dir, station_str, self._results_file_name)
            try:
                df = pd.read_csv(results_path)
            except FileNotFoundError:
                continue  # this station has not been searched yet
            best_param = self._best_row(df, metric)
            if best_param is None:
                continue
            self._merge_best(best_yaml, station_str, best_param, metric)
        self._dump_best_yaml(best_yaml)

    # Correctly spelled alias; the original name is kept for existing callers.
    save_all_to_yaml = save_all_to_yame

    def save_station_results(self, rmse, mae, mape, nse, loop):
        """Append one search run to the target station's results file.

        The header line is written first when the file is new or empty.  Besides
        the given metrics and ``loop`` (column ``global_loop``), the row records
        ``alpha``, ``global_lr`` and ``global_lr_factor`` from the run arguments.
        """
        os.makedirs(self._results_dir, exist_ok=True)
        file_path = os.path.join(self._results_dir, self._results_file_name)
        needs_header = not os.path.exists(file_path) or os.path.getsize(file_path) == 0
        row = (f'{rmse},{mae},{mape},{nse},{self._args.alpha},{loop},'
               f'{self._args.global_lr},{self._args.global_lr_factor}\n')
        with open(file_path, 'a') as f:
            if needs_header:
                f.write(self._RESULTS_HEADER)
            f.write(row)

    def save_best_results_to_csv(self):
        """Flatten the YAML table into ``best_results.csv``.

        Writes one ``station,rmse,mae,mape,nse`` line per known station followed
        by an ``Avg`` line; the file has no header line and no trailing newline.

        Raises:
            FileNotFoundError: If the YAML table has not been written yet.
        """
        with open(self._yaml_path, 'r') as f:
            best_yaml = yaml.safe_load(f) or {}
        stations = [str(station) for station in self._constants.all_stations]
        rows = []
        totals = [0] * len(self._SUMMARY_METRICS)
        for station_str in stations:
            entry = best_yaml.get(station_str) or {}
            try:
                values = [entry[name] for name in self._SUMMARY_METRICS]
            except KeyError:
                # NOTE(review): a station with an incomplete entry is reported as
                # all zeros and still counted in the averages below; confirm this
                # is intended rather than skipping such stations.
                values = [0] * len(self._SUMMARY_METRICS)
            totals = [total + value for total, value in zip(totals, values)]
            rows.append(','.join([station_str, *(f'{value}' for value in values)]))
        averages = [total / len(stations) for total in totals]
        rows.append(','.join(['Avg', *(f'{average}' for average in averages)]))
        with open(self._best_results_path, 'w') as f:
            f.write('\n'.join(rows))

    def save_best_station_param(self):
        """Read the target station's best row and merge it into the YAML table."""
        self._read_best_results(self._metric)
        self._write_to_yaml(self._metric)


if __name__ == '__main__':
    import argparse

    # The parser defines no arguments of its own: the YAML base configuration
    # is installed as defaults so the parsed namespace carries every configured value.
    parser = argparse.ArgumentParser(description='Water Flow Prediction')
    base_configs = os.path.join(current_dir, 'configs', 'LamaH_daily', 'pl1', 'base_configs.yaml')
    with open(base_configs, 'r') as f:
        yaml_config = yaml.safe_load(f)
    parser.set_defaults(**yaml_config)

    args = parser.parse_args()
    # Select the experiment whose search results are aggregated.
    args.model = 'DMT'
    args.run_type = 'FlowNet'
    args.other_station = 'lag_correlation'
    constants = Constants(args)

    # Any station can be passed here: the two calls below iterate over all
    # stations and only use files shared by the whole configuration.
    # To refresh a single station instead, build the helper for that station
    # and call save_best_station_param().
    station_search_param = StationSpecSearchHyperparam(args, constants.all_stations[0], metric='nse')
    station_search_param.save_all_to_yame(metric='nse')
    station_search_param.save_best_results_to_csv()
