import numpy as np
from matplotlib import pyplot as plt
from numpy.random import SeedSequence
import pandas as pd
from datetime import datetime, timedelta
import seaborn as sns

from tigramite import data_processing as pp
import tigramite.plotting as tp


def load_dataset(csv_path, standardize=True):
    """Load a long-format CSV and return it as a wide, time-indexed DataFrame.

    The CSV is read with ``pd.read_csv`` and must provide the columns
    ``time``, ``name`` and ``value``. The table is pivoted so that every
    distinct ``name`` becomes one column holding its ``value`` entries, and
    the ``time`` index is converted with ``pd.to_datetime``.

    Parameters
    ----------
    csv_path : str or path-like
        Location of the CSV file.
    standardize : bool, default True
        If True, every column is centred on its mean and divided by its
        standard deviation (pandas default, i.e. the sample standard
        deviation). A constant column therefore becomes NaN.

    Returns
    -------
    pandas.DataFrame
        One row per timestamp, one column per variable name.
    """
    long_df = pd.read_csv(csv_path)
    pivoted_df = long_df.pivot(index='time', columns='name', values='value')
    pivoted_df.index = pd.to_datetime(pivoted_df.index)

    if standardize:
        # Column-wise z-score in a single vectorized step. Building a new
        # frame (instead of writing floats into the existing columns one by
        # one) also works when the raw values were parsed as integers.
        pivoted_df = (pivoted_df - pivoted_df.mean()) / pivoted_df.std()

    return pivoted_df


def _parse_subsample_step(subsample_freq):
    """Return the leading positive integer of a spec such as ``'3D'`` or ``'10D'``."""
    text = str(subsample_freq).strip()
    # Everything that lstrip() removes is the run of leading digits.
    digits = text[:len(text) - len(text.lstrip('0123456789'))]
    if not digits or int(digits) < 1:
        raise ValueError(
            f"subsample_freq={subsample_freq!r} must start with a positive integer "
            "(e.g. '3D') when subsampling_type='subsample'."
        )
    return int(digits)


def get_dataset_with_context_indicator(pivoted_df, context_variable='swvl3', ranges=None,
                                       years_to_keep=None, 
                                       subsample=True, subsampling_type='subsample', subsample_freq='3D'):
    """Append a three-level ``context`` column and optionally trim/subsample.

    The ``context`` column is derived from ``context_variable``:
    0 for values ``<= lower``, 1 for ``lower < value <= upper`` and 2 for
    values ``> upper``. Rows that match none of these (e.g. NaN) get 0.

    Parameters
    ----------
    pivoted_df : pandas.DataFrame
        Wide frame containing ``context_variable``. It is not modified.
    context_variable : str
        Column used to derive the context indicator.
    ranges : sequence of two numbers
        The two thresholds ``(lower, upper)``. If ``ranges[0] < 1`` both
        entries are interpreted as quantile levels of ``context_variable``
        (NaNs ignored); otherwise they are used as absolute thresholds.
    years_to_keep : sequence of two ints, optional
        ``(first_year, last_year)``. The frame is reindexed to every calendar
        day from 1 January of the first year through 31 December of the last
        year (inclusive); days absent from the data are filled with ``999.``
        in every column.
    subsample : bool
        If False, the frame is returned without subsampling.
    subsampling_type : {'subsample', 'mean'}
        ``'subsample'`` keeps every k-th row, where k is the leading integer
        of ``subsample_freq``; ``'mean'`` averages over ``subsample_freq``
        bins with ``DataFrame.resample`` and drops rows containing NaN.
    subsample_freq : str
        Frequency spec such as ``'3D'``.

    Returns
    -------
    pandas.DataFrame
        A new frame with the additional ``context`` column.

    Raises
    ------
    ValueError
        If ``ranges`` is missing or does not hold exactly two values, if
        ``subsampling_type`` is unknown, or if no positive row step can be
        read from ``subsample_freq`` in ``'subsample'`` mode.
    """
    # ranges is indexed unconditionally below, so None must be rejected here
    # rather than surfacing later as an opaque TypeError.
    if ranges is None or len(ranges) != 2:
        raise ValueError("Ranges must contain exactly two values defining the intervals.")

    context_values = pivoted_df[context_variable]
    if ranges[0] < 1.:
        # Quantile levels; nanquantile keeps the thresholds finite when the
        # context variable has gaps.
        lower, upper = np.nanquantile(context_values, [ranges[0], ranges[1]])
    else:
        lower, upper = ranges[0], ranges[1]

    df_with_context_indicator = pivoted_df.copy()

    # Conditions are evaluated in order; the first match wins.
    conditions = [
        context_values <= lower,
        (context_values > lower) & (context_values <= upper),
        context_values > upper,
    ]
    choices = [0, 1, 2]

    # Rows meeting none of the conditions (NaN in the context variable) fall
    # back to category 0.
    df_with_context_indicator['context'] = np.select(conditions, choices, default=0)

    if years_to_keep is not None:
        # date_range includes both end points, so 31 December of the last
        # requested year is kept.
        all_days_to_keep = pd.date_range(datetime(years_to_keep[0], 1, 1),
                                         datetime(years_to_keep[1], 12, 31), freq='D')
        df_with_context_indicator = df_with_context_indicator.reindex(all_days_to_keep, fill_value=999.)

    if not subsample:
        return df_with_context_indicator

    if subsampling_type == 'mean':
        # NOTE(review): 999. placeholders introduced by the reindex above are
        # averaged like ordinary values here, and the context column is
        # averaged too -- confirm this is intended.
        return df_with_context_indicator.resample(subsample_freq).mean().dropna()
    if subsampling_type == 'subsample':
        step = _parse_subsample_step(subsample_freq)
        return df_with_context_indicator.iloc[::step].copy()
    raise ValueError('This type of sampling is not supported!')


