import os
import time

import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Pad, RandomAutocontrast, RandomErasing, RandomOrder

from data.datasets import SVHN
from data.random_augment import *
from exemplary_contrastive_self_training import experiment
from networks.model import Network
from networks.wideresnet import WideResNet
from utility.schedulers import SigmoidScheduler
from utility.seed import random_seed


def _lr_multiplier(step, total_steps, hold_fraction=0.6, final_factor=0.1):
    """Return the LambdaLR learning-rate multiplier for optimizer step *step*.

    The multiplier is 1.0 for the first ``hold_fraction * total_steps`` steps.
    It then decays linearly to ``final_factor`` at ``total_steps``. Past
    ``total_steps`` it stays at ``final_factor``, so running extra steps can
    never drive the learning rate to zero or below.

    Args:
        step: Scheduler step counter (LambdaLR passes its ``last_epoch``).
        total_steps: Step at which the decay reaches ``final_factor``.
        hold_fraction: Fraction of ``total_steps`` spent at the full rate.
        final_factor: Multiplier reached at the end of the linear decay.

    Returns:
        The multiplicative factor applied to the optimizer's base LR.
    """
    hold_steps = hold_fraction * total_steps
    if step < hold_steps:
        return 1.0
    decay_steps = (1.0 - hold_fraction) * total_steps
    if decay_steps <= 0:
        # No decay window (hold_fraction >= 1): jump straight to the floor.
        return final_factor
    # Clamp progress at 1.0 so an overrun does not extrapolate the line
    # into zero or negative learning rates.
    progress = min((step - hold_steps) / decay_steps, 1.0)
    return final_factor + (1.0 - final_factor) * (1.0 - progress)


def main():
    """Run the SVHN exemplary-contrastive self-training experiment for seeds 0..2.

    Loads SVHN through the project ``SVHN`` loader and scales pixels to
    [0, 1]. It then builds the augmentation pipeline and, for each seed,
    constructs a WideResNet backbone with a linear head, an Adam optimizer
    and a hold-then-linear-decay LR schedule. All of these are handed to
    ``experiment``, which performs the actual training and evaluation.

    Raises:
        ValueError: If ``num_samples`` is not divisible by ``epoch_steps``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --- Data / sampling hyper-parameters ---------------------------------
    num_samples = 73200  # only the first num_samples training images are used
    num_classes = 10
    shot = 25  # presumably labelled examples per class; see experiment()
    total_epochs = 200
    epoch_steps = 732
    # Validate explicitly: an `assert` would be stripped under `python -O`.
    if num_samples % epoch_steps != 0:
        raise ValueError(
            f"num_samples ({num_samples}) must be divisible by "
            f"epoch_steps ({epoch_steps})"
        )
    # Samples per step, split evenly across classes (100 // 10 == 10 here).
    class_batch_size = (num_samples // epoch_steps) // num_classes

    # --- Loss / pseudo-label hyper-parameters -----------------------------
    weight_loss_invariance = 1.0
    margin = 0.01
    threshold_scheduler = SigmoidScheduler(10 / total_epochs, 2 * (0.9 - 1 / num_classes), 0, 1, -(0.9 - 2 / num_classes))
    rate_confidence_scheduler = SigmoidScheduler(0.5, 0.99, -1, 4)
    rate_similarity_scheduler = SigmoidScheduler(0.5, 0.99, -1, 4)

    # --- Data -------------------------------------------------------------
    data_train, label_train, data_test, label_test = map(np.copy, SVHN("./datasets/SVHN/"))
    data_train = torch.from_numpy(data_train[:num_samples].astype(np.float32)) / 255
    # `% 10` folds any label 10 onto 0; presumably the raw loader uses the
    # SVHN .mat convention where digit 0 is labelled 10 — confirm in SVHN().
    label_train = torch.from_numpy(label_train[:num_samples].astype(np.int64)) % 10
    data_test = torch.from_numpy(data_test.astype(np.float32)) / 255
    label_test = torch.from_numpy(label_test.astype(np.int64)) % 10

    # --- Augmentation (passed to experiment) ------------------------------
    transform = Compose([
        # Reflect-pad before the affine warp so the final 32x32 centre crop
        # does not contain empty borders.
        Pad(8, padding_mode='reflect'),
        RandomAffine((-10, 10), 1.0, (1, 1), 1.0, (0.9, 1.1), 1.0, (-10, 10, -10, 10), 1.0, InterpolationMode.BILINEAR),
        CenterCrop(32),
        RandomOrder([
            RandomHue((-0.5, 0.5), 1.0),
            RandomBrightness((-0.25, 0.25), 0.5),
            RandomContrast((-0.5, 0.5), 0.5),
            RandomSharpness((-0.5, 0.5), 0.5),
            RandomSaturation((-1.0, 1.0), 0.5),
            RandomSolarize((0.0, 1.0), 0.2),
            RandomPosterize((4, 8), 0.2),
            RandomEqualize(0.2),
            RandomAutocontrast(0.2),
        ]),
        RandomErasing(1.0, (0.2, 0.2), (1, 1)),
    ])

    # Run identifier: dataset, current directory name and local start time.
    run_id = f"SVHN-{os.path.basename(os.getcwd())}-{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}"

    # --- Architecture (identical for every seed) --------------------------
    num_groups = 3
    num_blocks = 4
    factor_base = 64
    factor_widen = 2
    # Input width of the linear head; must match the backbone's output width.
    num_feature = factor_base * pow(2, num_groups - 1) * factor_widen
    total_steps = total_epochs * epoch_steps

    for seed in range(3):
        # Seed before building the model so that the weight initialisation
        # depends on `seed` (assumes random_seed seeds torch — confirm).
        random_seed(seed)
        backbone = nn.Sequential(WideResNet(num_groups, num_blocks, factor_base, factor_widen, 3, momentum=0.1, dropout=0.0, global_pool='max'))
        head = nn.Linear(num_feature, num_classes)

        model = Network(backbone, head).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: _lr_multiplier(step, total_steps),
        )
        experiment(run_id, seed, model, optimizer, scheduler, data_train, label_train, data_test, label_test, transform, device, num_samples, num_classes, num_feature, shot, total_epochs, epoch_steps, class_batch_size, weight_loss_invariance, margin, threshold_scheduler, rate_confidence_scheduler, rate_similarity_scheduler)


if __name__ == "__main__":
    main()
