import os
import torch

from scipy.fft import next_fast_len
from scipy.signal import csd
from sklearn.manifold import SpectralEmbedding
import copy
from tqdm import tqdm
import pandas as pd

from utils.graph_utils import cal_adj, cal_lape, compute_clustering_coefficient

import numpy as np
from sklearn.decomposition import PCA


def build_aux_data_from_all(subnodes_set, target_node_set, remain_idx, adj, full_node_period_feats, full_node_delay_matrix, full_node_csd_matrix, cal_lap_emb_dim, device, static_type, period_type, adj_type='auto'):
    """Assemble the static prior features and adjacency matrices for one node subset.

    Which priors are built is controlled by substrings of ``static_type``:
      * ``'Pp'``  - period features (rows of ``full_node_period_feats``);
      * ``'Sl'``  - spectral embedding of the CSD matrix restricted to ``target_node_set``;
      * ``'Ll'``  - spectral embedding of the delay matrix restricted to ``target_node_set``;
      * ``'lap'`` - ``cal_lape`` embedding of the subset adjacency.
    Features of a prior that is not requested are neither computed nor returned (None).

    Args:
        subnodes_set: node indices of the subset (rows/cols of ``adj`` to keep).
        target_node_set: indices into the full-node prior matrices for the nodes that
            have priors; their results are written to rows ``remain_idx``.
        remain_idx: rows of the subset that receive the ``target_node_set`` features;
            all other rows stay zero.
        adj: full adjacency matrix.
        full_node_period_feats: (N_full, d) period features, or None.
        full_node_delay_matrix, full_node_csd_matrix: square full-node matrices.
        cal_lap_emb_dim: embedding size for the spectral / Laplacian features.
        device: torch device for the returned tensors.
        static_type: string selecting the priors (see above).
        period_type: unused here; kept for interface compatibility.
        adj_type: forwarded to ``cal_adj``.

    Returns:
        (aux_data, emb_dim, [period_feat_dim, stci_feat_dim, structure_feat_dim]) where
        ``aux_data`` has keys 'period_feats', 'stci_feats', 'structure_feats' (FloatTensor
        or None) and 'adj_mx' (list of FloatTensors from ``cal_adj``); ``emb_dim`` is the
        sum of the three feature widths.
    """
    N_subset = len(subnodes_set)
    aux_data = {}
    use_period = 'Pp' in static_type
    use_lap = 'lap' in static_type
    use_csd_lap = 'Sl' in static_type
    use_delay_lap = 'Ll' in static_type

    def _to_tensor(feats):
        return None if feats is None else torch.FloatTensor(feats).to(device)

    def _feat_dim(feats):
        return 0 if feats is None else feats.shape[-1]

    def _spectral_of_submatrix(full_matrix):
        # Embed the target-node submatrix and scatter it into the subset rows.
        feats = np.zeros((N_subset, cal_lap_emb_dim))  # (N_subset, lap_dim)
        sub = full_matrix[target_node_set, :][:, target_node_set]
        feats[remain_idx] = get_lap_feats(sub, cal_lap_emb_dim)
        return feats

    # --- Period priors ---
    batch_period_feats = None
    if use_period and full_node_period_feats is not None:
        batch_period_feats = np.zeros((N_subset, full_node_period_feats.shape[-1]))  # (N, pca_dim)
        batch_period_feats[remain_idx] = full_node_period_feats[target_node_set]
    aux_data['period_feats'] = _to_tensor(batch_period_feats)
    period_feat_dim = _feat_dim(batch_period_feats)

    # --- Spatio-temporal interaction priors (only the requested embeddings are computed) ---
    # Delay is embedded before CSD to keep the original call order.
    delay_feats = _spectral_of_submatrix(full_node_delay_matrix) if use_delay_lap else None
    csd_feats = _spectral_of_submatrix(full_node_csd_matrix) if use_csd_lap else None
    stci_parts = [f for f in (csd_feats, delay_feats) if f is not None]  # CSD first, then delay
    stci_feats = np.concatenate(stci_parts, axis=-1) if stci_parts else None
    aux_data['stci_feats'] = _to_tensor(stci_feats)
    stci_feat_dim = _feat_dim(stci_feats)

    # --- Topo priors (paper) and adjacency ---
    # astype(float) already returns a fresh array; each helper gets its own copy.
    known_adj = adj[subnodes_set, :][:, subnodes_set].astype(float)
    structure_feats = cal_lape(known_adj.copy(), cal_lap_emb_dim) if use_lap else None
    adj_mx = cal_adj(known_adj.copy(), adj_type)

    aux_data['structure_feats'] = _to_tensor(structure_feats)
    structure_feat_dim = _feat_dim(structure_feats)
    aux_data['adj_mx'] = [torch.FloatTensor(a).to(device) for a in adj_mx]

    emb_dim = period_feat_dim + stci_feat_dim + structure_feat_dim
    return aux_data, emb_dim, [period_feat_dim, stci_feat_dim, structure_feat_dim]


