import pandas as pd
import numpy as np
import torch
import re
import copy
from dice_ml import Data, Model, Dice
import multiprocessing as mp
from tqdm import tqdm
from sklearn.neighbors import KDTree
from sklearn.metrics import DistanceMetric
from abc import ABC
import carla.recourse_methods.catalog as recourse_catalog
from lgmvae import *

# Dice
def get_dice_explainer(X_train, y_train, clf):
    """
    Build a DiCE explainer (``method="random"``) around a scikit-learn classifier.

    Features get placeholder names ``feature_0 ... feature_{d-1}`` and are all
    declared continuous; the labels ``y_train`` are stored in a ``target`` column.

    Returns:
        The ``dice_ml.Dice`` explainer.
    """
    feature_names = [f'feature_{col}' for col in range(X_train.shape[1])]
    training_frame = pd.DataFrame(X_train, columns=feature_names)
    training_frame['target'] = y_train
    return Dice(
        Data(dataframe=training_frame, continuous_features=feature_names, outcome_name='target'),
        Model(clf, backend='sklearn'),
        method="random",
    )

def dice_worker(queue, dice_exp, x_df, total_cfs):
    """
    Run DiCE for one query row and report the outcome through ``queue``.

    Exactly one item is put on ``queue``:
      * a NumPy array of the counterfactual feature values (all examples stacked,
        last column of DiCE's ``final_cfs_df`` dropped), when CFs were found;
      * ``None`` when DiCE returned no examples or the first has no CF frame;
      * the raised exception object, if anything in the process failed.

    Args:
        queue: queue-like object with ``put`` (a ``multiprocessing.Queue`` when
            used from ``generate_dice_ce``).
        dice_exp: DiCE explainer.
        x_df: single-row DataFrame to explain.
        total_cfs: number of counterfactuals requested from DiCE.
    """
    try:
        explanation = dice_exp.generate_counterfactuals(x_df, total_CFs=total_cfs, verbose=False)
        examples = explanation.cf_examples_list
        if not examples or examples[0].final_cfs_df is None:
            queue.put(None)  # no counterfactuals found
            return
        stacked = pd.concat([example.final_cfs_df for example in examples], ignore_index=True)
        queue.put(stacked.values[:, :-1])
    except Exception as err:
        # Hand the failure back to the parent process instead of dying silently.
        queue.put(err)

def generate_dice_ce(X_test_ce, dice_exp, total_cfs=5, timeout_seconds=2):
    """
    Generate DiCE counterfactuals for every row of ``X_test_ce``, one process per row.

    Each row runs ``dice_worker`` in a child process so a slow DiCE search can be
    abandoned after ``timeout_seconds``.

    Args:
        X_test_ce: 2-D array of query points.
        dice_exp: DiCE explainer (see ``get_dice_explainer``).
        total_cfs: counterfactuals requested per row (default 5, as before).
        timeout_seconds: how long to wait for each row's result (default 2).

    Returns:
        list of np.ndarray, one per row: the CF array produced by DiCE, or a
        ``(total_cfs, n_features)`` array filled with -1 when the row timed out,
        DiCE found nothing, or DiCE raised.
    """
    from queue import Empty  # multiprocessing.Queue.get raises queue.Empty on timeout

    n_features = X_test_ce.shape[1]
    columns = [f'feature_{col}' for col in range(n_features)]
    dice_ces = []

    for x in X_test_ce:
        result_queue = mp.Queue()
        x_df = pd.DataFrame(x.reshape(1, -1), columns=columns)
        worker = mp.Process(target=dice_worker, args=(result_queue, dice_exp, x_df, total_cfs))
        try:
            worker.start()
            # Read BEFORE joining: a child that has put data on a Queue may not
            # exit until that data is consumed, so join-then-get can stall. The
            # timeout also stops us blocking forever if the child dies without
            # putting anything.
            try:
                result = result_queue.get(timeout=timeout_seconds)
            except Empty:
                result = "timeout"
            worker.join(timeout=timeout_seconds)
            if worker.is_alive():
                worker.terminate()
                worker.join()  # reap the terminated process
        finally:
            result_queue.close()

        if isinstance(result, np.ndarray):
            dice_ces.append(result)
        else:
            # Timeout, "no CFs" (None) or a forwarded exception: placeholder row block.
            dice_ces.append(np.full((total_cfs, n_features), -1.0))
    return dice_ces

