import time
from typing import Optional

import numpy as np
from torch.utils.data import Dataset
import scipy
import networkx as nx
import gies
from numpy import linalg
import wandb
#from causalscbench.evaluation.statistical_evaluation import Evaluator

from .base._base_model import BaseModel
import numpy as np
import pandas as pd

def causal_tail_coeff(v1, v2, k=None, to_rank=True, both_tails=True):
    """Estimate the causal tail coefficient of ``v1`` on ``v2``.

    Selects the ``k`` observations where ``v1`` is most extreme: its ``k``
    largest ranks, or ``k/2`` from each tail when ``both_tails`` is set.
    It then averages how extreme ``v2`` is at those observations, measured
    on the rank scale. For rank inputs the result lies in ``[0, 1]``.
    Values near 1 mean that extremes of ``v1`` are accompanied by extremes
    of ``v2``.

    Parameters
    ----------
    v1, v2 : 1-D array-like (pandas.Series, numpy.ndarray, list, ...)
        Paired observations of equal length ``n``. Pairing is positional;
        pandas indices are ignored.
    k : int, optional
        Number of extreme observations of ``v1`` to use. Defaults to
        ``int(n ** 0.4)``. When ``both_tails`` is True it is rounded down to
        an even number.
    to_rank : bool, default True
        Rank both inputs first with pandas ``rank(method='first')``. Pass
        False when the inputs already hold ranks ``1..n``.
    both_tails : bool, default True
        If True, use the upper and lower tails of ``v1`` and score ``v2`` by
        the distance of its rank from the median rank. Otherwise use only
        the upper tail of ``v1`` and the raw rank of ``v2``.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``v1`` and ``v2`` differ in length, or ``k`` is not in ``(1, n)``.
    """
    n = len(v1)
    if len(v2) != n:
        raise ValueError("v1 and v2 must have the same length.")
    if k is None:
        k = int(n ** 0.4)

    if k <= 1 or k >= n:
        raise ValueError("k must be greater than 1 and smaller than n.")

    if to_rank:
        # Wrapping in a Series lets lists and ndarrays be ranked too.
        r1 = pd.Series(v1).rank(method='first').to_numpy()
        r2 = pd.Series(v2).rank(method='first').to_numpy()
    else:
        # Plain arrays keep the tail mask positional. pandas would instead
        # align a boolean Series by index, which breaks for mismatched indices.
        r1 = np.asarray(v1)
        r2 = np.asarray(v2)

    if both_tails:
        # Even k, so that k/2 of the ranks are taken from each tail.
        k = (k // 2) * 2
        in_tail = (r1 > n - k / 2) | (r1 <= k / 2)
        # 2|r2 - (n+1)/2| is v2's distance from the median rank, rescaled
        # so the most extreme ranks map close to n.
        return 1 / (k * n) * np.sum(2 * np.abs(r2[in_tail] - (n + 1) / 2))

    return 1 / (k * n) * np.sum(r2[r1 > n - k])

def causal_tail_matrix(dat, k=None, both_tails=True):
    """Compute pairwise causal tail coefficients between all columns of ``dat``.

    Parameters
    ----------
    dat : pandas.DataFrame or 2-D array-like
        Observations: one row per sample, one column per variable. Non-DataFrame
        input is wrapped with ``pd.DataFrame``.
    k : int, optional
        Number of extreme observations used by ``causal_tail_coeff``.
        Defaults to ``int(n ** 0.4)``, with ``n`` the number of rows.
    both_tails : bool, default True
        Forwarded to ``causal_tail_coeff``.

    Returns
    -------
    numpy.ndarray
        ``(p, p)`` float matrix. Entry ``[i, j]`` is
        ``causal_tail_coeff(column i, column j)`` on the column ranks. The
        diagonal is NaN.

    Raises
    ------
    ValueError
        Propagated from ``causal_tail_coeff`` when ``k`` is not in ``(1, n)``.
        Only raised if there are at least two columns.
    """
    frame = pd.DataFrame(dat)
    n, p = frame.shape
    if k is None:
        k = int(n ** 0.4)

    # Rank every column once and convert to NumPy. The pairwise loop then
    # slices plain arrays instead of calling ``.iloc`` p**2 times.
    ranked = frame.rank(method='first').to_numpy()

    causal_mat = np.full((p, p), np.nan)
    for i in range(p):
        for j in range(p):
            if i != j:
                causal_mat[i, j] = causal_tail_coeff(
                    ranked[:, i],
                    ranked[:, j],
                    k,
                    to_rank=False,
                    both_tails=both_tails,
                )
    return causal_mat

def ease(dat, k=None, both_tails=True):
    """Estimate a causal order of the columns of ``dat`` by greedy tail search.

    The first variable chosen is the one with the smallest column maximum in
    ``causal_tail_matrix``, i.e. the smallest largest coefficient with that
    variable in the ``v2`` role. After each choice, the chosen variable's
    row is removed (set to NaN). The same rule is then applied to the
    remaining columns.

    Parameters
    ----------
    dat : pandas.DataFrame
        Observations: one row per sample, one column per variable.
    k : int, optional
        Number of extremes used for the tail coefficients. Defaults to
        ``int(len(dat) ** 0.4)``.
    both_tails : bool, default True
        Forwarded to ``causal_tail_matrix``.

    Returns
    -------
    list of int
        Column indices, earliest in the estimated causal order first. Ties
        are broken in favour of the lowest column index. Zero or one column
        yields ``[]`` or ``[0]``.
    """
    if k is None:
        k = int(len(dat) ** 0.4)

    d = dat.shape[1]
    if d <= 1:
        # Trivially ordered. Handled up front because nanargmin raises
        # ValueError on the all-NaN matrix a single column would produce.
        return list(range(d))

    causal_mat = causal_tail_matrix(dat, k, both_tails)

    order = [int(np.nanargmin(np.nanmax(causal_mat, axis=0)))]
    placed = set(order)
    while len(order) < d:
        # Drop the just-ordered variable's row so it no longer contributes
        # to the column maxima of the remaining variables.
        causal_mat[order[-1], :] = np.nan
        # Keep ``remaining`` in ascending order so nanargmin ties resolve
        # to the lowest column index.
        remaining = [j for j in range(d) if j not in placed]
        if len(remaining) == 1:
            # Every row of the last column is NaN by now; just take it.
            nxt = remaining[0]
        else:
            scores = np.nanmax(causal_mat[:, remaining], axis=0)
            nxt = remaining[int(np.nanargmin(scores))]
        order.append(nxt)
        placed.add(nxt)
    return order


def createFullyConnectedGraph(topological_order):
    """Build the complete DAG that is consistent with a causal order.

    Every node gets an edge to every node that comes after it in
    ``topological_order``.

    Parameters
    ----------
    topological_order : sequence of int
        Node indices, earliest first. Expected to be a permutation of
        ``range(len(topological_order))``; out-of-range entries raise
        IndexError.

    Returns
    -------
    numpy.ndarray
        Float matrix of shape ``(n, n)``. ``[a, b]`` is 1.0 when ``a``
        precedes ``b`` in the order, otherwise 0.0.
    """
    nodes = list(topological_order)
    size = len(nodes)
    adjacency = np.zeros((size, size))
    for position, parent in enumerate(nodes):
        for child in nodes[position + 1:]:
            adjacency[parent, child] = 1
    return adjacency

class EASE(BaseModel):
    """Model wrapper around the ``ease`` causal-ordering procedure.

    ``train`` estimates a causal order of the variables and stores the
    complete DAG compatible with that order, as built by
    ``createFullyConnectedGraph``. ``get_adjacency_matrix`` returns it.
    """

    def __init__(self):
        super().__init__()
        # Set by ``train``; stays ``None`` until then.
        self._adj_matrix = None

    def train(
        self,
        dataset: Dataset,
        log_wandb: bool = False,
        wandb_project: str = "ease",
        wandb_config_dict: Optional[dict] = None,
        **model_kwargs,
    ):
        """Estimate the causal order of ``dataset`` and build its adjacency matrix.

        Parameters
        ----------
        dataset : Dataset
            Must expose ``tensors[0]``, e.g. a ``TensorDataset``. That
            tensor is converted with ``.numpy()``, wrapped in a DataFrame,
            and its columns are treated as the variables to order.
        log_wandb : bool, default False
            If True, start a Weights & Biases run named ``"ease"``.
        wandb_project : str, default "ease"
            Project name passed to ``wandb.init``.
        wandb_config_dict : dict, optional
            Run config passed to ``wandb.init``. An empty dict is used when
            omitted.
        **model_kwargs
            Accepted but currently unused.

        Side effects: sets ``self._adj_matrix`` and
        ``self._train_runtime_in_sec``. The latter is the wall-clock seconds
        spent ordering the variables and building the graph.
        """
        samples = dataset.tensors[0].numpy()
        # NOTE(review): this rebinds ``bool`` onto the ``np`` object exposed
        # by ``gies`` (presumably the numpy module, so the change is
        # process-wide). It looks like a shim for NumPy versions without the
        # ``np.bool`` alias. ``gies`` is not otherwise used in this class;
        # confirm whether the line is still needed.
        gies.np.bool = bool

        if log_wandb:
            wandb.init(
                project=wandb_project,
                name="ease",
                config=wandb_config_dict or {},
            )

        tic = time.time()
        order = ease(pd.DataFrame(samples))
        self._adj_matrix = createFullyConnectedGraph(order)
        self._train_runtime_in_sec = time.time() - tic

    def get_adjacency_matrix(self, threshold: bool = True) -> np.ndarray:
        """Return the adjacency matrix computed by ``train``.

        ``threshold`` is ignored. Returns ``None`` if ``train`` has not run.
        """
        return self._adj_matrix
