import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import combinations_with_replacement
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.feature_selection import f_regression, r_regression, mutual_info_regression

from scipy.cluster.vq import kmeans, vq
from scipy.spatial.distance import cdist
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV

import numpy as np
from itertools import combinations_with_replacement
from scipy.special import legendre, eval_chebyt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.gaussian_process.kernels import RBF

from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score
import yaml
from joblib import Parallel, delayed
import os
from statsmodels.nonparametric.kernel_density import KDEMultivariate
from scipy.stats import gaussian_kde

def load_data(env_name, normalize, num_chunks, gamma):
    """
    Load expert / non-expert rollouts for an environment and split them into chunks.

    Parameters:
    - env_name: Name of the sub-directory of ``data/`` that holds the ``.npy`` files.
    - normalize: Accepted for interface compatibility; not used by this function.
    - num_chunks: Maximum number of chunks requested from the chunking helpers.
    - gamma: Discount factor applied when folding each reward chunk into a return.

    Returns:
    - traj_chunks: Expert trajectories split by ``divide_into_chunks``.
    - rewards: One discounted return per expert reward chunk.
    - expert_ts: The expert time-index array as loaded from disk.
    - non_expert_chunks: Non-expert trajectories split by ``divide_into_chunks``.
    - non_expert_ts: The non-expert time-index array as loaded from disk.
    """
    data_path = f'data/{env_name}/'

    def _read(filename):
        # Every array lives directly under the environment's data directory.
        return np.load(data_path + filename)

    expert_trajs = _read('expert_trajs.npy')
    expert_ts = _read('expert_ts.npy')
    expert_rs = _read('expert_rs.npy')
    non_expert_trajs = _read('non_expert_trajs.npy')
    non_expert_ts = _read('non_expert_ts.npy')

    traj_chunks = divide_into_chunks(expert_trajs, expert_ts, num_chunks)
    reward_chunks = divide_1d_array_into_chunks(expert_rs, expert_ts, num_chunks)
    non_expert_chunks = divide_into_chunks(non_expert_trajs, non_expert_ts, num_chunks)

    rewards = []
    for episode in reward_chunks:
        # Walk the episode backwards so that G_t = r_t + gamma * G_{t+1}.
        episode_return = 0
        for step_reward in episode[::-1]:
            episode_return = step_reward + gamma * episode_return
        rewards.append(episode_return)

    return traj_chunks, rewards, expert_ts, non_expert_chunks, non_expert_ts

    
def divide_1d_array_into_chunks(data_array, time_indexes, num_chunks):
    """
    Split a 1D sequence wherever the accompanying time index decreases.

    Parameters:
    - data_array: The 1D array of shape (N,) to be split.
    - time_indexes: Time index of each point; a value smaller than its
      predecessor marks the first element of a new chunk.
    - num_chunks: Stop as soon as this many chunks have been collected.

    Returns:
    - chunks: A list of 1D slices of ``data_array``. If fewer than
      ``num_chunks`` resets were found, the remaining tail is appended as a
      final chunk.
    """
    segments = []
    seg_start = 0

    neighbours = zip(time_indexes[:-1], time_indexes[1:])
    for pos, (prev_t, cur_t) in enumerate(neighbours, start=1):
        if cur_t < prev_t:
            # The time index dropped: everything up to ``pos`` is one chunk.
            segments.append(data_array[seg_start:pos])
            if len(segments) == num_chunks:
                return segments
            seg_start = pos

    # Not enough resets were seen: the tail counts as the last chunk.
    if len(segments) < num_chunks:
        segments.append(data_array[seg_start:])

    return segments


def divide_into_chunks(data_array, time_indexes, num_chunks):
    """
    Split a 2D array row-wise wherever the accompanying time index decreases.

    Parameters:
    - data_array: The 2D array of shape (N, d) to be split.
    - time_indexes: Time index of each row; a value smaller than its
      predecessor marks the first row of a new chunk.
    - num_chunks: Stop as soon as this many chunks have been collected.

    Returns:
    - chunks: A list of 2D slices of ``data_array``. Only segments that are
      closed by a reset are returned; the rows after the last reset are not
      emitted.
    """
    segments = []
    seg_start = 0

    neighbours = zip(time_indexes[:-1], time_indexes[1:])
    for pos, (prev_t, cur_t) in enumerate(neighbours, start=1):
        if cur_t < prev_t:
            # The time index dropped: rows [seg_start, pos) form one chunk.
            segments.append(data_array[seg_start:pos, :])
            seg_start = pos
            if len(segments) == num_chunks:
                break

    return segments