def randomize_priors_np(Z_prior, mode="highVar", seed=42, match_scale=False):
    """Return a seeded Gaussian stand-in with the same 2-D shape and dtype as ``Z_prior``.

    ``mode == "highVar"`` samples with standard deviation 1.0; any other mode uses 1e-4.
    When ``match_scale`` is true, the draw is standardized per column and then rescaled
    to the per-column mean and standard deviation of ``Z_prior``.
    """
    n_rows, n_cols = Z_prior.shape
    scale = 1.0 if mode == "highVar" else 0.0001
    generator = np.random.default_rng(seed)
    noise = generator.normal(loc=0.0, scale=scale, size=(n_rows, n_cols)).astype(Z_prior.dtype)

    if not match_scale:
        return noise

    target_mean = Z_prior.mean(axis=0, keepdims=True)
    target_std = Z_prior.std(axis=0, keepdims=True) + 1e-8
    standardized = (noise - noise.mean(axis=0, keepdims=True)) / (noise.std(axis=0, keepdims=True) + 1e-8)
    return standardized * target_std + target_mean

def build_aux_data_from_stage1(stage2_data, stage1_aux_data, stage1_node_idx, stage2_node_idx, stage2_node_csd_matrix, device, topk=3, eps=1e-5, fuse_type ='stci'):
    """Transfer stage-1 node priors onto the stage-2 node set.

    Nodes present in stage 1 copy their stage-1 feature rows. Each node that is new in
    stage 2 gets a weighted average of the features of its ``topk`` most related old
    nodes, with weights proportional to |relation|:
      * ``fuse_type == 'corr'``: Pearson correlation of the stage-2 series (NaN -> 0);
      * otherwise: the entries of ``stage2_node_csd_matrix``.

    Args:
        stage2_data: array of shape (L, N2), one column per stage-2 node.
        stage1_aux_data: dict with keys 'period_feats', 'stci_feats', 'structure_feats';
            each value is a tensor with one row per entry of ``stage1_node_idx``, or None.
        stage1_node_idx: node ids matching the rows of the stage-1 features.
        stage2_node_idx: node ids matching the columns of ``stage2_data``.
        stage2_node_csd_matrix: (N2, N2) relation matrix in stage-2 node order
            (only read when ``fuse_type != 'corr'``).
        device: torch device for the returned tensors.
        topk: number of old neighbours averaged for each new node.
        eps: if the neighbours' summed |relation| is below this, uniform weights are used.
        fuse_type: 'corr' or anything else (CSD matrix).

    Returns:
        (period_feats, stci_feats, structure_feats) for the N2 stage-2 nodes, each a
        FloatTensor on ``device`` or None when the stage-1 counterpart is None.
    """
    feat_keys = ('period_feats', 'stci_feats', 'structure_feats')
    stage1_feats = [
        None if stage1_aux_data[key] is None else stage1_aux_data[key].cpu().numpy()
        for key in feat_keys
    ]
    _, N2 = stage2_data.shape
    stage2_ids = np.asarray(stage2_node_idx)

    # Row of each node id inside the stage-1 feature arrays.
    idx_map = {node: i for i, node in enumerate(stage1_node_idx)}
    old_mask = np.isin(stage2_ids, stage1_node_idx)
    old_pos = np.flatnonzero(old_mask)   # stage-2 positions of nodes seen in stage 1
    new_pos = np.flatnonzero(~old_mask)  # stage-2 positions of new nodes
    src_rows = np.array([idx_map[node] for node in stage2_ids[old_pos]], dtype=np.intp)

    # Old nodes: copy their stage-1 rows into stage-2 order.
    stage2_feats = []
    for feats in stage1_feats:
        if feats is None:
            stage2_feats.append(None)
            continue
        out = np.zeros((N2, feats.shape[-1]))
        out[old_pos] = feats[src_rows]
        stage2_feats.append(out)

    # New nodes: relation-weighted average over the top-k related old nodes.
    if new_pos.size > 0:
        B_old = stage2_data[:, old_pos]  # (L, N_old)
        if fuse_type != 'corr':
            all_corrs = stage2_node_csd_matrix[~old_mask, :][:, old_mask]  # (N_new, N_old)
        for i_new, b_idx in enumerate(new_pos):
            if fuse_type == 'corr':
                b_seq = stage2_data[:, b_idx]
                corrs = np.array([np.corrcoef(b_seq, B_old[:, k])[0, 1] for k in range(B_old.shape[1])])
                corrs[np.isnan(corrs)] = 0.0  # constant series give NaN correlation
            else:
                corrs = all_corrs[i_new]

            topk_idx = np.argsort(np.abs(corrs))[-topk:]
            weights = np.abs(corrs[topk_idx])
            if weights.sum() < eps:
                weights = np.ones_like(weights) / len(weights)
            else:
                weights = weights / weights.sum()

            # topk_idx indexes the old-node columns (stage-2 order), NOT stage-1 rows;
            # map back to stage-2 positions, which already hold those nodes' features.
            neighbour_pos = old_pos[topk_idx]
            for out in stage2_feats:
                if out is not None:
                    out[b_idx] = np.sum(weights[:, None] * out[neighbour_pos], axis=0)

    return tuple(None if out is None else torch.FloatTensor(out).to(device) for out in stage2_feats)




