import os
import sys
import numpy as np
import pandas as pd

from lib import datasets_path
from .pd_dataset import PandasDataset
from ..utils import sample_mask, adj_keep_idx
from ..utils.utils import disjoint_months, infer_mask, compute_mean, geographical_distance, thresholded_gaussian_kernel


class NrelMd(PandasDataset):
    """NREL solar-power dataset for Maryland (NREL-MD).

    Original notes: "Maryland, 80, with 0, no nan" -- presumably 80 stations
    (TODO confirm). The raw CSV contains zeros but no NaNs. ``load`` treats
    zero readings as missing in the returned mask.

    Attributes set by ``__init__``:
        dist: station distance matrix from ``load_distance_matrix``.
        keep_idx: node indices retained by ``adj_keep_idx`` on the thresholded graph.
        cols: random permutation applied to the retained columns (unseeded ``np.random``).
        adj: adjacency restricted to ``keep_idx`` and permuted by ``cols``.
        capacities: per-node capacities of shape ``(1, n_kept, 1)``, in the same
            node order as the dataframe columns and ``adj``.
        min_val: zeros with the same shape as ``capacities``.
    """

    def __init__(self, adj_thr=0.1):
        """Load the data, prune nodes via the adjacency graph, then shuffle node order.

        Args:
            adj_thr: threshold forwarded to ``get_similarity``.
                NOTE(review): ``get_similarity`` currently overrides ``thr`` with a
                fixed 0.75, so this value has no effect on the graph.
        """
        # Per-node capacities, expanded to shape (1, n_nodes, 1).
        files_info = pd.read_pickle("datasets/nrel_md/nrel_file_infos.pkl")
        capacities = np.array(files_info['capacity']).astype('float32')
        capacities = np.expand_dims(capacities, axis=(0, -1))

        df, dist, mask = self.load()
        self.dist = dist

        # Keep only the nodes selected by adj_keep_idx on the self-loop-free graph.
        adj = self.get_similarity(thr=adj_thr)
        np.fill_diagonal(adj, 0.)
        keep_idx = adj_keep_idx(adj)
        self.keep_idx = keep_idx
        self.adj = adj[keep_idx, :][:, keep_idx]
        df = df.iloc[:, keep_idx]
        mask = mask[:, keep_idx]
        capacities = capacities[:, keep_idx, :]

        # Random shuffle of columns. Every per-node array (adjacency, data, mask,
        # capacities) must receive the same permutation so node j stays aligned.
        cols = np.arange(df.shape[1])
        np.random.shuffle(cols)
        self.adj = self.adj[cols, :][:, cols]
        df = df.iloc[:, cols]
        mask = mask[:, cols]
        self.capacities = capacities[:, cols, :]
        self.cols = cols

        self.min_val = np.zeros_like(self.capacities)

        super().__init__(dataframe=df, u=None, mask=mask, name="nrel_md", freq='5T', aggr='nearest')

    def load(self, impute_zeros=True):
        """Read the raw CSV and build the observation mask.

        Args:
            impute_zeros: currently unused; kept for interface compatibility.

        Returns:
            ``(df, dist, mask)``. ``df`` is float32 and re-indexed with a 5-minute
            ``DatetimeIndex`` spanning the first to last timestamp. Zero and NaN
            readings are stored as 0 in ``df``. ``dist`` comes from
            ``load_distance_matrix``. ``mask`` is uint8: 1 where the raw value was
            non-zero and not NaN.
        """
        path = os.path.join(datasets_path["nrel_md"], 'nrel_X_wrapped.csv')
        df = pd.read_csv(path, index_col="timestamps")
        datetime_idx = sorted(df.index)
        # Assumes one row per 5-minute step between the first and last timestamp;
        # pandas raises on a length mismatch. 'min' is the non-deprecated alias of 'T'.
        df.index = pd.date_range(datetime_idx[0], datetime_idx[-1], freq='5min')
        # Zero readings are treated as missing values.
        df = df.replace(0, np.nan)
        mask = ~np.isnan(df.values)
        df = df.fillna(0)
        dist = self.load_distance_matrix(list(df.columns))
        return df.astype('float32'), dist, mask.astype('uint8')

    def load_distance_matrix(self, ids):
        """Return the station distance matrix, computing and caching it if needed.

        The matrix is loaded from ``nrel_md_dist.npy`` when that file can be read.
        Otherwise it is computed from the latitude/longitude columns of
        ``nrel_file_infos.csv`` and saved to that path.

        Args:
            ids: currently unused; kept for interface compatibility.
        """
        dist_path = os.path.join(datasets_path['nrel_md'], 'nrel_md_dist.npy')
        try:
            return np.load(dist_path)
        except (OSError, ValueError, EOFError):
            # Cache missing or unreadable: recompute from station coordinates.
            stations = pd.read_csv(os.path.join(datasets_path["nrel_md"], "nrel_file_infos.csv"))
            st_coord = stations.loc[:, ['latitude', 'longitude']]
            dist = geographical_distance(st_coord, to_rad=True).values
            np.save(dist_path, dist)
            return dist

    def get_similarity(self, type='dcrnn', thr=0.1, include_self=False, force_symmetric=False, sparse=False):
        """Build a thresholded Gaussian-kernel adjacency from ``self.dist``.

        Args:
            type: currently unused.
            thr: requested kernel threshold. NOTE(review): overridden below by a
                fixed 0.75. Honouring it would change the graph and therefore the
                node set chosen in ``__init__``, so confirm before changing.
            include_self: keep the diagonal when True, otherwise zero it.
            force_symmetric: replace ``adj`` with ``max(adj, adj.T)`` element-wise.
            sparse: return a ``scipy.sparse.coo_matrix`` instead of a dense array.
        """
        thr = 0.75
        # Kernel bandwidth is the standard deviation of all pairwise distances.
        theta = np.std(self.dist)
        adj = thresholded_gaussian_kernel(self.dist, theta=theta, threshold=thr)
        if not include_self:
            adj[np.diag_indices_from(adj)] = 0.
        if force_symmetric:
            adj = np.maximum.reduce([adj, adj.T])
        if sparse:
            import scipy.sparse as sps
            adj = sps.coo_matrix(adj)
        return adj

    @property
    def mask(self):
        """Explicit mask if one was set, otherwise ``df.values != 0``."""
        if self._mask is None:
            return self.df.values != 0.
        return self._mask


