'''
Basic functions for I/O.
'''
import numpy as np
import torch
import datetime
import json
from typing import Any, Tuple

import pandas as pd
from sklearn.neighbors import NearestNeighbors

#* ==========================================================
#* Parameters, data, and logging
#* ==========================================================

class Parameters:
    '''
    Load and use one named parameter dictionary stored in a .json file.

    The .json file contains a list of dictionaries, each tagged by a
    'DictName' entry. Parameter entries are two-element lists of the form
    [string of help, value].
    '''
    def __init__(self, json_name: str, dict_name="Model") -> None:
        self.dictionary = self.load_parameters(json_name, dict_name)

    @staticmethod
    def load_parameters(json_name: str, dict_name="Model") -> dict:
        '''
        Read `json_name` and return the dictionary whose 'DictName' equals
        `dict_name`.

        Parameters
        -----------------
        json_name: str
            json file name; its top level is a list of dictionaries.

        dict_name: str
            value of 'DictName' to look for.

        Returns
        -----------------
        dictionary: dict
            parameter dictionary. If several entries share the same
            'DictName', the last one in the file is returned.

        Raises
        -----------------
        Exception
            if no entry has 'DictName' equal to `dict_name`.
        '''
        with open(json_name, 'r') as fin:
            candidates = json.load(fin)

        # Scan the whole list (no early exit) so that the last match wins.
        selected = None
        for candidate in candidates:
            if candidate['DictName'] == dict_name:
                selected = candidate

        if selected is None:
            raise Exception('There is not a dictionary with DictName = %s'%(dict_name))

        return selected

    def __call__(self, key: str) -> Any:
        '''
        Shortcut for the value stored under `key` (index 1 of its entry).
        '''
        return self.dictionary[key][1]

    def help(self, key: str) -> str:
        '''
        Return the help string stored under `key` (index 0 of its entry).
        '''
        return self.dictionary[key][0]

    def value(self, key: str) -> Any:
        '''
        Return the value stored under `key` (index 1 of its entry).
        '''
        return self.dictionary[key][1]

    def set_value(self, key: str, value) -> None:
        '''
        Overwrite the value stored under `key` in place; the help string is kept.
        '''
        self.dictionary[key][1] = value

    def has_key(self, key: str) -> bool:
        '''
        Return True when `key` is present in the loaded dictionary.
        '''
        return key in self.dictionary


def create_tensor(data, gpu_id=0, device=None, requires_grad=True) -> torch.Tensor:
    '''
    Build a float32 tensor from `data` and set its `requires_grad` flag.

    When `device` is None, the tensor goes to GPU #`gpu_id` if CUDA is
    available, otherwise to the CPU. Python lists are turned into a numpy
    array before conversion.
    '''
    if device is None:
        device = 'cuda:%d'%(gpu_id) if torch.cuda.is_available() else 'cpu'

    source = np.array(data) if isinstance(data, list) else data
    tensor = torch.tensor(source, dtype=torch.float32, device=device)
    return tensor.requires_grad_(requires_grad)

def get_lr(optimizer: torch.optim.Optimizer) -> float:
    '''
    Get the learning rate of the first parameter group of the optimizer.

    Reads `optimizer.param_groups` directly instead of building the whole
    `state_dict()`, which also packs every per-parameter state entry, just
    to read one scalar. If the optimizer has several parameter groups with
    different learning rates, only the first group's rate is returned.
    '''
    return optimizer.param_groups[0]['lr']
 
 
def init_log(folder_result, fname='logging.log') -> None:
    '''
    Initialize logging: create (or truncate) `fname` and write a header.

    The header holds the current local time and `folder_result`, followed
    by a separator line.

    Parameters
    ----------
    folder_result : Any
        Result path recorded in the header (written via `str()`).

    fname : str
        Log file name; an existing file is overwritten.
    '''
    # `with` closes the file even if one of the writes raises.
    with open(fname, 'w') as f0:
        f0.write('\n')
        now_time = datetime.datetime.now()
        f0.write('Time:        '+now_time.strftime('%Y-%m-%d %H:%M:%S \n'))
        f0.write('Result path: ' + str(folder_result) + '\n')

        f0.write('\n')
        f0.write('============================== \n')
        f0.write('\n')
    