# nnce
def get_nnce_tree(X_train, clf):
    """
    Index the training points that ``clf`` predicts as class 1.

    At most 1000 such points are kept: when there are more, they are shuffled in
    place with the global NumPy RNG and the first 1000 are used. The kept points
    are indexed by a ``KDTree`` (leaf size 40) under the Manhattan metric.

    Returns:
        (tree, points): the KDTree and the array it was built on; tree indices
        refer to rows of ``points``.
    """
    positives = X_train[clf.predict(X_train) == 1]
    if len(positives) > 1000:
        np.random.shuffle(positives)
        positives = positives[:1000]
    manhattan = DistanceMetric.get_metric("manhattan")
    return KDTree(positives, leaf_size=40, metric=manhattan), positives

def generate_nn_ce(X_test_ce, X1_class1_clf, tree):
    """
    Return the nearest indexed point for every query row (nearest-neighbour CE).

    ``tree`` must be built on ``X1_class1_clf``; its default single-neighbour
    query supplies the row indices.

    Returns:
        np.ndarray with one row of ``X1_class1_clf`` per row of ``X_test_ce``.
    """
    _, nearest = tree.query(X_test_ce)
    return X1_class1_clf[np.asarray(nearest).ravel()]

def generate_ic_ce(X_test_ce, X1_class1_clf, tree, clf, n_ces=5):
    """
    Generate ``n_ces`` diverse counterfactuals per query point.

    For each row ``x`` of ``X_test_ce``:

    1. query up to 1000 nearest neighbours of ``x`` in ``X1_class1_clf`` via
       ``tree``;
    2. greedily select ``n_ces`` of them. Start with the nearest one, then
       repeatedly add the nearest neighbour whose Manhattan distance to every
       already-selected point is at least 1.5x the distance from ``x`` to its
       nearest neighbour (as reported by ``tree``);
    3. if fewer than ``n_ces`` points qualify, use ``n_ces`` copies of the
       nearest neighbour;
    4. otherwise move each selected point towards ``x``. Walk the segment
       ``x -> ce`` at fractions 0, 0.05, ..., 1 and keep the first point ``clf``
       predicts as class 1; keep ``ce`` unchanged if none is.

    Args:
        X_test_ce: 2-D array of query points.
        X1_class1_clf: candidate points indexed by ``tree``.
        tree: neighbour index over ``X1_class1_clf`` exposing
            ``query(X, k=...) -> (distances, indices)``.
        clf: classifier exposing ``predict``.
        n_ces: counterfactuals per query point (default 5, as before).

    Returns:
        np.ndarray of shape ``(len(X_test_ce), n_ces, n_features)``.
    """
    n_queries, n_features = len(X_test_ce), X_test_ce.shape[1]
    ces = np.zeros((n_queries, n_ces, n_features))
    dists, idxs = tree.query(X_test_ce, k=min(1000, len(X1_class1_clf)))
    # Normalise to (n_queries, k) so [i][0] works even for trees that drop
    # the neighbour axis when k == 1.
    dists = np.asarray(dists).reshape(n_queries, -1)
    idxs = np.asarray(idxs).reshape(n_queries, -1)
    taus = np.arange(0, 1.01, 0.05)[:, np.newaxis]  # 21 interpolation fractions

    for i, x in enumerate(X_test_ce):
        neighbours = X1_class1_clf[idxs[i]]
        dist_thres = 1.5 * dists[i][0]
        selected = neighbours[:1].astype(float)  # copy; always starts with the nearest

        # Greedy diversity selection (Manhattan distance to all chosen points).
        for _ in range(n_ces - 1):
            for candidate in neighbours:
                if (np.abs(selected - candidate).sum(axis=1) >= dist_thres).all():
                    selected = np.vstack([selected, candidate])
                    break

        if len(selected) != n_ces:
            ces[i] = np.tile(neighbours[0], (n_ces, 1))
            continue

        # Line search. Bug fix: the original overwrote the working point each
        # step (k*current + (1-k)*x). At k=0 that collapsed it to x, so the
        # search never moved. Interpolate between the FIXED endpoints instead.
        for j in range(n_ces):
            end = selected[j].copy()
            segment = taus * end + (1 - taus) * x
            hits = np.flatnonzero(np.asarray(clf.predict(segment)) == 1)
            if hits.size:
                selected[j] = segment[hits[0]]
        ces[i] = selected
    return ces

