import os
import logging
import numpy as np
import sys

import torch
import torch.nn as nn
import geoopt
from torch.utils.data import DataLoader, TensorDataset
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl


class ManifoldProjectionLayer(nn.Module):
    """Map inputs to coordinates in a d-dimensional subspace spanned by a Stiefel d-frame.

    The layer holds an ``n x d`` matrix ``A`` initialised as a random point on the
    Stiefel manifold, i.e. with orthonormal columns. For row-vector inputs ``x`` whose
    trailing size is ``n``, ``forward`` returns ``x @ A``. As long as ``A`` keeps
    orthonormal columns, these are the coordinates (in the basis formed by ``A``'s
    columns) of the orthogonal projection of ``x`` onto ``span(A)``. The projection
    expressed back in the ambient n-dimensional space would be ``x @ A @ A.T``.
    """

    def __init__(self, n: int, d: int):
        """
        Args:
            n: Dimension of the observable space (the one we want to project *from*).
            d: Dimension of the linear subspace (the one we want to project *onto*).
                Must satisfy ``d <= n`` so that ``d`` orthonormal columns exist.

        Raises:
            ValueError: If ``d > n``.
        """
        super().__init__()
        if d > n:
            raise ValueError(
                f"subspace dimension d={d} cannot exceed ambient dimension n={n}"
            )
        # `A` is an `n x d` matrix. For a single column vector v, the subspace
        # coordinates are A.T @ v. For a batch of row vectors x, that is x @ A
        # (see `forward`).
        self.A = geoopt.ManifoldParameter(geoopt.Stiefel().random(n, d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ A``: subspace coordinates with trailing size ``d``."""
        return x @ self.A


class AffineMapLayer(nn.Module):
    """Learnable affine transformation ``x -> x @ W + b``.

    ``W`` has shape ``(n, d)`` and ``b`` has shape ``(d,)``. Both are drawn from a
    standard normal distribution. They are wrapped as ``geoopt.ManifoldParameter``
    without an explicit manifold, so geoopt's Riemannian optimisers can update them.
    """

    def __init__(self, n: int, d: int):
        """
        Args:
            n: Dimension of the input space.
            d: Dimension of the output space.
        """
        super().__init__()
        # Draw the weight before the bias so the RNG stream is consumed in a fixed order.
        weight_init = torch.randn(n, d)
        bias_init = torch.randn(d)
        self.W = geoopt.ManifoldParameter(weight_init)
        self.b = geoopt.ManifoldParameter(bias_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map to the trailing dimension of ``x``."""
        return torch.matmul(x, self.W).add(self.b)


class DoubleAffineLayer(nn.Module):
    """Two independent affine maps, one applied to each input of a pair.

    ``forward(x, y)`` returns ``(x @ W1 + b1, y @ W2 + b2)``. The two maps share no
    parameters, so the input dimensions (``n1``, ``n2``) and output dimensions
    (``d1``, ``d2``) may all differ. Comparing the two outputs element-wise
    (e.g. in a loss) needs ``d1 == d2``.
    """

    def __init__(self, n1: int, d1: int, n2: int, d2: int):
        """
        Args:
            n1: Dimension of the input space of the first map (applied to ``x``).
            d1: Dimension of the output space of the first map.
            n2: Dimension of the input space of the second map (applied to ``y``).
            d2: Dimension of the output space of the second map.
        """
        super().__init__()
        # Affine map applied to x.
        self.W1 = geoopt.ManifoldParameter(torch.randn(n1, d1))  # Affine matrix
        self.b1 = geoopt.ManifoldParameter(torch.randn(d1))  # Affine bias

        # Affine map applied to y; its shapes are independent of the x map.
        self.W2 = geoopt.ManifoldParameter(torch.randn(n2, d2))  # Affine matrix
        self.b2 = geoopt.ManifoldParameter(torch.randn(d2))  # Affine bias

    def forward(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return the pair ``(x @ W1 + b1, y @ W2 + b2)``."""
        return x @ self.W1 + self.b1, y @ self.W2 + self.b2


class OrthogonalProjectionAffineModel(nn.Module):
    """Model composed of a Stiefel-frame projection followed by an affine map.

    ``forward(x)`` computes ``(x @ A) @ W + b``. Here ``A`` (``n x d``) is the
    parameter of a :class:`ManifoldProjectionLayer`, and ``W`` (``d x d``) and
    ``b`` (``d``) are the parameters of an :class:`AffineMapLayer`.
    """

    def __init__(self, n: int, d: int):
        """
        Args:
            n: Dimension of the observable space (the one we want to project *from*).
            d: Dimension of the linear subspace (the one we want to project *onto*).
                This is also both the input and the output dimension of the affine map.
        """
        super().__init__()
        self.proj = ManifoldProjectionLayer(n, d)
        self.affine = AffineMapLayer(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``d`` subspace coordinates, then apply the affine map."""
        x_proj = self.proj(x)
        return self.affine(x_proj)


# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Which slices along dim 0 of the loaded tensors (presumably model layers) `main`
# fits: "all" fits every slice, "last" only the final one. Any other value is
# rejected with a ValueError in `main`.
eval_mode = "all"

logger = logging.getLogger(__name__)
# Setup logging: timestamped records written to stdout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def _per_sample_loss(loss_vec: torch.Tensor) -> torch.Tensor:
    """Average an element-wise loss over every non-batch dimension (one value per sample)."""
    return loss_vec.reshape(loss_vec.size(0), -1).mean(dim=1)


def _evaluate(model, dataloader, criterion, device):
    """Return ``(mean element-wise loss, max per-sample loss)`` of ``model`` over ``dataloader``."""
    model.eval()
    total_loss = 0.0
    max_loss = 0.0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            loss_vec = criterion(model(batch_x), batch_y)
            total_loss += loss_vec.mean().item() * len(batch_x)
            max_loss = max(_per_sample_loss(loss_vec).max().item(), max_loss)
    # Guard against an empty eval set instead of dividing by zero.
    return total_loss / max(len(dataloader.dataset), 1), max_loss


def training_loop(
    train_x,
    train_y,
    eval_x,
    eval_y,
    model,
    lr: float = 2,
    num_epochs: int = 200,
    do_eval: bool = False,
):
    """Fit ``model`` so that ``model(train_x)`` approximates ``train_y`` with Riemannian SGD.

    The optimiser minimises the *worst* per-sample MSE in each mini-batch. The mean
    MSE is only tracked for logging and for the return value.

    Args:
        train_x: Training inputs (tensor or array-like), batched along dim 0 and
            converted to float32.
        train_y: Training targets, with the same number of samples as ``train_x``.
        eval_x: Validation inputs. Only used when ``do_eval`` is True.
        eval_y: Validation targets. Only used when ``do_eval`` is True.
        model: Module to train in place. Batches are moved to the device of its
            first parameter.
        lr: Learning rate for ``geoopt.optim.RiemannianSGD``.
        num_epochs: Number of passes over the training set.
        do_eval: If True, evaluate on ``(eval_x, eval_y)`` every 10th epoch and log it.

    Returns:
        ``(epoch_loss, epoch_max_loss)`` as Python floats for the last epoch:
        - ``epoch_loss``: the mean element-wise MSE over the training set.
        - ``epoch_max_loss``: the largest per-sample MSE seen in any batch.
        Both are ``0.0`` when ``num_epochs`` is 0.

    Raises:
        ValueError: If the training set is empty.
    """
    # `torch.as_tensor` accepts arrays and tensors on any device. The legacy
    # `torch.Tensor(...)` constructor rejects CUDA tensors, which is what `main`
    # loads with `map_location=DEVICE`.
    train_dataset = TensorDataset(
        torch.as_tensor(train_x, dtype=torch.float32),
        torch.as_tensor(train_y, dtype=torch.float32),
    )
    if len(train_dataset) == 0:
        raise ValueError("training_loop needs at least one training sample")
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Only build the eval set when it will be used, so unused eval tensors need
    # not be convertible or size-compatible.
    eval_dataloader = None
    if do_eval:
        eval_dataset = TensorDataset(
            torch.as_tensor(eval_x, dtype=torch.float32),
            torch.as_tensor(eval_y, dtype=torch.float32),
        )
        eval_dataloader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

    device = next(model.parameters()).device

    optimizer = geoopt.optim.RiemannianSGD(
        model.parameters(), lr=lr
    )  # LR IS HYPERPARAMETER
    # Keep element-wise losses so both the mean and the per-sample max can be derived.
    criterion = nn.MSELoss(reduction="none")

    # Pre-initialised so the return value is defined even when num_epochs == 0.
    epoch_loss = 0.0
    epoch_max_loss = 0.0
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_max_loss = 0.0
        model.train()
        for batch_x, batch_y in train_dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            loss_vec = criterion(model(batch_x), batch_y)
            loss = loss_vec.mean()
            max_loss = _per_sample_loss(loss_vec).max()
            # Deliberately optimise the worst-case sample, not the mean loss.
            max_loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * len(batch_x)  # Aggregate loss over the epoch
            # `.item()` stores a plain float. Keeping the tensor would retain its
            # autograd graph and leak a grad-tracking tensor into the return value.
            epoch_max_loss = max(max_loss.item(), epoch_max_loss)
        epoch_loss /= len(train_dataset)  # Mean loss over the epoch
        logger.info(
            f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Max-Loss: {epoch_max_loss:.4f}"
        )

        if do_eval and epoch % 10 == 9:
            val_loss, epoch_max_val_loss = _evaluate(
                model, eval_dataloader, criterion, device
            )
            logger.info(
                f"Validation Loss: {val_loss:.4f}, max-val-loss: {epoch_max_val_loss:.4f}"
            )

    return (epoch_loss, epoch_max_loss)


def main():
    """Fit an affine map between the saved representations of every ordered pair of seeds.

    For each task in ``["sst2", "mrpc"]`` and every ordered pair ``(seed, seed_y)`` of
    distinct seeds in ``range(25)``, this function:
    - loads ``seed-{seed}-task-{task}-{train,eval}`` and the matching ``seed_y`` files
      from ``representations-large``;
    - trains a *fresh* ``AffineMapLayer(768, 768)`` mapping ``seed``'s tensors onto
      ``seed_y``'s, once per index along dim 0 (``eval_mode == "all"``) or only for
      the last index (``eval_mode == "last"``).

    The ``(loss, max_loss)`` pair returned by every fit is saved to
    ``representations-large/metrics-{task}.npy``. This gives one row per seed pair
    and one entry per fitted index, assuming all files have the same size along dim 0.

    Raises:
        ValueError: If the module-level ``eval_mode`` is not ``"all"`` or ``"last"``.
    """
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()

    # Fail fast, before any tensor is loaded from disk.
    if eval_mode not in ("all", "last"):
        raise ValueError(f"eval_mode must be 'all' or 'last', got {eval_mode=}")

    # global vars
    set_seed(42)
    root_path = "representations-large"
    num_seeds = 25
    hidden_size = 768  # width of the saved representations (magic number)

    def load(seed: int, task: str, split: str):
        """Load one saved representation tensor onto DEVICE."""
        path = os.path.join(root_path, f"seed-{seed}-task-{task}-{split}")
        return torch.load(path, map_location=DEVICE)

    for task in ["sst2", "mrpc"]:
        train_losses_task = []
        for seed in range(num_seeds):
            train_x = load(seed, task, "train")
            eval_x = load(seed, task, "eval")
            # Indices along dim 0 to fit. train_x itself is never reassigned, so
            # every seed_y iteration slices the full tensor.
            layers = [-1] if eval_mode == "last" else range(train_x.size(0))
            for seed_y in range(num_seeds):
                if seed_y == seed:
                    continue
                train_y = load(seed_y, task, "train")
                eval_y = load(seed_y, task, "eval")
                # A new list per (seed, seed_y) pair, so each metrics row holds only
                # that pair's results.
                train_losses_layers = []
                for layer in layers:
                    # A fresh model per fit, so no fit is warm-started from the
                    # weights of a previous pair or layer.
                    model = AffineMapLayer(hidden_size, hidden_size)
                    model.to(DEVICE)
                    # NOTE: eval tensors are not sliced per layer. They are only
                    # consumed when do_eval=True, which is not set here.
                    train_losses_layers.append(
                        training_loop(
                            train_x=train_x.select(0, layer),
                            train_y=train_y.select(0, layer),
                            eval_x=eval_x,
                            eval_y=eval_y,
                            model=model,
                            lr=2,
                            num_epochs=100,
                        )
                    )
                train_losses_task.append(train_losses_layers)
        metrics_path = os.path.join(root_path, f"metrics-{task}.npy")
        with open(metrics_path, "wb") as f:
            np.save(f, np.array(train_losses_task))


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
