import time
from typing import Optional

import numpy as np
from torch.utils.data import Dataset
import scipy
import networkx as nx
from sklearn.linear_model import LassoLarsIC
import gies
import wandb

from .base._base_model import BaseModel

# Module-level default keyword arguments; currently an empty mapping.
_DEFAULT_MODEL_KWARGS = {}

def _sort_ranking(score_matrix, expr, masks, lmbda):
    """Greedily build a DAG from pairwise scores, strongest candidates first.

    Off-diagonal entries of ``score_matrix`` are visited in descending order
    of score. A candidate edge ``i -> j`` is kept when its score is strictly
    greater than ``lmbda`` and adding it would not close a directed cycle
    with the edges kept so far.

    Args:
        score_matrix: 2-D array; entry ``[i, j]`` scores the candidate edge
            ``i -> j``. Assumed square -- TODO confirm against callers.
        expr: Unused; kept so existing call sites keep working.
        masks: Unused; kept so existing call sites keep working.
        lmbda: Threshold a score must strictly exceed for its edge to be
            considered.

    Returns:
        A ``(cols, cols)`` float array holding 1 for every kept edge and 0
        elsewhere.
    """
    rows, cols = score_matrix.shape
    G = nx.DiGraph()
    G.add_nodes_from(range(cols))

    # Flat indices ordered by descending score (NaNs are sorted to the end).
    descending = np.argsort(-score_matrix.flatten())
    for i, j in zip(*np.unravel_index(descending, (rows, cols))):
        if not score_matrix[i, j] > lmbda:
            # Scores only decrease from here on, so nothing later can pass
            # the threshold either.
            break
        if i == j:
            continue
        # G is acyclic at this point, so i -> j closes a cycle exactly when
        # i is already reachable from j. Checking reachability up front
        # avoids re-validating the whole graph after every insertion.
        if not nx.has_path(G, j, i):
            G.add_edge(i, j)

    W = np.zeros((cols, cols))
    for i, j in G.edges():
        W[i, j] = 1
    return W

