import numpy as np
import pandas as pd
import torch
import string
from scipy.spatial.distance import euclidean
from numpy import linalg as LA
import warnings
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import FastICA

from utils.augmentation.sdforger import SDForger
    
def _print_banner(title, blank_lines):
    """Print `blank_lines` empty lines followed by a dashed section banner."""
    for _ in range(blank_lines):
        print('')
    print('-----------------------------------------------')
    print(title)
    print('-----------------------------------------------')


def _print_all_columns(frame):
    """Print a DataFrame without pandas truncating its columns."""
    # option_context restores whatever value the caller had configured, whereas
    # set_option/reset_option would force the option back to the pandas default.
    with pd.option_context('display.max_columns', None):
        print(frame)


def _fit_fica(data, n_components):
    """Fit FastICA on `data`; return (model, embedded data, retained variance ratio)."""
    fica = FastICA(n_components=n_components, random_state=0)
    embedded = fica.fit_transform(data)
    reconstructed = embedded @ fica.mixing_.T + fica.mean_
    # Ratio of summed per-sample variances: reconstruction vs. original.
    retained = np.var(reconstructed, axis=1).sum() / np.var(data, axis=1).sum()
    return fica, embedded, retained


def sdforger_augmentation(time_series_train, min_generations=10, max_generations=1500, stopping_threshold=0.98,
                          k='auto', percentage_of_variance=0.7, embedding_type='fpc', train_epochs=100,
                          llm='distilgpt2', text_template='base_template', permute=True, init_value=True, learning_rate=5e-5,
                          batch_size_forger=32, sdforger_float_type='float32',
                          device='mps', textual_info=None):
    """
    Augment time series data using SDForger on top of FPCA or FastICA embeddings.

    Each channel is embedded into a low-dimensional table, the per-channel tables are
    concatenated column-wise, SDForger is fitted on that table and asked to generate new
    rows, and the generated rows are mapped back to curves with the stored bases.

    Parameters:
    - time_series_train: list of lists of lists, list of 2-D arrays, or 3-D np.array
        first dimension: channels
        second dimension: windows/samples
        third dimension: timestamps in each window
    - min_generations: forwarded to `SDForger.generate` as `n_samples_min`.
    - max_generations: forwarded to `SDForger.generate` as `n_samples_max`.
    - stopping_threshold: forwarded to `SDForger.generate` as `stopping_treshold`.
    - k: Embedding dimension used for every channel. If 'auto', it is derived from
        `percentage_of_variance` (at least 2 components; for 'fica' at most 30).
    - percentage_of_variance: Fraction of variance the embedding must retain when
        `k == 'auto'` (default: 0.7).
    - embedding_type: One of 'fpc', 'fpc-filled', 'fica'.
    - train_epochs: Number of epochs passed to `SDForger.fit`.
    - llm: Passed to `SDForger` as `model_path`.
    - text_template: Passed to `SDForger` as `text_template`.
    - permute: Permute option passed to `SDForger.fit`.
    - init_value: Initial value option passed to `SDForger.generate`.
    - learning_rate: Learning rate passed to `SDForger.fit`.
    - batch_size_forger: Batch size passed to `SDForger.fit`.
    - sdforger_float_type: Passed to `SDForger` as `float_type`.
    - device: Device the fitted model is moved to before generation (e.g., 'cpu', 'mps').
    - textual_info: Optional per-sample values stored in an extra 'data' column of the
        table given to SDForger.

    Returns:
    - A 5-tuple `(new_data, sdforger_input, new_data_embedded, embedding_dims, var_explained)`:
        new_data: list with one reconstructed array per channel (for 'fpc-filled', the
            reconstruction that also uses resampled training scores of the remaining
            components up to 99% variance);
        sdforger_input: the column-stacked training embeddings;
        new_data_embedded: list with the generated embedding columns of each channel;
        embedding_dims: embedding dimension of each channel;
        var_explained: retained-variance ratio of each channel.

    Raises:
    - ValueError: if `embedding_type` is not one of the supported values.
    """
    # Validate first. The previous `embedding_type in ('fica')` was a substring test
    # on the string 'fica', so values such as '' or 'ica' slipped into the FastICA path.
    if embedding_type not in ('fpc', 'fpc-filled', 'fica'):
        raise ValueError("Possible embedding types: ['fpc', 'fpc-filled', 'fica']")
    use_fpc = embedding_type in ('fpc', 'fpc-filled')

    _print_banner('FROM TIME SERIES TO TABULAR DATA', blank_lines=2)

    # One 2-D array (windows x timestamps) per channel; the input is only read.
    data_train = [np.asarray(series) for series in time_series_train]
    n_var = len(data_train)  # Number of channels

    # ---------------------------------------------------------------------------------------------------------------- #
    # Compute embeddings

    embedding_dims = []
    var_explained = []
    channel_embeddings = []

    if use_fpc:
        column_tag = 'pca'
        channel_means = []
        channel_stds = []
        data_embedded_full = []
        fpc_basis = []
        fpc_basis_full = []

        for var in range(n_var):
            # Standardize each channel independently (per timestamp).
            channel = data_train[var]
            channel_mean = channel.mean(axis=0)
            channel_std = channel.std(axis=0)
            standardized = (channel - channel_mean) / (channel_std + 1e-32)

            # eigh returns eigenvalues in ascending order, so the leading components
            # are the LAST columns of `eigenfuns`.
            eigvals, eigenfuns = LA.eigh(standardized.T @ standardized)
            eigvals_desc = eigvals[::-1]
            cumsum_eigvals = np.cumsum(eigvals_desc) / np.sum(eigvals_desc)

            # Determine embedding dimensions
            if k == 'auto':
                embedding_dim = max(2, int(np.argmax(cumsum_eigvals >= percentage_of_variance)) + 1)
            else:
                embedding_dim = k  # Use fixed embedding dimension if provided
            print(f'Latent dimension channel {var}: {embedding_dim}')
            print(f"Variance retained for channel {var}: {cumsum_eigvals[embedding_dim - 1]:.4f}")
            print('')

            # Components needed for 99% of the variance, used by 'fpc-filled'.
            full_embedding_dim = max(int(np.argmax(cumsum_eigvals >= 0.99)) + 1, embedding_dim)

            basis = eigenfuns[:, -embedding_dim:]
            basis_full = eigenfuns[:, -full_embedding_dim:]

            embedding_dims.append(embedding_dim)
            var_explained.append(cumsum_eigvals[embedding_dim - 1])
            channel_means.append(channel_mean)
            channel_stds.append(channel_std)
            fpc_basis.append(basis)
            fpc_basis_full.append(basis_full)
            data_embedded_full.append(standardized @ basis_full)
            channel_embeddings.append(standardized @ basis)

    else:
        column_tag = 'fica'
        fica_mixing = []
        fica_mean = []
        min_components = 2
        max_components = 30

        for var in range(n_var):
            channel = data_train[var]

            # Determine the number of components
            if k == 'auto':
                # Fall back to the cap when the threshold is never reached; the
                # previous while-loop overshot to max_components + 1 in that case.
                embedding_dim = max_components
                for candidate in range(min_components, max_components + 1):
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        _, _, candidate_retained = _fit_fica(channel, candidate)
                    if candidate_retained >= percentage_of_variance:
                        embedding_dim = candidate
                        break
            else:
                embedding_dim = k

            print(f'Latent dimension channel {var}: {embedding_dim}')
            # Final fit, with warnings visible to the caller.
            fica, embedded, retained = _fit_fica(channel, embedding_dim)
            print(f"Variance retained for channel {var}: {retained:.4f}")
            print('')

            embedding_dims.append(embedding_dim)
            var_explained.append(retained)
            fica_mixing.append(fica.mixing_)
            fica_mean.append(fica.mean_)
            channel_embeddings.append(embedded)

    sdforger_input = np.hstack(channel_embeddings)

    # Table with descriptive column names; used for display and to label the generated table.
    sdforger_input_df = pd.DataFrame(
        sdforger_input,
        columns=["channel_" + str(c) + "_" + column_tag + "_" + str(component)
                 for c in range(len(embedding_dims))
                 for component in range(embedding_dims[c])])
    if textual_info is not None:
        sdforger_input_df['data'] = np.array(textual_info)

    print('')
    _print_all_columns(sdforger_input_df)

    # ---------------------------------------------------------------------------------------------------------------- #
    # Fitting SDForger

    df_input = pd.DataFrame(sdforger_input,
                            columns=['value_' + str(col) for col in range(sdforger_input.shape[1])])
    if textual_info is not None:
        df_input['data'] = np.array(textual_info)
    _print_all_columns(df_input)

    total_embedding_dim = sum(embedding_dims)
    total_instances = sdforger_input_df.shape[0]
    print('')
    print(f"Total embedding dimension: {total_embedding_dim}")
    print(f"Total instancers: {total_instances}")

    if total_embedding_dim > 25 or total_embedding_dim > 0.8 * total_instances:
        warnings.warn(
            f"The total embedding dimension ({total_embedding_dim}) is large. "
            f"If you have few instances ({total_instances}), you may want to decrease the percentage "
            f"of variance explained or specify a smaller 'k' as input.",
            UserWarning
        )

    model = SDForger(model_path=llm, text_template=text_template, float_type=sdforger_float_type)

    print('')
    print(f'BATCH SIZE = {batch_size_forger}\n')

    model.fit(df_input, batch_size=batch_size_forger, epochs=train_epochs, hf_trainer=True, permute=permute,
              learning_rate=learning_rate, embedded_dims=embedding_dims)
    model.model.to(device)

    # NOTE: `stopping_treshold` is the keyword spelling used in the call to SDForger.generate.
    sdforger_output = model.generate(n_samples_min=min_generations, n_samples_max=max_generations,
                                     stopping_treshold=stopping_threshold, max_length=10000, init_value=init_value,
                                     check_distribution=True)
    sdforger_output = sdforger_output.to_numpy()

    # ---------------------------------------------------------------------------------------------------------------- #
    # Retrieve to original form

    # Split the generated table back into one block of columns per channel.
    new_data_embedded = []
    start = 0
    for width in embedding_dims:
        new_data_embedded.append(sdforger_output[:, start:start + width])
        start += width

    _print_banner('FROM TEXT TO TABULAR DATA', blank_lines=3)

    augmented_data = pd.DataFrame(sdforger_output)
    augmented_data.columns = sdforger_input_df.columns
    print(augmented_data)

    _print_banner('FROM TABULAR DATA TO CURVES', blank_lines=3)

    new_data = []
    new_data_full = [] if embedding_type == 'fpc-filled' else None

    if use_fpc:
        for var in range(n_var):
            # Project the generated scores onto the basis and undo the standardization.
            new_data.append(channel_stds[var] * (new_data_embedded[var] @ fpc_basis[var].T)
                            + channel_means[var])

            if embedding_type == 'fpc-filled':
                # Resample training scores for all components up to 99% variance
                # (fancy indexing yields a copy), then overwrite the leading
                # components -- the last columns -- with the generated scores.
                sampled_indices = np.random.choice(data_embedded_full[var].shape[0],
                                                   size=new_data_embedded[var].shape[0], replace=True)
                filled_scores = data_embedded_full[var][sampled_indices]
                filled_scores[:, -embedding_dims[var]:] = new_data_embedded[var]
                new_data_full.append(channel_stds[var] * (filled_scores @ fpc_basis_full[var].T)
                                     + channel_means[var])
    else:
        for var in range(n_var):
            # Map the generated sources back through the FastICA mixing matrix.
            print('new_data_embedded[var].shape', new_data_embedded[var].shape)
            new_data.append(np.dot(new_data_embedded[var], fica_mixing[var].T) + fica_mean[var])

    return ((new_data_full if embedding_type == 'fpc-filled' else new_data),
            sdforger_input, new_data_embedded, embedding_dims, var_explained)


