from collections import defaultdict
import logging
import numpy as np
import os
import sys

import torch
import torch.nn as nn
import geoopt

from torch.utils.data import DataLoader, TensorDataset
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Intended values: 'all' or 'last'.
# NOTE(review): not referenced elsewhere in this file as shown — confirm
# whether it is still needed.
eval_mode = "all"
# Directory that main() reads stored representations/labels from and writes
# its predictions, probe weights and accuracy arrays to.
ROOT_PATH = "representations-large"


logger = logging.getLogger(__name__)
# Set up logging: timestamped records written to stdout via the root logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


# Define the model
class LinearSoftmax(nn.Module):
    """One affine layer followed by a softmax over dimension 1.

    ``forward`` returns normalised class probabilities (each row sums to 1),
    not logits. The unnormalised logits are ``self.linear(x)``; those are what
    ``nn.CrossEntropyLoss`` expects, since it applies its own log-softmax.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Affine map from input_size features to output_size class scores.
        self.linear = nn.Linear(
            in_features=input_size, out_features=output_size, bias=True
        )

    def forward(self, x):
        scores = self.linear(x)
        # Turn each row of scores into a probability distribution.
        return scores.softmax(dim=1)


def convert_class_indicies_to_one_hot(
    class_indices: torch.Tensor, n_classes: int
) -> torch.Tensor:
    """Encode a 1-D tensor of class indices as a one-hot float matrix.

    Args:
        class_indices: 1-D tensor of class ids. Any numeric dtype is accepted;
            values are cast to int64 before indexing.
        n_classes: number of columns in the output.

    Returns:
        A float32 tensor of shape ``(len(class_indices), n_classes)`` on the
        same device as ``class_indices``, with a single 1 in each row.

    Raises:
        IndexError: if an index is outside ``[0, n_classes)``.
    """
    indices = class_indices.to(torch.long)
    n_rows = indices.shape[0]
    # Allocate on the indices' device: indexing a CPU tensor with CUDA
    # indices raises, and downstream comparisons need matching devices.
    one_hot_labels = torch.zeros((n_rows, n_classes), device=indices.device)
    one_hot_labels[torch.arange(n_rows, device=indices.device), indices] = 1
    return one_hot_labels


def _load_features(task_name, seed, split, layer):
    """Load stored representations for one split and keep a single layer.

    Reads ``ROOT_PATH/seed-{seed}-task-{task_name}-{split}`` onto DEVICE and
    takes index ``layer`` along dim 0. NOTE(review): assumes the saved tensor
    stacks layers along dim 0 — confirm against the script that wrote it.

    Returns:
        The selected slice as float32 (the linear probe's parameter dtype).
    """
    stacked = torch.load(
        os.path.join(ROOT_PATH, f"seed-{seed}-task-{task_name}-{split}"),
        map_location=DEVICE,
    )
    return stacked.select(0, layer).float()


def _load_labels(task_name, split):
    """Load the class labels for one split as an int64 tensor on DEVICE.

    ``torch.as_tensor`` accepts either a saved list or a saved tensor; the
    int64 cast is what ``nn.CrossEntropyLoss`` requires for class-index targets.
    """
    labels = torch.load(
        os.path.join(ROOT_PATH, f"{task_name}_labels_{split}"),
        map_location=DEVICE,
    )
    return torch.as_tensor(labels).long().to(DEVICE)


def _accuracy(probs, labels):
    """Return the fraction of rows of ``probs`` whose argmax equals ``labels``.

    Computed on the tensors' own device and returned as a Python float, so it
    also works for CUDA tensors (``Tensor.numpy()`` only accepts CPU tensors).
    """
    return (probs.argmax(dim=1) == labels).double().mean().item()


def _train_probe(dataloader, x_eval, y_eval, input_size, output_size, lr, num_epochs):
    """Train a freshly initialised LinearSoftmax probe at learning rate ``lr``.

    After every epoch, logs the last mini-batch training loss and the loss on
    the full evaluation split. Returns the trained model.
    """
    model = LinearSoftmax(input_size, output_size).to(DEVICE)
    optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        model.train()
        loss = None
        for batch_x, batch_y in dataloader:
            # CrossEntropyLoss applies log-softmax itself, so it must see the
            # raw logits of the linear layer. model(x) already applies softmax
            # and would normalise twice, flattening the gradients.
            loss = criterion(model.linear(batch_x), batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is not None:  # an empty dataloader yields no batches
            logger.info(
                f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {loss.item():.4f}"
            )

        model.eval()
        with torch.no_grad():
            val_loss = criterion(model.linear(x_eval), y_eval)
        logger.info(
            f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss.item():.4f}"
        )
    return model


def main():
    """Train linear probes on stored layer representations and save results.

    For each task and each of 25 seeds, loads layer ``layer``'s train/eval
    representations from ROOT_PATH and trains one freshly initialised probe
    per learning rate. For each lr it saves the softmax predictions on the
    eval and train splits and the probe's ``state_dict``. Per seed, it saves
    the train and eval accuracies (one value per lr) as ``.npy`` arrays of
    shape ``(1, n_learning_rates)``.
    """
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()
    set_seed(42)

    batch_size = 128
    num_epochs = 20
    layer = 12
    learning_rates = [0.01, 0.1, 0.2, 0.4]

    logger.info(f"running on {DEVICE}")

    for task_name in ["sst2", "mrpc"]:
        logger.info(f"task name: {task_name}")
        output_size = 3 if task_name == "mnli" else 2
        # Label files do not depend on the seed, so load them once per task.
        y_train = _load_labels(task_name, "train")
        y_eval = _load_labels(task_name, "eval")

        for seed in range(25):
            logger.info(f"seed: {seed}")
            logger.info(f"layer: {layer}")
            x_train = _load_features(task_name, seed, "train", layer)
            x_eval = _load_features(task_name, seed, "eval", layer)
            input_size = x_train.shape[-1]

            dataloader = DataLoader(
                TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True
            )

            train_accuracies = []
            accuracies = []
            for lr in learning_rates:
                logger.info(f"lr: {lr}")
                # A new probe per lr, so each learning rate is an independent
                # run rather than a warm start from the previous one.
                model = _train_probe(
                    dataloader, x_eval, y_eval, input_size, output_size, lr, num_epochs
                )

                model.eval()
                with torch.no_grad():
                    train_preds = model(x_train)
                    eval_preds = model(x_eval)
                train_accuracies.append(_accuracy(train_preds, y_train))
                prediction_accuracy = _accuracy(eval_preds, y_eval)
                accuracies.append(prediction_accuracy)
                logger.info(prediction_accuracy)

                # save predictions
                torch.save(
                    eval_preds,
                    os.path.join(
                        ROOT_PATH,
                        f"final-layer-predictions-{task_name}-{seed}-lr{lr}",
                    ),
                )
                torch.save(
                    train_preds,
                    os.path.join(
                        ROOT_PATH,
                        f"final-layer-train-predictions-{task_name}-{seed}-lr{lr}",
                    ),
                )
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        ROOT_PATH, f"final-layer-model-{task_name}-{seed}-lr{lr}"
                    ),
                )

            # save accuracies (wrapped in an outer list to keep the original
            # (1, n_learning_rates) array layout)
            with open(
                os.path.join(
                    ROOT_PATH,
                    f"final-layer-train-prediction-accuracies-{task_name}-{seed}.npy",
                ),
                "wb",
            ) as f:
                np.save(f, np.array([train_accuracies]))
            with open(
                os.path.join(
                    ROOT_PATH,
                    f"final-layer-prediction-accuracies-{task_name}-{seed}.npy",
                ),
                "wb",
            ) as f:
                np.save(f, np.array([accuracies]))


# Run the probing experiments only when executed as a script, not on import.
if __name__ == "__main__":
    main()
