from ._diffintersort_cons import _compute_scores
from .base._base_model import BaseModel
import time
import numpy as np
from torch.utils.data import Dataset
import scipy
import networkx as nx
import gies
from numpy import linalg
import wandb
import torch.nn as nn
import torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from ._diffintersort_cons import *

class DiffIntersortNoCons(BaseModel):
    """Wrapper that runs ``causal_discovery`` with ``lambda_int=0.0``.

    ``train`` computes a score matrix with ``_compute_scores``, remaps it to
    two levels around a threshold, and passes it to ``causal_discovery``.
    The adjacency matrix returned by that call is stored and exposed through
    ``get_adjacency_matrix``.
    """

    # Score threshold per number of variables; sizes without an explicit
    # entry fall back to _DEFAULT_LAMBDA (every listed value is identical).
    _LAMBDA_BY_DIM = {10: 0.3, 30: 0.3, 100: 0.3}
    _DEFAULT_LAMBDA = 0.3
    # Value assigned to scores that are positive but below the threshold.
    _LOW_SCORE_VALUE = 0.1

    def __init__(self):
        super().__init__()
        self._adj_matrix = None
        # Hyper-parameters keyed by number of variables. "scaling" is read by
        # train(); the whole dict is forwarded to causal_discovery.
        self.config = {
            3: {
                "lr": 0.05,
            },
            10: {
                "lr": 0.005,
                "scaling": 0.1,
            },
            30: {
                "lr": 0.001,
                "scaling": 0.5,
            },
            100: {
                "lr": 0.0001,
                "scaling": 1.0,
            },
            500: {
                "lr": 0.00005,
                "scaling": 10.0,
            },
            1000: {
                "lr": 0.0005,
                "scaling": 1.0,
            },
            2000: {
                "lr": 0.0001,
                "scaling": 1.0,
            },
        }

    def _score_settings(self, d):
        """Return ``(threshold, high_value)`` used to remap scores for *d* variables.

        Raises:
            KeyError: if ``self.config`` has no ``"scaling"`` entry for *d*.
        """
        try:
            scaling = self.config[d]["scaling"]
        except KeyError:
            supported = sorted(
                dim for dim, params in self.config.items() if "scaling" in params
            )
            raise KeyError(
                f"no 'scaling' configured for d={d}; "
                f"sizes with a 'scaling' entry: {supported}"
            ) from None
        threshold = self._LAMBDA_BY_DIM.get(d, self._DEFAULT_LAMBDA)
        return threshold, scaling * d

    def train(
        self,
        dataset: Dataset,
        log_wandb: bool = False,
        wandb_project: str = "diffintersort_no_cons",
        wandb_config_dict: Optional[dict] = None,
        **model_kwargs,
    ):
        """Fit the model and store the resulting adjacency matrix.

        Args:
            dataset: object exposing ``.tensors``; ``tensors[0]`` is used as
                the data matrix and ``tensors[1]`` as the intervention mask
                (both converted with ``.numpy()``).
            log_wandb: when true, a wandb run named ``"DiffIntersortNoCons"``
                is initialised before training.
            wandb_project: wandb project name for that run.
            wandb_config_dict: config passed to ``wandb.init`` (``{}`` if
                ``None``).
            **model_kwargs: accepted for interface compatibility; unused.

        Raises:
            KeyError: if the number of columns of the data has no
                ``"scaling"`` entry in ``self.config``.
        """
        data = dataset.tensors[0].numpy()
        intervention_mask = dataset.tensors[1].numpy()

        # Resolve the per-size settings first so an unsupported size fails
        # before a wandb run is opened or any scores are computed.
        d = data.shape[1]
        lmbda, high_value = self._score_settings(d)

        # NOTE(review): presumably a compatibility shim for NumPy releases
        # that no longer provide the ``np.bool`` alias - confirm before removing.
        gies.np.bool = bool

        if log_wandb:
            wandb.init(
                project=wandb_project,
                name="DiffIntersortNoCons",
                config=wandb_config_dict or {},
            )

        start = time.time()
        score_matrix = _compute_scores(data, intervention_mask, d)

        # Build both masks from the original scores before writing, so the
        # second assignment can never re-classify values set by the first.
        above = score_matrix > lmbda
        below = (score_matrix < lmbda) & (score_matrix > 0.0)
        score_matrix[above] = high_value
        score_matrix[below] = self._LOW_SCORE_VALUE

        _, self._adj_matrix = causal_discovery(
            data,
            intervention_mask,
            score_matrix,
            lmbda,
            self.config,
            init_ordering=None,
            lambda_int=0.0,
        )
        self._train_runtime_in_sec = time.time() - start

    def get_adjacency_matrix(self, threshold: bool = True) -> np.ndarray:
        """Return the adjacency matrix stored by ``train`` (``None`` before training).

        ``threshold`` is accepted for interface compatibility and is ignored.
        """
        return self._adj_matrix
