import pathlib
from itertools import product
from typing import Union

import numpy as np
import torch
from filelock import FileLock

from ..data.data_class import TrainDataSet, TestDataSet

DATA_PATH = pathlib.Path(__file__).resolve().parent.parent.parent.joinpath("data/")


def posX_scale_target(latents):
    """Return the per-row sum of latent columns 4 and 2.

    NOTE(review): judging by the function name, column 4 is presumably the
    x-position and column 2 the scale of a dSprites latent vector -- confirm
    against the dataset's latent ordering.
    """
    col_four = latents[:, 4]
    col_two = latents[:, 2]
    return col_four + col_two

def confounding_fn(x):
    """Shift ``x`` down by one half and return the result."""
    shift = 0.5
    return x - shift

def generate_full_rank_matrix(rows, cols, seed=None):
    """Draw a random ``rows x cols`` matrix of rank ``min(rows, cols)``.

    When ``seed`` is given, the global torch RNG is reseeded with it before
    anything is drawn (this affects later random draws elsewhere too).

    For the general case the result is the product of two matrices with
    orthonormal columns (the Q factors of random Gaussian matrices). When one
    of the dimensions is 1, a plain Gaussian vector is returned instead,
    redrawn until it is not entry-wise within 1e-2 of zero.
    """
    if seed is not None:
        torch.manual_seed(seed)

    assert rows > 0 and cols > 0, "Matrix dimensions must be positive"
    rank = min(rows, cols)

    if rank == 1:
        # Vector case: any vector that is not (numerically) zero has rank 1.
        while True:
            candidate = torch.randn(rows, cols)
            if not torch.allclose(candidate, torch.zeros_like(candidate), atol=1e-2):
                return candidate

    # Reduced QR yields orthonormal columns; keep only the Q factors.
    left_basis = torch.linalg.qr(torch.randn(rows, rank))[0]
    right_basis = torch.linalg.qr(torch.randn(cols, rank))[0]

    return left_basis @ right_basis.T

class Linear_mixing:
    """Callable that right-multiplies its input by a fixed weight matrix."""

    def __init__(self, weights, device='cpu'):
        # Keep the mixing matrix on the requested device.
        self.weights = weights.to(device)

    def __call__(self, latents):
        """Return ``torch.matmul(latents, weights)``."""
        mixed = torch.matmul(latents, self.weights)
        return mixed

def _get_images_from_latents(Z, latents_values, imgs):
    """
    Look up, for every latent vector in Z, the image stored at the first row
    of latents_values that equals it exactly (element-wise ``==``).

    Returns the selected images stacked along a new leading axis.
    Raises ValueError when some vector in Z has no exactly matching row.
    """
    def _image_for(query):
        # Boolean mask over rows: True where every component equals the query.
        row_hits = (latents_values == query).all(dim=1)
        positions = torch.nonzero(row_hits, as_tuple=False)
        if positions.numel() == 0:
            raise ValueError(f"No matching latent found for {query}")
        # Several rows may match; the first one wins.
        return imgs[positions[0, 0]]

    return np.stack([_image_for(query) for query in Z])


# Transform applied to the confounding latent before it is added to the
# outcome; an alias of the named helper ``confounding_fn`` (x -> x - 0.5).
CONFOUNDING_FUNC = confounding_fn
# Fixed linear map from 5 latent components to an 8-dimensional instrument.
# It is drawn once at import time with no explicit seed, so it changes between
# interpreter runs unless the global torch RNG was seeded before this import.
INSTRUMENT_MIXING = Linear_mixing(generate_full_rank_matrix(5, 8), device="cpu")
# Structural function mapping latents to the noise-free target.
TARGET_FN = posX_scale_target

