import itertools
from typing import Callable, List, Dict, Any

import torch
from tqdm import tqdm, trange

from src.evaluation import evaluate_mse_pearson
from src.regression import torchOLS
from src.anchorRegression import AnchorRegression
from src.irm_models import InvariantRiskMinimization  # or whatever class you use
from src.irm_models import InvariantCausalPrediction

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV

from src.mlp import MLP, train_and_evaluate_mlp_attention, train_and_evaluate_mlp_attention_MSK_new_loss
import torch.nn as nn
import torch.optim as optim

import numpy as np
import os
import concurrent.futures

#torch.inverse(torch.ones((1, 1), device="cuda:0"))



def cross_validate_imputation_method(
        method_fn: Callable,
        X: torch.Tensor,
        Z: torch.Tensor,
        Y: torch.Tensor,
        hyperparam_grid: Dict[str, List[Any]]
) -> torch.Tensor:
    """
    Grid-search the hyperparameters of an imputation method, then refit on Y.

    Every combination of the values in ``hyperparam_grid`` is tried: the
    method is fitted with Z as target (``W_Z = method_fn(X, Z, **params)``)
    and scored by the MSE between ``Z`` and ``W_Z @ X``. The combination with
    the lowest score is then used to fit the returned weights on Y.

    NOTE: the score is computed on the same Z the weights were fitted to, so
    it is an in-sample fit error rather than a held-out error.

    Args:
        method_fn (Callable): Solver called as ``method_fn(X, target, **params)``
            that returns W such that ``W @ X`` approximates ``target``
            (e.g. weights_ridge_regression, iterative_reweighted_ridge_regression,
            lasso_regression).
        X (torch.Tensor): Input matrix; ``W @ X`` must have the shape of the target.
        Z (torch.Tensor): Target matrix used for hyperparameter selection.
        Y (torch.Tensor): Target matrix used for the final fit.
        hyperparam_grid (Dict[str, List[Any]]): Maps each keyword argument of
            ``method_fn`` to the list of values to try. An empty grid means a
            single run with the defaults of ``method_fn``.

    Returns:
        torch.Tensor: ``method_fn(X, Y, **best_params)``.

    Raises:
        ValueError: If no combination produces a loss below infinity
            (e.g. every fit returned NaN), so no best setting exists.
    """
    # Expand the grid into a list of keyword-argument dicts. An empty grid
    # cannot be unpacked by zip(*...), so it is mapped to one default run.
    if hyperparam_grid:
        keys, values = zip(*hyperparam_grid.items())
        combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
    else:
        combinations = [{}]

    print(f"\n[CV] Trying {len(combinations)} combinations for {method_fn.__name__}...")

    best_params = None
    best_val_loss = float("inf")

    for param_set in tqdm(combinations, desc=f"CV: {method_fn.__name__}", unit="combo"):
        # Fit on the selection target Z and measure how well Z is reconstructed.
        W_Z = method_fn(X, Z, **param_set)
        Z_pred = W_Z @ X
        val_loss = torch.nn.functional.mse_loss(Z, Z_pred).item()

        # Strict '<' keeps the first of several equally good combinations and
        # never selects a NaN loss (comparisons with NaN are False).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_params = param_set

    if best_params is None:
        raise ValueError(
            f"No hyperparameter combination for {method_fn.__name__} produced a "
            "finite validation loss."
        )

    # Final weights: best hyperparameters, fitted on Y.
    return method_fn(X, Y, **best_params)


