from copy import deepcopy
import random
import pandas as pd
import os
import yaml
import json
# from geopy.distance import geodesic
import numpy as np
from sklearn.preprocessing import StandardScaler

# Absolute directory of this module; dataset and config paths used below are
# built relative to it (typically as ``current_dir/../<something>``).
current_dir = os.path.dirname(os.path.abspath(__file__))


class Constants:
    def __init__(self, args):
        """Load the station-graph metadata for the dataset named by ``args.data``.

        ``other_data.yaml`` is always read from ``<current_dir>/../<dataset_path>``.
        What else is populated depends on ``args.data``:

        * contains ``'LamaH'``: the adjacency CSV(s) under ``processed/`` are
          read and the station list plus child/parent dictionaries are built.
        * equals ``'MeKong'``: the weighted and the distance adjacency
          matrices are built.
        * equals ``'camels'``: the station list comes from ``other_data.yaml``,
          the kNN adjacency CSV is read, and the child/parent/basin
          dictionaries are built.

        Args:
            args: configuration namespace. ``dataset_path`` and ``data`` are
                always read; ``missing_graph_ratio`` is read for LamaH.
        """
        self._args = args

        # Everything derived starts out empty; the dataset branch below fills
        # in only what it supports.
        self._MAIN_STATIONS_LIST = None
        self._MAIN_WATER_FLOW_LIST = None
        self._STATION_LOCATION_LIST = None
        self._station_channels_dict = None
        self._child_stations_dict = None
        self._parent_stations_dict = None
        self._adj_matrix = None
        self._distance_adj_matrix = None
        self._lag_correlation_dict = None
        self._edge_index = None
        self._edge_attr = None
        self._basin_dict = None

        yaml_path = os.path.join(current_dir, '..', self._args.dataset_path, 'other_data.yaml')
        with open(yaml_path, 'r', encoding='utf-8') as f:
            self._other_data = yaml.safe_load(f)

        dataset_name = self._args.data
        if 'LamaH' in dataset_name:
            self.missing_graph_ratio = self._args.missing_graph_ratio
            processed_dir = os.path.join(current_dir, '..', self._args.dataset_path, 'processed')
            self._original_adjacency = pd.read_csv(os.path.join(processed_dir, 'adjacency_399_True.csv'))
            if self.missing_graph_ratio > 0.0:
                # A degraded graph is stored in its own file, keyed by the ratio.
                degraded_file = f'adjacency_399_True_missing_{self.missing_graph_ratio}.csv'
                self.adjacency = pd.read_csv(os.path.join(processed_dir, degraded_file))
            else:
                self.adjacency = self._original_adjacency

            # The station list always comes from the original adjacency file;
            # station 399 is appended explicitly on top of the 'ID' column.
            station_ids = self._original_adjacency['ID'].to_list()
            station_ids.append(399)
            self._MAIN_STATIONS_LIST = [str(station_id) for station_id in station_ids]

            # The parent builder reads _MAIN_WATER_FLOW_LIST, so the child
            # dictionary has to be in place first.
            self._child_stations_dict = self._get_child_stations_dict()
            self._MAIN_WATER_FLOW_LIST = self._child_stations_dict
            self._parent_stations_dict = self._get_parent_stations_dict()
        elif dataset_name == 'MeKong':
            self._adj_matrix = self._get_adj_matrix()
            self._distance_adj_matrix = self._get_distance_adj_matrix()
        elif dataset_name == 'camels':
            self._MAIN_STATIONS_LIST = self._other_data['MAIN_STATIONS_LIST']
            knn_path = os.path.join(current_dir, '..', str(self._args.dataset_path), 'camels_knn_adj_2.csv')
            self._adj_matrix = pd.read_csv(knn_path, index_col=0)
            # Row and column labels are normalised to 8-character,
            # zero-padded strings.
            self._adj_matrix.columns = self._adj_matrix.columns.map(lambda x: str(x).zfill(8))
            self._adj_matrix.index = self._adj_matrix.index.map(lambda x: str(x).zfill(8))

            self._child_stations_dict = self._get_child_stations_dict()
            self._MAIN_WATER_FLOW_LIST = self._child_stations_dict
            self._parent_stations_dict = self._get_parent_stations_dict()
            self._basin_dict = self._get_basin_dict()

    def _get_basin_dict(self):
        """Group the known stations by the basin directory their file lives in.

        Looks at the sub-directories ``01`` .. ``18`` of
        ``<current_dir>/../<dataset_path>/<data_root_path>``. For each one that
        exists, the first 8 characters of every regular file name are taken
        as a station id and kept when that id is in ``self._MAIN_STATIONS_LIST``.

        Returns:
            dict: basin code (``'01'`` .. ``'18'``) -> list of station ids
            found in that directory. Basins whose directory does not exist
            are reported with a printed warning and left out of the result.
        """
        # Build the membership lookup once: testing each file name against the
        # raw list would rescan the whole station list for every file.
        known_stations = set(self._MAIN_STATIONS_LIST)

        basin_dict = {}
        for basin_i in (f"{i:02d}" for i in range(1, 19)):
            basin_path = os.path.join(current_dir, '../', self._args.dataset_path, self._args.data_root_path, basin_i)
            # Skip basins that have no directory on disk.
            if not os.path.exists(basin_path):
                print(f"Warning: Path not found - {basin_path}")
                continue
            # Regular files only; sub-directories are ignored.
            filenames = [f for f in os.listdir(str(basin_path))
                         if os.path.isfile(os.path.join(str(basin_path), f))]
            # The first 8 characters of the file name are the station id.
            basin_dict[basin_i] = [filename[:8] for filename in filenames
                                   if filename[:8] in known_stations]
        return basin_dict

    def _get_child_stations_dict(self):
        """Map every station id to the list of its child (neighbour) station ids.

        * LamaH (``'LamaH'`` in ``args.data``): for each station, the children
          are the ``NEXTDOWNID`` values (as strings) of the rows in
          ``self.adjacency`` whose ``ID`` equals the station.
        * camels: for each station present in the adjacency matrix index, the
          children are the column labels whose entry in that row equals 1,
          excluding the station itself. Stations missing from the index are
          reported with a printed warning and get no entry.
        * any other dataset: an empty dict.

        Returns:
            dict: station id (str) -> list of child station ids (str).
        """
        dataset_name = self._args.data

        if 'LamaH' in dataset_name:
            adjacency = self.adjacency
            id_column = adjacency['ID']
            return {
                station: [
                    str(child)
                    for child in adjacency.loc[id_column == int(station), 'NEXTDOWNID'].to_list()
                ]
                for station in self._MAIN_STATIONS_LIST
            }

        if dataset_name != 'camels':
            return {}

        children_by_station = {}
        for station in self._MAIN_STATIONS_LIST:
            key = str(station)
            if key in self._adj_matrix.index:
                # Columns holding a 1 in this station's row are its neighbours;
                # the station itself is dropped from its own list.
                row = self._adj_matrix.loc[key]
                children_by_station[key] = [
                    nid for nid in row[row == 1].index.tolist() if nid != key
                ]
            else:
                print(f"Warning: {key} not in adjacency.")
        return children_by_station

    def _get_parent_stations_dict(self):
        """Invert ``self._MAIN_WATER_FLOW_LIST`` into a station -> parents mapping.

        A station ``p`` is recorded as a parent of ``s`` when ``s`` appears in
        ``p``'s child list, unless ``p`` is itself listed as a child of ``s``
        (mutual links are not treated as parent links).

        Only stations from ``self._MAIN_STATIONS_LIST`` get an entry. Parents
        are listed in the iteration order of ``self._MAIN_WATER_FLOW_LIST``.

        Returns:
            dict: station id (str) -> list of parent station ids.
        """
        flow = self._MAIN_WATER_FLOW_LIST
        parent_stations_dict = {str(station): [] for station in self._MAIN_STATIONS_LIST}

        # Walk each edge once instead of rescanning the whole flow dict for
        # every station.
        for parent_station, child_list in flow.items():
            # dict.fromkeys drops repeated children but keeps their order, so
            # a parent is recorded at most once per child.
            for child in dict.fromkeys(child_list):
                if child not in parent_stations_dict:
                    continue
                # A station may be missing from the flow dict (no child list
                # was built for it); treat that as "no children" rather than
                # raising KeyError.
                if parent_station in flow.get(child, ()):
                    continue
                parent_stations_dict[child].append(parent_station)
        return parent_stations_dict

    def _get_lag_correlation_dict(self):
        """Read ``shift_best_lag.csv`` and list, per station, its lag-correlated stations.

        The CSV is read from ``<current_dir>/../<args.data_time_dir>``. Its
        ``'Unnamed: 0'`` column becomes the row index. For every station
        ``s1`` the result holds each other station ``s2`` whose cell at row
        ``s2`` / column ``s1`` is not NaN.

        Returns:
            dict | None: station id (str) -> list of station ids (str), or
            ``None`` when the CSV file does not exist.

        Raises:
            AssertionError: if the CSV columns do not match the station list.
        """
        # The directory lives on the args namespace (as in get_other_list);
        # this object has no `_data_time_dir` attribute of its own.
        csv_path = os.path.join(current_dir, '..', self._args.data_time_dir, 'shift_best_lag.csv')
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError:
            if self._args.verbose:
                print('No shift_best_lag.csv file, return lag_correlation_dict None.')
            return None
        df = df.set_index('Unnamed: 0')

        # Station ids may be strings or integers depending on the dataset, and
        # pandas parses a numeric index column as integers. Compare and look
        # up everything as strings so the labels line up either way.
        df.index = df.index.map(str)
        df.columns = df.columns.map(str)
        station_ids = [str(station) for station in self.all_stations]
        assert sorted(df.columns) == sorted(station_ids), 'lag_correlation stations list is not match'

        lag_correlation_dict = {}
        for station1_str in station_ids:
            lag_correlation_dict[station1_str] = [
                station2_str
                for station2_str in station_ids
                if station2_str != station1_str
                and not np.isnan(df.at[station2_str, station1_str])
            ]
        return lag_correlation_dict

    def _get_adj_matrix(self):
        """Build a dense, directed adjacency matrix weighted by station distance.

        For every edge ``u -> v`` in ``self._MAIN_WATER_FLOW_LIST`` the entry
        ``matrix[idx(u), idx(v)]`` is set to ``exp(-d(u, v)**2 / sigma**2)``,
        where ``d`` is the ``geodesic`` distance in km between the two
        stations' coordinates and ``sigma`` is ``np.std`` of all edge
        distances. Rows/columns follow the order of ``self.all_stations``.

        Returns:
            numpy.ndarray: ``(n, n)`` float matrix, zero where no edge exists.

        NOTE(review): `geodesic` is not defined in this file -- the geopy
        import at the top is commented out -- so this method raises NameError
        unless the name is provided some other way.
        NOTE(review): `__init__` calls this for 'MeKong' while
        `_STATION_LOCATION_LIST`, `_MAIN_WATER_FLOW_LIST` and
        `_MAIN_STATIONS_LIST` are still None in the visible code; confirm
        where they are meant to be populated.
        """
        station_coords = self._STATION_LOCATION_LIST
        flow_structure = self._MAIN_WATER_FLOW_LIST

        stations = self.all_stations
        n = len(stations)
        station_to_idx = {station: i for i, station in enumerate(stations)}
        matrix = np.zeros((n, n))

        # Step 1: Compute sigma (std of all directed edge distances)
        # Unlike step 2, this loop does not skip stations that are missing
        # from the coordinate table or the station index.
        distances = []
        for u in flow_structure:
            for v in flow_structure[u]:
                u_coord = (station_coords[u]['latitude'], station_coords[u]['longitude'])
                v_coord = (station_coords[v]['latitude'], station_coords[v]['longitude'])
                dist = geodesic(u_coord, v_coord).km
                distances.append(dist)

        sigma = np.std(distances)
        if self._args.verbose:
            print(f"Computed σ (standard deviation) = {sigma:.3f} km")
        # NOTE(review): `temp` is never read afterwards.
        temp = 0
        # Step 2: Build adjacency matrix using Gaussian formula
        for u in flow_structure.keys():
            if u not in station_to_idx:
                if self._args.verbose:
                    print(f"Station '{u}' not in location list. Skipping.")
                continue

            u_idx = station_to_idx[u]
            u_coord = (station_coords[u]['latitude'], station_coords[u]['longitude'])

            for v in flow_structure[u]:
                if v not in station_to_idx:
                    if self._args.verbose:
                        print(f"Station '{v}' not in location list. Skipping.")
                    continue

                v_idx = station_to_idx[v]
                v_coord = (station_coords[v]['latitude'], station_coords[v]['longitude'])

                try:
                    dist_uv = geodesic(u_coord, v_coord).km
                    # Gaussian kernel on distance; sigma is not checked for
                    # zero here.
                    weight = np.exp(-(dist_uv ** 2) / (sigma ** 2))
                    matrix[u_idx, v_idx] = weight
                    # if self._args.verbose:
                    #     print(u_idx, v_idx, weight)
                except Exception as e:
                    # Any failure on a single edge is printed and the edge is
                    # left at weight 0.
                    print(f"Could not compute edge {u} → {v}: {e}")
        # Step 3: Save to sparse format
        # adj_sparse = sparse.csr_matrix(matrix)
        # The sparse conversion above is disabled; the dense matrix is returned.
        return matrix

    def _get_distance_adj_matrix(self):
        """Read ``stations_distance_adjacency_matrix.csv`` from the dataset directory.

        Returns:
            pandas.DataFrame: the CSV contents with the ``'Unnamed: 0'``
            column used as the row index.
        """
        csv_path = os.path.join(current_dir, '../', self._args.dataset_path, 'stations_distance_adjacency_matrix.csv')
        return pd.read_csv(csv_path).set_index('Unnamed: 0')

    def get_other_list(self, station, other_station_type=None):
        """Return the stations related to ``station`` under the given relation.

        Args:
            station: station id used as the lookup key.
            other_station_type: which relation to use. One of

                * ``None`` or any unrecognised value: empty list.
                * ``'flow'``: entry for ``station`` in the JSON file
                  ``<model>_flow_dict_<base_flow>.json`` under
                  ``args.data_time_dir`` (read on every call).
                * ``'child'`` / ``'parent'``: the stored child / parent list
                  itself (not a copy).
                * ``'child_parent'``: children followed by parents (new list).
                * ``'lag_correlation'`` / ``'granger'``: entry from the
                  corresponding dictionary, or an empty list when that
                  dictionary is not available.

        Returns:
            list: related station ids.

        Raises:
            KeyError: if ``station`` is missing from the selected dictionary.
        """
        if other_station_type is None:
            return []

        if other_station_type == 'flow':
            flow_file = os.path.join(current_dir, '..', self._args.data_time_dir,
                                     f'{self._args.model}_flow_dict_{self._args.base_flow}.json')
            with open(flow_file, 'r', encoding='utf-8') as f:
                best_other_dict = json.load(f)
            return best_other_dict[station]

        if other_station_type == 'child':
            return self._child_stations_dict[station]

        if other_station_type == 'parent':
            return self._parent_stations_dict[station]

        if other_station_type == 'child_parent':
            return self._child_stations_dict[station] + self._parent_stations_dict[station]

        if other_station_type == 'lag_correlation':
            lag_links = self.lag_correlation_dict
            if lag_links is None:
                if self._args.verbose:
                    print('No lag_correlation file, return other_list empty list.')
                return []
            return lag_links[station]

        if other_station_type == 'granger':
            # This class never assigns `granger_links_dict`; read it
            # defensively so the "not available" fallback below is reached
            # instead of an AttributeError.
            granger_links = getattr(self, 'granger_links_dict', None)
            if granger_links is None:
                if self._args.verbose:
                    print('No granger_links file, return other_list empty list.')
                return []
            return granger_links[station]

        return []

    def _get_edge_index(self):
        """Build edge-list arrays for the station -> child graph.

        Stations are numbered by their position in ``self.all_stations``. For
        every station and each of its children (``get_other_list(...,
        'child')``) one edge is emitted.

        Returns:
            tuple: ``(edge_index, edge_attr)`` where ``edge_index`` is a
            ``(2, E)`` array of source/target indices and ``edge_attr`` is a
            length-``E`` array of ``geodesic`` distances in km.

        NOTE(review): `geodesic` is not defined in this file (the geopy import
        at the top is commented out), and `_STATION_LOCATION_LIST` is only
        ever set to None in the visible code; nothing visible calls this
        method.
        """
        station_coords = self._STATION_LOCATION_LIST
        # Station id -> integer node index.
        id_dict = {station: i for i, station in enumerate(self.all_stations)}
        edge_index = [[], []]
        edge_attr = []
        for station in self.all_stations:
            station_loc = station_coords[station]['latitude'], station_coords[station]['longitude']
            other_list = self.get_other_list(station, other_station_type='child')
            for station2 in other_list:
                station2_loc = station_coords[station2]['latitude'], station_coords[station2]['longitude']
                i = id_dict[station]
                j = id_dict[station2]
                # Row 0 holds the source index, row 1 the target index.
                edge_index[0].append(i)
                edge_index[1].append(j)

                dis = geodesic(station_loc, station2_loc).km
                edge_attr.append(dis)
        return np.array(edge_index), np.array(edge_attr)

    @property
    def station_channels_dict(self):
        """Stored ``_station_channels_dict``; only ever set to None in ``__init__``."""
        return self._station_channels_dict

    @property
    def child_stations_dict(self):
        """Station id -> child station ids, or None if not built for this dataset."""
        return self._child_stations_dict

    @property
    def parent_stations_dict(self):
        """Station id -> parent station ids, or None if not built for this dataset."""
        return self._parent_stations_dict

    @property
    def all_stations(self):
        """A new sorted list of the station ids (re-sorted on every access)."""
        return sorted(self._MAIN_STATIONS_LIST)

    @property
    def adj_matrix(self):
        """Adjacency matrix set by ``__init__`` (MeKong / camels), else None."""
        return self._adj_matrix

    @property
    def distance_adj_matrix(self):
        """Distance adjacency DataFrame set by ``__init__`` (MeKong), else None."""
        return self._distance_adj_matrix

    @property
    def lag_correlation_dict(self):
        """Stored ``_lag_correlation_dict``; None unless assigned after ``__init__``."""
        return self._lag_correlation_dict

    @property
    def edge_index(self):
        """Stored ``_edge_index``; initialised to None in ``__init__``."""
        return self._edge_index

    @property
    def edge_attr(self):
        """Stored ``_edge_attr``; initialised to None in ``__init__``."""
        return self._edge_attr

    @property
    def basin_dict(self):
        """Basin code -> station ids, built for camels only; otherwise None."""
        return self._basin_dict


if __name__ == '__main__':
    # Manual smoke run: build Constants from the LamaH base config.
    import argparse

    base_configs = os.path.join(current_dir, '../', 'configs/LamaH/base_configs1.yaml')
    with open(base_configs, 'r') as f:
        yaml_config = yaml.safe_load(f)

    # The parser declares no arguments of its own; the YAML values become
    # its defaults.
    parser = argparse.ArgumentParser(description='Water Flow Prediction')
    parser.set_defaults(**yaml_config)

    # Derive the data-to-time directory from the YAML-provided defaults and
    # register it as one more default before the final parse.
    preliminary = parser.parse_args()
    data_time_dir = (preliminary.dataset_path + preliminary.target + '/'
                     + preliminary.data_time_path)
    parser.set_defaults(data_time_dir=data_time_dir)

    args = parser.parse_args()
    # Overrides for this run.
    args.verbose = 1
    args.missing_graph_ratio = 0.1
    args.model = 'DMT'
    args.base_flow = 'lag_correlation'

    constant = Constants(args)