###  For Period Priors
def compute_period_embeddings_by_period(data, period, period_pca_emb_dim, target_node_idx=None, norm_type='none'):
    """Embed each node's within-period profile with PCA, averaged over all whole periods.

    ``data`` of shape (L, N) (L time steps, N nodes) is truncated to ``D = L // period``
    whole periods and reshaped to Z of shape (D, N, T), T = ``period``. For every period d,
    PCA is fitted on Z[d] restricted to the target nodes (rows = nodes, columns = steps),
    and the resulting component scores are averaged over the D periods.

    Args:
        data: array of shape (L, N).
        period: number of time steps per period (T).
        period_pca_emb_dim: number of PCA components kept per period.
        target_node_idx: optional node indices to embed; defaults to all N nodes.
        norm_type: normalization applied to Z before PCA:
            'S'   - per node, over the T steps of each period;
            'T'   - per step, across nodes;
            'S-T' - 'S' followed by 'T';
            'ST'  - a single mean/std per period over all nodes and steps;
            anything else - no normalization.

    Returns:
        array of shape (len(target_node_idx), period_pca_emb_dim).

    Raises:
        ValueError: if ``data`` holds fewer than ``period`` time steps.
    """
    eps = 1e-8
    L, N = data.shape
    D = L // period
    if D == 0:
        raise ValueError(f'need at least one full period ({period} steps), got {L} steps')
    # (D*T, N) -> (D, T, N) -> (D, N, T)
    Z = data[:D * period].reshape(D, period, N).transpose((0, 2, 1))

    if norm_type in ['S', 'S-T']:
        node_mean = Z.mean(axis=-1, keepdims=True)  # (D, N, 1)
        node_std = Z.std(axis=-1, keepdims=True)    # (D, N, 1)
        # Check before adding eps; otherwise the warning could never trigger.
        if (node_std == 0).any():
            print('std has zero!')
        Z = (Z - node_mean) / (node_std + eps)
    elif norm_type == 'T':
        series_mean = Z.mean(axis=1, keepdims=True)     # (D, 1, T)
        series_std = Z.std(axis=1, keepdims=True) + eps  # (D, 1, T); eps guards constant steps
        Z = (Z - series_mean) / series_std
    elif norm_type == 'ST':
        # One statistic per period over all nodes and steps (not the std of per-step stds).
        period_mean = Z.mean(axis=(1, 2), keepdims=True)     # (D, 1, 1)
        period_std = Z.std(axis=(1, 2), keepdims=True) + eps  # (D, 1, 1)
        Z = (Z - period_mean) / period_std

    if norm_type == 'S-T':
        series_mean = Z.mean(axis=1, keepdims=True)     # (D, 1, T)
        series_std = Z.std(axis=1, keepdims=True) + eps  # (D, 1, T)
        Z = (Z - series_mean) / series_std

    target_node_idx = target_node_idx if target_node_idx is not None else np.arange(N)
    Z_target = Z[:, target_node_idx, :]

    print('Z_target.shape', Z_target.shape)
    PCA_period_list = np.zeros((D, len(target_node_idx), period_pca_emb_dim))
    for d in tqdm(range(D)):
        pca = PCA(svd_solver='full', n_components=period_pca_emb_dim)
        PCA_period_list[d] = pca.fit_transform(Z_target[d])  # (N, n_components)
    # NOTE(review): PCA component signs are fitted independently per period, so
    # averaging scores across periods may partially cancel; confirm this is intended.
    return PCA_period_list.mean(axis=0)  # (D, N, n_components) -> (N, n_components)


