from sklearn.metrics import mutual_info_score
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde, entropy
from sklearn.feature_selection import mutual_info_regression
from sklearn.metrics import mutual_info_score
from sklearn.neighbors import KernelDensity
from sklearn.metrics.cluster import normalized_mutual_info_score
from scipy.stats import spearmanr
from scipy.stats import pearsonr
import seaborn as sns
import matplotlib.pyplot as plt
import json
import os

from models.entity import Entity
from models.relationship import Relationship 
from .util import *


import pandas as pd
import numpy as np
np.random.seed(42)


def z_normalize(x, mean=None, std=None):
    """Standardise ``x`` as ``(x - mean) / std``.

    Previously the ``mean``/``std`` arguments were silently ignored and the
    statistics of ``x`` itself were always used. They are now honoured; when
    either is omitted (``None``) it falls back to ``np.mean(x)`` /
    ``np.std(x)`` (population std, ddof=0), matching the old behaviour.

    Parameters
    ----------
    x : array-like
        Values to standardise.
    mean, std : float, optional
        Reference location and scale. A zero ``std`` yields inf/NaN, as before.

    Returns
    -------
    The standardised values (same container type numpy arithmetic produces).
    """
    if mean is None:
        mean = np.mean(x)
    if std is None:
        std = np.std(x)
    return (x - mean) / std