def weights_ridge_regression(X: torch.Tensor, Y: torch.Tensor, lambda_: float) -> torch.Tensor:
    """
    Closed-form ridge weights W = Y X^T (X X^T + lambda I)^(-1).

    The inverse is never formed explicitly; the equivalent linear system
    (X X^T + lambda I) W^T = (Y X^T)^T is solved instead.

    Arguments:
        X: [n_features x n_samples] input matrix (e.g. [1400 x 10])
        Y: [n_outputs x n_samples] target matrix (e.g. [100 x 10])
        lambda_: scalar ridge regularization parameter

    Returns:
        W: [n_outputs x n_features] weights (e.g. [100 x 1400])
    """
    # Regularized Gram matrix of the features: X X^T + lambda * I.
    gram = X @ X.t()
    identity = torch.eye(gram.size(0), dtype=gram.dtype, device=gram.device)
    regularized = gram + lambda_ * identity

    # Right-hand side (Y X^T)^T, one column per output row of W.
    rhs = (Y @ X.t()).t()

    # W M = B  <=>  M W^T = B^T (M is symmetric); transpose the solution back.
    return torch.linalg.solve(regularized, rhs).t()


def iterative_reweighted_ridge_regression(
        X: torch.Tensor,
        Y: torch.Tensor,
        lam: float = 1e-2,
        alpha: float = 1e-3,
        max_iter: int = 10,
        init_mode: str = "ridge"
) -> torch.Tensor:
    r"""
    Iterative Reweighted Ridge Heuristic for (approximately) sparse W.

    Each outer iteration k minimizes, with the weights of the previous
    iterate W^{(k)} held fixed:
        \|W X - Y\|_F^2 + \lambda \sum_{i,j} (W_{i,j})^2 / (\alpha + |W_{i,j}^{(k)}|)

    The problem decouples over the rows of W; row i solves
        (X X^T + \lambda D_i) w_i = X y_i^T,
    with D_i = diag(1 / (\alpha + |w_i^{(k)}|)).

    Args:
        X: Input matrix,  shape [n_features, n_samples] (e.g. [1400, 10])
        Y: Target matrix, shape [n_outputs, n_samples]  (e.g.  [100, 10])
        lam: Ridge-like regularization parameter
        alpha: Small constant for the reweighting denominator
        max_iter: Number of outer iterations (0 returns the initialization)
        init_mode: How to initialize W^(0). Options:
            - "ridge": closed-form ridge solution with parameter ``lam``
            - "ols":   Y @ pinv(X)
            - anything else (e.g. "zeros"): start with all zeros
    Returns:
        W: Approximated sparse matrix of shape [n_outputs, n_features] (e.g. [100, 1400])
    """
    n_outputs = Y.shape[0]
    n_features = X.shape[0]

    # Loop invariants, computed once: the Gram matrix X X^T and the
    # right-hand sides Y X^T (row i is the right-hand side of row problem i).
    gram = X @ X.t()  # [n_features x n_features]
    rhs = Y @ X.t()   # [n_outputs x n_features]

    # --- Step 0: Initialization ---
    if init_mode == "ridge":
        # W = Y X^T (X X^T + lam I)^(-1), obtained from a linear solve
        # (the system matrix is symmetric, so M W^T = B^T).
        ridge_system = gram + lam * torch.eye(n_features, device=X.device, dtype=X.dtype)
        W = torch.linalg.solve(ridge_system, rhs.t()).t().contiguous()
    elif init_mode == "ols":
        W = Y @ torch.linalg.pinv(X)
    else:
        # "zeros" or fallback
        W = torch.zeros(n_outputs, n_features, device=X.device, dtype=X.dtype)

    # --- Main Iteration Loop ---
    for _ in range(max_iter):
        for i in range(n_outputs):
            # Reweighting from the current row; +1e-16 guards alpha == 0.
            reweight = 1.0 / (alpha + W[i].abs() + 1e-16)

            # X X^T + lam * diag(reweight): only the diagonal differs from the
            # Gram matrix, so add it there instead of building a dense diag.
            system = gram.clone()
            system.diagonal().add_(lam * reweight)

            # Overwrite row i of W with the solution of its row problem.
            W[i] = torch.linalg.solve(system, rhs[i].unsqueeze(-1)).squeeze(-1)

    return W