def plot_generated_data(output_path, original_data, new_data):
    """
    Plot generated curves over the original ones and save the figure as a PDF.

    One subplot is drawn per channel (stacked vertically), showing at most the first
    200 timestamps of every window. Generated windows are dashed orange lines and
    original windows are solid lines. The figure is written to
    `<output_path>/generated_data_plot.pdf`.

    Parameters:
    - output_path: Directory in which the PDF is saved.
    - original_data: Per-channel collection of original windows (windows x timestamps).
    - new_data: Per-channel collection of generated windows (3-D array or list of 2-D arrays).
    """
    # Local import: `os` is not imported at the top of this module.
    import os

    n_channels = len(new_data)
    if n_channels == 0:
        return  # Nothing to plot

    # All channels go into a single figure; previously each channel was saved to the
    # same file name, so only the last channel's plot survived.
    fig, axes = plt.subplots(n_channels, 1, figsize=(15, 6 * n_channels), squeeze=False)
    for var in range(n_channels):
        ax = axes[var, 0]
        ax.plot(np.asarray(new_data[var])[:, 0:200].T, '--', color='orange', alpha=0.3)
        ax.plot(np.asarray(original_data[var])[:, 0:200].T, lw=1, alpha=0.8)
        if n_channels > 1:
            ax.set_title(f'Channel {var}')

    fig.savefig(os.path.join(output_path, 'generated_data_plot.pdf'))
    plt.close(fig)