import torch
import argparse
from collections import OrderedDict
from data_loading_sym import PartialMNIST_AE_Dataloader, RotMNIST_AE_Dataloader
import torch.nn.functional as F
import numpy as np
import sys
import pandas as pd
import matplotlib.pyplot as plt
import os
from sklearn.manifold import TSNE
import wandb

# Configuration

# Command-line interface of the evaluation script. Every option has a default,
# so the script can be launched without any arguments.
parser = argparse.ArgumentParser()

# General
parser.add_argument("--model_ind", type=int, default=111)  # Model index; also names the "eval<model_ind>" output folder

# Datasets
parser.add_argument("--dataset", type=str, default="PartMNIST")  # "PartMNIST" or "RotMNIST"
parser.add_argument("--customdata_train_path", type=str,  # Path to oriented train dataset
                    default="./datasets/mnist_all_rotation_normalized_float_train_valid.amat")
parser.add_argument("--customdata_test_path", type=str,  # Path to oriented test dataset
                    default="./datasets/mnist_all_rotation_normalized_float_test.amat")

# Type of evaluation
parser.add_argument("--linear_classifier", action='store_true', default=False)  # Supervised method evaluation
parser.add_argument("--k_means", action='store_true', default=False)  # Unsupervised method evaluation

# Wandb logging
# The script reads config.wandb_key and config.wandb_mode right after parsing,
# so both options must exist; otherwise attribute access fails on the Namespace.
parser.add_argument("--wandb_key", type=str, default=None)  # API key; wandb.login is skipped when unset
parser.add_argument("--wandb_mode", type=str, default=None)  # Forwarded to wandb.init(mode=...); None keeps wandb's own default

# Parse the command line once; all later stages read their settings from `config`.
config = parser.parse_args()

# Wandb integration
# Read the wandb options with a None fallback so a parser that does not declare
# them cannot abort the run with an AttributeError.
wandb_key = getattr(config, "wandb_key", None)
wandb_mode = getattr(config, "wandb_mode", None)
if wandb_key:
    # Authenticate explicitly only when a key was supplied.
    wandb.login(key=wandb_key)
wandb.init(
    project="unsup-equiv",
    # Log every option except the API key, which must not be stored with the run.
    config={name: value for name, value in vars(config).items() if name != "wandb_key"},
    entity="ck-experimental",
    mode=wandb_mode,
)

# Output folder
# Results go to "eval<model_ind>/". exist_ok=True tolerates a folder left over
# from an earlier run, while real failures (e.g. missing permissions) are raised
# here rather than hidden until the first write.
folder_name = f"eval{config.model_ind}"
os.makedirs(folder_name, exist_ok=True)
config.out_dir = folder_name + "/"

# Build the dataloaders for the selected dataset. Both factories are called with
# the same keyword arguments: once for the train split and once for the test
# split (the latter with no_val_split=True).
if config.dataset == "PartMNIST":
    print("Loading toy MNIST datasets.")
    train_dataloader = PartialMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True)
    test_dataloader = PartialMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True, no_val_split=True)
elif config.dataset == "RotMNIST":
    print("Loading benchmark MNIST datasets (MNISTRot/MNIST).")
    train_dataloader = RotMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True)
    test_dataloader = RotMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True, no_val_split=True)
else:
    # Fail with a clear message instead of a NameError on the unpacking below.
    raise ValueError(
        f"Unsupported dataset {config.dataset!r}; expected 'PartMNIST' or 'RotMNIST'."
    )

# The train call's result is indexed as [0] = train loader, [1] = validation
# loader; the test call's result keeps only element [0].
train_dataloader, val_dataloader = train_dataloader[0], train_dataloader[1]
test_dataloader = test_dataloader[0]