def compute_intraSeries_embedding(data, period=None, period_pca_emb_dim=24, topk=2, target_node_idx=None, norm_type='none'):
    """Concatenate period-PCA embeddings computed for one or more period lengths.

    Args:
        data: (L, N) series passed to ``compute_period_embeddings_by_period``.
        period: a single period length or an iterable of period lengths (time steps).
        period_pca_emb_dim: PCA components per period length.
        topk: unused; kept for backward compatibility with existing callers.
        target_node_idx: optional subset of nodes to embed.
        norm_type: normalization mode forwarded to ``compute_period_embeddings_by_period``.

    Returns:
        array of shape (n_nodes, len(periods) * period_pca_emb_dim), the per-period
        embeddings concatenated along the last axis in the order given.

    Raises:
        ValueError: if ``period`` is None or an empty iterable.
    """
    if period is None:
        raise ValueError('period must be given as a period length or a list of period lengths')
    periods = [period] if np.isscalar(period) else list(period)
    if not periods:
        raise ValueError('period must contain at least one period length')

    E_node_list = [
        compute_period_embeddings_by_period(data, d, period_pca_emb_dim, target_node_idx, norm_type=norm_type)
        for d in periods
    ]
    return np.concatenate(E_node_list, axis=-1)  # (N, len(periods) * d)

def get_full_week_indices(time_df, ignore_full_weeks=False):
    """Return row indices spanning the whole weeks that start at the first Monday 00:00.

    The sampling step is inferred from the first two entries of ``time_df['date']``.
    Rows before the first Monday 00:00 and the trailing partial week are excluded.

    Args:
        time_df: table with a 'date' column of timestamps at a fixed step.
        ignore_full_weeks: if True, return None instead of raising when no whole week fits.

    Returns:
        list of consecutive row indices, or None (see ``ignore_full_weeks``).

    Raises:
        ValueError: fewer than two timestamps, an unusable sampling step, no Monday 00:00,
            or no whole week after it (unless ``ignore_full_weeks``).
    """
    stamps = pd.DatetimeIndex(time_df['date'])
    if len(stamps) < 2:
        raise ValueError('at least two timestamps are required to infer the sampling frequency!')
    delta = (stamps[1] - stamps[0]).total_seconds() / 60
    freq_minutes = int(round(delta))
    if freq_minutes <= 0:
        raise ValueError(f'cannot infer a positive sampling step from the first two timestamps ({delta} minutes)!')
    print('auto freq_minutes:', freq_minutes)

    slices_per_day = int(24 * 60 / freq_minutes)
    if slices_per_day == 0:
        raise ValueError(f'sampling step of {freq_minutes} minutes is longer than a day!')
    slices_per_week = 7 * slices_per_day

    # First "Monday 00:00".
    is_week_start = (stamps.weekday == 0) & (stamps.hour == 0) & (stamps.minute == 0)
    candidates = np.flatnonzero(is_week_start)
    if candidates.size == 0:
        raise ValueError('there is no start point of a full week!')
    start_idx = int(candidates[0])

    num_weeks = (len(stamps) - start_idx) // slices_per_week
    if num_weeks <= 0:
        if ignore_full_weeks:
            return None
        raise ValueError('no full week remained for calculation!')
    end_idx = start_idx + num_weeks * slices_per_week
    print('start_idx:', start_idx)
    print('end_idx:', end_idx)
    return list(range(start_idx, end_idx))


