import pandas as pd
import os
import yaml
import json
from geopy.distance import geodesic
import numpy as np
from sklearn.preprocessing import StandardScaler
# Absolute path of the directory containing this module. Data and config files
# below are located by joining onto the parent of this directory.
current_dir = os.path.dirname(os.path.abspath(__file__))


class Constants:
    """Station-network metadata derived from the dataset files.

    On construction the class reads ``other_data.yaml`` (station list, flow
    structure, station coordinates) plus several optional/required CSV files
    and derives lookup structures from them: usable channels per station,
    child/parent station dicts, a Gaussian-weighted adjacency matrix, a
    distance adjacency table, lag-correlation and Granger link dicts, edge
    index/attribute arrays and a random neighbour dict.

    All files are located relative to the parent of this module's directory.
    """

    def __init__(self, args):
        """Load the metadata files and build every derived structure.

        Args:
            args: Namespace-like configuration object. The constructor reads
                ``dataset_path``, ``target``, ``data_time_path``,
                ``data_root_path`` and ``verbose``. Other methods additionally
                read ``num_vali``/``num_test`` (scalers) and ``model``/
                ``base_flow``/optional ``data_time_dir`` (``get_other_list``).

        Raises:
            FileNotFoundError: If ``other_data.yaml``, a station CSV or the
                distance adjacency CSV is missing.
        """
        # constants
        self._args = args
        self._data_time_dir = self._args.dataset_path + self._args.target + '/' + self._args.data_time_path
        self._data_root_dir = self._data_time_dir + self._args.data_root_path

        with open(self._project_path(self._data_time_dir, 'other_data.yaml'),
                  'r', encoding='utf-8') as f:
            self._other_data = yaml.safe_load(f)
        self._MAIN_STATIONS_LIST = self._other_data['MAIN_STATIONS_LIST']
        self._MAIN_WATER_FLOW_LIST = self._other_data['MAIN_WATER_FLOW_LIST']
        self._STATION_LOCATION_LIST = self._other_data['STATION_LOCATION_LIST']
        # The order below matters: later builders read the earlier results
        # (e.g. the edge index needs the child dict).
        self._station_channels_dict = self._get_station_channels_dict()
        self._child_stations_dict = self._get_child_stations_dict()
        self._parent_stations_dict = self._get_parent_stations_dict()
        self._adj_matrix = self._get_adj_matrix()
        self._distance_adj_matrix = self._get_distance_adj_matrix()
        self._lag_correlation_dict = self._get_lag_correlation_dict()
        self._granger_links_dict = self._get_granger_links_dict()
        self._edge_index, self._edge_attr = self._get_edge_index()
        self._randomize_dict = self._get_randomize_dict()

    @staticmethod
    def _project_path(*parts):
        """Join ``parts`` onto the parent directory of this module.

        ``os.path.join`` discards earlier components when a later one is
        absolute, so absolute dataset paths keep working unchanged.
        """
        return os.path.join(current_dir, '..', *parts)

    def _get_randomize_dict(self):
        """Return a dict mapping each station to a random subset of the others.

        Every other station is kept with probability 0.5 using NumPy's global
        random state, so the result is only reproducible if the caller seeds
        ``np.random`` beforehand.
        """
        stations = self.all_stations
        return {
            station: [
                other for other in stations
                # ``and`` short-circuits, so no random draw is spent on the
                # station itself (keeps the draw sequence stable).
                if other != station and np.random.uniform(0, 1) < 0.5
            ]
            for station in stations
        }

    def get_unified_scaler_dict(self):
        """Fit one ``StandardScaler`` per channel on all stations' training rows.

        The training part of each station CSV is everything except the last
        ``args.num_vali + args.num_test`` rows. Rows with NaN in the channel
        are dropped before fitting.

        Returns:
            dict: Channel name -> fitted ``StandardScaler`` for the channels
            ``'Discharge.Daily'``, ``'Water.Level'`` and ``'Rainfall.Manual'``.

        Raises:
            ValueError: If a channel has no non-NaN training data at all.
            KeyError: If a station CSV lacks one of the channel columns.
        """
        channels = ('Discharge.Daily', 'Water.Level', 'Rainfall.Manual')
        train_frames = {channel: [] for channel in channels}
        for station in self.all_stations:
            # Anchor the path at the project root like every other loader in
            # this class, instead of depending on the working directory.
            df_raw = pd.read_csv(self._project_path(self._data_root_dir, f"{station}.csv"))
            num_train = len(df_raw) - self._args.num_vali - self._args.num_test
            # Each CSV is read once and sliced for all channels.
            for channel in channels:
                train_frames[channel].append(df_raw[[channel]][:num_train])

        scaler_dict = {}
        for channel in channels:
            frames = train_frames[channel]
            if frames:
                df_all = pd.concat(frames, axis=0, ignore_index=True).dropna()
            else:
                df_all = pd.DataFrame()
            if df_all.empty:
                raise ValueError(f"No data found in channel {channel}")
            scaler = StandardScaler()
            scaler.fit(df_all.values)
            scaler_dict[channel] = scaler
        return scaler_dict

    def _get_station_channels_dict(self):
        """Return the usable data columns of every main station.

        For each station CSV the columns are ordered as the non-target,
        non-timestamp columns followed by the target column. Among the first
        three of those, any column containing a NaN is removed.

        Returns:
            dict: Station -> list of remaining column names.

        Raises:
            ValueError: If the target or ``'Timestamp'`` column is missing.
            AssertionError: If no column remains for a station.
        """
        station_channels_dict = {}
        for station in self._MAIN_STATIONS_LIST:
            df_raw = pd.read_csv(self._project_path(self._data_root_dir, str(station) + '.csv'))
            cols = list(df_raw.columns)
            cols.remove(self._args.target)
            cols.remove('Timestamp')
            ordered_cols = cols + [self._args.target]
            # check NaN value and remove the col if it has (first three only)
            nan_cols = {col for col in ordered_cols[:3] if df_raw[col].isna().any()}
            kept_cols = [col for col in ordered_cols if col not in nan_cols]
            # Explicit raise (same exception type as the former ``assert``) so
            # the check is not stripped when Python runs with -O.
            if not kept_cols:
                raise AssertionError(f"{station} data all columns have NaN value")
            station_channels_dict[station] = kept_cols
        return station_channels_dict

    def _get_child_stations_dict(self):
        """Return the flow structure from the YAML file (station -> children)."""
        return self._MAIN_WATER_FLOW_LIST

    def _get_parent_stations_dict(self):
        """Invert the flow structure: station -> list of stations listing it as a child."""
        return {
            station: [
                parent_station
                for parent_station, child_list in self._MAIN_WATER_FLOW_LIST.items()
                if station in child_list
            ]
            for station in self._MAIN_STATIONS_LIST
        }

    def _load_link_dict(self, file_name, name):
        """Read a station-by-station CSV and return its undirected link dict.

        Two stations are linked when the table holds a non-NaN value for the
        pair in either direction.

        Args:
            file_name: CSV file name inside the data-time directory.
            name: Label used in the verbose and error messages.

        Returns:
            dict | None: Station -> list of linked stations (in sorted station
            order), or ``None`` when the file does not exist.

        Raises:
            AssertionError: If the CSV columns differ from the station list.
        """
        try:
            df = pd.read_csv(self._project_path(self._data_time_dir, file_name))
        except FileNotFoundError:
            if self._args.verbose:
                print(f'No {file_name} file, return {name}_dict None.')
            return None
        df = df.set_index('Unnamed: 0')
        stations = self.all_stations
        if sorted(df.columns) != stations:
            raise AssertionError(f'{name} stations list is not match')
        has_value = df.notna()
        return {
            station1: [
                station2 for station2 in stations
                if station2 != station1
                and (has_value.at[station2, station1] or has_value.at[station1, station2])
            ]
            for station1 in stations
        }

    def _get_lag_correlation_dict(self):
        """Return lag-correlation links restricted to direct flow neighbours.

        Links are read from ``shift_best_lag.csv``; a link is kept only when
        one station is listed as a child of the other in the flow structure.

        Returns:
            dict | None: Station -> linked stations, or ``None`` when the CSV
            file does not exist.
        """
        lag_correlation_dict = self._load_link_dict('shift_best_lag.csv', 'lag_correlation')
        if lag_correlation_dict is None:
            return None

        flow = self._MAIN_WATER_FLOW_LIST
        # follow the main flow dict: keep station2 if it is a child or a parent of station1
        return {
            station1: [
                station2 for station2 in linked
                if station2 in flow[station1] or station1 in flow[station2]
            ]
            for station1, linked in lag_correlation_dict.items()
        }

    def _get_granger_links_dict(self):
        """Return the links read from ``granger_links_diff.csv``.

        Returns:
            dict | None: Station -> linked stations, or ``None`` when the CSV
            file does not exist.
        """
        return self._load_link_dict('granger_links_diff.csv', 'granger_links')

    def _get_adj_matrix(self):
        """Build the directed adjacency matrix of the flow structure.

        The weight of edge ``u -> v`` is ``exp(-d**2 / sigma**2)`` where ``d``
        is the geodesic distance in km between the two stations and ``sigma``
        is the standard deviation of all edge distances in the flow structure.
        Rows/columns follow the sorted station order.

        NOTE(review): ``sigma`` is 0 when there is a single edge or all edges
        have the same length, which drives the weights to 0 (with a NumPy
        warning) -- confirm whether that is intended.

        Returns:
            numpy.ndarray: ``(n, n)`` float matrix, zero where there is no edge.
        """
        station_coords = self._STATION_LOCATION_LIST
        flow_structure = self._MAIN_WATER_FLOW_LIST

        def coord(station):
            location = station_coords[station]
            return location['latitude'], location['longitude']

        stations = self.all_stations
        station_to_idx = {station: i for i, station in enumerate(stations)}
        matrix = np.zeros((len(stations), len(stations)))

        # Step 1: Compute sigma (std of all directed edge distances)
        distances = [
            geodesic(coord(u), coord(v)).km
            for u in flow_structure
            for v in flow_structure[u]
        ]
        sigma = np.std(distances)
        if self._args.verbose:
            print(f"Computed σ (standard deviation) = {sigma:.3f} km")
        sigma_sq = sigma ** 2

        # Step 2: Build adjacency matrix using Gaussian formula
        for u in flow_structure:
            if u not in station_to_idx:
                if self._args.verbose:
                    print(f"Station '{u}' not in location list. Skipping.")
                continue

            u_idx = station_to_idx[u]
            u_coord = coord(u)

            for v in flow_structure[u]:
                if v not in station_to_idx:
                    if self._args.verbose:
                        print(f"Station '{v}' not in location list. Skipping.")
                    continue

                v_idx = station_to_idx[v]
                v_coord = coord(v)

                # Best effort: a failing edge is reported and left at weight 0.
                try:
                    dist_uv = geodesic(u_coord, v_coord).km
                    matrix[u_idx, v_idx] = np.exp(-(dist_uv ** 2) / sigma_sq)
                except Exception as e:
                    print(f"Could not compute edge {u} → {v}: {e}")
        return matrix

    def _get_distance_adj_matrix(self):
        """Load ``stations_distance_adjacency_matrix.csv`` indexed by its first column.

        Returns:
            pandas.DataFrame: The table with ``'Unnamed: 0'`` as index.
        """
        file_path = os.path.join(current_dir, '../', self._args.dataset_path, 'stations_distance_adjacency_matrix.csv')
        return pd.read_csv(file_path).set_index('Unnamed: 0')

    def _load_flow_json(self, file_name):
        """Load a JSON file from the data-time directory.

        ``args.data_time_dir`` is used when the caller provided it; otherwise
        the directory computed in ``__init__`` from the same path components
        is used, so the attribute is no longer mandatory.
        """
        data_time_dir = getattr(self._args, 'data_time_dir', self._data_time_dir)
        with open(self._project_path(data_time_dir, file_name), 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_other_list(self, station, other_station_type=None):
        """Return the stations related to ``station`` under the given relation.

        Args:
            station: Station whose related stations are requested.
            other_station_type: One of ``'st_flow'``, ``'flow'``, ``'child'``,
                ``'parent'``, ``'child_parent'``, ``'lag_correlation'``,
                ``'granger'``, ``'random'``. ``None`` or any other value
                yields an empty list.

        Returns:
            list: Related stations. Lists taken from cached dicts are returned
            as copies so callers cannot modify the cached graph by accident.

        Side effects:
            ``'st_flow'`` stores the seasonal and trend lists on the instance
            as ``seasonal_other_list`` and ``trend_other_list``.
        """
        if other_station_type is None:
            return []

        if other_station_type == 'st_flow':
            best_other_dict = self._load_flow_json(
                f'{self._args.model}_st_flow_dict_{self._args.base_flow}.json')
            self.seasonal_other_list = best_other_dict['seasonal'][station]
            self.trend_other_list = best_other_dict['trend'][station]
            # dict.fromkeys de-duplicates while keeping first-seen order.
            return list(dict.fromkeys(self.seasonal_other_list + self.trend_other_list))

        if other_station_type == 'flow':
            best_other_dict = self._load_flow_json(
                f'{self._args.model}_flow_dict_{self._args.base_flow}.json')
            return best_other_dict[station]

        if other_station_type == 'child':
            return list(self._child_stations_dict[station])

        if other_station_type == 'parent':
            return list(self._parent_stations_dict[station])

        if other_station_type == 'child_parent':
            return self._child_stations_dict[station] + self._parent_stations_dict[station]

        if other_station_type == 'lag_correlation':
            if self.lag_correlation_dict is None:
                if self._args.verbose:
                    print('No lag_correlation file, return other_list empty list.')
                return []
            return list(self.lag_correlation_dict[station])

        if other_station_type == 'granger':
            if self.granger_links_dict is None:
                if self._args.verbose:
                    print('No granger_links file, return other_list empty list.')
                return []
            return list(self.granger_links_dict[station])

        if other_station_type == 'random':
            return list(self._randomize_dict[station])

        return []

    def _get_edge_index(self):
        """Build edge index and edge attributes for the parent -> child edges.

        Returns:
            tuple: ``(edge_index, edge_attr)`` where ``edge_index`` is a
            ``(2, E)`` array of source/target positions in the sorted station
            list and ``edge_attr`` is a length-``E`` array of geodesic
            distances in km.
        """
        station_coords = self._STATION_LOCATION_LIST
        stations = self.all_stations
        id_dict = {station: i for i, station in enumerate(stations)}
        sources, targets, edge_attr = [], [], []
        for station in stations:
            station_loc = station_coords[station]['latitude'], station_coords[station]['longitude']
            for child in self.get_other_list(station, other_station_type='child'):
                child_loc = station_coords[child]['latitude'], station_coords[child]['longitude']
                sources.append(id_dict[station])
                targets.append(id_dict[child])
                edge_attr.append(geodesic(station_loc, child_loc).km)
        return np.array([sources, targets]), np.array(edge_attr)

    @property
    def station_channels_dict(self):
        """Station -> list of usable column names."""
        return self._station_channels_dict

    @property
    def child_stations_dict(self):
        """Station -> list of child stations (the flow structure)."""
        return self._child_stations_dict

    @property
    def parent_stations_dict(self):
        """Station -> list of parent stations."""
        return self._parent_stations_dict

    @property
    def all_stations(self):
        """Sorted list of the main stations (a new list on every access)."""
        return sorted(self._MAIN_STATIONS_LIST)

    @property
    def adj_matrix(self):
        """Gaussian-weighted adjacency matrix of the flow structure."""
        return self._adj_matrix

    @property
    def distance_adj_matrix(self):
        """Distance adjacency table loaded from CSV."""
        return self._distance_adj_matrix

    @property
    def lag_correlation_dict(self):
        """Lag-correlation links, or ``None`` if the CSV file was missing."""
        return self._lag_correlation_dict

    @property
    def granger_links_dict(self):
        """Granger links, or ``None`` if the CSV file was missing."""
        return self._granger_links_dict

    @property
    def edge_index(self):
        """``(2, E)`` array of edge endpoints."""
        return self._edge_index

    @property
    def edge_attr(self):
        """Length-``E`` array of edge distances in km."""
        return self._edge_attr

    @property
    def randomize_dict(self):
        """Station -> randomly chosen other stations."""
        return self._randomize_dict


if __name__ == '__main__':
    import argparse

    # Smoke run: build Constants from the base YAML config and print the
    # shape of the resulting edge-attribute array.
    cli = argparse.ArgumentParser(description='Water Flow Prediction')
    config_path = os.path.join(current_dir, '../', 'configs/Water.Level/base_configs.yaml')
    with open(config_path, 'r') as config_file:
        config_defaults = yaml.safe_load(config_file)

    # Every YAML entry becomes a parser default.
    cli.set_defaults(**config_defaults)

    # A first parse provides the path components; the combined data-to-time
    # directory is then registered as one more default before the final parse.
    preliminary = cli.parse_args()
    cli.set_defaults(
        data_time_dir=preliminary.dataset_path + preliminary.target + '/' + preliminary.data_time_path
    )

    args = cli.parse_args()
    args.verbose = 0
    args.model = 'DMT'
    args.base_flow = 'lag_correlation'

    constant = Constants(args)
    print(constant.edge_attr.shape)