def generate_dsprites(rho: float,
                      train: bool,
                      data_size: int,
                      rand_seed: int = 42) -> Union[TrainDataSet, TestDataSet]:
    """Sample a dSprites image dataset with latent-driven instrument and outcome.

    Latent vectors are drawn by sampling each latent dimension uniformly over
    its value range (padded by half a grid step on each side) and snapping the
    sample to the nearest value that occurs in the dataset. The image with
    exactly those latents becomes the (flattened) treatment.

    Args:
        rho: Weight of the confounder term ``CONFOUNDING_FUNC(Z[:, 5])`` that
            is added to the outcome. Only used when ``train`` is True.
        train: If True, return a ``TrainDataSet`` including instrument and
            noisy outcome; otherwise return a ``TestDataSet`` holding only the
            treatment and the noise-free structural value.
        data_size: Number of samples to draw.
        rand_seed: Seed passed to ``torch.manual_seed``; note that this
            reseeds the global torch RNG.

    Returns:
        ``TrainDataSet`` when ``train`` is True, else ``TestDataSet``. In both
        cases ``treatment`` has shape ``(data_size, 64 * 64)``.

    Raises:
        ValueError: Propagated from ``_get_images_from_latents`` when a
            sampled latent vector has no exactly matching dataset row.
    """
    torch.manual_seed(rand_seed)

    # ``np.load`` on an .npz archive returns a lazy NpzFile: member arrays are
    # only read on access. Pull them out while the lock is still held, and let
    # the context manager close the archive so no file handle is leaked.
    with FileLock("./data.lock"):
        with np.load(DATA_PATH.joinpath("dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"),
                     allow_pickle=True, encoding="bytes") as dataset_zip:
            imgs = dataset_zip['imgs']
            latents_values = dataset_zip['latents_values']

    latents_values = torch.from_numpy(latents_values).float()
    n_latents = latents_values.shape[1]

    # Sorted distinct values per latent dimension, and a sampling interval
    # that extends half a grid step beyond the smallest and largest value.
    uniques = [torch.unique(latents_values[:, i]) for i in range(n_latents)]
    steps = [(u[1] - u[0]) / 2 if len(u) > 1 else u[0] / 2 for u in uniques]
    ranges = [(u[0] - s, u[-1] + s) for u, s in zip(uniques, steps)]

    # Sample intermediate Q uniformly in ranges
    Q = torch.empty((data_size, n_latents))
    for i, (low, high) in enumerate(ranges):
        Q[:, i] = torch.empty(data_size).uniform_(low, high)

    # compute Z from Q as the closest latent value for each component
    Z = torch.empty_like(Q)
    for i, vals in enumerate(uniques):
        # (data_size, n_values) distance table; pick the nearest grid value.
        idx = torch.argmin(torch.abs(vals.unsqueeze(0) - Q[:, i].unsqueeze(1)), dim=1)
        Z[:, i] = vals[idx]

    treatment = _get_images_from_latents(Z, latents_values, imgs)
    flat_treatment = treatment.reshape((data_size, 64 * 64))

    # The instrument only sees the first five latent columns; column 5 is the
    # one that enters the outcome through the confounder term below.
    mask = [0, 1, 2, 3, 4]
    instrument = INSTRUMENT_MIXING(Z[:, mask])

    structural = TARGET_FN(Z).unsqueeze(1)

    if not train:
        return TestDataSet(treatment=flat_treatment,
                           covariate=None,
                           structural=structural.numpy(),)

    # Outcome = structural value + rho-weighted confounder + small Gaussian noise.
    outcome = structural + rho * CONFOUNDING_FUNC(Z[:, 5]).unsqueeze(1) + 0.01 * torch.randn((data_size, 1))

    return TrainDataSet(treatment=flat_treatment,
                        covariate=None,
                        instrumental=instrument.numpy(),
                        outcome=outcome.numpy(),
                        structural=structural.numpy(),)


def generate_train_dsprites_latent(data_size: int,
                            rho: float,
                            rand_seed: int = 42) -> TrainDataSet:
    """Build the training split by delegating to ``generate_dsprites`` with ``train=True``."""
    settings = dict(rho=rho, train=True, data_size=data_size, rand_seed=rand_seed)
    return generate_dsprites(**settings)


def generate_test_dsprites_latent(data_size: int = 1000,
                                  rand_seed: int = 42) -> TestDataSet:
    """Build the test split by delegating to ``generate_dsprites`` with ``train=False``.

    Args:
        data_size: Number of test samples to draw (defaults to 1000).
        rand_seed: Seed forwarded to ``generate_dsprites`` (defaults to 42).

    Returns:
        The ``TestDataSet`` produced with ``rho=0.0``; ``rho`` is only used on
        the training path of ``generate_dsprites``, so its value is irrelevant
        here.
    """
    return generate_dsprites(rho=0.0, train=False, data_size=data_size, rand_seed=rand_seed)