class MissingValuesNrelMd(NrelMd):
    """NREL-MD with whole nodes (columns) held out for validation and testing.

    By default the last ``num_val + num_test`` columns of the (shuffled) data are
    held out. Alternatively, the test columns can be listed explicitly in a file.
    Both masks are intersected with the observation mask.
    """

    def __init__(self, p_fault=0.0015, p_noise=(0.025, 0.025), mode="random", test_entries="",
                 eval_sample_strategy="random", adj_thr=0.1):
        """
        Args:
            p_fault: stored as ``self.p_fault``; not otherwise used here.
            p_noise: ``[val_fraction, test_fraction]`` of nodes to hold out. A single
                number is split evenly between validation and test.
            mode: currently unused (it fed the disabled ``sample_mask`` path).
            test_entries: file name under ``datasets/nrel_md/`` with one column index
                per line. When given, those columns are the test set and the
                validation mask is empty.
            eval_sample_strategy: ``"region"`` / ``"degree"`` compute node positions /
                adjacency intended for the disabled ``sample_mask`` path.
            adj_thr: forwarded to ``NrelMd``.
        """
        super().__init__(adj_thr=adj_thr)
        self.p_fault = p_fault
        # Accept a scalar for backward compatibility: split it between val and test.
        if isinstance(p_noise, (int, float)):
            p_noise = [p_noise / 2, p_noise / 2]
        self.p_noise = list(p_noise)
        shape = self.numpy().shape

        if test_entries != "":
            with open("datasets/nrel_md/{}".format(test_entries), "r") as f:
                print("use test entries {}...".format(test_entries))
                # One integer column index per line; blank lines are ignored.
                entries = [int(line) for line in f if line.strip()]
            val_mask = np.zeros(shape, dtype='uint8')
            test_mask = np.zeros(shape, dtype='uint8')
            test_mask[:, entries] = 1
        else:
            # NOTE(review): pos/adj are only consumed by the disabled
            # sample_mask(..., pos=pos, adj=adj, sample_strategy=...) call, so they are
            # currently unused. The 'degree' adjacency is also built on the unpruned,
            # unshuffled node set.
            adj, pos = None, None
            if eval_sample_strategy == "region":
                pos = self.get_position()
            if eval_sample_strategy == 'degree':
                adj = self.get_similarity(thr=adj_thr)
                np.fill_diagonal(adj, 0.)

            num_nodes = shape[1]
            num_val = int(self.p_noise[0] * num_nodes)
            num_test = int(self.p_noise[1] * num_nodes)
            # Positive bounds avoid the `-0:` pitfall: with num_test == 0 the old
            # negative slice selected every column as test.
            test_start = max(num_nodes - num_test, 0)
            val_start = max(test_start - num_val, 0)
            val_mask = np.zeros(shape, dtype='uint8')
            val_mask[:, val_start:test_start] = 1
            test_mask = np.zeros(shape, dtype='uint8')
            test_mask[:, test_start:] = 1
        self.val_mask = (val_mask & self.mask).astype('uint8')
        self.test_mask = (test_mask & self.mask).astype('uint8')

    @property
    def training_mask(self):
        """Observed entries that are in neither the validation nor the test mask."""
        if self.test_mask is None:
            return self.mask
        # OR first so overlapping masks cannot underflow the uint8 subtraction.
        held_out = self.val_mask | self.test_mask
        return self.mask & (1 - held_out)

    def splitter(self, dataset, val_len=0, test_len=0, window=0):
        """Split ``range(len(dataset))`` chronologically into train/val/test indices.

        Lengths below 1 are fractions: ``test_len`` of the whole series, ``val_len``
        of what remains after the test part. ``window`` indices are dropped before
        the val and test starts, so windows do not straddle splits.
        """
        idx = np.arange(len(dataset))
        if test_len < 1:
            test_len = int(test_len * len(idx))
        if val_len < 1:
            val_len = int(val_len * (len(idx) - test_len))
        test_start = len(idx) - test_len
        val_start = test_start - val_len
        # Clamp at 0 so a large window yields empty splits instead of wrapping around.
        train_end = max(val_start - window, 0)
        val_end = max(test_start - window, 0)
        return [idx[:train_end], idx[val_start:val_end], idx[test_start:]]

    def get_position(self):
        """Return z-scored (latitude, longitude) per node as float32.

        Rows follow the dataset's column order: ``keep_idx`` is applied first,
        then the ``cols`` permutation.
        """
        stations = pd.read_csv(os.path.join(datasets_path["nrel_md"], "nrel_file_infos.csv"))
        pos = stations.loc[:, ['latitude', 'longitude']]

        # z-score normalization over all stations, before pruning.
        pos = (pos - pos.mean()) / pos.std()
        pos = pos.values.astype(np.float32)
        pos = pos[self.keep_idx, :]
        # Reorder to match the shuffled columns.
        pos = pos[self.cols, :]
        return pos
