import argparse
import copy
import os
import time

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torchvision import models

from data_loading_sym import PartialMNIST_AE_Dataloader, RotMNIST_AE_Dataloader
from utils import config_to_str



# Configuration ---------------------------------------------------------------------------------------
# Every command-line option is declared once in the table below as a
# (flag, add_argument keyword arguments) pair, in registration order.
_CLI_OPTIONS = (
    # General
    ("--model_ind", {"type": int, "required": True}),  # ID
    ("--wandb_key", {"type": str, "default": ""}),
    ("--seed", {"type": int, "default": 0}),
    # Dataset
    ("--dataset_root", {"type": str, "default": "/MNIST"}),
    ("--dataset", {"type": str, "default": "PartMNIST"}),
    (
        "--customdata_train_path",
        {"type": str, "default": "./src/datasets/mnist60/invariant_dataset_train.pkl"},
    ),
    (
        "--customdata_test_path",
        {"type": str, "default": "./src/datasets/mnist60/invariant_dataset_test.pkl"},
    ),
    # Output
    ("--out_root", {"type": str, "default": "/saves/"}),
    # Net params
    ("--batch_sz", {"type": int, "default": 100}),  # Batch size
    ("--epochs", {"type": int, "default": 3}),
    # Logging
    ("--wandb_mode", {"type": str, "default": "online"}),
    ("--log_every", {"type": int, "default": 100}),
)

# Module-level parser consumed by main().
parser = argparse.ArgumentParser()
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)


def _resolve_dataset(name):
    """Return ``(dataloader_factory, loading_message)`` for a dataset name.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    table = {
        "PartMNIST": (
            PartialMNIST_AE_Dataloader,
            "Loading custom Partial Rot MNIST datasets (.pkl files)",
        ),
        "RotMNIST": (
            RotMNIST_AE_Dataloader,
            "Loading RotMNIST or MNIST benchmarks (.amat files)",
        ),
    }
    if name not in table:
        raise ValueError(
            f"Unsupported --dataset {name!r}; expected one of {sorted(table)}"
        )
    return table[name]


def _prepare_batch(x, label, device):
    """Move one ``(x, label)`` batch to ``device`` in the form the model consumes."""
    # unsqueeze(1) inserts the single input channel expected by the replaced
    # conv1 layer -- assumes x arrives as (batch, H, W); TODO confirm against
    # the dataloader.
    return x.to(device).unsqueeze(1), label.long().to(device)


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over every batch in ``loader``."""
    model.train()
    for x, label in loader:
        x, label = _prepare_batch(x, label, device)
        optimizer.zero_grad()
        loss = criterion(model(x), label)
        loss.backward()
        optimizer.step()


def _mean_batch_loss(model, loader, criterion, device):
    """Return the mean of the per-batch losses over ``loader`` (``inf`` if empty)."""
    model.eval()
    running_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, label in loader:
            x, label = _prepare_batch(x, label, device)
            running_loss += criterion(model(x), label).item()
            num_batches += 1
    # With no batches there is nothing to average; "inf" never beats a real loss.
    return running_loss / num_batches if num_batches else float("inf")


def _accuracy_percent(model, loader, device):
    """Return top-1 accuracy over ``loader`` as a percentage.

    Raises:
        RuntimeError: If ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, label in loader:
            x, label = _prepare_batch(x, label, device)
            predicted = model(x).argmax(dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    if total == 0:
        raise RuntimeError("Test dataloader yielded no samples; accuracy is undefined.")
    return 100 * correct / total


def main():
    """Train a ResNet-18 classifier and report its test accuracy.

    Steps, in order:
      1. Parse the command line with the module-level ``parser``.
      2. Seed all RNGs via ``pl.seed_everything`` (``--seed -1`` draws a
         random seed in ``[0, 100000)``).
      3. Create the run directory ``saves/exp_<timestamp>/`` and start a
         wandb run.
      4. Train for ``--epochs`` epochs with Adam (lr=0.001) and cross-entropy,
         keeping a snapshot of the weights with the lowest mean validation
         loss.
      5. Restore that snapshot (if one was taken) and print the accuracy on
         the test dataloader.

    Raises:
        ValueError: If ``--dataset`` is not a supported dataset name.
        RuntimeError: If the test dataloader yields no samples.
    """
    config = parser.parse_args()

    # Fail fast on an unknown dataset, before any directory or wandb run exists.
    loader_factory, loading_message = _resolve_dataset(config.dataset)

    # Set seed
    if config.seed == -1:
        config.seed = int(np.random.randint(0, 100000))
    pl.seed_everything(config.seed)

    # Setup ------------------------------------------------------------------------

    config.dataloader_batch_sz = int(config.batch_sz)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    # NOTE: the run directory is always saves/exp_<timestamp>/ relative to the
    # working directory; --out_root and --model_ind do not influence it.
    config.out_dir = f"saves/exp_{timestamp}/"
    # exist_ok: two runs started within the same second share a timestamp.
    os.makedirs(config.out_dir, exist_ok=True)

    print("Config: %s" % config_to_str(config))

    # Initialize wandb
    if config.wandb_key:
        wandb.login(key=config.wandb_key)
    wandb.init(
        project="unsup-equiv",
        config=config,
        entity="ck-experimental",
        mode=config.wandb_mode
    )

    print(loading_message)
    train_loaders = loader_factory(config, train=True, test=False, shuffle=True)
    train_dataloader = train_loaders[0]
    val_dataloader = train_loaders[1]

    # Use the GPU when present, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = models.resnet18(weights=None)
    # Modify the first layer to accept grayscale images
    model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # 10 output classes (for MNIST)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training
    best_val_loss = float('inf')
    best_state = None

    for _ in range(config.epochs):
        _train_one_epoch(model, train_dataloader, criterion, optimizer, device)
        avg_val_loss = _mean_batch_loss(model, val_dataloader, criterion, device)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live tensors, so it must
            # be deep-copied; otherwise later epochs overwrite the snapshot.
            best_state = copy.deepcopy(model.state_dict())

    # Load the best model. No snapshot exists when no epoch ran or the
    # validation loss never improved on infinity; keep the current weights then.
    if best_state is not None:
        model.load_state_dict(best_state)

    test_loaders = loader_factory(config, train=False, test=True, shuffle=True,
                                  no_val_split=True)
    test_dataloader = test_loaders[0]

    test_accuracy = _accuracy_percent(model, test_dataloader, device)
    print(f"Test Accuracy: {test_accuracy:.2f}%")


# Run the training/evaluation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()


