import torch
import numpy as np
from models.base import Stein_hess, cam_pruning, TabPFN_pruning, xgb_pruning, rf_pruning, mlp_pruning
from utils import (
    full_DAG,
    pre_pruning_with_parents_score,
    add_edge_with_parents_score,
    get_parents_score,
)


def compute_top_order(X, eta_G, eta_H, dispersion="mean"):
    """Estimate a topological order and pairwise parent scores via Stein Hessians.

    Columns of ``X`` are peeled off one at a time. At each step the Stein
    Hessian estimate is recomputed on the remaining columns, and the column
    with the largest ``H.mean(dim=0)`` entry is removed. The removal
    sequence is then reversed, so the first column removed ends up last in
    the returned order.

    Args:
        X: 2-D tensor whose columns are the variables (nodes). It is
            presumably shaped (n_samples, n_nodes) — confirm with callers.
        eta_G: Passed through unchanged to ``Stein_hess``.
        eta_H: Passed through unchanged to ``Stein_hess``.
        dispersion: Statistic used to select the column to remove. Only
            ``"mean"`` is supported.

    Returns:
        ``(order, parents_score)``:
        - ``order`` is a list of column indices.
        - ``parents_score`` is a ``(d, d)`` NumPy array. Row ``i`` is
          ``get_parents_score`` applied to the column-mean Hessian computed
          without column ``i`` and the one computed with all columns.
        For a zero-column ``X`` the result is ``([], empty (0, 0) array)``.

    Raises:
        ValueError: if ``dispersion`` is not ``"mean"``.
    """
    # Validate before any expensive work, and also when d <= 1, where the
    # peeling loop below never runs.
    if dispersion != "mean":
        raise ValueError("Unknown dispersion criterion")

    _, d = X.shape  # also enforces a 2-D input
    full_X = X
    full_H = None
    order = []
    active_nodes = list(range(d))

    while len(active_nodes) > 1:
        H_mean = Stein_hess(X, eta_G, eta_H).mean(dim=0)
        if full_H is None:
            # The first pass runs on all columns, which is exactly the input
            # the full-graph Hessian needs below, so reuse it instead of
            # recomputing. This assumes Stein_hess is deterministic.
            full_H = H_mean
        pick = int(H_mean.argmax())
        order.append(active_nodes.pop(pick))
        # Drop the selected column; the remaining columns stay aligned with
        # active_nodes.
        X = torch.hstack([X[:, :pick], X[:, pick + 1 :]])
    order.extend(active_nodes)  # the last remaining node, if any
    order.reverse()

    if full_H is None and d > 0:
        # d == 1: the peeling loop never ran, so compute the full Hessian here.
        full_H = Stein_hess(full_X, eta_G, eta_H).mean(dim=0)

    parents_score = np.zeros((d, d))
    for i in range(d):
        # Hessian means with column i removed, compared against the full ones.
        curr_X = torch.hstack([full_X[:, :i], full_X[:, i + 1 :]])
        curr_H = Stein_hess(curr_X, eta_G, eta_H).mean(dim=0)
        parents_score[i] = get_parents_score(curr_H, full_H, i)

    return order, parents_score


def train_caps(train_set, args):
    """Run the ordering / pruning / edge-addition pipeline and return the DAG.

    Steps:
    1. Compute a topological order and parent scores with
       ``compute_top_order``.
    2. Build the full DAG for that order, optionally pre-pruning it with the
       parent scores.
    3. Prune it with the configured method.
    4. Optionally add edges back using the parent scores.

    Args:
        train_set: 2-D array-like. Column 0 is dropped (``train_set[:, 1:]``)
            before any processing; presumably it is an index/ID column —
            confirm with the data loader.
        args: Namespace-like config. The following attributes are read:
            - ``model``, ``pre_pruning``, ``lambda1``, ``add_edge``,
              ``lambda2``.
            - ``pruning_method``, only when ``model == 'CaPS'``.
            ``args`` is also forwarded to the pruning helpers.

    Returns:
        ``(dag, order, tau_stats)``: the final DAG, the topological order,
        and the stats returned by (or synthesized for) the pruning step.

    Raises:
        ValueError: if ``args.model`` is unsupported, or if
            ``args.pruning_method`` is unsupported for ``model == 'CaPS'``.
            Both are checked before the expensive ordering step.
    """
    # Fail fast on bad configuration. Without the model check, an unknown
    # model would leave `dag` unassigned and crash later with
    # UnboundLocalError.
    if args.model not in ("OURS", "CaPS"):
        raise ValueError(f"Model '{args.model}' is not supported.")
    if args.model == "CaPS" and args.pruning_method not in (
        "tabpfn",
        "cam",
        "xgb",
        "rf",
        "mlp",
    ):
        raise ValueError(f"Pruning method '{args.pruning_method}' is not supported.")

    train_set_tensor = torch.Tensor(train_set[:, 1:])

    order, parents_score = compute_top_order(
        train_set_tensor, eta_G=0.001, eta_H=0.001, dispersion="mean"
    )

    init_dag = full_DAG(order)
    if args.pre_pruning:
        init_dag = pre_pruning_with_parents_score(init_dag, parents_score, args.lambda1)

    train_set_numpy = train_set_tensor.numpy()

    if args.model == "OURS":
        dag, tau_stats = TabPFN_pruning(init_dag, train_set_numpy, args)
    else:  # "CaPS" — pruning_method validated above
        if args.pruning_method == "tabpfn":
            dag, tau_stats = TabPFN_pruning(init_dag, train_set_numpy, args)
        elif args.pruning_method == "cam":
            dag, _ = cam_pruning(init_dag, train_set_numpy, 0.001)
            tau_stats = {'mode': 'cam'}
        elif args.pruning_method == "xgb":
            dag, _, tau_stats = xgb_pruning(init_dag, train_set_numpy, args)
        elif args.pruning_method == "rf":
            dag, _, tau_stats = rf_pruning(init_dag, train_set_numpy, args)
        else:  # "mlp"
            dag, _, tau_stats = mlp_pruning(init_dag, train_set_numpy, args)

    if args.add_edge:
        dag = add_edge_with_parents_score(dag, parents_score, args.lambda2)

    return dag, order, tau_stats
