from typing import List, Tuple

import torch
import torch.nn as nn
import numpy as np
from diffintersort import *
from torch.optim import Adam 
# Module-wide device used for every tensor and model below: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def causal_discovery(
        X: np.ndarray,
        interventions: List[int],
        score_matrix,
        eps,
        config,
        init_ordering=None,
        lambda_int=100.0,
    ) -> Tuple[List[Tuple[int, int]], np.ndarray, np.ndarray]:
    """Learn a DAG from (partly interventional) data and return its edges.

    Args:
        X: Data matrix with one row per sample and one column per variable.
        interventions: For each row of ``X``, the column index of the
            intervened variable, or -1 for an observational sample. Any
            sequence accepted by ``np.asarray`` works.
        score_matrix: Pairwise score matrix forwarded to
            ``causal_disco_training``; presumably (d, d) — TODO confirm.
        eps: Offset subtracted from positive scores during training.
        config: Mapping indexed by ``d`` (number of variables) whose entries
            provide ``"scaling"`` and ``"lr"``.
        init_ordering: Optional initial ordering used to seed the ordering
            scores.
        lambda_int: Weight of the ordering-constraint loss; <= 0 disables it.

    Returns:
        ``(edges, adjacency, ordering)``. ``edges`` lists ``(parent, child)``
        pairs whose estimated weight is nonzero and whose learned scores satisfy
        ``p[parent] < p[child]``. ``adjacency`` is the corresponding boolean
        (d, d) matrix. ``ordering`` is ``np.argsort`` of the learned scores.

    Raises:
        ValueError: if ``interventions`` has fewer entries than ``X`` has rows.
    """
    n = X.shape[0]
    d = X.shape[1]
    # The training loop indexes interventions with an index tensor, which only
    # works for arrays, so normalise list input here.
    interventions = np.asarray(interventions)
    if interventions.shape[0] < n:
        raise ValueError(
            f"interventions has {interventions.shape[0]} entries but X has {n} rows"
        )

    # Zero out the intervened variable of each interventional sample so it is
    # excluded from the reconstruction loss.
    mask = np.ones(X.shape, dtype=np.float32)
    intervened_rows = np.flatnonzero(interventions[:n] != -1)
    mask[intervened_rows, interventions[intervened_rows]] = 0
    X = X.astype(np.float32)

    W_est, p_est = causal_disco_training(
        X,
        lambda1=0.01,
        lambda_int=lambda_int,
        d=d,
        score_matrix=score_matrix,
        interventions=interventions,
        intervention_mask=mask,
        init_ordering=init_ordering,
        scaling=config[d]["scaling"],
        n_iter=2000,
        lr_int=config[d]["lr"],
        t_sinkhorn=0.05,
        n_iter_sinkhorn=500,
        eps=eps,
    )

    parents, children = np.nonzero(np.abs(W_est) > 0)
    adjacency = np.zeros((d, d))
    edges = set()
    for parent, child in zip(parents, children):
        # Keep only edges that point forward in the learned ordering.
        if p_est[parent] < p_est[child]:
            edges.add((parent, child))
            adjacency[parent, child] = 1
    return list(edges), adjacency > 0.0, np.argsort(p_est)