def lasso_regression(
        X: torch.Tensor,
        Y: torch.Tensor,
        lam: float = 1e-3,
        lr: float = 1e-4,
        max_iter: int = 1000,
        tol: float = 1e-6
) -> torch.Tensor:
    r"""
    LASSO via proximal gradient descent (ISTA):
        min_W  0.5 * ||W X - Y||_F^2 + lambda * ||W||_1

    Starting from W = 0, each iteration takes a gradient step on the
    quadratic term and then soft-thresholds the result by lr * lam.

    Args:
        X: [n_features x n_samples],   input matrix
        Y: [n_out x n_samples],   target matrix
        lam:  L1 regularization parameter (lambda)
        lr:   step size (eta) for gradient updates
        max_iter:  maximum number of iterations
        tol:   stop once the Frobenius norm of the update drops below this

    Returns:
        W: [n_out x n_features] solution
    """
    n_out, _ = Y.shape
    weights = torch.zeros(n_out, X.shape[0], dtype=X.dtype, device=X.device)

    X_transposed = X.t()  # [n_samples x n_features], reused by every gradient
    threshold = lr * lam  # soft-threshold level of the proximal step

    for _ in range(max_iter):
        # Gradient of 0.5 * ||W X - Y||_F^2 with respect to W.
        gradient = (weights @ X - Y) @ X_transposed

        # Plain gradient step ...
        stepped = weights - lr * gradient

        # ... followed by the L1 proximal operator: sign(z) * max(|z| - t, 0).
        shrunk = torch.sign(stepped) * torch.clamp(stepped.abs() - threshold, min=0.0)

        # Size of this update; the new iterate is kept even when stopping.
        update_norm = (shrunk - weights).norm().item()
        weights = shrunk
        if update_norm < tol:
            break

    return weights


def select_best_lambda_reg(
    X_train: torch.Tensor,
    Y_train: torch.Tensor,
    X_val: torch.Tensor,
    Y_val: torch.Tensor,
    W: torch.Tensor,
    lambda_reg_grid: list
) -> float:
    """
    Pick the lambda_reg for torchOLS that gives the lowest validation MSE.

    For each candidate, torchOLS is called with the element-wise squared
    weights ``W ** 2`` and its output is compared with ``Y_val``.

    Returns:
        The candidate with the smallest validation MSE (the first one on
        ties), or None if the grid is empty or no loss is below infinity.
    """
    winner = None
    lowest_mse = float("inf")

    for candidate in lambda_reg_grid:
        val_prediction = torchOLS(
            X_train, Y_train, X_val, Y_val, W ** 2, lambda_reg=candidate
        )
        val_mse = torch.nn.functional.mse_loss(Y_val, val_prediction).item()

        # Strict improvement only: ties keep the earlier candidate and a NaN
        # loss is never selected.
        if val_mse < lowest_mse:
            lowest_mse, winner = val_mse, candidate

    return winner