def log(text: str, prefix='', show_time=True, fname='logging.log') -> None:
    '''
    Print `prefix + text` and, unless `fname` is None, append it as one
    line to the file `fname`.

    When `show_time` is True, the file line (not the printed line) starts
    with the local time formatted as '%Y-%m-%d %H:%M:%S | '.
    '''
    message = prefix + text
    print(message)

    if fname is not None:
        stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S | ') if show_time else ''
        with open(fname, 'a') as f:
            f.write(stamp + message + '\n')


def loss_history(epoch: int, info: dict, fname_loss='loss-history.dat') -> None:
    '''
    Append one row of values from `info` to the text file `fname_loss`.

    When `epoch` <= 0 the file is truncated first and a 'Variables=' header
    listing the keys of `info` is written. Every row has the values of
    `info` in key order, each formatted as ' %19.10E'.

    info.keys = ['epoch', 'time(s)', 'loss', ..., 'lr']
    '''
    columns = info.keys()

    if epoch <= 0:
        with open(fname_loss, 'w') as out:
            out.write('Variables=')
            out.writelines(' %19s' % name for name in columns)
            out.write('\n')

    with open(fname_loss, 'a') as out:
        out.write('          ')
        out.writelines(' %19.10E' % info[name] for name in columns)
        out.write('\n')


#* ==========================================================
#* Statistical analysis
#* ==========================================================

def cal_total_variance(y: np.ndarray) -> np.ndarray:
    '''
    Total QoI variance, computed per output column over all samples.

    - V_t = Var_{x~p(x), e~p(e)}[y] = V_a + V_m
    - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]
    - V_m = Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]

    Parameters
    ----------
    y : ndarray [num_samples, dim_output]
        Output samples.

    Returns
    -------
    V_t : ndarray [dim_output]
        Total (population, ddof=0) variance of each output column.
    '''
    V_t = np.var(y, axis=0, ddof=0)
    return V_t