def get_pandas_and_tigramite_dfs(pd_df, vals_to_drop=None, columns_name=None, context_var_name='swvl123', drop_context=True):
    """Rename/drop columns and wrap the result in a tigramite DataFrame.

    Parameters
    ----------
    pd_df : pandas.DataFrame
        Input frame. It is never modified; renames and drops produce new
        frames.
    vals_to_drop : label or list of labels, optional
        Columns to remove. These are looked up *after* renaming, so they must
        use the new names when ``columns_name`` is given.
    columns_name : mapping, optional
        ``{old_name: new_name}`` passed to ``DataFrame.rename``.
    context_var_name : str
        Original (pre-rename) name of the context variable column.
        NOTE(review): the default ``'swvl123'`` differs from the ``'swvl3'``
        default used by the other helpers in this file -- confirm.
    drop_context : bool
        If True, the context variable column is removed as well.

    Returns
    -------
    tuple
        ``(pd_df, tig_df)``: the processed pandas frame and a tigramite
        ``DataFrame`` built from its values, with ``999.`` as missing flag.
    """
    if columns_name is not None:
        pd_df = pd_df.rename(columns=columns_name)

    if vals_to_drop is not None:
        # Non-inplace drop: the caller's frame must not change as a side effect.
        pd_df = pd_df.drop(vals_to_drop, axis=1)

    if drop_context:
        # After a rename the context column carries its new name; without a
        # rename map (or without an entry for it) it keeps the original one.
        if columns_name is not None:
            context_column = columns_name.get(context_var_name, context_var_name)
        else:
            context_column = context_var_name
        pd_df = pd_df.drop(context_column, axis=1)

    # 999. is the placeholder value tigramite should treat as missing.
    tig_df = pp.DataFrame(pd_df.to_numpy(), var_names=pd_df.columns, missing_flag=999.)

    return pd_df, tig_df


def get_dataset_for_era5(csv_path, context_variable='swvl3', ranges=None,
                         years_to_keep=None,
                         subsample=True, subsampling_type='subsample', subsample_freq='3D',
                         vals_to_drop=None, columns_name=None,
                         drop_context=True):
    """Run the full preparation pipeline for one CSV file.

    Chains ``load_dataset``, ``get_dataset_with_context_indicator`` and
    ``get_pandas_and_tigramite_dfs``, forwarding the keyword arguments of the
    same name to each step, and builds a type mask for the resulting frame.

    Returns
    -------
    tuple
        ``(pd_df, tig_df, type_mask)`` where ``type_mask`` is a float array
        with the shape of ``pd_df``: all zeros except the last column, which
        is set to 1 (presumably flagging the context indicator as a discrete
        variable -- confirm against the consumer of the mask).
    """
    # Step 1: read and pivot the raw data (standardized by default).
    wide_frame = load_dataset(csv_path)

    # Step 2: attach the context indicator, then trim / subsample.
    labelled_frame = get_dataset_with_context_indicator(
        wide_frame,
        context_variable=context_variable,
        ranges=ranges,
        years_to_keep=years_to_keep,
        subsample=subsample,
        subsampling_type=subsampling_type,
        subsample_freq=subsample_freq,
    )

    # Step 3: rename / drop columns and build the tigramite wrapper.
    pd_df, tig_df = get_pandas_and_tigramite_dfs(
        labelled_frame,
        vals_to_drop=vals_to_drop,
        columns_name=columns_name,
        context_var_name=context_variable,
        drop_context=drop_context,
    )

    # Step 4: zero mask with only the final column switched on.
    type_mask = np.zeros(pd_df.shape)
    type_mask[:, -1] = 1

    return pd_df, tig_df, type_mask
