import torch 
import torch.nn as nn 
from torch.distributions import MultivariateNormal

from ..data.data_class import TrainDataSet, TestDataSet

from numpy.random import default_rng


def generate_full_rank_matrix(rows, cols, seed=None):
    """Return a random ``rows x cols`` tensor of rank ``min(rows, cols)``.

    When both dimensions exceed one, the result is ``Q1 @ Q2.T`` where ``Q1``
    and ``Q2`` are the orthonormal factors of the reduced QR decompositions of
    two Gaussian matrices. When either dimension is one, a Gaussian vector is
    drawn and redrawn until it is not (almost) entirely zero.

    Args:
        rows: Number of rows; must be positive.
        cols: Number of columns; must be positive.
        seed: Optional seed. When given, it is passed to ``torch.manual_seed``
            and therefore reseeds torch's *global* RNG.

    Returns:
        A ``torch.Tensor`` of shape ``(rows, cols)``.

    Raises:
        ValueError: If ``rows`` or ``cols`` is not positive.
    """
    # Validate before touching the RNG so a rejected call has no side effect;
    # an explicit raise also survives ``python -O``, unlike ``assert``.
    if not (rows > 0 and cols > 0):
        raise ValueError("Matrix dimensions must be positive")

    if seed is not None:
        torch.manual_seed(seed)

    min_dim = min(rows, cols)

    # In the 1D case any nonzero vector has full rank.
    if min_dim == 1:
        M = torch.randn(rows, cols)
        # Redraw in the unlikely event that every entry is within 1e-2 of zero.
        while torch.allclose(M, torch.zeros_like(M), atol=1e-2):
            M = torch.randn(rows, cols)
        return M

    A = torch.randn(rows, min_dim)
    B = torch.randn(cols, min_dim)

    # Keep only the Q factors: both have ``min_dim`` orthonormal columns.
    Q1, _ = torch.linalg.qr(A)
    Q2, _ = torch.linalg.qr(B)

    return Q1 @ Q2.T

def nonlinear_func(input_dim: int, output_dim: int, n_layers: int, hidden_dim: int) -> nn.Module:
    """Build a randomly initialised MLP with LeakyReLU(0.2) activations.

    The network has ``n_layers`` hidden blocks (``Linear`` followed by
    ``LeakyReLU``), each ``hidden_dim`` wide, followed by a final ``Linear``
    layer producing ``output_dim`` features.

    Args:
        input_dim: Width of the input features.
        output_dim: Width of the output features.
        n_layers: Number of hidden blocks. With ``0`` (or a negative value)
            the result is a single ``Linear(input_dim, output_dim)``.
        hidden_dim: Width of every hidden layer.

    Returns:
        An ``nn.Sequential`` mapping ``(..., input_dim)`` to ``(..., output_dim)``.
    """
    layers = []

    # Track the width feeding the next Linear layer so that the output layer
    # is wired to ``input_dim`` when there are no hidden blocks at all.
    in_features = input_dim
    for _ in range(n_layers):
        layers.append(nn.Linear(in_features, hidden_dim))
        layers.append(nn.LeakyReLU(negative_slope=0.2))
        in_features = hidden_dim

    layers.append(nn.Linear(in_features, output_dim))
    return nn.Sequential(*layers)


def generate_causal_effect(dimZ, dimY, causal_effect):
    """Create the module mapping a ``dimZ``-wide input to a ``dimY``-wide effect.

    Args:
        dimZ: Input width.
        dimY: Output width.
        causal_effect: ``"linear"`` for a single ``nn.Linear`` whose weight is
            replaced by a matrix from ``generate_full_rank_matrix``, or
            ``"nonlinear"`` for a one-hidden-layer (16 units) tanh network.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``causal_effect`` is neither of the two options.
    """
    if causal_effect not in ("linear", "nonlinear"):
        raise ValueError("wrong causal_effect.")

    if causal_effect == "nonlinear":
        hidden = nn.Linear(dimZ, 16)
        readout = nn.Linear(16, dimY)
        return nn.Sequential(hidden, nn.Tanh(), readout)

    # Linear case: keep the default bias but swap in a full-rank weight.
    effect = nn.Linear(dimZ, dimY)
    effect.weight = nn.Parameter(generate_full_rank_matrix(dimY, dimZ))
    return effect

def sample_covariance_matrix(dim, device='cpu'):
    """Draw a random symmetric positive-definite ``dim x dim`` matrix.

    A factor with entries uniform on [0, 1) is multiplied by its own transpose
    and the identity is added to the product.

    Args:
        dim: Side length of the matrix.
        device: Device on which the tensors are created.

    Returns:
        A ``(dim, dim)`` tensor equal to ``A @ A.T + I``.
    """
    factor = torch.rand(dim, dim, device=device)
    identity = torch.eye(dim, device=device)
    gram = factor.matmul(factor.T)
    return gram + identity


# Widths of the synthetic variables used by the generators in this module.
DIM_A = 12  # instrument
DIM_Z = 10  # confounder / latent treatment
DIM_X = 16  # observed treatment (output width of TREATMENT_MIXING)
DIM_Y = 1   # outcome