# STABLE CE
def counterfactual_stability(xp, model):
    """
    Score how stable counterfactual ``xp`` is under Gaussian perturbation.

    Draws 1000 samples from N(xp, 0.1^2 I) with the global NumPy RNG (std 0.1,
    i.e. variance 0.01, and 1000 samples, as in the original paper). It returns
    the average over samples s of ``p(s) - |p(s) - p(xp)|``, where ``p`` is the
    class-1 probability from ``model.predict_proba``. Higher means more stable.
    """
    anchor_score = model.predict_proba(xp.reshape(1, -1))[0][1]
    perturbed = np.random.normal(xp, 0.1, (1000, len(xp)))
    sample_scores = model.predict_proba(perturbed)[:, 1]
    n_samples = len(sample_scores)
    return np.sum((sample_scores - np.abs(sample_scores - anchor_score)) / n_samples)

def generate_stable_ce(X_test_ce, X_class1_clf, model, tree, threshold=0.45):
    """
    Return, per query point, the nearest candidate whose stability reaches ``threshold``.

    Candidates are visited in the order ``tree.query(x, k=len(X_class1_clf))``
    returns them. sklearn's ``KDTree`` sorts by increasing distance by default;
    presumably ``tree`` is built on ``X_class1_clf`` — confirm. Each candidate is
    scored with ``counterfactual_stability``. If no candidate reaches the
    threshold, the candidate with the highest score is returned. Without this
    fallback the output would no longer line up row-for-row with ``X_test_ce``.

    Args:
        X_test_ce: query points, one per row.
        X_class1_clf: candidate counterfactuals indexed by ``tree``.
        model: classifier exposing ``predict_proba``.
        tree: neighbour index exposing ``query(X, k=...) -> (distances, indices)``.
        threshold: minimum stability score accepted.

    Returns:
        list of np.ndarray, one counterfactual per row of ``X_test_ce``.

    Raises:
        ValueError: if ``X_class1_clf`` is empty.
    """
    n_candidates = len(X_class1_clf)
    if n_candidates == 0:
        raise ValueError("X_class1_clf contains no candidate counterfactuals.")

    ccf_ces = []
    for x in X_test_ce:
        # One query covering every candidate. The original re-queried for
        # k = 1..N-1, which cost O(N^2) and never considered the N-th
        # (farthest) candidate.
        _, order = tree.query(x.reshape(1, -1), k=n_candidates)
        chosen = None
        best_score, best_candidate = -np.inf, None
        for idx in np.ravel(order):
            candidate = X_class1_clf[idx]
            score = counterfactual_stability(candidate, model)
            if score >= threshold:
                chosen = candidate
                break
            if best_candidate is None or score > best_score:
                best_score, best_candidate = score, candidate
        ccf_ces.append(chosen if chosen is not None else best_candidate)
    return ccf_ces

# FACE
class MLModel(ABC):
    """
    Minimal model wrapper exposing the interface the carla recourse methods call.

    It forwards ``predict`` / ``predict_proba`` to the wrapped estimator, reports
    ``"sklearn"`` as its backend, passes features through unchanged and marks
    every feature as mutable.
    """

    def __init__(
        self,
        data: Data,
        m
    ) -> None:
        """
        Args:
            data: training data. NOTE(review): annotated as ``Data`` (the dice_ml
                class imported in this module), but ``get_face_explainer`` passes a
                ``pandas.DataFrame``, and ``get_mutable_mask`` relies on
                ``data.shape`` — confirm what carla expects here.
            m: fitted estimator exposing ``predict`` and ``predict_proba``.
        """
        self._data: Data = data
        self._model = m

    @property
    def data(self) -> Data:
        """The training data handed to the constructor."""
        return self._data

    @data.setter
    def data(self, data: Data) -> None:
        self._data = data

    @property
    def backend(self):
        """Backend name reported to carla; always ``"sklearn"``."""
        return "sklearn"

    @property
    def raw_model(self):
        """The wrapped estimator."""
        return self._model

    def predict(self, x):
        """Delegate to the wrapped estimator's ``predict``."""
        return self._model.predict(x)

    def predict_proba(self, x):
        """Delegate to the wrapped estimator's ``predict_proba``."""
        return self._model.predict_proba(x)

    def get_ordered_features(self, x):
        """Return ``x`` unchanged (no feature reordering is applied)."""
        return x

    def get_mutable_mask(self):
        """Boolean mask with one ``True`` per feature column: every feature is mutable."""
        # Bug fix: the original called len() on the int ``shape[1]``, which
        # raises TypeError.
        return np.ones(self.data.shape[1], dtype=bool)
    
