import os
import time

import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Pad, RandomAutocontrast, RandomErasing, RandomHorizontalFlip, RandomOrder, ToPILImage, ToTensor

from data.datasets import CIFAR
from data.random_augment import *
from exemplary_contrastive_self_training import experiment
from networks.model import Network
from networks.wideresnet import WideResNet
from utility.schedulers import ConstantScheduler, SigmoidScheduler
from utility.seed import random_seed


def _build_transform():
    """Build the augmentation pipeline applied to training images.

    The pipeline converts a tensor to a PIL image, applies a random flip,
    a randomly ordered set of geometric distortions on a reflect-padded
    canvas, a centre crop back to 32x32, a randomly ordered set of
    photometric distortions, and finally erases one square patch.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    # NOTE(review): the Random* classes without a torchvision import come from
    # data.random_augment; their positional arguments look like
    # (magnitude range, probability, interpolation) -- confirm there.
    geometric = RandomOrder([
        RandomRotation((-30, 30), 0.3, InterpolationMode.BICUBIC),
        RandomTranslation((8, 8), 0.3, InterpolationMode.BICUBIC),
        RandomScaling((0.9, 1.1), 0.3, InterpolationMode.BICUBIC),
        RandomShear((-30, 30, -30, 30), 0.3, InterpolationMode.BICUBIC),
    ])
    photometric = RandomOrder([
        RandomBrightness((-0.25, 0.25), 1.0),
        RandomContrast((-0.5, 0.5), 1.0),
        RandomSharpness((-0.5, 0.5), 1.0),
        RandomSaturation((-1.0, 1.0), 1.0),
        RandomSolarize((0.0, 1.0), 0.2),
        RandomPosterize((4, 8), 0.2),
        RandomEqualize(0.2),
        RandomAutocontrast(0.2),
    ])
    return Compose([
        ToPILImage(),
        RandomHorizontalFlip(),
        # Reflect-pad before the geometric distortions so that the centre crop
        # below does not pick up empty borders introduced by them.
        Pad(16, padding_mode='reflect'),
        geometric,
        CenterCrop(32),
        photometric,
        ToTensor(),
        # p=1.0, scale=(0.1, 0.1), ratio=(1, 1): always erase one square patch
        # covering 10% of the image area.
        RandomErasing(1.0, (0.1, 0.1), (1, 1)),
    ])


def _build_model(num_classes, device):
    """Create the WideResNet backbone with a linear head and move it to ``device``.

    Args:
        num_classes: Number of output classes; also selects the base width
            (16 when it equals 10, otherwise 32).
        device: Device the assembled ``Network`` is moved to.

    Returns:
        A ``(model, num_feature)`` pair, where ``num_feature`` is the input
        width of the linear head.
    """
    num_groups = 3
    num_blocks = 4
    factor_base = 16 if num_classes == 10 else 32
    factor_widen = 8
    # Input width of the classification head below.
    num_feature = factor_base * 2 ** (num_groups - 1) * factor_widen
    # NOTE(review): the positional 3 is presumably the number of input
    # channels -- confirm against networks.wideresnet.WideResNet.
    backbone = nn.Sequential(
        WideResNet(num_groups, num_blocks, factor_base, factor_widen, 3,
                   momentum=0.1, dropout=0.0, global_pool='avg')
    )
    head = nn.Linear(num_feature, num_classes)
    model = Network(backbone, head).to(device)
    return model, num_feature


def _build_optimizer(model, num_classes, total_epochs, epoch_steps):
    """Create the optimiser and learning-rate schedule for ``model``.

    With 10 classes: Adam (lr 1e-3) whose LR multiplier stays at 1.0 for the
    first 60% of ``total_epochs * epoch_steps`` and then decays linearly to
    0.1 at the end. Otherwise: Nesterov SGD (lr 0.03, momentum 0.9) with
    weight decay 1e-5 on every parameter whose name contains neither
    ``'bias'`` nor ``'bn'``, and an LR multiplier that is 1.0 for the first
    10% and is multiplied by 0.3 at the start of every following 10%.

    Args:
        model: Module whose parameters are optimised.
        num_classes: Selects which of the two recipes above is used.
        total_epochs: Number of epochs used to size the schedule.
        epoch_steps: Number of steps per epoch used to size the schedule.

    Returns:
        An ``(optimizer, scheduler)`` pair.
    """
    # NOTE(review): both schedules are expressed in units of
    # total_epochs * epoch_steps, so the scheduler is presumably stepped once
    # per training iteration -- confirm in `experiment`.
    if num_classes == 10:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        decay_start = 0.6 * total_epochs * epoch_steps
        decay_span = (1.0 - 0.6) * total_epochs * epoch_steps

        def lr_factor(step):
            if step < decay_start:
                return 1.0
            return 0.1 + 0.9 * (1.0 - (step - decay_start) / decay_span)
    else:
        no_decay_markers = ('bias', 'bn')
        decayed, undecayed = [], []
        for name, parameter in model.named_parameters():
            if any(marker in name for marker in no_decay_markers):
                undecayed.append(parameter)
            else:
                decayed.append(parameter)
        grouped_parameters = [
            {'params': decayed, 'weight_decay': 1e-5},
            {'params': undecayed, 'weight_decay': 0.0},
        ]
        optimizer = torch.optim.SGD(grouped_parameters, lr=0.03, momentum=0.9, nesterov=True)
        stage = 0.1 * total_epochs * epoch_steps

        def lr_factor(step):
            if step < stage:
                return 1.0
            return 0.3 ** ((step - stage) // stage + 1)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
    return optimizer, scheduler


def main():
    """Run the CIFAR experiment for seeds 0, 1 and 2.

    Loads the dataset, builds the augmentation pipeline once, and then for
    each seed creates a fresh model, optimiser and LR schedule and hands
    everything to ``experiment``.

    Raises:
        ValueError: If ``num_samples`` is not a multiple of ``epoch_steps``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_samples = 50000
    num_classes = 10
    shot = 25  # NOTE(review): presumably labelled samples per class -- confirm in `experiment`.
    total_epochs = 600
    epoch_steps = 500

    # Validate before deriving anything from these values. An explicit raise is
    # used because `assert` statements are removed when Python runs with -O.
    if num_samples % epoch_steps != 0:
        raise ValueError(
            f"num_samples ({num_samples}) must be divisible by epoch_steps ({epoch_steps})"
        )

    # Samples per step, split evenly across the classes.
    class_batch_size = (num_samples // epoch_steps) // num_classes
    weight_loss_invariance = 1.0 if num_classes == 10 else 10.0
    margin = 0.01
    # NOTE(review): the scheduler arguments are passed through unchanged; their
    # meaning is defined in utility.schedulers.
    threshold_scheduler = SigmoidScheduler(10 / total_epochs, 2 * (0.9 - 1 / num_classes), 0, 1, -(0.9 - 2 / num_classes))
    rate_confidence_scheduler = SigmoidScheduler(0.3, 0.99, -1, 2)
    rate_similarity_scheduler = ConstantScheduler(1.0)

    data_train, label_train, data_test, label_test = map(np.copy, CIFAR("./datasets/CIFAR/", num_classes))
    # Images become float32 divided by 255 (presumably 8-bit pixel values, giving
    # a [0, 1] range -- confirm against data.datasets.CIFAR); labels become int64.
    data_train = torch.from_numpy(data_train[:num_samples].astype(np.float32)) / 255
    label_train = torch.from_numpy(label_train[:num_samples].astype(np.int64))
    data_test = torch.from_numpy(data_test.astype(np.float32)) / 255
    label_test = torch.from_numpy(label_test.astype(np.int64))

    transform = _build_transform()

    # Run identifier: working-directory name plus a local timestamp. Named
    # `experiment_id` so the builtin `id` is not shadowed.
    experiment_id = f"CIFAR-{os.path.basename(os.getcwd())}-{time.strftime('%Y-%m-%d-%H-%M-%S')}"

    for seed in range(3):
        # Seed first so that model initialisation is reproducible per seed.
        random_seed(seed)
        model, num_feature = _build_model(num_classes, device)
        optimizer, scheduler = _build_optimizer(model, num_classes, total_epochs, epoch_steps)

        experiment(
            experiment_id, seed, model, optimizer, scheduler,
            data_train, label_train, data_test, label_test,
            transform, device,
            num_samples, num_classes, num_feature, shot,
            total_epochs, epoch_steps, class_batch_size,
            weight_loss_invariance, margin,
            threshold_scheduler, rate_confidence_scheduler, rate_similarity_scheduler,
        )


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