def get_full_day_indices(time_df):
    """Return row indices spanning the whole days that start at the first 00:00.

    The sampling step is inferred from the first two entries of ``time_df['date']``.
    Rows before the first midnight and the trailing partial day are excluded.

    Args:
        time_df: table with a 'date' column of timestamps at a fixed step.

    Returns:
        list of consecutive row indices.

    Raises:
        ValueError: fewer than two timestamps, an unusable sampling step, no 00:00
            timestamp, or no whole day after it.
    """
    stamps = pd.DatetimeIndex(time_df['date'])
    if len(stamps) < 2:
        raise ValueError('at least two timestamps are required to infer the sampling frequency!')
    delta = (stamps[1] - stamps[0]).total_seconds() / 60
    freq_minutes = int(round(delta))
    if freq_minutes <= 0:
        raise ValueError(f'cannot infer a positive sampling step from the first two timestamps ({delta} minutes)!')
    print('auto freq_minutes:', freq_minutes)

    slices_per_day = int(24 * 60 / freq_minutes)
    if slices_per_day == 0:
        raise ValueError(f'sampling step of {freq_minutes} minutes is longer than a day!')

    # First "00:00".
    is_day_start = (stamps.hour == 0) & (stamps.minute == 0)
    candidates = np.flatnonzero(is_day_start)
    if candidates.size == 0:
        raise ValueError('there is no start point of a full day!')
    start_idx = int(candidates[0])

    num_days = (len(stamps) - start_idx) // slices_per_day
    if num_days <= 0:
        raise ValueError('no full day remained for calculation!')
    end_idx = start_idx + num_days * slices_per_day
    print('start_idx:', start_idx)
    print('end_idx:', end_idx)
    return list(range(start_idx, end_idx))