def get_face_explainer(X_train, clf):
    """
    Build a carla FACE explainer (``mode="knn"``, ``fraction=0.3``) around ``clf``.

    ``X_train`` is wrapped in a DataFrame with placeholder columns
    ``feature_0 ... feature_{d-1}`` and passed to ``MLModel`` together with ``clf``.
    """
    columns = [f'feature_{col}' for col in range(X_train.shape[1])]
    wrapped_model = MLModel(pd.DataFrame(X_train, columns=columns), clf)
    return recourse_catalog.Face(wrapped_model, {"mode": "knn", "fraction": 0.3})

def generate_face_ce(X_test_ce, face):
    """
    Run a FACE explainer on ``X_test_ce`` and return its counterfactuals as an array.

    The queries are passed as a DataFrame with placeholder columns
    ``feature_0 ... feature_{d-1}``; the explainer's output is converted with
    ``np.array``.
    """
    columns = [f'feature_{col}' for col in range(X_test_ce.shape[1])]
    factuals = pd.DataFrame(X_test_ce, columns=columns)
    counterfactuals = face.get_counterfactuals(factuals)
    return np.array(counterfactuals)

# ours
def generate_ours_ce(X_test_ce, model, clf, cluster_centroids, device):
    """
    Generate counterfactual paths with the LGM-VAE model and pick three CEs per path.

    For each row ``x``, ``generate_counterfactuals_with_steps_centroid`` (from
    ``lgmvae``) is called with positional arguments ``0, 1``. These are
    presumably the source/target classes — confirm in lgmvae. Its result is
    converted to NumPy and iterated as a collection of paths, each a sequence of
    points that ``clf.predict`` accepts. For every path:

      * first:  the first step predicted as class 1;
      * middle: the first class-1 step at or after index
        ``floor((first_idx + len(path)) / 2)``, else the last step;
      * last:   the final step.

    If no step of a path is predicted as class 1, all three picks fall back to
    the final step. The original raised IndexError in that case.

    Returns:
        (first, middle, last, paths) as np.ndarray, built from per-row lists.
    """
    ours_ces_first = []
    ours_ces_middle = []
    ours_ces_last = []
    paths = []
    for x in X_test_ce:
        x_tensor = torch.from_numpy(x).to(torch.float32)
        cfs = generate_counterfactuals_with_steps_centroid(
            model, x_tensor, 0, 1, cluster_centroids, device
        ).cpu().numpy()
        paths.append(cfs)

        this_first, this_middle, this_last = [], [], []
        for path in cfs:
            # One batched prediction per path. The middle search reuses it
            # instead of re-predicting single rows.
            positive = np.flatnonzero(np.asarray(clf.predict(path)) == 1)
            last_ce = path[-1]
            if positive.size:
                first_idx = positive[0]
                middle_start = (first_idx + len(path)) // 2
                later = positive[positive >= middle_start]
                first_ce = path[first_idx]
                middle_ce = path[later[0]] if later.size else last_ce
            else:
                first_ce = middle_ce = last_ce
            this_first.append(first_ce)
            this_middle.append(middle_ce)
            this_last.append(last_ce)

        ours_ces_first.append(this_first)
        ours_ces_middle.append(this_middle)
        ours_ces_last.append(this_last)

    return (
        np.array(ours_ces_first),
        np.array(ours_ces_middle),
        np.array(ours_ces_last),
        np.array(paths),
    )

# Path CEs
def generate_ours_path_ce(X_test_ce, model, clf, cluster_centroids, device, rules_lists):
    """
    Generate rule-constrained counterfactual paths for every row of ``X_test_ce``.

    For row ``i``, ``create_multifaceted_constraint_fn`` (from ``lgmvae``) builds
    a constraint loss from ``rules_lists[i]``. ``generate_cf_with_correction``
    then produces that row's paths with 10 correction steps at learning rate 0.1.
    Results are written into a CPU tensor of shape
    ``(n_rows, int(model.c_dim / model.y_dim), 21, n_features)``. Each call must
    therefore return something assignable to one row slice of it — TODO confirm
    in lgmvae.

    Returns:
        np.ndarray with the shape above.
    """
    n_rows, n_features = len(X_test_ce), X_test_ce.shape[-1]
    n_paths = int(model.c_dim / model.y_dim)
    paths = torch.zeros((n_rows, n_paths, 21, n_features), device="cpu")
    for row, x in enumerate(torch.from_numpy(X_test_ce).float()):
        constraint_loss = create_multifaceted_constraint_fn(x, rules_lists[row])
        paths[row] = generate_cf_with_correction(
            model, x, 0, 1, cluster_centroids, constraint_loss, device, clf,
            correction_steps=10, correction_lr=0.1,
        )
    return paths.cpu().numpy()

