from collections import defaultdict
import logging
import numpy as np
import os
import sys

import torch
import torch.nn as nn
import geoopt

from torch.utils.data import DataLoader, TensorDataset
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl
from multiprocessing import Pool


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Evaluation mode flag; it is not referenced anywhere in the visible code.
eval_mode = "all"  # 'all' or 'last'

logger = logging.getLogger(__name__)
# Set up logging: timestamped records are written to stdout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


# Define the model
class LinearSoftmax(nn.Module):
    """A single affine layer (with bias) followed by a softmax over dim 1."""

    def __init__(self, input_size, output_size):
        """Create the affine map from `input_size` to `output_size` features."""
        super().__init__()
        self.linear = nn.Linear(
            in_features=input_size,
            out_features=output_size,
            bias=True,
        )

    def forward(self, x):
        """Return the softmax (along dimension 1) of the affine projection of `x`."""
        logits = self.linear(x)
        return logits.softmax(dim=1)


def convert_class_indicies_to_one_hot(
    class_indices: torch.Tensor, n_classes: int
) -> torch.Tensor:
    """Convert a 1-D tensor of class indices into a one-hot matrix.

    Args:
        class_indices: Tensor whose first dimension enumerates the samples;
            values are cast to integers and used as column indices.
        n_classes: Number of columns in the result.

    Returns:
        A float tensor of shape ``(class_indices.shape[0], n_classes)`` with a
        1 in the column given by each sample's class index and 0 elsewhere.
        The result lives on the same device as ``class_indices``.
    """
    n_samples = class_indices.shape[0]
    # Allocate on the input's device: indexing a CPU tensor with indices that
    # live on another device is not supported.
    device = class_indices.device
    one_hot_labels = torch.zeros((n_samples, n_classes), device=device)
    rows = torch.arange(n_samples, device=device)
    one_hot_labels[rows, class_indices.to(torch.long)] = 1
    return one_hot_labels