def load_init_period_features(dataset_name, cal_period_pca_emb_dim, slice_size_per_day, period_type, load_tip, full_train_data, df_time, target_node_idx = None, norm_type='none',ignore_full_weeks=False, usePPCA=False,last_part_ratio=None):
    """Load (or compute and cache) the period PCA features of the training series.

    The training data is trimmed to whole weeks (``period_type`` containing 'week') or
    whole days, then embedded with ``compute_intraSeries_embedding`` using period lengths
    of one day and/or one week (``slice_size_per_day`` steps per day). Results are cached
    under ``datasets/<dataset_name>/period_feats/``.

    Args:
        dataset_name: dataset folder / cache-name prefix.
        cal_period_pca_emb_dim: PCA components per period length.
        slice_size_per_day: time steps per day.
        period_type: 'day', 'week' or 'day_week'.
        load_tip: free-form tag included in the cache file name.
        full_train_data: (L, N) training series aligned with ``df_time``.
        df_time: table with a 'date' column used to find whole days/weeks.
        target_node_idx: optional node subset forwarded to the embedding.
            NOTE(review): not part of the cache key; presumably encoded in ``load_tip``.
        norm_type: normalization mode; non-default values are part of the cache key.
        ignore_full_weeks: return None instead of raising when no whole week exists.
        usePPCA: only changes the cache-name tag.
        last_part_ratio: unused; kept for interface compatibility.

    Returns:
        (N, len(periods) * cal_period_pca_emb_dim) array, or None when no whole week
        exists and ``ignore_full_weeks`` is set.

    Raises:
        ValueError: for an unsupported ``period_type``.
    """
    period_lengths = {
        'day': [slice_size_per_day],
        'week': [slice_size_per_day * 7],
        'day_week': [slice_size_per_day, slice_size_per_day * 7],
    }
    if period_type not in period_lengths:
        raise ValueError(f"period_type must be one of {sorted(period_lengths)}, got {period_type!r}")
    period = period_lengths[period_type]

    os.makedirs(f'datasets/{dataset_name}/period_feats', exist_ok=True)

    print('df_time.shape:', df_time.shape)
    print('full_train_data.shape:', full_train_data.shape)
    if 'week' in period_type:
        indices = get_full_week_indices(df_time, ignore_full_weeks)
    else:
        indices = get_full_day_indices(df_time)

    if indices is None:
        return None

    trimmed_train_data = full_train_data[indices]
    print('trimmed_train_data.shape:', trimmed_train_data.shape)

    full_tip = 'fullweeks' if 'week' in period_type else 'fulldays'
    if usePPCA:
        full_tip = f'{full_tip}_usePPCA'
    # Keep the historical file name for the default normalization so existing caches stay valid.
    norm_tip = '' if norm_type == 'none' else f'_norm{norm_type}'

    ppca_cache_path = f'datasets/{dataset_name}/period_feats/{dataset_name}_train_intraSeries_feats_{period_type}_ppca_dim{cal_period_pca_emb_dim}_{load_tip}_{full_tip}{norm_tip}.npy'

    if os.path.exists(ppca_cache_path):
        period_pca_feats = np.load(ppca_cache_path)
        print('train_intraSeries_feats loaded!')
    else:
        print('creating train_intraSeries_feats ...')
        period_pca_feats = compute_intraSeries_embedding(trimmed_train_data, period, cal_period_pca_emb_dim, None, target_node_idx=target_node_idx, norm_type=norm_type)
        np.save(ppca_cache_path, period_pca_feats)
        print(f'{ppca_cache_path} saved and loaded!')

    print('period_pca_feats.shape:', period_pca_feats.shape)
    return period_pca_feats  # (N, len(period) * d)

###  For Time-delayed Interaction Priors
def compute_max_corr_fft_fast(ori_data, window, noverlap, norm_type ='none'):
    """Estimate, for every node pair, the lag and strength of the peak cross-correlation.

    For each pair (i, j) the Welch cross-spectral density of the two series is estimated
    with ``scipy.signal.csd`` (Hann window, segments of ``window`` steps, two-sided
    spectrum) and inverted to a cross-correlation; the lag with the largest
    |correlation| is kept.

    Args:
        ori_data: array of shape (L, N), L time steps and N nodes.
        window: segment length (``nperseg``) of the CSD estimate.
        noverlap: overlap between consecutive segments.
        norm_type: 'none'; 'S' standardizes each node over time; 'S-T' additionally
            standardizes each time step across nodes.

    Returns:
        (delay_matrix, corr_matrix): symmetric (N, N) arrays holding |lag| in steps (int)
        and |peak cross-correlation| for every pair; diagonals are 0.

    Raises:
        ValueError: for an unknown ``norm_type``.
    """
    eps = 1e-8
    L, N = ori_data.shape
    fs = 1.0
    nperseg = window
    # Padding to >= 2*nperseg-1 makes the circular correlation equal the linear one.
    nfft = next_fast_len(2 * nperseg - 1)
    # After fftshift, entry k holds lag k - nfft // 2 (correct for odd and even nfft).
    lags = np.arange(nfft) - nfft // 2

    if norm_type == 'none':
        data = ori_data
    elif norm_type in ('S', 'S-T'):
        # Per-node statistics over time: axis 0 is time for (L, N) data.
        node_mean = ori_data.mean(axis=0, keepdims=True)  # (1, N)
        node_std = ori_data.std(axis=0, keepdims=True)    # (1, N)
        if (node_std == 0).any():
            print('std has zero!')
        data = (ori_data - node_mean) / (node_std + eps)
        if norm_type == 'S-T':
            time_mean = data.mean(axis=1, keepdims=True)  # (L, 1)
            time_std = data.std(axis=1, keepdims=True)    # (L, 1)
            data = (data - time_mean) / (time_std + eps)
    else:
        raise ValueError(f"norm_type must be 'none', 'S' or 'S-T', got {norm_type!r}")

    delay_matrix = np.zeros((N, N), dtype=int)
    corr_matrix = np.zeros((N, N))  # max |correlation| over all lags

    for i in tqdm(range(N)):
        for j in range(i + 1, N):
            # Two-sided CSD so that its inverse FFT is the full (real) cross-correlation;
            # the inverse FFT of a one-sided spectrum is not.
            _, Pxy = csd(data[:, i], data[:, j], fs=fs, window='hann', nperseg=nperseg,
                         noverlap=noverlap, nfft=nfft, return_onesided=False)
            cc = np.fft.fftshift(np.fft.ifft(Pxy).real)

            max_idx = np.argmax(np.abs(cc))
            delay_matrix[i, j] = delay_matrix[j, i] = abs(int(lags[max_idx]))
            corr_matrix[i, j] = corr_matrix[j, i] = abs(cc[max_idx])

    return delay_matrix, corr_matrix