def cal_Va_Vm_from_data(x: np.ndarray, y: np.ndarray, n_neighbor=3, 
                    ratio_neighbor=0.01) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''
    Calculate the average aleatoric uncertainty (V_a) 
    and the variance of the mean function (V_m) in the system.
    
    - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]
    - V_m = Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]

    Each sample is grouped with its nearest neighbors in the standardized
    input space. The variance and mean of `y` over a group approximate the
    conditional variance and conditional mean of `y` at that sample.
    
    Parameters
    ----------
    x : ndarray [num_samples, dim_input]
        Input data.
    
    y : ndarray [num_samples, dim_output]
        Noisy output data.
        
    n_neighbor : int
        Minimum number of neighbors in each group.
        
    ratio_neighbor : float
        Group size as a fraction of `num_samples`. The larger of
        `n_neighbor` and `int(ratio_neighbor*num_samples)` is used, capped
        at `num_samples`.
        
    Returns
    -------
    V_a : ndarray [dim_output]
        Average aleatoric uncertainty (mean of the local variances).
        
    V_m : ndarray [dim_output]
        Variance of the mean function (variance of the local means).
        
    V_noise : ndarray [dim_output]
        Variance of the local standard deviations, i.e. how much the
        aleatoric noise level changes across the input space.
    '''
    num_samples = x.shape[0]
    dim_output = y.shape[1]
    
    #* Standardize the input data so every input dimension carries equal
    #* weight in the Euclidean neighbor distance
    x_std = (x - np.mean(x, axis=0)) / (np.std(x, axis=0) + 1e-8)

    # Group size; kneighbors raises if asked for more neighbors than samples
    n_neighbors = min(max(n_neighbor, int(ratio_neighbor*num_samples)), num_samples)
    
    # The query points are passed explicitly (they equal the fitted points),
    # so each sample's group contains the sample itself.
    nn = NearestNeighbors(n_neighbors=n_neighbors).fit(x_std)
    _, indices = nn.kneighbors(x_std)   # [num_samples, n_neighbors]

    V_a = np.zeros(dim_output)
    V_m = np.zeros(dim_output)
    V_noise = np.zeros(dim_output)
    
    for i in range(dim_output):
        # Gather every group at once with fancy indexing instead of one
        # DataFrame lookup per sample: [num_samples, n_neighbors]
        group_y = y[:, i][indices]
        variances_x = np.var(group_y, axis=1)
        means_x = np.mean(group_y, axis=1)

        # Expected local variance, variance of local means, spread of local std
        V_a[i] = np.mean(variances_x)
        V_m[i] = np.var(means_x)
        V_noise[i] = np.var(np.sqrt(variances_x))
    
    return V_a, V_m, V_noise

def cal_Va_Vm_from_kmeans(x: np.ndarray, y: np.ndarray, 
            n_neighbor=3, ratio_neighbor=0.01) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''
    Calculate the average aleatoric uncertainty (V_a) 
    and the variance of the mean function (V_m) using k-means clustering.
    
    - V_a = E_{x~p(x)}[Var_{e~p(e)}[y|x,e]]
    - V_m = Var_{x~p(x)}[E_{e~p(e)}[y|x,e]]

    Samples are clustered in the standardized input space. The variance and
    mean of `y` within each cluster approximate the conditional variance and
    conditional mean of `y`.
    
    Parameters
    ----------
    x : ndarray [num_samples, dim_input]
        Input data.
    
    y : ndarray [num_samples, dim_output]
        Noisy output data.
        
    n_neighbor : int
        Target minimum number of samples per cluster. The number of
        clusters is at most `int(num_samples/n_neighbor)`.

    ratio_neighbor : float
        The number of clusters is also capped at `int(1/ratio_neighbor)`.
        A non-positive value disables this cap.

    The resulting number of clusters is clamped to [1, num_samples].
        
    Returns
    -------
    V_a : ndarray [dim_output]
        Average aleatoric uncertainty (mean of within-cluster variances).
        
    V_m : ndarray [dim_output]
        Variance of the mean function (variance of the cluster means).
        
    V_noise : ndarray [dim_output]
        Variance of the within-cluster standard deviations.
    '''
    from sklearn.cluster import KMeans
    
    num_samples = x.shape[0]
    dim_output = y.shape[1]
    
    # Standardize the input data
    x_std = (x - np.mean(x, axis=0)) / (np.std(x, axis=0) + 1e-8)

    # Number of clusters: previously this could be 0 (num_samples < n_neighbor)
    # or divide by zero (ratio_neighbor == 0); clamp to a valid range instead.
    n_clusters = int(num_samples/n_neighbor)
    if ratio_neighbor > 0:
        n_clusters = min(n_clusters, int(1.0/ratio_neighbor))
    n_clusters = max(1, min(n_clusters, num_samples))
    
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(x_std)
    cluster_labels = kmeans.labels_

    # Row indices of each non-empty cluster, computed once for all outputs
    # (empty clusters should not occur with k-means but are skipped anyway).
    members = (np.flatnonzero(cluster_labels == cluster_id) for cluster_id in range(n_clusters))
    groups = [idx for idx in members if idx.size > 0]
    
    V_a = np.zeros(dim_output)
    V_m = np.zeros(dim_output)
    V_noise = np.zeros(dim_output)
    
    for i in range(dim_output):
        variances_x = np.array([np.var(y[idx, i]) for idx in groups])
        means_x = np.array([np.mean(y[idx, i]) for idx in groups])
        
        # V_a: mean of the within-cluster variances
        V_a[i] = np.mean(variances_x)
        # V_m: variance of the cluster means
        V_m[i] = np.var(means_x)
        # V_noise: variance of the within-cluster standard deviations
        V_noise[i] = np.var(np.sqrt(variances_x))
    
    return V_a, V_m, V_noise