# methods to get paths, for actionability part.
def get_path_ce(X_test_ce, ces, dataset, tree, method="interpolate"):
    '''
    Turn single counterfactuals into 21-point paths from each query point.

    Args:
        X_test_ce: query points, shape (n, n_features) (array or nested list).
        ces: one counterfactual per query point, shape (n, n_features).
        dataset: points indexed by ``tree`` (only used by ``"greedy"``).
        tree: neighbour index over ``dataset`` (only used by ``"greedy"``).
        method: ``"interpolate"`` for straight-line interpolation at fractions
            0, 0.05, ..., 1; ``"greedy"`` for a walk through ``dataset`` built by
            ``find_greedy_path(x, ce, dataset, tree, 20, 10)``.

    Returns:
        np.ndarray of shape (n, 21, n_features). For ``"greedy"``, rows whose
        walk got stuck (``find_greedy_path`` returned None) are filled with NaN.

    Raises:
        ValueError: for an unknown ``method`` (previously returned None silently).
    '''
    X_test_ce = np.asarray(X_test_ce)
    ces = np.asarray(ces)
    if method == "interpolate":
        tau = np.arange(0, 1.01, 0.05).reshape(1, -1, 1)
        start_points = X_test_ce[:, np.newaxis, :]
        end_points = ces[:, np.newaxis, :]
        return (1 - tau) * start_points + tau * end_points
    if method == "greedy":
        paths = np.zeros((len(X_test_ce), 21, X_test_ce.shape[-1]))
        for i, x in enumerate(X_test_ce):
            path = find_greedy_path(x, ces[i], dataset, tree, 20, 10)
            # Make the "stuck" case explicit instead of assigning None into a float array.
            paths[i] = np.nan if path is None else path
        return paths
    raise ValueError(f"Unknown method {method!r}; expected 'interpolate' or 'greedy'.")

def find_greedy_path(
    start_point,
    end_point,
    dataset,
    tree,
    path_length=20,
    k=10
):
    """
    Build a path towards ``end_point`` by greedily hopping between dataset points.

    The walk starts at the dataset point nearest to ``start_point``. At each of
    ``path_length - 2`` steps it queries the ``k`` nearest neighbours of the
    current point (plus the point itself, capped at ``len(dataset)``). It then
    moves to the unvisited neighbour with the smallest Euclidean distance to
    ``end_point``, taking the first such neighbour in query order on ties.

    Args:
        start_point (np.ndarray): starting point, shape (n_features,) or (1, n_features).
        end_point (np.ndarray): target point, same shape convention.
        dataset (np.ndarray): points indexed by ``tree``, shape (n_samples, n_features).
        tree: neighbour index built on ``dataset`` exposing
            ``query(X, k=...) -> (distances, indices)`` (e.g. sklearn ``KDTree``).
        path_length (int): controls the walk length; the walk visits
            ``path_length - 1`` dataset points.
        k (int): number of neighbours considered at each step.

    Returns:
        np.ndarray or None: an array of shape ``(path_length + 1, n_features)``:
        ``start_point``, the ``path_length - 1`` visited dataset points, then
        ``end_point``. (The old docstring said ``(path_length, n_features)``;
        ``get_path_ce`` relies on the 21 rows produced for ``path_length=20``.)
        Returns ``None``, with a warning, if every neighbour of the current point
        has already been visited.

    Raises:
        ValueError: if ``path_length < 2``.
    """
    import warnings

    if path_length < 2:
        raise ValueError("Path length must be at least 2.")

    end_row = np.asarray(end_point).reshape(1, -1)
    # +1 because the current point is its own nearest neighbour; never ask the
    # tree for more neighbours than it holds (sklearn raises in that case).
    n_query = min(k + 1, len(dataset))

    _, initial_idx = tree.query(np.asarray(start_point).reshape(1, -1), k=1)
    current_idx = int(np.ravel(initial_idx)[0])
    path_indices = [current_idx]
    visited = {current_idx}

    for _ in range(path_length - 2):
        _, neighbour_idx = tree.query(dataset[current_idx].reshape(1, -1), k=n_query)
        candidates = [int(idx) for idx in np.ravel(neighbour_idx) if int(idx) not in visited]
        if not candidates:
            warnings.warn("Greedy search got stuck. No valid next step found.")
            return None
        dist_to_target = np.linalg.norm(dataset[candidates] - end_row, axis=1)
        current_idx = candidates[int(np.argmin(dist_to_target))]
        path_indices.append(current_idx)
        visited.add(current_idx)

    return np.vstack([start_point, dataset[path_indices], end_point])