def min_max_normalize(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    A constant input has zero range, so the division yields NaN/inf
    (numpy emits a RuntimeWarning); callers must handle that case.
    """
    lowest = np.min(x)
    value_range = np.max(x) - lowest
    shifted = x - lowest
    return shifted / value_range

class wag_retriever:
    """Score metrics related to a query metric by combining several evidence sources.

    Sources combined by this class:
    * an LLM prior per metric (``llm_prior_df.loc[metric, query_metric]``),
    * population and individual relationship strengths with sample sizes
      (from ``rel_info``), folded in by two successive ``update_prior`` calls
      (defined outside this file, presumably in ``.util``),
    * optionally, a local weight from recent abnormality of ``ind_df``.

    ``run()`` returns ``(res, abnormality_df)``.
    """

    def __init__(self,
                 metric_to_check,
                 rel_info,
                 llm_prior_df,
                 hyperparameter,
                 ind_df=None,
                 query_info=None,
                 ):
        """
        Parameters
        ----------
        metric_to_check : hashable
            The query metric: a column of ``llm_prior_df`` and of the
            relationship matrices in ``rel_info``.
        rel_info : dict
            Read by ``update_rel_info``: ``rel_pop_all``, ``rel_var_pop_all``,
            ``rel_pop_sample_size_all``, ``rel_ind_all``, ``rel_var_ind_all``,
            ``rel_ind_sample_size_all``, ``numeric_metrics``,
            ``data_associated_metrics``.
        llm_prior_df : pandas.DataFrame
            LLM prior scores indexed by metric, one column per query metric.
        hyperparameter : dict
            Keys read by this class: ``rel_type``, ``full_rel_only``,
            ``t_global``, ``t_local``, ``min_samples_needed``,
            ``alpha_prior``, ``alpha_pop``, ``alpha_ind``, ``beta``.
        ind_df : pandas.DataFrame, optional
            Individual data with a ``date`` column plus the numeric metrics.
            When ``None`` the abnormality / local-weight stage is skipped.
        query_info : dict, optional
            Keys read: ``query_date``, ``time_granularity``, ``openness``.
            Only required when ``ind_df`` is given.
        """
        self._rng = np.random.RandomState(42)
        self.query_metric = metric_to_check
        self.hyper = hyperparameter
        self.rel_type = hyperparameter['rel_type']
        # Only consider metrics that have every source available
        # (LLM prior, population and individual relationships).
        self.full_rel_only = hyperparameter['full_rel_only']
        self.update_rel_info(rel_info)

        self.llm_prior_df = llm_prior_df
        # ind_df is optional; calling .copy() on None used to crash here.
        self.ind_df = ind_df.copy() if ind_df is not None else None
        self.query_info = query_info
        self.res = self._initialize_result()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _standardize(values):
        """Z-score ``values`` with the mean / sample std of its non-NaN entries.

        Returns ``(scaled, std)``. When there are no valid values or the std is
        zero/undefined (e.g. a single value), ``values`` is returned unchanged
        and ``std`` is ``None`` so callers know no rescaling happened.
        """
        valid = values.dropna()
        if valid.empty:
            return values, None
        mean, std = valid.mean(), valid.std()
        if std == 0 or np.isnan(std):
            return values, None
        return (values - mean) / std, std

    def _insufficient_samples(self, size_col):
        """Boolean mask of rows whose ``size_col`` is NaN or below ``min_samples_needed``."""
        sizes = self.res[size_col]
        return (sizes < self.hyper['min_samples_needed']) | sizes.isna()

    @staticmethod
    def _off_diagonal_stats(matrix):
        """Return ``(mean, std)`` of the non-NaN off-diagonal entries of a square frame.

        ``(nan, nan)`` when ``matrix`` is ``None`` or has no such entries.
        The std is the population std (ddof=0), as with ``np.nanstd``.
        """
        if matrix is None:
            return np.nan, np.nan
        vals = np.asarray(matrix.values, dtype=float)
        off_diag = vals[~np.eye(vals.shape[0], dtype=bool)]
        off_diag = off_diag[~np.isnan(off_diag)]
        if off_diag.size == 0:
            return np.nan, np.nan
        return off_diag.mean(), off_diag.std()

    def _apply_sigmoid(self, x):
        """Logistic squash ``1 / (1 + exp(-t_global * x))``; NaN stays NaN."""
        return 1 / (1 + np.exp(-self.hyper['t_global'] * x))

    # ------------------------------------------------------------------ stages

    def _initialize_result(self):
        """Build the result frame with a standardised, sigmoid-squashed LLM prior."""
        res = pd.DataFrame(index=self.data_associated_metrics)
        res['llm_prior'] = self.llm_prior_df.loc[self.data_associated_metrics, self.query_metric]
        res['llm_prior'], _ = self._standardize(res['llm_prior'])
        res['llm_prior'] = self._apply_sigmoid(res['llm_prior'])
        return res

    def _handle_abnormality(self):
        """Add ``recent_abnormality`` and ``weight_local`` columns (no-op without ``ind_df``)."""
        if self.ind_df is None:
            return
        used_cols = list(self.numeric_metrics) + ['date']
        par_df = self.ind_df[used_cols].copy()
        anomaly = detect_anomalies(par_df, self.query_info['query_date'], self.query_info['time_granularity'])
        self.res['recent_abnormality'] = self.res.index.map(anomaly)
        # .isna() also works when the mapped column ends up object-dtype.
        if self.res['recent_abnormality'].isna().all():
            self.res['weight_local'] = np.nan
        else:
            _, _, weight = blend_openness_abnormality(
                self.res['recent_abnormality'],
                self.query_info['openness'],
                self.hyper['t_local'],
            )
            self.res['weight_local'] = weight

    def _calculate_mi_features(self):
        """Placeholder for mutual-information features; not implemented."""
        return None

    def _calculate_correlation_features(self):
        """Add squashed relationship strengths and their variances to ``self.res``.

        Columns written, for ``src`` in (pop, ind):
        ``rel_<src>`` = sigmoid(zscore(fisher_transform(|r|))),
        ``<src>_sample_size``, ``var_<src>`` (delta-method variance of the
        squashed score, starting from the Fisher SE ``1/sqrt(n-3)``), plus
        ``var_prior`` = mean of the non-NaN ``var_pop`` values.

        Raises
        ------
        NotImplementedError
            If ``rel_type == 'mi'``.
        ValueError
            For any other ``rel_type`` besides 'spearman' / 'pearson'.
        """
        if self.rel_type == 'mi':
            raise NotImplementedError("MI is not implemented yet")
        if self.rel_type not in ('spearman', 'pearson'):
            raise ValueError(f"Invalid relation type: {self.rel_type}")

        sources = (
            ('pop', self.rel_pop_all, self.rel_pop_sample_size_all),
            ('ind', self.rel_ind_all, self.rel_ind_sample_size_all),
        )
        for src, rel_all, size_all in sources:
            rel_col, size_col, var_col = f'rel_{src}', f'{src}_sample_size', f'var_{src}'
            if rel_all is not None:
                self.res[rel_col] = rel_all[self.query_metric].abs()
            else:
                self.res[rel_col] = np.nan
            if size_all is not None:
                self.res[size_col] = size_all[self.query_metric]
                # Standard error of the Fisher z of a correlation; n clipped so n-3 >= 1.
                self.res[var_col] = 1 / np.sqrt(self.res[size_col].clip(lower=4) - 3)
            else:
                self.res[size_col] = np.nan
                self.res[var_col] = np.nan

        if self.full_rel_only:
            # Keep only rows where both sources have enough samples (NaN fails both tests).
            keep = ((self.res['pop_sample_size'] >= self.hyper['min_samples_needed'])
                    & (self.res['ind_sample_size'] >= self.hyper['min_samples_needed']))
            self.res = self.res[keep].copy()
        else:
            for src in ('pop', 'ind'):
                self.res.loc[self._insufficient_samples(f'{src}_sample_size'), f'rel_{src}'] = np.nan

        for src in ('pop', 'ind'):
            rel_col, var_col = f'rel_{src}', f'var_{src}'
            self.res[rel_col] = fisher_transform(self.res[rel_col])
            self.res[rel_col], std = self._standardize(self.res[rel_col])
            self.res[rel_col] = self._apply_sigmoid(self.res[rel_col])
            # The SE is rescaled by the same std used to standardise the score.
            if std is not None:
                self.res[var_col] = self.res[var_col] / std
            # Delta method through the sigmoid: scale SE by s * (1 - s).
            # NOTE(review): d/dx sigmoid(t*x) = t * s * (1 - s); the t_global
            # factor is not applied here — confirm whether that is intended.
            self.res[var_col] *= self.res[rel_col] * (1 - self.res[rel_col])
            # Convert standard error to variance.
            self.res[var_col] = self.res[var_col] ** 2

        # Variance assigned to the LLM prior: average population variance.
        self.res['var_prior'] = self.res['var_pop'].dropna().mean()

    def _compute_pop_posterior(self):
        """Update the LLM prior with population evidence via ``update_prior``.

        Writes ``posterior_pop`` / ``posterior_pop_var`` (NaN for rows without
        enough population samples, whose ``rel_pop`` is also reset to NaN) and
        keeps the raw ``self._mu_pop`` / ``self._sigma_pop`` for the
        individual-level update.
        """
        llm_prior = self.res['llm_prior']
        llm_var = np.nan_to_num(np.diag(self.hyper['alpha_prior'] * self.res['var_prior']), nan=1e-8)
        # NOTE(review): missing observations become value 1e-8 with variance
        # 1e-8, i.e. a near-certain ~0 observation rather than "no
        # information". Those rows are masked to NaN below, but the unmasked
        # _mu_pop/_sigma_pop still feed the individual update — confirm.
        pop_r = np.nan_to_num(self.res['rel_pop'], nan=1e-8)
        pop_var = np.nan_to_num(np.diag(self.hyper['alpha_pop'] * self.res['var_pop']), nan=1e-8)
        mu_pop, sigma_pop = update_prior(llm_prior, llm_var, pop_r, pop_var)

        self.res['posterior_pop'] = mu_pop
        self.res['posterior_pop_var'] = np.diag(sigma_pop)

        missing = self._insufficient_samples('pop_sample_size')
        self.res.loc[missing, ['posterior_pop', 'posterior_pop_var', 'rel_pop']] = np.nan
        # Stored for use as the prior of the individual update.
        self._mu_pop = mu_pop
        self._sigma_pop = sigma_pop

    def _round_all_values(self, decimal_places=2):
        """Round every value in ``self.res`` to ``decimal_places`` decimals."""
        self.res = self.res.round(decimal_places)

    def _compute_ind_posterior(self):
        """Update the population posterior with individual evidence via ``update_prior``.

        Raises
        ------
        ValueError
            If ``_compute_pop_posterior`` has not run yet.
        """
        if not hasattr(self, '_mu_pop') or not hasattr(self, '_sigma_pop'):
            raise ValueError("Population posterior must be computed before individual posterior.")

        ind_r = np.nan_to_num(self.res['rel_ind'], nan=1e-8)
        ind_var = np.nan_to_num(np.diag(self.hyper['alpha_ind'] * self.res['var_ind']), nan=1e-8)
        mu_ind, sigma_ind = update_prior(self._mu_pop, self._sigma_pop, ind_r, ind_var)

        self.res['posterior_ind'] = mu_ind
        self.res['posterior_ind_var'] = np.diag(sigma_ind)

        missing = self._insufficient_samples('ind_sample_size')
        self.res.loc[missing, ['rel_ind', 'var_ind', 'posterior_ind', 'posterior_ind_var']] = np.nan

    def _handle_nan_posterior(self):
        """Fill NaN ``posterior_pop`` from ``llm_prior`` and NaN ``posterior_ind`` from ``posterior_pop``."""
        if 'posterior_pop' in self.res.columns:
            self.res['posterior_pop'] = self.res['posterior_pop'].fillna(self.res['llm_prior'])
        if 'posterior_ind' in self.res.columns:
            self.res['posterior_ind'] = self.res['posterior_ind'].fillna(self.res['posterior_pop'])

    def _final_weighting(self):
        """Compute ``weight_global`` and ``weight_final`` for data-associated metrics.

        ``weight_global`` is the first non-NaN of posterior_ind, posterior_pop,
        llm_prior. ``weight_final`` = beta * weight_local + (1 - beta) *
        weight_global, falling back to ``weight_global`` where no local weight
        exists (including when there is no ``ind_df``).
        """
        self.res = self.res[self.res.index.isin(self.data_associated_metrics)].copy()
        self.res['weight_global'] = (self.res['posterior_ind']
                                     .fillna(self.res['posterior_pop'])
                                     .fillna(self.res['llm_prior']))

        if self.ind_df is not None:
            beta = self.hyper['beta']
            local = self.res['weight_local']
            blended = beta * local + (1 - beta) * self.res['weight_global']
            self.res['weight_final'] = blended.where(local.notna(), self.res['weight_global'])
        else:
            self.res['weight_final'] = self.res['weight_global']

    def _remove_query_metric(self):
        """Drop the query metric's own row from ``self.res``."""
        self.res = self.res[self.res.index != self.query_metric].copy()

    def update_rel_info(self, rel_info):
        """Load relationship matrices and metric lists from ``rel_info``.

        Also stores mean/std of the off-diagonal entries of the population and
        individual relationship matrices (NaN when unavailable).
        """
        self.rel_pop_all = rel_info['rel_pop_all']
        self.rel_pop_mean, self.rel_pop_std = self._off_diagonal_stats(self.rel_pop_all)
        self.rel_var_pop_all = rel_info['rel_var_pop_all']
        self.rel_pop_sample_size_all = rel_info['rel_pop_sample_size_all']
        self.rel_ind_all = rel_info['rel_ind_all']
        self.rel_ind_mean, self.rel_ind_std = self._off_diagonal_stats(self.rel_ind_all)
        self.rel_var_ind_all = rel_info['rel_var_ind_all']
        self.rel_ind_sample_size_all = rel_info['rel_ind_sample_size_all']
        self.numeric_metrics = rel_info['numeric_metrics']
        self.data_associated_metrics = rel_info['data_associated_metrics']

    def sort_by_column(self, column_name):
        """Sort the results dataframe by the specified column in descending order."""
        return self.res.sort_values(by=column_name, ascending=False)

    def run(self):
        """Run every stage and return ``(res, abnormality_df)``.

        ``abnormality_df`` has one ``abnormality`` column over the remaining
        metrics plus the query metric; it stays empty (NaN) without ``ind_df``.
        """
        self._handle_abnormality()
        # Remember the query metric's abnormality before its row is removed.
        if self.ind_df is not None and self.query_metric in self.data_associated_metrics:
            self.query_abnormality = self.res.loc[self.query_metric, 'recent_abnormality']
        else:
            self.query_abnormality = np.nan
        self._calculate_correlation_features()
        # TODO: self._calculate_mi_features()
        self._compute_pop_posterior()
        self._compute_ind_posterior()
        self._remove_query_metric()

        self.abnormality_df = pd.DataFrame(index=self.res.index, columns=['abnormality'])
        if self.ind_df is not None:
            self.abnormality_df['abnormality'] = self.res['recent_abnormality']
            self.abnormality_df.loc[self.query_metric, 'abnormality'] = self.query_abnormality
        self._final_weighting()
        return self.res, self.abnormality_df




def get_relationship_matrix(relationship_dict, rel_type = 'spearman', data_type = 'pop'):
    """Unpack square relationship matrices for one relation type and data source.

    The metric order is taken from the columns of
    ``relationship_dict[f'{rel_type}_{data_type}']``; every returned frame is
    built with that order for both index and columns.

    Returns
    -------
    (rel_all, sample_size_all, third, numeric_metrics)
        ``third`` is the ``{rel_type}_var_{data_type}`` matrix when
        ``rel_type == 'mi'``, otherwise the ``{rel_type}_p_{data_type}``
        p-value matrix.
    """
    rel_key = f'{rel_type}_{data_type}'
    metrics = list(pd.DataFrame(relationship_dict[rel_key]).columns)

    def as_square(key):
        # Same constructor call for every matrix so they share one metric order.
        return pd.DataFrame(relationship_dict[key], index=metrics, columns=metrics)

    rel_all = as_square(rel_key)
    sample_size_all = as_square(f'sample_size_{data_type}')
    if rel_type != 'mi':
        return rel_all, sample_size_all, as_square(f'{rel_type}_p_{data_type}'), metrics
    return rel_all, sample_size_all, as_square(f'{rel_type}_var_{data_type}'), metrics




def detect_anomalies(df, query_date, n_days ):
    """
    Score how unusual the most recent rows up to ``query_date`` are, per feature.

    For every column except ``date``, the last ``n_days`` rows dated on or
    before ``query_date`` are z-scored against the mean / sample std of all
    rows dated on or before ``query_date``. The feature's score is the mean
    absolute z-score. Rows after ``query_date`` are excluded from the baseline.
    Previously they leaked into the mean/std.

    Parameters:
    - df: DataFrame with a ``date`` column (parseable by ``pd.to_datetime``)
      and numeric feature columns
    - query_date: date-like or str parseable by ``pd.to_datetime``
    - n_days: number of most recent *rows* to score. Rows are taken as-is, so
      this equals days only with one row per day. When fewer rows exist, all
      of them are scored. ``'all'`` returns an empty dict.

    Returns:
    - dict mapping feature -> mean |z|. The value is NaN when the feature has
      no non-NaN history or its std is zero/undefined.
    """
    if n_days == 'all':
        return {}

    temp_df = df.copy()
    temp_df['date'] = pd.to_datetime(temp_df['date'])
    query_date = pd.to_datetime(query_date)
    # Baseline: only data available as of the query date, in time order.
    history = temp_df[temp_df['date'] <= query_date].sort_values(by='date')
    # tail() returns every row when fewer than n_days exist.
    recent = history.tail(n_days)

    anomaly_results = {}
    for feature in (col for col in history.columns if col != 'date'):
        baseline = history[feature]
        if baseline.isna().all():
            anomaly_results[feature] = np.nan
            continue

        mean = baseline.mean()
        std = baseline.std()
        if std == 0 or np.isnan(std):
            anomaly_results[feature] = np.nan
        else:
            z_scores = (recent[feature] - mean) / std
            anomaly_results[feature] = np.mean(np.abs(z_scores))

    return anomaly_results

def blend_openness_abnormality(x, S, sharpness = 0.7):
    """Turn abnormality scores into (0, 1) weights, oriented by ``S``.

    Steps:
    1. min-max scale the non-NaN scores to [0, 1] (all 0.5 if they are equal);
    2. mix with the reversed scale: ``y = S*z + (1-S)*(1-z)``. ``S = 1`` keeps
       the ordering and ``S = 0`` reverses it. ``S`` is presumably an openness
       value in [0, 1] — confirm with callers;
    3. z-normalise ``y`` over the non-NaN entries (population std), or set
       them to 0 when ``y`` is constant;
    4. apply ``1 / (1 + exp(-sharpness * y))``.

    Missing (NaN) scores stay NaN at every step. Previously, when all observed
    scores were equal, NaN entries were turned into 0.5.

    Parameters
    ----------
    x : array-like or pandas.Series of float
        Abnormality scores; NaN marks a missing score.
    S : float
        Mixing weight between forward and reversed ordering.
    sharpness : float
        Slope of the final sigmoid.

    Returns
    -------
    z, y, y2 : list of float
        Scaled scores, mixed scores and final weights (same length as ``x``).
    """
    values = np.asarray(x, dtype=float)
    observed = ~np.isnan(values)
    has_data = bool(observed.any())

    # 1) Min-max normalise observed scores; missing ones stay NaN.
    z = np.full(values.shape, np.nan)
    if has_data:
        xmin = values[observed].min()
        xmax = values[observed].max()
        if xmax == xmin:
            z[observed] = 0.5
        else:
            z[observed] = (values[observed] - xmin) / (xmax - xmin)

    # 2) Mix forward and reverse orderings.
    y = S * z + (1 - S) * (1 - z)

    # 3) Z-normalise over observed entries.
    y2 = np.full(values.shape, np.nan)
    if has_data:
        observed_y = y[observed]
        y_std = observed_y.std()
        if y_std == 0:
            y2[observed] = 0.0
        else:
            y2[observed] = (observed_y - observed_y.mean()) / y_std

    # 4) Sigmoid to (0, 1); NaN propagates.
    y2 = 1 / (1 + np.exp(-sharpness * y2))

    return z.tolist(), y.tolist(), y2.tolist()




