from collections import defaultdict
import logging
import numpy as np
import os
import sys

import torch
import torch.nn as nn
import geoopt
import torch.nn.functional as F

from torch.utils.data import DataLoader, TensorDataset
from transformers import set_seed
import datasets
import transformers.utils.logging as trfl
from multiprocessing import Pool


# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Accepted values: 'all' or 'last'.
# NOTE(review): not referenced anywhere else in the visible part of this file.
eval_mode = "all"

logger = logging.getLogger(__name__)
# Configure the root logger to write timestamped records to stdout. No `level`
# is passed here; this module's logger level is set in the `__main__` block.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


# Define the model
class LinearSoftmax(nn.Module):
    """A single affine layer whose output is normalised with a softmax.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of outputs (classes) per sample.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Keep the attribute name `linear`: it determines the state_dict keys.
        self.linear = nn.Linear(
            in_features=input_size, out_features=output_size, bias=True
        )

    def forward(self, x):
        """Return the softmax (taken over dimension 1) of the affine map of ``x``."""
        logits = self.linear(x)
        return logits.softmax(dim=1)


def convert_class_indicies_to_one_hot(
    class_indices: torch.Tensor, n_classes: int
) -> torch.Tensor:
    """Return a one-hot encoding of a 1-D tensor of class indices.

    Args:
        class_indices: 1-D tensor of class ids; values are cast to int64
            before being used as column indices.
        n_classes: Number of columns in the result.

    Returns:
        A tensor of shape ``(len(class_indices), n_classes)`` in the default
        floating dtype, on the same device as ``class_indices``, with a 1 at
        ``[i, class_indices[i]]`` and 0 elsewhere.
    """
    columns = class_indices.to(torch.long)
    n_rows = columns.shape[0]
    # Allocate on the indices' device: indexing a CPU tensor with indices that
    # live on another device raises an error.
    one_hot_labels = torch.zeros((n_rows, n_classes), device=columns.device)
    rows = torch.arange(n_rows, device=columns.device)
    one_hot_labels[rows, columns] = 1
    return one_hot_labels


def train_model(task_name, seed_x, seed_y):
    """Fit linear-softmax probes from one seed's data to targets built from another's.

    Two tensors are loaded from ``representations-large`` (files
    ``seed-{seed_x}-task-{task_name}-train`` and
    ``seed-{seed_y}-task-{task_name}-train``) and index 12 along dimension 0 is
    taken from each (presumably the final layer's representations — TODO
    confirm against the code that wrote these files).

    For each of 1000 random seeds:

    1. random probability rows are drawn (softmax of Gaussian noise);
    2. a least-squares linear map from the ``seed_y`` tensor to those rows is
       fitted, and its predictions become the training targets;
    3. a ``LinearSoftmax`` model on the ``seed_x`` tensor is trained for 20
       epochs, stepping on the *largest* per-sample MSE in each batch;
    4. the arg-max agreement between the model and the targets, the last
       epoch's mean loss and the last epoch's maximum per-sample loss are
       recorded.

    The collected ``(1000, 3)`` array is written to
    ``v2_extrinsic-equivalence-maxloss-accuracies-{task_name}-{seed_x}-{seed_y}-nseeds-1000.npy``
    inside ``representations-large``.

    Args:
        task_name: Task identifier used in the file names; ``"mnli"`` selects
            3 output classes, anything else 2.
        seed_x: Seed id of the tensor used as model input.
        seed_y: Seed id of the tensor used to build the regression targets.

    Returns:
        None. The result is the ``.npy`` file described above.
    """
    # Function-scope import, done once per call instead of once per loop
    # iteration.
    from sklearn.linear_model import LinearRegression

    batch_size = 64
    num_epochs = 20
    learning_rate = 0.001
    nseeds = 1000
    layer = 12

    ROOT_PATH = "representations-large"

    logger.info(f"running on {DEVICE}")
    logger.info(f"task name: {task_name}")
    output_size = 3 if task_name == "mnli" else 2
    logger.info(f"seed: {seed_x}")
    logger.info(f"layer: {layer}")

    x_train = torch.load(
        os.path.join(ROOT_PATH, f"seed-{seed_x}-task-{task_name}-train"),
        map_location=DEVICE,
    ).select(0, layer)
    y_train = torch.load(
        os.path.join(ROOT_PATH, f"seed-{seed_y}-task-{task_name}-train"),
        map_location=DEVICE,
    ).select(0, layer)

    input_size = x_train.shape[-1]
    # scikit-learn needs host memory; convert once since it is loop-invariant.
    y_train_np = y_train.detach().cpu().numpy()
    train_accuracies = []

    for _s in range(nseeds):
        # Seeding here fixes the random targets, the model initialisation and
        # the DataLoader shuffling for this iteration, in that order.
        torch.manual_seed(_s)
        random_probs = torch.nn.functional.softmax(
            torch.randn(y_train.shape[0], output_size), dim=1
        )
        regression = LinearRegression(fit_intercept=True).fit(
            y_train_np, random_probs.numpy()
        )
        # Targets must be float32 and live on the same device as the model.
        y_train_probs = torch.tensor(
            regression.predict(y_train_np), dtype=torch.float32, device=DEVICE
        )

        dataset = TensorDataset(x_train, y_train_probs)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        model = LinearSoftmax(input_size, output_size)
        model.to(DEVICE)
        # Unreduced loss so the worst sample in a batch can be picked out.
        criterion = nn.MSELoss(reduction="none")

        logger.info(f"lr: {learning_rate}")
        optimizer = geoopt.optim.RiemannianSGD(model.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            epoch_max_loss = 0.0
            model.train()

            for batch_x, batch_y in dataloader:
                optimizer.zero_grad()
                output = model(batch_x)
                loss_vec = criterion(output, batch_y)
                loss = loss_vec.mean()  # tracked for reporting only
                max_loss = loss_vec.mean(dim=1).max()
                # Optimise the worst-case sample rather than the batch mean.
                max_loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch mean is per-sample.
                epoch_loss += loss.item() * len(batch_x)
                # Store a plain float so no autograd graph is kept alive.
                epoch_max_loss = max(epoch_max_loss, max_loss.item())
            epoch_loss /= len(dataset)
            logger.info(
                f"seed: {_s}, Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Max-Loss: {epoch_max_loss:.4f}"
            )

        model.eval()
        with torch.no_grad():
            train_preds = model(x_train)
            logger.info(f"shapes: {train_preds.shape}, {y_train_probs.shape}")
            # Move to host before converting: .numpy() only works on CPU tensors.
            train_prediction_accuracy = (
                (train_preds.argmax(dim=1) == y_train_probs.argmax(dim=1))
                .cpu()
                .numpy()
                .mean()
            )
        # Losses are stored at float32 precision (values from the last epoch).
        train_accuracies.append(
            [
                train_prediction_accuracy,
                np.float32(epoch_loss),
                np.float32(epoch_max_loss),
            ]
        )

    with open(
        os.path.join(
            ROOT_PATH,
            f"v2_extrinsic-equivalence-maxloss-accuracies-{task_name}-{seed_x}-{seed_y}-nseeds-{nseeds}.npy",
        ),
        "wb",
    ) as f:
        np.save(f, np.array(train_accuracies))


if __name__ == "__main__":
    # Use INFO verbosity for this module and for the datasets/transformers loggers.
    log_level = logging.INFO
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    trfl.set_verbosity(log_level)
    trfl.enable_default_handler()
    trfl.enable_explicit_format()
    set_seed(42)

    # Only "sst2" is run here.
    tasks = ["sst2"]
    seeds = range(25)

    # One job per ordered pair of distinct seeds, for every task.
    jobs = [
        (task, seed, seedy)
        for task in tasks
        for seed in seeds
        for seedy in seeds
        if seedy != seed
    ]
    with Pool(processes=32) as pool:
        pool.starmap(train_model, jobs)