# When True, BASE_COV below is the identity; otherwise a random SPD matrix.
INDEP_LATENTS = False


# NOTE(review): not referenced anywhere in the visible code; the generators
# take the confounding strength through their ``rho`` argument instead.
CONFOUNDING_STRENGTH = 0.8


# Fixed random networks defining the data-generating process. They are built
# once, at import time, with PyTorch's default (random) initialisation, and no
# seed is set in the visible code beforehand. The statement order determines
# which random numbers each network receives, so do not reorder these lines.
MEAN_FUNC = nonlinear_func(DIM_A, DIM_Z, 2, 16)  # instrument -> latent mean
VAR_FUNC = nonlinear_func(DIM_A, DIM_Z, 2, 16)  # instrument -> multiplier applied to the confounder
TREATMENT_MIXING = nonlinear_func(DIM_Z, DIM_X, 2, 16)  # latent -> observed treatment
STRUCTURAL_FUNC = generate_causal_effect(DIM_Z, DIM_Y, "nonlinear")  # latent -> structural outcome
CONFOUNDING_FUNC = nn.Linear(DIM_Z, DIM_Y, bias=False)  # confounder -> outcome noise

# Covariance (before the 0.01 scaling applied by the training generator) of
# the confounder.
if INDEP_LATENTS:
    BASE_COV = torch.eye(DIM_Z)
else:
    BASE_COV = sample_covariance_matrix(DIM_Z)


def generate_train_imca(data_size: int,
                        rho: float,
                        rand_seed: int = 42,
                        ) -> TrainDataSet:
    """Sample a training set from the module-level data-generating process.

    The instrument is standard normal, the confounder is drawn from
    ``N(0, 0.01 * BASE_COV)``, and the latent treatment is
    ``MEAN_FUNC(instrument) + VAR_FUNC(instrument) * confounder``. The observed
    treatment is ``TREATMENT_MIXING(latent)``; the outcome is
    ``STRUCTURAL_FUNC(latent)`` plus ``rho * CONFOUNDING_FUNC(confounder)`` and
    independent Gaussian noise with standard deviation 0.01.

    Args:
        data_size: Number of samples to draw.
        rho: Multiplier on the confounder's contribution to the outcome.
        rand_seed: Seed for all random draws made by this call. ``None``
            requests a non-deterministic seed.

    Returns:
        A ``TrainDataSet`` of NumPy arrays with ``covariate`` set to ``None``.

    Note:
        Only the draws made here are controlled by ``rand_seed``; the
        module-level networks and ``BASE_COV`` are fixed when the module is
        imported.
    """
    # A private generator makes the draws reproducible from ``rand_seed``
    # without disturbing torch's global RNG state.
    generator = torch.Generator()
    if rand_seed is None:
        generator.seed()
    else:
        generator.manual_seed(int(rand_seed))

    instrument = torch.randn(data_size, DIM_A, generator=generator)

    # Sample N(0, 0.01 * BASE_COV) as eps @ L^T, where L L^T is the covariance.
    scale_tril = torch.linalg.cholesky(0.01 * BASE_COV)
    confounder = torch.randn(data_size, DIM_Z, generator=generator) @ scale_tril.T

    latent_treatment = MEAN_FUNC(instrument) + VAR_FUNC(instrument) * confounder
    treatment = TREATMENT_MIXING(latent_treatment)

    structural = STRUCTURAL_FUNC(latent_treatment)
    measurement_noise = 0.01 * torch.randn(data_size, DIM_Y, generator=generator)
    outcome_noise = rho * CONFOUNDING_FUNC(confounder) + measurement_noise
    outcome = structural + outcome_noise

    return TrainDataSet(treatment=treatment.detach().numpy(),
                        instrumental=instrument.numpy(),
                        covariate=None,
                        structural=structural.detach().numpy(),
                        outcome=outcome.detach().numpy())




def generate_test_imca() -> TestDataSet:
    """Build the fixed 1000-sample test set.

    Reseeds torch's global RNG with 0, then draws a standard-normal instrument
    and a confounder from ``N(0, 0.01 * I)``, and pushes them through the
    module-level networks.

    Returns:
        A ``TestDataSet`` holding the observed treatment and the structural
        outcome as NumPy arrays, with ``covariate`` set to ``None``.
    """
    n_samples = 1000
    torch.manual_seed(0)  # side effect: resets the global torch RNG

    instrument = torch.randn(n_samples, DIM_A)

    # NOTE(review): generate_train_imca uses 0.01 * BASE_COV for the
    # confounder, whereas the identity is used here -- confirm this
    # difference is intended.
    confounder_law = MultivariateNormal(torch.zeros(DIM_Z), 0.01 * torch.eye(DIM_Z))
    confounder = confounder_law.sample((n_samples,))

    latent_mean = MEAN_FUNC(instrument)
    latent_scale = VAR_FUNC(instrument)
    latent_treatment = latent_mean + latent_scale * confounder

    observed = TREATMENT_MIXING(latent_treatment).detach().numpy()
    target = STRUCTURAL_FUNC(latent_treatment).detach().numpy()

    return TestDataSet(treatment=observed,
                       covariate=None,
                       structural=target)