def create_succ_points(trajs):
    """
    Build concatenated (point, next point) samples from every trajectory.

    Parameters:
    - trajs: An iterable of trajectories, each a sequence of points.

    Returns:
    - points: Array whose rows are ``traj[i]`` followed by ``traj[i + 1]``.
    - indices: Array holding, for each row, its step position ``i`` within
      its own trajectory (the count restarts at 0 for every trajectory).
    """
    pairs = []
    step_ids = []
    for traj in trajs:
        for step, (current, following) in enumerate(zip(traj[:-1], traj[1:])):
            pairs.append(np.concatenate((current, following), axis=0))
            step_ids.append(step)
    return np.array(pairs), np.array(step_ids)



def find_feature_trajs(num_feats, feats, gamma):
    """
    Compute the discounted feature sum of every trajectory.

    Parameters:
    - num_feats: Length of each per-step feature vector.
    - feats: A sequence of trajectories, each a sequence of feature vectors.
    - gamma: Discount factor; step ``i`` is weighted by ``gamma ** i``.

    Returns:
    - An array of shape (len(feats), num_feats) whose row ``k`` is the
      discounted sum of the feature vectors of trajectory ``k``.
    """
    discounted = np.zeros((len(feats), num_feats))

    for traj_idx, traj in enumerate(feats):
        # Row view: in-place additions below write straight into ``discounted``.
        running_total = discounted[traj_idx]
        for step, feat in enumerate(traj):
            running_total += feat * (gamma ** step)

    return discounted


def normalize_trajs_between_neg1_and_1(trajs):
    """
    Rescale trajectories with fixed per-dimension bounds and stack them.

    Each trajectory is mapped linearly so that the hard-coded lower bound of
    a dimension becomes -1 and the upper bound becomes 1. No clipping is
    applied, so values outside the bounds land outside [-1, 1].

    Parameters:
    - trajs: A non-empty sequence of arrays whose last dimension has 17
      entries, matching the hard-coded bounds below.

    Returns:
    - normalized_trajs: All rescaled trajectories concatenated along axis 0.
    - global_min_vals: The per-dimension lower bounds that were used.
    - global_max_vals: The per-dimension upper bounds that were used.
    """
    # Fixed bounds (one entry per state dimension) rather than data-derived
    # min/max, so the same scaling applies to every data set.
    global_min_vals = np.array([-0.600, -3.220, -0.644, -0.858, -0.555, -1.072, -1.108, -0.681, -3.315, -3.527, -6.308, -20.155, -24.883, -23.288, -22.412, -25.717, -26.089])
    global_max_vals = np.array([0.378, 3.802, 0.910, 0.865, 0.873, 0.814, 1.017, 0.664, 3.221, 3.184, 6.993, 19.490, 23.201, 20.072, 25.405, 26.938, 23.360])

    value_range = global_max_vals - global_min_vals
    normalized_chunks = [
        2 * (traj - global_min_vals) / value_range - 1
        for traj in trajs
    ]

    # Stack the *rescaled* chunks; the previous code concatenated the raw
    # input here and therefore returned un-normalized data.
    normalized_trajs = np.concatenate(normalized_chunks, axis=0)
    return normalized_trajs, global_min_vals, global_max_vals


def normalize_data_point(data_point, min_vals, max_vals):
    """
    Clip a data point to the given bounds and rescale it to [-1, 1].

    Parameters:
    - data_point: The value(s) to normalize.
    - min_vals: Lower bound(s); mapped to -1.
    - max_vals: Upper bound(s); mapped to 1.

    Returns:
    - The clipped data point, linearly rescaled so that ``min_vals`` maps to
      -1 and ``max_vals`` maps to 1.
    """
    # Clipping first guarantees the result cannot leave [-1, 1].
    clipped = np.clip(data_point, min_vals, max_vals)
    offset_from_min = clipped - min_vals
    value_range = max_vals - min_vals
    return 2 * offset_from_min / value_range - 1


    