def _compute_scores(data, masks, d):
    """Score every ordered variable pair by an interventional distribution shift.

    Entry ``[node, var]`` is the Wasserstein distance between the values of
    ``var`` in the observational samples and its values in the samples where
    ``node`` was intervened on.

    Args:
        data: 2-D array of samples (rows) by variables (columns).
        masks: Array with the same layout as ``data``. A row whose entries
            are all truthy is treated as observational; a falsy entry in
            column ``k`` marks that row as an intervention on variable ``k``.
        d: Number of variables to score (size of the returned matrix).

    Returns:
        A ``(d, d)`` float array. The diagonal, and the whole row of any
        variable that has no interventional samples, stay 0.

    Raises:
        ValueError: Propagated from ``wasserstein_distance`` when a
            distribution is empty, e.g. when there are no observational rows.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``scipy.stats`` submodule has been loaded.
    from scipy.stats import wasserstein_distance

    # The observational samples do not depend on the pair being scored.
    observational = data[np.where(masks.all(axis=1))[0]]

    score_matrix = np.zeros((d, d))
    for node in range(d):
        # Rows in which ``node`` was intervened on; the same for every ``var``.
        intervened_rows = np.where(1 - masks[:, node])[0]
        if len(intervened_rows) == 0:
            continue
        interventional = data[intervened_rows]
        for var in range(d):
            if var != node:
                score_matrix[node, var] = wasserstein_distance(
                    observational[:, var], interventional[:, var]
                )
    return score_matrix
   
def score_ordering(topological_order, score_matrix, d, eps=0.3):
    """Score a node ordering against the pairwise score matrix.

    Walking the ordering front to back, every node whose score row contains
    at least one strictly positive entry contributes the sum of
    ``score_matrix[node, later] - eps`` over the nodes not yet visited.

    Args:
        topological_order: Sequence of node indices; each must be present in
            ``range(d)`` and not repeated (otherwise ``list.remove`` raises).
        score_matrix: 2-D array of pairwise scores.
        d: Number of nodes.
        eps: Offset subtracted from every score that is counted.

    Returns:
        The accumulated score; the integer 0 when no node contributes.
    """
    remaining = list(range(d))
    total = 0
    for node in topological_order:
        # After this, ``remaining`` holds exactly the nodes placed later.
        remaining.remove(node)
        if not (score_matrix[node] > 0.0).any():
            continue
        total += (score_matrix[node, remaining] - eps).sum()
    return total


def move_variable(perm, from_index, to_index):
    """Return ``perm`` with the element at ``from_index`` relocated to ``to_index``.

    The input is never modified. When both indices are equal, the very same
    ``perm`` object is returned; otherwise a new list is returned.
    """
    if from_index != to_index:
        reordered = perm.copy()
        element = reordered.pop(from_index)
        reordered.insert(to_index, element)
        return reordered
    return perm

def generate_all_possible_moves(perm):
    """List every permutation obtained by moving one element to another position.

    Results are ordered by source index first, then by destination index.
    Distinct (source, destination) pairs that yield the same permutation are
    all included, so the list may contain duplicates.
    """
    size = len(perm)
    return [
        move_variable(perm, src, dst)
        for src in range(size)
        for dst in range(size)
        if src != dst
    ]

def local_search_extended(initial_perm, score_matrix, d):
    """Hill-climb over orderings using single-element moves.

    Starting from ``initial_perm``, repeatedly take the first neighbouring
    permutation (in the order produced by ``generate_all_possible_moves``)
    whose ``score_ordering`` value is strictly higher than the current one.
    The search stops once no neighbour improves the score.

    Returns:
        The last accepted permutation (``initial_perm`` itself when no move
        improves on it).
    """
    best_perm = initial_perm
    best_score = score_ordering(best_perm, score_matrix, d)

    improved = True
    while improved:
        improved = False
        for candidate in generate_all_possible_moves(best_perm):
            candidate_score = score_ordering(candidate, score_matrix, d)
            if candidate_score > best_score:  # maximising the score
                best_perm, best_score = candidate, candidate_score
                improved = True
                break  # restart the scan from the improved permutation
    return best_perm

def createFullyConnectedGraph(topological_order):
    """Return the adjacency matrix of the full DAG implied by an ordering.

    Every node gets an edge to each node that appears after it in
    ``topological_order``.

    Returns:
        An ``(n, n)`` float array, ``n = len(topological_order)``, with 1 at
        ``[earlier, later]`` and 0 elsewhere.
    """
    size = len(topological_order)
    adjacency = np.zeros((size, size))
    for position, parent in enumerate(topological_order):
        for child in topological_order[position + 1:]:
            adjacency[parent, child] = 1
    return adjacency

def create_graph(g):
    """Convert an adjacency matrix into a ``networkx.DiGraph``.

    One node is created per row of ``g``, and an edge ``i -> j`` is added for
    every entry ``g[i, j]`` that equals 1.
    """
    digraph = nx.DiGraph()
    digraph.add_nodes_from(range(g.shape[0]))
    for src, dst in zip(*np.nonzero(g == 1)):
        digraph.add_edge(src, dst)
    return digraph

class Intersort(BaseModel):
    """Ordering-based structure learner driven by interventional score shifts.

    ``train`` scores variable pairs, derives an initial ordering from a
    thresholded greedy DAG, refines that ordering by local search, and stores
    the fully connected DAG consistent with the final ordering.
    """

    def __init__(self):
        super().__init__()
        # Set by ``train``; stays ``None`` until training has run.
        self._adj_matrix = None

    def train(
        self,
        dataset: Dataset,
        log_wandb: bool = False,
        wandb_project: str = "intersort",
        wandb_config_dict: Optional[dict] = None,
        **model_kwargs,
    ):
        """Fit the model and store the resulting adjacency matrix.

        Args:
            dataset: Object exposing ``tensors``; ``tensors[0]`` holds the
                samples and ``tensors[1]`` the intervention mask (both are
                converted with ``.numpy()``). In the mask, a falsy entry marks
                an intervened variable.
            log_wandb: When true, start a wandb run before training.
            wandb_project: Project name passed to ``wandb.init``.
            wandb_config_dict: Config passed to ``wandb.init``; ``None`` is
                treated as an empty dict.
            **model_kwargs: Optional ``lmbda`` (default 0.3) sets the score
                threshold for the initial edge ranking. Other keys are
                ignored.
        """
        data = dataset.tensors[0].numpy()
        # NOTE(review): sets ``bool`` on the numpy module referenced by gies;
        # presumably a workaround for the ``np.bool`` alias missing from some
        # NumPy releases -- confirm before removing.
        gies.np.bool = bool

        if log_wandb:
            wandb.init(
                project=wandb_project,
                name="Intersort",
                config=wandb_config_dict or {},
            )

        intervention_mask = dataset.tensors[1].numpy()

        # The threshold used to come from a lookup table that only had
        # entries for d == 10 and d == 30 (both 0.3), so any other number of
        # variables raised KeyError. Use the same value for every d and let
        # callers override it.
        lmbda = model_kwargs.get("lmbda", 0.3)

        start = time.time()
        d = data.shape[1]
        score_matrix = _compute_scores(data, intervention_mask, d)
        self._adj_matrix = _sort_ranking(
            score_matrix, data, intervention_mask, lmbda
        )
        # The thresholded DAG only serves to seed the ordering search.
        initial_order = list(
            nx.topological_sort(create_graph(self._adj_matrix))
        )
        candidate = local_search_extended(initial_order, score_matrix, d)
        self._adj_matrix = createFullyConnectedGraph(candidate)

        self._train_runtime_in_sec = time.time() - start

    def get_adjacency_matrix(self, threshold: bool = True) -> np.ndarray:
        """Return the adjacency matrix stored by ``train``.

        ``threshold`` is accepted but not used by this implementation. The
        result is ``None`` if ``train`` has not been called.
        """
        return self._adj_matrix
