import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm  # Import the progress bar library

from src.evaluation import evaluate_mse_pearson
from src.regression import torchOLS


class MLP(nn.Module):
    """
    Plain fully-connected network: a stack of ``Linear -> activation`` pairs
    followed by one final ``Linear`` layer with no activation.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Width of each hidden layer, in order. An empty list
            yields a single ``Linear(input_dim, output_dim)``.
        output_dim: Size of the last dimension of the output.
        activation: Module placed after every hidden ``Linear``. The same
            instance is reused at every position (default is ReLU).

    Example usage:
        model = MLP(
            input_dim=10,
            hidden_dims=[32, 64, 32],
            output_dim=1,
            activation=nn.ReLU()
        )
        x = torch.randn(5, 10)
        y = model(x)
        print(y.shape)  # => torch.Size([5, 1])
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dims: list,
                 output_dim: int,
                 activation=nn.ReLU()):
        super().__init__()

        # Widths seen by consecutive hidden layers: input -> h1 -> h2 -> ...
        widths = [input_dim, *hidden_dims]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), activation]

        # Output projection; deliberately left without an activation.
        stack.append(nn.Linear(widths[-1], output_dim))

        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.mlp(x)


def train_mlp_model(
        E_train,
        X_train,
        E_test,
        X_test,
        mlp_hidden_units,
        E_val=None,
        X_val=None,
        lr=0.001,
        epochs=1000,
        seed=None,
        device='cpu'  # Default to CPU
):
    """
    Fit a two-hidden-layer MLP that maps ``E`` to ``X`` and track its losses.

    Each epoch performs one full-batch Adam step on the MSE between
    ``X_train`` and the model output for ``E_train``, then evaluates on the
    test split (and on the validation split when one is supplied).

    Args:
        E_train, X_train: Training inputs and targets (tensors).
        E_test, X_test: Test inputs and targets (tensors).
        mlp_hidden_units: Width used for both hidden layers.
        E_val, X_val: Optional validation inputs and targets. Validation is
            enabled when ``E_val`` is not None; ``X_val`` is then required.
        lr: Adam learning rate.
        epochs: Number of full-batch training steps.
        seed: If given, seeds torch (CPU and CUDA) and switches cuDNN to
            deterministic mode. These are process-wide settings.
        device: Device the data and model are moved to.

    Returns:
        A tuple ``(X_test_pred, X_train_loss, X_test_loss, X_val_loss)``.
        With validation data, ``X_test_pred`` is the test prediction from the
        epoch with the lowest validation loss; otherwise it is the prediction
        from the final epoch. It is None when no epoch ran or, with
        validation, when no epoch improved on the initial ``inf``. The three
        loss entries are per-epoch lists of 0-d numpy arrays;
        ``X_val_loss`` is empty when no validation data is given.

    Raises:
        ValueError: If ``E_val`` is given without ``X_val``.
    """
    has_val = E_val is not None
    if has_val and X_val is None:
        raise ValueError("X_val must be provided when E_val is given")

    # Set the seed for reproducibility
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Move input data to the specified device
    E_train = E_train.to(device)
    X_train = X_train.to(device)
    E_test = E_test.to(device)
    X_test = X_test.to(device)
    if has_val:
        E_val = E_val.to(device)
        X_val = X_val.to(device)

    # Define model and move to device
    feature_dim = E_train.shape[-1]
    output_dim = X_train.shape[-1]
    model = MLP(
        input_dim=feature_dim,
        hidden_dims=[mlp_hidden_units, mlp_hidden_units],
        output_dim=output_dim,
        activation=torch.nn.ReLU()
    ).to(device)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    X_train_loss = []
    X_test_loss = []
    X_val_loss = []

    best_val_loss = float('inf')
    best_test_X_approx = None
    # Pre-bound so the return below is well-defined even when epochs == 0.
    X_test_approx = None
    val_loss = None

    # Training loop
    with tqdm(total=epochs, desc="Training Progress", unit="epoch") as pbar:

        for epoch in range(epochs):
            optimizer.zero_grad()
            model.train()

            # Forward pass
            X_train_approx = model(E_train)
            loss = criterion(X_train, X_train_approx)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Track training loss (moved to CPU for logging)
            X_train_loss.append(loss.detach().cpu().numpy())

            # Evaluation step
            model.eval()
            with torch.no_grad():
                X_test_approx = model(E_test)
                test_loss = criterion(X_test, X_test_approx)
                X_test_loss.append(test_loss.detach().cpu().numpy())

                if has_val:
                    X_val_approx = model(E_val)
                    val_loss = criterion(X_val, X_val_approx)
                    # Record the validation curve alongside train/test.
                    X_val_loss.append(val_loss.detach().cpu().numpy())
                    # Keep the test prediction of the best-validating epoch.
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_test_X_approx = X_test_approx.clone().detach()

            # Update the progress bar
            pbar.set_postfix({"Train Loss": f"{loss.item():.4f}",
                              "Val Loss": f"{val_loss.item():.4f}" if val_loss is not None else "None",
                              "Test Loss": f"{test_loss.item():.4f}"})
            pbar.update(1)

    if has_val:
        return best_test_X_approx, X_train_loss, X_test_loss, X_val_loss
    return X_test_approx, X_train_loss, X_test_loss, X_val_loss


def train_and_evaluate_mlp_attention(
        E_train,
        X_train,
        Y_train,
        E_test,
        Y_test,
        E_val=None,
        Y_val=None,
        hidden_dim=32,
        lambda_reg=1e-6,
        lr=0.004,
        epochs=100,
        device="cpu"
):
    """
    Train and evaluate an MLP followed by an attention mechanism.

    Pipeline per split: LayerNorm -> one-hidden-layer MLP -> LayerNorm ->
    single-head attention against the training embeddings -> ``torchOLS``.
    Each epoch takes one full-batch Adam step (cosine-annealed learning rate,
    gradient norm clipped to 1.0) on the MSE between ``Y_train`` and the
    ``torchOLS`` prediction, then evaluates on the test split and, when
    supplied, the validation split.

    Args:
        E_train: Training embeddings; the last dimension is the feature size.
        X_train: Currently unused; kept so existing callers keep working.
        Y_train: Training targets.
        E_test, Y_test: Test embeddings and targets.
        E_val, Y_val: Optional validation embeddings and targets. Validation
            is enabled when ``E_val`` is not None; ``Y_val`` is then required.
        hidden_dim: Width of the MLP hidden layer, of its output, and of the
            attention embedding.
        lambda_reg: Forwarded to ``torchOLS`` as ``lambda_reg``.
        lr: Initial Adam learning rate; annealed down to ``0.1 * lr``.
        epochs: Number of training steps (must be at least 1).
        device: Device the data and modules are moved to.

    Returns:
        With validation data: a tuple ``(best_test_result, best_val_result)``
        holding the ``evaluate_mse_pearson`` outputs for the test and
        validation splits at the epoch with the lowest validation loss (both
        None if the validation loss never improved on ``inf``, e.g. NaN).
        Without validation data: ``evaluate_mse_pearson(Y_hat_test, Y_test)``
        for the final epoch.

    Raises:
        ValueError: If ``epochs < 1`` or ``E_val`` is given without ``Y_val``.
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    has_val = E_val is not None
    if has_val and Y_val is None:
        raise ValueError("Y_val must be provided when E_val is given")

    # Move input data to the specified device
    E_train = E_train.to(device)
    E_test = E_test.to(device)
    Y_train = Y_train.to(device)
    Y_test = Y_test.to(device)
    if has_val:
        E_val = E_val.to(device)
        Y_val = Y_val.to(device)

    # Define model and move to device
    input_dim = E_train.shape[-1]
    output_dim = hidden_dim
    embedding_mlp = MLP(
        input_dim=input_dim,
        hidden_dims=[hidden_dim],
        output_dim=output_dim,
        activation=torch.nn.ReLU()
    ).to(device)

    # NOTE(review): the LayerNorm parameters are not handed to the optimizer,
    # so they keep their initial values -- confirm this is intended.
    layer_norm_input = nn.LayerNorm(input_dim).to(device)
    layer_norm_attention = nn.LayerNorm(output_dim).to(device)

    attn_model = nn.MultiheadAttention(embed_dim=output_dim, num_heads=1, batch_first=False).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(list(embedding_mlp.parameters()) + list(attn_model.parameters()), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.1 * lr)

    # Boolean mask with True on the diagonal. For a boolean attn_mask, True
    # marks positions that may NOT be attended to, so during evaluation each
    # training row is prevented from attending to itself.
    n_train = E_train.shape[0]
    attn_mask = torch.full((n_train, n_train), False, device=device)
    attn_mask.fill_diagonal_(True)

    def embed(E):
        # LayerNorm -> MLP -> LayerNorm. Always starts from the raw inputs so
        # normalisation is applied exactly once per forward pass.
        return layer_norm_attention(embedding_mlp(layer_norm_input(E)))

    best_val_loss = float('inf')
    best_val_result = None
    best_test_result = None
    loss_val = None  # stays None when no validation split is supplied

    # Training loop
    with tqdm(total=epochs, desc="Training Progress", unit="epoch") as pbar:
        for epoch in range(epochs):
            # ---- Training phase ----
            embedding_mlp.train()
            attn_model.train()
            optimizer.zero_grad()

            train_emb = embed(E_train)

            # NOTE(review): training attends without the diagonal mask while
            # evaluation applies it -- confirm this asymmetry is intended.
            attn_output, attn_weights = attn_model(query=train_emb, key=train_emb, value=train_emb,
                                                   attn_mask=None)

            # Predict the training targets with torchOLS
            Y_hat_batch = torchOLS(attn_output, Y_train, attn_output, Y_train, attn_weights ** 2,
                                   lambda_reg=lambda_reg)

            loss_main = criterion(Y_train, Y_hat_batch)
            loss_combined = loss_main

            loss_combined.backward()

            # Gradient clipping (applied per module)
            torch.nn.utils.clip_grad_norm_(embedding_mlp.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(attn_model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()

            # ---- Evaluation phase ----
            embedding_mlp.eval()
            attn_model.eval()
            with torch.no_grad():
                # Re-embed the training set with the just-updated weights so
                # validation and test are scored against the same reference.
                train_emb_eval = embed(E_train)
                attn_output_train, _ = attn_model(query=train_emb_eval, key=train_emb_eval,
                                                  value=train_emb_eval,
                                                  attn_mask=attn_mask)

                if has_val:
                    val_emb = embed(E_val)
                    attn_output_val, attn_weights_val = attn_model(query=val_emb, key=train_emb_eval,
                                                                   value=train_emb_eval)
                    Y_hat_val = torchOLS(attn_output_train, Y_train, attn_output_val, Y_val,
                                         attn_weights_val ** 2, lambda_reg=lambda_reg)
                    loss_val = criterion(Y_val, Y_hat_val)

                test_emb = embed(E_test)
                attn_output_test, attn_weights_test = attn_model(query=test_emb, key=train_emb_eval,
                                                                 value=train_emb_eval)
                Y_hat_test = torchOLS(attn_output_train, Y_train, attn_output_test, Y_test,
                                      attn_weights_test ** 2, lambda_reg=lambda_reg)
                loss_main_test = criterion(Y_test, Y_hat_test)

                # Model selection: remember the metrics of the epoch with the
                # lowest validation loss.
                if has_val and loss_val.item() < best_val_loss:
                    best_val_loss = loss_val.item()
                    best_val_result = evaluate_mse_pearson(Y_hat_val, Y_val)
                    best_test_result = evaluate_mse_pearson(Y_hat_test, Y_test)

            # Update the progress bar
            pbar.set_postfix({
                "Train Loss": f"{loss_combined.item():.4f}",
                "Train Loss Y pred": f"{loss_main.item():.4f}",
                "Val Loss Y pred": f"{loss_val.item():.4f}" if loss_val is not None else "None",
                "Test Loss Y pred": f"{loss_main_test.item():.4f}"
            })
            pbar.update(1)

    if has_val:
        return best_test_result, best_val_result
    return evaluate_mse_pearson(Y_hat_test, Y_test)
    