# ##############################################################################
def compute_prob_marginal(n_batches, trajs, expert_ts, num_chunks, kde_states):
    """
    Sum per-state KDE scores over each trajectory.

    All trajectories are stacked, scored in parallel batches with
    ``kde_states``, and the scores are then split back into trajectories
    using ``expert_ts`` before being summed.

    Parameters:
    - n_batches: Number of batches the stacked data is divided into for scoring.
    - trajs: Sequence of trajectory arrays that can be stacked with ``np.vstack``.
    - expert_ts: Time indexes used to split the scores back into trajectories.
    - num_chunks: Maximum number of trajectories to recover from the scores.
    - kde_states: Fitted estimator exposing ``score_samples``
      (presumably an sklearn ``KernelDensity``, whose scores are log-densities).

    Returns:
    - An array with one summed score per recovered trajectory.
    """
    states = np.vstack(trajs)
    rows_per_batch = int(np.ceil(len(states) / n_batches))

    # One delayed scoring job per contiguous slice of the stacked states.
    scoring_jobs = (
        delayed(score_subset)(kde_states, states[lo:lo + rows_per_batch])
        for lo in range(0, len(states), rows_per_batch)
    )
    # n_jobs=-1 lets joblib use all available cores.
    per_state_scores = np.concatenate(Parallel(n_jobs=-1)(scoring_jobs))

    per_traj_scores = divide_1d_array_into_chunks(per_state_scores, expert_ts, num_chunks)
    return np.array([segment.sum() for segment in per_traj_scores])

def compute_prob_seq(logP_tau, n_batches, trajs, num_chunks, kde_succ_states):
    """
    Sum successor-pair KDE scores per trajectory and subtract ``logP_tau``.

    Consecutive-state pairs are built with ``create_succ_points``, scored in
    parallel batches with ``kde_succ_states``, split back into trajectories
    (the pair step indices restart at 0 for every trajectory), and summed.

    Parameters:
    - logP_tau: Per-trajectory values subtracted from the summed pair scores.
    - n_batches: Number of batches the pair data is divided into for scoring.
    - trajs: Sequence of trajectories passed to ``create_succ_points``.
    - num_chunks: Maximum number of trajectories to recover from the scores.
    - kde_succ_states: Fitted estimator exposing ``score_samples``
      (presumably an sklearn ``KernelDensity``, whose scores are log-densities).

    Returns:
    - The summed pair score of each trajectory minus ``logP_tau``.
    """
    pairs, step_ids = create_succ_points(trajs)
    rows_per_batch = int(np.ceil(len(pairs) / n_batches))

    # One delayed scoring job per contiguous slice of the pair samples.
    scoring_jobs = (
        delayed(score_subset)(kde_succ_states, pairs[lo:lo + rows_per_batch])
        for lo in range(0, len(pairs), rows_per_batch)
    )
    # n_jobs=-1 lets joblib use all available cores.
    per_pair_scores = np.concatenate(Parallel(n_jobs=-1)(scoring_jobs))

    per_traj_scores = divide_1d_array_into_chunks(per_pair_scores, step_ids, num_chunks)
    summed_pair_scores = np.array([segment.sum() for segment in per_traj_scores])

    return summed_pair_scores - logP_tau
    

def score_subset(kde, subset):
    """
    Score one batch of samples with a fitted estimator.

    Kept as a module-level function so it can be wrapped with joblib's
    ``delayed`` by the batch-scoring helpers.

    Parameters:
    - kde: Object exposing a ``score_samples`` method.
    - subset: The batch of samples to score.

    Returns:
    - Whatever ``kde.score_samples(subset)`` returns.
    """
    batch_scores = kde.score_samples(subset)
    return batch_scores

def fit_states(trajs):
    """
    Fit a Gaussian kernel density estimator on all states of all trajectories.

    Parameters:
    - trajs: Sequence of trajectory arrays that can be stacked with ``np.vstack``.

    Returns:
    - The fitted ``KernelDensity`` estimator (bandwidth chosen by the
      "silverman" rule).
    """
    stacked_states = np.vstack(trajs)
    estimator = KernelDensity(kernel='gaussian', bandwidth="silverman")
    estimator.fit(stacked_states)
    return estimator


def fit_succ_states(trajs):
    """
    Fit a Gaussian kernel density estimator on consecutive-state pairs.

    Parameters:
    - trajs: Sequence of trajectories passed to ``create_succ_points``.

    Returns:
    - The fitted ``KernelDensity`` estimator (bandwidth chosen by the
      "silverman" rule), trained on the concatenated (state, next state) rows.
    """
    # Only the pair samples are needed; the step indices are discarded.
    pair_samples = create_succ_points(trajs)[0]
    estimator = KernelDensity(kernel='gaussian', bandwidth="silverman")
    estimator.fit(pair_samples)
    return estimator