def train_model(task_name, seed_x):
    """Fit linear-softmax probes from one seed's representations to other seeds' predictions.

    For the given task and seed ``seed_x`` this function:

    1. Loads the stored accuracies for ``seed_x`` and picks the entry of
       ``lr_list`` at their argmax as ``best_lr``.
    2. Loads the stored training representations for ``seed_x`` and selects
       index 12 along dimension 0 (presumably the final layer -- TODO confirm
       against the code that wrote the file).
    3. For every other seed ``seed_y`` in ``range(25)`` and every probe
       learning rate, trains a fresh ``LinearSoftmax`` for 20 epochs to
       reproduce ``seed_y``'s stored predictions, then saves the probe's
       predictions and ``state_dict`` under ``representations-large``.
    4. Saves an array holding, per ``seed_y`` and learning rate,
       ``[accuracy, last-epoch mean MSE, last-epoch max MSE]``.

    Args:
        task_name: Task identifier used in file names; ``"mnli"`` uses 3
            output classes, anything else uses 2.
        seed_x: Seed whose representations are used as probe inputs.

    Returns:
        None. All results are written to disk.
    """
    batch_size = 64
    num_epochs = 20
    layer = 12
    probe_lrs = [0.001, 0.01, 0.02]

    ROOT_PATH = "representations-large"

    logger.info(f"running on {DEVICE}")
    lr_list = ["0.01", "0.1", "0.2", "0.4"]
    logger.info(f"task name: {task_name}")
    output_size = 3 if task_name == "mnli" else 2
    logger.info(f"seed: {seed_x}")
    train_accuracies_global = []
    seed_accuracies = np.load(
        os.path.join(
            ROOT_PATH,
            f"final-layer-train-prediction-accuracies-{task_name}-{seed_x}.npy",
        )
    )
    # The stored predictions are keyed by the learning rate that scored best.
    best_lr = lr_list[seed_accuracies.argmax()]
    logger.info(f"layer: {layer}")
    x_train = torch.load(
        os.path.join(ROOT_PATH, f"seed-{seed_x}-task-{task_name}-train"),
        map_location=DEVICE,
    ).select(0, layer)
    input_size = x_train.shape[-1]

    for seed_y in (s for s in range(25) if s != seed_x):
        # Fit the data to the best performing predictions of the other seed.
        # Load onto DEVICE so targets, inputs and model share one device.
        y_train = torch.load(
            os.path.join(
                ROOT_PATH,
                f"final-layer-train-predictions-{task_name}-{seed_y}-lr" + best_lr,
            ),
            map_location=DEVICE,
        )

        dataset = TensorDataset(x_train, y_train)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        train_accuracies = []
        for lr in probe_lrs:
            # A fresh probe is trained for every learning rate.
            model = LinearSoftmax(input_size, output_size)
            model.to(DEVICE)
            criterion = nn.MSELoss(reduction="none")

            logger.info(f"lr: {lr}")
            optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=lr)
            for epoch in range(num_epochs):
                epoch_loss = 0.0
                epoch_max_loss = 0.0
                model.train()

                for batch_x, batch_y in dataloader:
                    optimizer.zero_grad()
                    output = model(batch_x)
                    loss_vec = criterion(output, batch_y)
                    loss = loss_vec.mean()
                    # Worst per-sample MSE in the batch; this (not the mean
                    # loss) is the quantity being optimised.
                    max_loss = loss_vec.mean(axis=1).max()
                    max_loss.backward()
                    optimizer.step()
                    # Weight by batch size so the epoch mean is per-sample.
                    epoch_loss += loss.item() * len(batch_x)
                    # Track a plain float so no autograd-tracked tensor is
                    # kept alive across iterations.
                    epoch_max_loss = max(max_loss.item(), epoch_max_loss)
                epoch_loss /= len(dataset)
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Max-Loss: {epoch_max_loss:.4f}"
                )

            model.eval()
            with torch.no_grad():
                train_preds = model(x_train)
                # Reuse train_preds instead of a second forward pass, and move
                # to the CPU before converting to NumPy.
                train_prediction_accuracy = (
                    (train_preds.argmax(axis=1) == y_train.argmax(axis=1))
                    .cpu()
                    .numpy()
                    .mean()
                )
            # Accuracy, mean MSE and max MSE of the last epoch; the losses are
            # stored with float32 rounding.
            train_accuracies.append(
                [
                    train_prediction_accuracy,
                    np.float32(epoch_loss),
                    np.float32(epoch_max_loss),
                ]
            )
            # Save predictions and probe weights.
            torch.save(
                train_preds,
                os.path.join(
                    ROOT_PATH,
                    f"v3-extrinsic-maxloss-predictions-{task_name}-seedx-{seed_x}-seedy-{seed_y}-lr{lr}",
                ),
            )
            torch.save(
                model.state_dict(),
                os.path.join(
                    ROOT_PATH,
                    f"v3-extrinsic-maxloss-model-{task_name}-seedx-{seed_x}-seedy-{seed_y}-lr{lr}",
                ),
            )

        train_accuracies_global.append(train_accuracies)

    # Save accuracies.
    with open(
        os.path.join(
            ROOT_PATH,
            f"v3-extrinsic-maxloss-accuracies-{task_name}-small-{seed_x}.npy",
        ),
        "wb",
    ) as f:
        np.save(f, np.array(train_accuracies_global))


if __name__ == "__main__":
    # Use INFO verbosity for this module's logger, `datasets` and `transformers`.
    log_level = logging.INFO
    logger.setLevel(log_level)
    for set_verbosity in (datasets.utils.logging.set_verbosity, trfl.set_verbosity):
        set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()

    set_seed(42)

    # "mnli" is currently left out of the run.
    tasks = ["sst2", "mrpc"]
    seeds = range(25)

    # One job per (task, seed) pair in task-major order, run on 16 processes.
    jobs = []
    for task in tasks:
        for seed in seeds:
            jobs.append((task, seed))
    with Pool(processes=16) as pool:
        pool.starmap(train_model, jobs)
