import os
import time

import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import Pad, RandomOrder

from data.datasets import MNIST
from data.random_augment import *
from exemplary_contrastive_self_training import experiment
from networks.cnn import CNN
from networks.mlp import MLP
from networks.model import Network
from utility.schedulers import SigmoidScheduler
from utility.seed import random_seed


def _lr_multiplier(step, total_epochs, epoch_steps):
    """Return the learning-rate multiplier for scheduler step ``step``.

    The multiplier stays at 1.0 for the first 60% of the
    ``total_epochs * epoch_steps`` steps, then decreases linearly, reaching
    0.1 at the last step. No clamping is applied past that point.

    Args:
        step: Step counter handed to the ``LambdaLR`` callback.
        total_epochs: Number of training epochs.
        epoch_steps: Number of optimisation steps per epoch.

    Returns:
        The factor by which the optimizer's base learning rate is scaled.
    """
    # Expressions are kept in the same evaluation order as the original
    # inline lambda so the floating-point results are unchanged.
    decay_start = 0.6 * total_epochs * epoch_steps
    if step < decay_start:
        return 1.0
    decay_length = (1.0 - 0.6) * total_epochs * epoch_steps
    return 0.1 + 0.9 * (1.0 - ((step - decay_start) / decay_length))


def main():
    """Run the MNIST experiment once for each seed in ``range(10)``.

    Loads MNIST from ``./datasets/MNIST/``, scales pixel values by 1/255,
    pads every image by 2 pixels, and then, per seed, builds a fresh
    CNN+MLP backbone with a linear head, an Adam optimizer and a
    ``LambdaLR`` schedule before handing everything to ``experiment``.

    Raises:
        ValueError: If ``num_samples`` is not a multiple of ``epoch_steps``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- Experiment configuration -------------------------------------
    num_samples = 60000
    num_classes = 10
    shot = 2
    total_epochs = 200
    epoch_steps = 75

    # Validate before deriving the batch size from these values. An explicit
    # raise (rather than ``assert``) keeps the check active under ``python -O``.
    if num_samples % epoch_steps != 0:
        raise ValueError(
            f"num_samples ({num_samples}) must be divisible by "
            f"epoch_steps ({epoch_steps})"
        )
    class_batch_size = (num_samples // epoch_steps) // num_classes

    weight_loss_invariance = 1.0
    margin = 0.01
    # NOTE(review): the meaning of the SigmoidScheduler arguments is defined
    # in utility.schedulers and is not visible here -- confirm before tuning.
    threshold_scheduler = SigmoidScheduler(
        10 / total_epochs,
        2 * (0.9 - 1 / num_classes),
        0,
        1,
        -(0.9 - 2 / num_classes),
    )
    rate_confidence_scheduler = SigmoidScheduler(0.3, 0.99, -1, 2)
    rate_similarity_scheduler = SigmoidScheduler(0.3, 0.99, -1, 2)

    # --- Data -----------------------------------------------------------
    # Copy each array returned by the loader before converting to tensors.
    data_train, label_train, data_test, label_test = map(np.copy, MNIST("./datasets/MNIST/"))
    data_train = torch.from_numpy(data_train[:num_samples].astype(np.float32)) / 255
    label_train = torch.from_numpy(label_train[:num_samples].astype(np.int64))
    data_test = torch.from_numpy(data_test.astype(np.float32)) / 255
    label_test = torch.from_numpy(label_test.astype(np.int64))
    # Pad 2 pixels on every side (presumably 28x28 -> 32x32 -- TODO confirm
    # against the shape the MNIST loader returns).
    data_train = Pad(2)(data_train)
    data_test = Pad(2)(data_test)

    # Augmentations are applied in a random order on each call.
    # NOTE(review): the trailing 1.0 is presumably an application
    # probability -- verify against data.random_augment.
    transform = RandomOrder([
        RandomRotation((-10, 10), 1.0),
        RandomTranslation((1, 1), 1.0),
        RandomScaling((0.9, 1.1), 1.0),
        RandomShear((-10, 10, -10, 10), 1.0),
    ])

    # Identifier shared by all seeds of this launch: dataset name, current
    # working directory name and the local start time. Named ``run_id`` so
    # the builtin ``id`` is not shadowed.
    run_id = f"MNIST-{os.path.split(os.getcwd())[-1]}-{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"

    for seed in range(10):
        random_seed(seed)
        features = (4 * 4 * 128, 128)
        num_feature = features[-1]
        backbone = nn.Sequential(CNN(), MLP(features))
        head = nn.Linear(num_feature, num_classes)

        model = Network(backbone, head).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: _lr_multiplier(step, total_epochs, epoch_steps),
        )
        experiment(
            run_id,
            seed,
            model,
            optimizer,
            scheduler,
            data_train,
            label_train,
            data_test,
            label_test,
            transform,
            device,
            num_samples,
            num_classes,
            num_feature,
            shot,
            total_epochs,
            epoch_steps,
            class_batch_size,
            weight_loss_invariance,
            margin,
            threshold_scheduler,
            rate_confidence_scheduler,
            rate_similarity_scheduler,
        )


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