class Linear(nn.Module):
    """Bias-free linear model predicting every variable from the others.

    For a batch ``X_b`` the prediction is ``X_b @ A.T``, where ``A`` is the
    weight of ``self.fc``, optionally gated elementwise by a (transposed)
    permutation/gating matrix and masked by ``intervention_mask``.
    """

    def __init__(self, d, lambda1, lambda_int, intervention_mask=None,):
        super().__init__()
        self.d = d
        self.lambda1 = lambda1
        self.lambda_int = lambda_int
        self.fc = torch.nn.Linear(self.d, self.d, bias=False)
        # Optional per-sample mask multiplied into the predictions (rows are
        # selected with the same batch indices as X).
        self.intervention_mask = intervention_mask
        # Last gating matrix seen by forward(); None until forward() gets one.
        self.perm_matrix = None

    def postprocess_A(self, threshold=0.05):
        """Return the weights as a NumPy array with small entries zeroed.

        Entry ``[i, j]`` is the coefficient of variable ``i`` in the
        prediction of variable ``j``. Entries with ``|w| <= threshold``
        are set to 0.
        """
        A = self.fc.weight.T
        A_est = torch.where(torch.abs(A) > threshold, A, torch.zeros_like(A))
        return A_est.detach().cpu().numpy()

    def l1_reg(self):
        """Return the L1 norm of the weight matrix."""
        return torch.sum(torch.abs(self.fc.weight))

    def forward(self, X, perm_matrix, batch_indices):
        """Predict the rows ``X[batch_indices]``.

        Args:
            X: Full data tensor; rows are selected with ``batch_indices``.
            perm_matrix: Optional (d, d) or (1, d, d) gating matrix. Entry
                ``[i, j]`` gates the use of variable ``i`` when predicting ``j``.
                ``None`` means no gating.
            batch_indices: Row indices of the batch.
        """
        if perm_matrix is not None:
            # Accept a singleton batch dimension; this must be checked only
            # after ruling out None (previously None crashed here).
            if perm_matrix.dim() == 3 and perm_matrix.shape[0] == 1:
                perm_matrix = perm_matrix.squeeze(0)
            A = self.fc.weight * perm_matrix.T
            self.perm_matrix = perm_matrix
        else:
            A = self.fc.weight
        input_batch = X[batch_indices, :]
        output = input_batch.mm(A.T)
        if self.intervention_mask is not None:
            output = output * self.intervention_mask[batch_indices, :]
        return output


def constrain_loss(d, n_iter_sinkhorn, t_sinkhorn, p, score_matrix_torch):
    """Score the soft ordering induced by ``p`` against a pairwise score matrix.

    Args:
        d: Number of variables.
        n_iter_sinkhorn: Sinkhorn iterations forwarded to ``compute_perm_matrix``.
        t_sinkhorn: Temperature forwarded to ``compute_perm_matrix``.
        p: Unconstrained ordering scores (one per variable).
        score_matrix_torch: Score tensor multiplied elementwise with the soft
            permutation matrix.

    Returns:
        ``(loss, perm_matrix)``: ``loss`` is the negated, summed agreement
        between ``perm_matrix`` and the scores (lower is better). ``perm_matrix``
        is the matrix returned by ``compute_perm_matrix``.
    """
    # Squash the scores into (-1, 1) before building the soft permutation.
    sig_p = (torch.sigmoid(p) * 2) - 1
    perm_matrix, _ = compute_perm_matrix(sig_p, d, sinkhorn_n_iter=n_iter_sinkhorn, t=t_sinkhorn)
    # NOTE(review): perm_matrix may be (d, d) or (1, d, d) (Linear.forward
    # accepts both); the mean over dim 0 depends on which — TODO confirm.
    score = perm_matrix * score_matrix_torch
    loss = -torch.sum(torch.mean(score, dim=0))
    return loss, perm_matrix

def _weighted_env_l1(input_batch, output, inter_batch, lambda_, envs):
    """Return the environment-weighted mean absolute reconstruction error of a batch.

    Rows whose ``inter_batch`` value is -1 (observational) are weighted by
    ``1 - lambda_ + lambda_ / envs``. The rows of every other intervention value
    are weighted by ``lambda_ / envs``. Each group contributes its column-wise
    mean absolute error. The weighted groups are summed, then averaged over
    columns. Returns a 0-dim zero tensor when no group contributes.
    """
    terms = []
    baseline_rows = inter_batch == -1
    baseline_err = torch.abs(input_batch[baseline_rows] - output[baseline_rows])
    if baseline_err.numel() > 0:
        terms.append(torch.mean(baseline_err, dim=0) * (1 - lambda_ + (lambda_ / envs)))
    # torch.unique returns sorted values, so groups are always visited in the same order.
    for intervention in torch.unique(inter_batch):
        if intervention == -1:
            continue
        rows = inter_batch == intervention
        env_err = torch.abs(input_batch[rows] - output[rows])
        if env_err.numel() > 0:
            terms.append(torch.mean(env_err, dim=0) * (lambda_ / envs))
    if not terms:
        return output.new_zeros(())
    return torch.mean(torch.sum(torch.stack(terms), dim=0))