def get_lap_feats(data, n_components):
    """Return an ``n_components``-dimensional spectral embedding of the rows of ``data``.

    NOTE(review): ``SpectralEmbedding`` is constructed with its default settings, so
    ``data`` is treated as a sample-by-feature matrix rather than as a precomputed
    affinity matrix; confirm this is intended for the square delay/CSD matrices passed in.
    """
    return SpectralEmbedding(n_components=n_components).fit_transform(data)


def load_init_stci_matrix(dataset_name, cal_lap_emb_dim, input_length, load_tip=None, full_train_data=None, norm_type='none'):
    """Load (or compute and cache) the pairwise delay and correlation matrices.

    On a cache miss the matrices are computed with
    ``compute_max_corr_fft_fast(full_train_data, input_length, noverlap=0, norm_type=norm_type)``
    and saved as ``.npy`` files under ``./datasets/<dataset_name>/stci_feats/``.

    Cache key notes: ``cal_lap_emb_dim`` appears in the file name only for compatibility
    with existing caches (the matrices do not depend on it). A non-default ``norm_type``
    is appended so differently normalized matrices do not shadow each other.
    NOTE(review): ``input_length`` is not part of the key; presumably ``load_tip``
    distinguishes it. TODO confirm.

    Returns:
        (delay_matrix, corr_matrix), each of shape (N, N).

    Raises:
        ValueError: if no cache exists and ``full_train_data`` is None.
    """
    cache_dir = f'./datasets/{dataset_name}/stci_feats'
    norm_tip = '' if norm_type == 'none' else f'_norm{norm_type}'
    delay_matrix_cache_path = f'{cache_dir}/{dataset_name}_train_stci_delay_mx_{cal_lap_emb_dim}_fast_{load_tip}{norm_tip}.npy'
    corr_matrix_cache_path = f'{cache_dir}/{dataset_name}_train_stci_corr_mx_{cal_lap_emb_dim}_fast_{load_tip}{norm_tip}.npy'

    if os.path.exists(delay_matrix_cache_path) and os.path.exists(corr_matrix_cache_path):
        delay_matrix = np.load(delay_matrix_cache_path)
        corr_matrix = np.load(corr_matrix_cache_path)
    else:
        if full_train_data is None:
            raise ValueError(f'no cached STCI matrices at {delay_matrix_cache_path} and full_train_data is None')
        delay_matrix, corr_matrix = compute_max_corr_fft_fast(full_train_data, input_length, noverlap=0, norm_type=norm_type)
        os.makedirs(cache_dir, exist_ok=True)
        np.save(delay_matrix_cache_path, delay_matrix)
        np.save(corr_matrix_cache_path, corr_matrix)
    print('delay_matrix.shape:', delay_matrix.shape)
    print('corr_matrix.shape:', corr_matrix.shape)
    return delay_matrix, corr_matrix