def _anchor_test_predictions(X_train_np, Y_train_np, A_anchor, X_val_np, Y_val, X_test_np, lamb_grid):
    """Fit AnchorRegression and return its predictions for ``X_test_np``.

    With a validation set (``X_val_np`` is not None) one model is fitted per
    value in ``lamb_grid`` and the model with the lowest validation MSE
    predicts; otherwise a single model with ``lamb_grid[0]`` is used.
    """
    def fit(lamb):
        model = AnchorRegression(
            lamb=lamb,
            fit_intercept=True,
            normalize=True,
            copy_X=True
        )
        model.fit(X_train_np, Y_train_np, A=A_anchor)
        return model

    if X_val_np is None:
        return fit(lamb_grid[0]).predict(X_test_np)

    best_val_loss = float("inf")
    best_model = None
    for lamb in tqdm(lamb_grid, desc="Anchor Regression CV", leave=False):
        model = fit(lamb)
        val_loss = torch.nn.functional.mse_loss(
            torch.tensor(model.predict(X_val_np), device=Y_val.device), Y_val
        ).item()

        tqdm.write(f"lambda={lamb:.4f} | val_loss={val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = model

    return best_model.predict(X_test_np)


def evaluate_imputations(
        X_train,
        Y_train,
        X_val,
        Y_val,
        X_test_approx,
        Y_test,
        criterion,
        lambda_reg_grid=None,
        ridge_param_grid=None,
        ols_ridge_param_grid=None,
        lasso_param_grid=None,
        attention_param_grid=None,
        W_true_val=None,
        W_true_test=None,
        attention_device='cuda:3',
):
    """
    Fit several regression baselines and score each of them on the test set.

    When a validation set is given (``X_val is not None``) it is used to tune
    lambda_reg and the method-specific hyperparameters; otherwise the first
    value of every grid is used.

    Args:
        X_train, Y_train: Training inputs (samples on axis 0, features on
            axis 1) and targets, as torch tensors.
        X_val, Y_val: Validation inputs/targets, or None for no tuning.
        X_test_approx: Test inputs the models predict from.
        Y_test: Test targets used for scoring.
        criterion: Accepted for interface compatibility; currently unused
            (the MLP is always trained with ``nn.MSELoss``).
        lambda_reg_grid: Candidate lambda_reg values for torchOLS
            (default ``[1e-6]``).
        ridge_param_grid: Grid for weights_ridge_regression.
        ols_ridge_param_grid: Grid of sklearn ``Ridge`` keyword arguments;
            None fits a single ``Ridge()`` with default settings.
        lasso_param_grid: Grid for lasso_regression.
        attention_param_grid: Grid of keyword arguments for
            train_and_evaluate_mlp_attention; None runs it once with defaults.
        W_true_val: Ground-truth weights used to tune lambda_reg for the
            'ground_truth' entry; if None, ``lambda_reg_grid[0]`` is used.
        W_true_test: Ground-truth test weights; the 'ground_truth' entry is
            only produced when this is given.
        attention_device: Device string forwarded to
            train_and_evaluate_mlp_attention.

    Returns:
        tuple: ``(losses, W_ridge, W_sparse, W_lasso)``. ``losses`` maps
        'ridge', 'sparse_ridge' (fixed placeholder), 'lasso', 'unweighted',
        'OLS_Ridge', 'ground_truth' (only with ``W_true_test``), 'anchor',
        'anchor_new', 'mlp', 'irm' and 'attention' to their test metrics.
        ``W_sparse`` is always None.
    """
    # Defaults for grid search
    if lambda_reg_grid is None:
        lambda_reg_grid = [1e-6]

    if ridge_param_grid is None:
        ridge_param_grid = {"lambda_": [1.0, 0.01, 0.001, 0.0001]}

    if lasso_param_grid is None:
        lasso_param_grid = {
            "lam": [1.0],
            "lr": [1e-6],
            "max_iter": [5000]
        }

    has_val = X_val is not None
    losses = {}

    # --- Ridge Regression ---
    # W_ridge is fitted so that W_ridge @ X_train approximates X_test_approx;
    # its element-wise square is passed to torchOLS as the weights.
    if has_val:
        W_ridge = cross_validate_imputation_method(
            method_fn=weights_ridge_regression,
            X=X_train, Z=X_val, Y=X_test_approx,
            hyperparam_grid=ridge_param_grid
        )
        best_lambda = select_best_lambda_reg(X_train, Y_train, X_val, Y_val, W_ridge, lambda_reg_grid)
    else:
        W_ridge = weights_ridge_regression(X_train, X_test_approx, ridge_param_grid["lambda_"][0])
        best_lambda = lambda_reg_grid[0]
    Y_hat_ridge = torchOLS(X_train, Y_train, X_test_approx, Y_test, W_ridge ** 2, lambda_reg=best_lambda)
    losses['ridge'] = evaluate_mse_pearson(Y_hat_ridge, Y_test)

    # The sparse (reweighted) ridge variant is not run; a placeholder is reported.
    losses['sparse_ridge'] = {'mse': 100, 'pearson_corr': -1}
    W_sparse = None

    # --- Lasso Regression ---
    if has_val:
        W_lasso = cross_validate_imputation_method(
            method_fn=lasso_regression,
            X=X_train, Z=X_val, Y=X_test_approx,
            hyperparam_grid=lasso_param_grid
        )
        best_lambda = select_best_lambda_reg(X_train, Y_train, X_val, Y_val, W_lasso, lambda_reg_grid)
    else:
        W_lasso = lasso_regression(X_train, X_test_approx, **{
            k: v[0] for k, v in lasso_param_grid.items()
        })
        best_lambda = lambda_reg_grid[0]
    Y_hat_lasso = torchOLS(X_train, Y_train, X_test_approx, Y_test, W_lasso ** 2, lambda_reg=best_lambda)
    losses['lasso'] = evaluate_mse_pearson(Y_hat_lasso, Y_test)

    # NumPy copies shared by all scikit-learn style models below.
    X_train_np = X_train.cpu().numpy()
    Y_train_np = Y_train.cpu().numpy()
    X_test_np = X_test_approx.cpu().numpy()
    X_val_np = X_val.cpu().numpy() if has_val else None

    # --- Unweighted least squares ---
    regressor = LinearRegression()
    regressor.fit(X_train_np, Y_train_np)
    Y_pred_unweighted = torch.tensor(regressor.predict(X_test_np), device=Y_test.device)
    losses['unweighted'] = evaluate_mse_pearson(Y_pred_unweighted, Y_test)

    # --- scikit-learn Ridge ---
    if ols_ridge_param_grid is not None and has_val:
        keys, values = zip(*ols_ridge_param_grid.items())
        combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
        best_val_loss = float("inf")
        best_test_results = float("inf")

        for comb in tqdm(combinations, desc="Ridge CV", leave=False):
            ridge = Ridge(**comb)
            ridge.fit(X_train_np, Y_train_np)
            val_metrics = evaluate_mse_pearson(
                torch.tensor(ridge.predict(X_val_np), device=Y_val.device), Y_val
            )

            if val_metrics['mse'] < best_val_loss:
                best_val_loss = val_metrics['mse']
                # Score the test set with the currently best model right away.
                best_test_results = evaluate_mse_pearson(
                    torch.tensor(ridge.predict(X_test_np), device=Y_test.device), Y_test
                )

        losses['OLS_Ridge'] = best_test_results
    else:
        # scikit-learn gets NumPy arrays here as well, so tensors that do not
        # live on the CPU are handled the same way as in the branch above.
        ridge = Ridge()
        ridge.fit(X_train_np, Y_train_np)
        losses['OLS_Ridge'] = evaluate_mse_pearson(
            torch.tensor(ridge.predict(X_test_np), device=Y_test.device), Y_test
        )

    # --- Ground Truth Weights ---
    if W_true_test is not None:
        if has_val and W_true_val is not None:
            best_lambda = select_best_lambda_reg(X_train, Y_train, X_val, Y_val, W_true_val, lambda_reg_grid)
        else:
            # No validation data or no validation weights to tune with.
            best_lambda = lambda_reg_grid[0]
        Y_hat_true = torchOLS(X_train, Y_train, X_test_approx, Y_test, W_true_test ** 2, lambda_reg=best_lambda)
        losses['ground_truth'] = evaluate_mse_pearson(Y_hat_true, Y_test)

    # --- Anchor Regression (direct prediction, no W, uses its own fit/predict) ---
    anchor_param_grid = {"lamb": [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]}

    # Variant 1: the anchor is the first principal component of X_train.
    A_anchor = PCA(n_components=1).fit_transform(X_train_np)
    Y_test_pred = _anchor_test_predictions(
        X_train_np, Y_train_np, A_anchor, X_val_np, Y_val, X_test_np, anchor_param_grid["lamb"]
    )
    losses['anchor'] = evaluate_mse_pearson(torch.tensor(Y_test_pred, device=Y_test.device), Y_test)

    # Variant 2: one feature column (chosen with a fixed seed) is the anchor
    # and is removed from the regressors.
    anchor = np.random.RandomState(42).choice(X_train.shape[1])
    other_columns = [i for i in range(X_train.shape[1]) if i != anchor]
    A_anchor = X_train_np[:, [anchor]]
    Y_test_pred = _anchor_test_predictions(
        X_train_np[:, other_columns],
        Y_train_np,
        A_anchor,
        X_val_np[:, other_columns] if has_val else None,
        Y_val,
        X_test_np[:, other_columns],
        anchor_param_grid["lamb"],
    )
    losses['anchor_new'] = evaluate_mse_pearson(torch.tensor(Y_test_pred, device=Y_test.device), Y_test)

    # --- MLP ---
    mlp_hidden_units = 64
    epochs = 1000
    lr = 1e-3

    input_dim = X_test_approx.shape[1]
    output_dim = Y_train.shape[1] if Y_train.ndim > 1 else 1

    model = MLP(
        input_dim=input_dim,
        hidden_dims=[mlp_hidden_units, mlp_hidden_units],
        output_dim=output_dim,
        activation=nn.ReLU()
    ).to(X_train.device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # The MLP is always trained with MSE; the `criterion` argument is not used.
    mse_criterion = nn.MSELoss()

    best_val_loss = float("inf")
    best_model_state = None

    pbar = trange(epochs, desc="MLP Training", leave=False)
    for epoch in pbar:
        model.train()
        optimizer.zero_grad()
        preds_train = model(X_train).squeeze(-1)
        loss = mse_criterion(preds_train, Y_train)
        loss.backward()
        optimizer.step()

        if has_val:
            model.eval()
            with torch.no_grad():
                val_preds = model(X_val).squeeze(-1)
                val_loss = mse_criterion(val_preds, Y_val).item()
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # state_dict() hands out the live parameter tensors, which
                    # later optimizer steps modify in place; clone them so the
                    # snapshot really stays at the best epoch.
                    best_model_state = {
                        name: tensor.detach().clone()
                        for name, tensor in model.state_dict().items()
                    }

            pbar.set_postfix({
                    "train_loss": f"{loss.detach():.4f}",
                    "val_loss": f"{val_loss:.4f}"
            })
        else:
            pbar.set_postfix({
                "train_loss": f"{loss.detach():.4f}"
            })

    # Restore the best epoch if validation was used
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    model.eval()
    with torch.no_grad():
        mlp_preds = model(X_test_approx).squeeze(-1)

    losses['mlp'] = evaluate_mse_pearson(mlp_preds, Y_test)

    # --- IRM ---
    # Treat each training row as its own environment
    envs = [(X_train[i:i + 1], Y_train[i:i + 1]) for i in range(len(X_train))]
    irm_model = InvariantRiskMinimization(
        environments=envs,
        x_val=X_val,
        y_val=Y_val,
        args={
            "n_iterations": 25000,
            "lr": 5e-4,
            "verbose": False
        },
    )
    irm_solution = irm_model.solution()
    # The mean of the training targets is added back to the linear prediction.
    Y_train_mean = Y_train.mean()
    Y_test_pred = (X_test_approx @ irm_solution).squeeze(-1) + Y_train_mean
    losses['irm'] = evaluate_mse_pearson(Y_test_pred, Y_test)

    # --- Attention Regression ---
    # Positional arguments shared by every call; X_train is passed twice, as
    # in the original call sites.
    attention_args = (X_train, X_train, Y_train, X_test_approx, Y_test, X_val, Y_val)

    if attention_param_grid is not None and has_val:
        keys, values = zip(*attention_param_grid.items())
        combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
        best_val_loss = float("inf")
        best_test_results = float("inf")

        for comb in tqdm(combinations, desc="MLP Attention CV", leave=False):
            result = train_and_evaluate_mlp_attention(
                *attention_args,
                device=attention_device,
                **comb,
            )

            # result[0] holds the test metrics, result[1] the validation metrics.
            if result[1]['mse'] < best_val_loss:
                best_val_loss = result[1]['mse']
                best_test_results = result[0]

        losses['attention'] = best_test_results
    else:
        losses['attention'] = train_and_evaluate_mlp_attention(
            *attention_args,
            device=attention_device)[0]

    return losses, W_ridge, W_sparse, W_lasso