def causal_disco_training(X, lambda1, lambda_int, d, score_matrix, interventions, init_ordering=None, scaling=0.1, n_iter=100, lr_int=0.001, n_iter_sinkhorn=300, t_sinkhorn=0.5, eps=0.3, intervention_mask=None,):
    """Fit a linear SEM and a differentiable variable ordering in two phases.

    Phase 1 (``n_iter // 2`` epochs) trains only the linear weights, with
    self-loops masked out. Phase 2 (up to ``n_iter`` epochs) also learns the
    ordering scores ``p``. The weights are gated by the soft permutation matrix
    from ``constrain_loss``, and that loss is added with weight ``lambda_int``.
    Phase 2 stops early after 100 consecutive epochs without improvement, once
    more than 500 epochs have run.

    Args:
        X: Data array with one row per sample and ``d`` columns.
        lambda1: Weight of the L1 penalty on the linear weights.
        lambda_int: Weight of the ordering loss; <= 0 disables ordering learning.
        d: Number of variables.
        score_matrix: Pairwise scores. The caller's array is not modified; a copy
            has ``eps`` subtracted from its positive entries, and entries in the
            transitive closure of the remaining positive pattern set to
            ``scaling * d``.
        interventions: Per-row intervened column index, or -1 for observational.
        init_ordering: Optional ordering used to arrange the initial scores.
        scaling: See ``score_matrix``.
        n_iter: Number of phase-2 epochs (phase 1 runs ``n_iter // 2``).
        lr_int: Adam learning rate for the ordering scores.
        n_iter_sinkhorn, t_sinkhorn: Forwarded to ``constrain_loss``.
        eps: See ``score_matrix``.
        intervention_mask: Optional per-row/column mask applied to both targets
            and predictions.

    Returns:
        ``(best_A, best_p)``: the thresholded weight matrix (``[i, j]`` is the
        coefficient of variable ``i`` in the prediction of ``j``) and the ordering
        scores, both NumPy arrays, taken at the best phase-2 epoch. "Best" means
        lowest constraint loss, ties broken by lowest summed reconstruction loss.
    """
    p_scale = 0.001
    p = p_scale * torch.randn((d), device=device)
    if init_ordering is not None:
        # Sort the random scores ascending, then place them according to the
        # indices that sort init_ordering.
        _, order = torch.sort(torch.tensor(init_ordering), descending=False)
        p_sorted, _ = torch.sort(p, descending=False)
        p = p_sorted[order]
    p.requires_grad = True
    p_opt = Adam([p], lr=lr_int, betas=(0.9, 0.99))

    # Work on a copy: the scores used to be shifted and overwritten in the
    # caller's array, so repeated calls subtracted eps again.
    score_matrix = np.array(score_matrix, copy=True)
    score_matrix[score_matrix > 0.0] -= eps
    transitive = transitive_closure(score_matrix > 0.0, depth=d)
    score_matrix[transitive > 0.0] = scaling * d
    score_matrix_torch = torch.tensor(score_matrix, dtype=torch.float32).to(device)

    X = torch.tensor(X, dtype=torch.float32).to(device)
    if intervention_mask is not None:
        intervention_mask = torch.tensor(intervention_mask, dtype=torch.float32).to(device)

    # Convert once instead of on every batch.
    interventions = np.asarray(interventions)
    interventions_t = torch.as_tensor(interventions, dtype=torch.long, device=device)
    envs = len(np.unique(interventions))
    N = X.shape[0]

    model = Linear(d, lambda1=lambda1, lambda_int=lambda_int, intervention_mask=intervention_mask)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    early_stop = 100
    best_loss = (float('inf'), float('inf'))
    n_batches = 3
    batch_size = max(N // n_batches, 1)
    lambda_ = 0.5
    # Gate that forbids self-loops (a variable predicting itself).
    off_diagonal = (torch.ones(d, d) - torch.eye(d)).to(device)

    def _batch_loss(perm_matrix, batch_indices):
        # Return (reconstruction loss, reconstruction + L1 loss) for one batch.
        output = model(X, perm_matrix, batch_indices)
        input_batch = X[batch_indices]
        if intervention_mask is not None:
            input_batch = input_batch * intervention_mask[batch_indices, :]
        loss_mse = _weighted_env_l1(input_batch, output, interventions_t[batch_indices], lambda_, envs)
        return loss_mse, 10 * loss_mse + lambda1 * model.l1_reg()

    # Phase 1: fit the weights without any ordering constraint.
    for i in range(n_iter // 2):
        total_loss = 0
        indices = torch.randperm(N)
        for a in range(0, N, batch_size):
            batch_indices = indices[a:a + batch_size]
            optimizer.zero_grad()
            _, loss = _batch_loss(off_diagonal, batch_indices)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if i % 10 == 0:
            print("Epoch: {}. Loss = {:.3f}".format(i, total_loss))
    A = model.postprocess_A()
    print("Number of proposed edges is = {}".format(np.count_nonzero(A)))

    # Defaults so the function returns something even if no epoch improves
    # (n_iter == 0 or NaN losses previously raised UnboundLocalError).
    best_A = A
    best_p = p.detach().cpu().numpy()

    # Phase 2: jointly learn the weights and the ordering scores.
    for i in range(n_iter):
        total_loss = 0
        total_mse = 0
        indices = torch.randperm(N)
        p_opt.zero_grad()

        if lambda_int <= 0.0:
            # No ordering learning; keep self-loops excluded as in phase 1
            # (an all-ones gate lets each variable trivially fit itself).
            perm_matrix = off_diagonal.to(p.dtype)
            constraint_loss = torch.tensor(0.0)
        else:
            constraint_loss, perm_matrix = constrain_loss(d, n_iter_sinkhorn, t_sinkhorn, p, score_matrix_torch)

        for a in range(0, N, batch_size):
            batch_indices = indices[a:a + batch_size]
            optimizer.zero_grad()
            p_opt.zero_grad()
            loss_mse, loss = _batch_loss(perm_matrix, batch_indices)
            if lambda_int > 0.0:
                loss = loss + lambda_int * constraint_loss
            # The constraint graph is built once per epoch and reused by every batch.
            loss.backward(retain_graph=True)
            optimizer.step()
            p_opt.step()
            total_mse += loss_mse.item()
            total_loss += loss.item()

        if i % 10 == 0:
            print("Epoch: {}. Loss = {:.3f}".format(i, total_loss))
            print("Epoch: {}. MSE Loss = {:.3f}".format(i, total_mse))
            print("Epoch: {}. Cons Loss = {:.3f}".format(i, constraint_loss.item()))

        cons = constraint_loss.item()
        if cons < best_loss[0] or (cons <= best_loss[0] and total_mse < best_loss[1]):
            early_stop = 100
            best_loss = (cons, total_mse)
            best_A = model.postprocess_A()
            best_p = p.detach().cpu().numpy()
        else:
            early_stop -= 1

        if early_stop <= 0 and i > 500:
            break

    print(model.fc.weight.max(), model.fc.weight.min())
    print(model.fc.weight)
    A = model.postprocess_A()
    print("Number of proposed edges is = {}".format(np.count_nonzero(A)))

    return best_A, best_p
