import os
import math
import time
import copy
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import torchvision
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from main_escnn import ESCNNEncoder
from main import adjust_learning_rate

from resnet import resnet18
from utils import knn_monitor, fix_seed, count_parameters_in_MB, CircularCrop


# Per-channel (R, G, B) normalization shared by every transform pipeline below.
# NOTE(review): presumably this must match the normalization used when the
# encoder was pretrained -- keep in sync with the training script. The std
# triple is the widely-copied CIFAR-10 value; the per-pixel std over the
# training set is closer to (0.247, 0.243, 0.261), so do not "fix" it here
# without also changing pretraining.
normalize = T.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2023, 0.1994, 0.2010],
)

def eval_loop(encoder, file_to_update, ind=None):
    """Fit a linear probe on frozen ``encoder`` features and report CIFAR-10 accuracy.

    The encoder is put in eval mode and run under ``torch.no_grad()``; only a
    single ``nn.Linear(feature_dim, 10)`` head is optimised (SGD, momentum
    0.9, base lr 5 scheduled by ``adjust_learning_rate``, no weight decay)
    for 400 epochs on the CIFAR-10 training split with random-crop, flip and
    full-circle random-rotation augmentation. Every 10 epochs the head is
    saved to ``<args.path_dir>/classifier_state_dict.pth`` and evaluated on
    the unrotated CIFAR-10 test split.

    Reads the module-level ``args`` namespace (``arch``, ``num_workers``,
    ``fp16``, ``path_dir``) set in the ``__main__`` block.

    Args:
        encoder: feature extractor, already on the GPU; it is not trained here.
        file_to_update: open text file that progress lines are appended to.
        ind: tag printed as ``seed`` in each progress line.

    Returns:
        ``(accuracy, classifier)``: test accuracy in percent from the last
        evaluation and the trained linear head.

    Raises:
        ValueError: if ``args.arch`` is neither ``'resnet'`` nor ``'escnn'``.
    """
    if args.arch == 'resnet':
        size = 32
    elif args.arch == 'escnn':
        size = 33
    else:
        raise ValueError(f'unsupported arch: {args.arch!r}')

    train_transform = T.Compose([
        T.RandomResizedCrop(size),
        CircularCrop(0),
        T.RandomHorizontalFlip(),
        T.RandomRotation(degrees=(0, 360), interpolation=T.InterpolationMode.BILINEAR),
        T.ToTensor(),
        normalize,
    ])
    # A (0, 0) rotation is the identity; the pipeline otherwise mirrors the
    # per-angle pipelines used in rotated_infer_loop.
    test_transform = T.Compose([
        T.Resize(36),
        T.CenterCrop(size),
        CircularCrop(0),
        T.RandomRotation(degrees=(0, 0), interpolation=T.InterpolationMode.BILINEAR),
        T.ToTensor(),
        normalize,
    ])

    train_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10('../data', train=True, transform=train_transform, download=True),
        shuffle=True,
        batch_size=256,
        pin_memory=True,
        num_workers=args.num_workers,
        drop_last=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10('../data', train=False, transform=test_transform, download=True),
        shuffle=False,
        batch_size=256,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    feature_dim = 512 if args.arch == 'resnet' else encoder.feature_dim
    classifier = nn.Linear(feature_dim, 10).cuda()

    epochs = 400
    base_lr = 5
    optimizer = torch.optim.SGD(
        classifier.parameters(),
        momentum=0.9,
        lr=base_lr,
        weight_decay=0,
    )
    # With enabled=False, autocast is a no-op and GradScaler's scale()/step()/
    # update() pass straight through to loss / optimizer.step(), so one code
    # path serves both fp16 and fp32.
    scaler = GradScaler(enabled=args.fp16)

    accuracy = float('nan')
    for e in range(1, epochs + 1):
        classifier.train()
        encoder.eval()
        # `it` counts optimisation steps globally across epochs (start offset).
        for it, (inputs, y) in enumerate(train_loader, start=(e - 1) * len(train_loader)):
            lr = adjust_learning_rate(epochs=epochs,
                                      warmup_epochs=0,
                                      base_lr=base_lr,
                                      optimizer=optimizer,
                                      loader=train_loader,
                                      step=it)
            classifier.zero_grad()

            with autocast(enabled=args.fp16):
                # The encoder is frozen: no graph is built through it.
                with torch.no_grad():
                    feats = encoder(inputs.cuda())
                logits = classifier(feats)
                loss = F.cross_entropy(logits, y.cuda())

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        if e % 10 == 0:
            classifier.eval()
            torch.save(classifier.state_dict(), os.path.join(args.path_dir, 'classifier_state_dict.pth'))

            # Count hits over all samples: averaging per-batch accuracies would
            # over-weight the smaller final test batch.
            hits = 0
            seen = 0
            with torch.no_grad(), autocast(enabled=args.fp16):
                for images, labels in test_loader:
                    preds = classifier(encoder(images.cuda())).argmax(dim=1)
                    hits += (preds == labels.cuda()).sum().item()
                    seen += labels.shape[0]
            accuracy = 100.0 * hits / seen

            # `lr` is the learning rate of the last step of this epoch.
            line_to_print = (
                f'seed: {ind} | accuracy (%) @ epoch {e} lr {lr:.2f}: {accuracy:.2f} '
            )
            file_to_update.write(line_to_print + '\n')
            file_to_update.flush()
            print(line_to_print)

    return accuracy, classifier

def rotated_infer_loop(encoder, classifier, file_to_update):
    """Evaluate ``encoder`` + ``classifier`` on CIFAR-10 test images rotated in 5-degree steps.

    For each angle 0, 5, ..., 355 the test images go through the same
    resize / center-crop / ``CircularCrop`` pipeline as the test set in
    ``eval_loop``. They are then rotated by exactly that angle
    (``RandomRotation`` with a degenerate ``(degree, degree)`` range) before
    normalization. Per-angle accuracy is appended to ``file_to_update`` and
    printed.

    Reads the module-level ``args`` namespace (``arch``, ``num_workers``,
    ``fp16``).

    Args:
        encoder: feature extractor on the GPU.
        classifier: linear head on the GPU mapping encoder features to 10 logits.
        file_to_update: open text file that per-angle lines are appended to.

    Returns:
        List of 72 accuracies in percent, ordered by increasing angle.

    Raises:
        ValueError: if ``args.arch`` is neither ``'resnet'`` nor ``'escnn'``.
    """
    encoder.eval()
    classifier.eval()

    if args.arch == 'resnet':
        size = 32
    elif args.arch == 'escnn':
        size = 33
    else:
        raise ValueError(f'unsupported arch: {args.arch!r}')

    # Build the split once and only swap its transform per angle. torchvision's
    # CIFAR10 constructor reloads the data files (and checksums them when
    # download=True), which would otherwise happen 72 times.
    test_dataset = torchvision.datasets.CIFAR10('../data', train=False, transform=None, download=True)

    accuracies = []
    for degree in range(0, 360, 5):
        # Set before the DataLoader is created and iterated, so worker
        # processes see this angle's transform.
        # NOTE(review): CircularCrop(0) presumably masks pixels outside an
        # inscribed circle so rotation does not expose corners -- see utils.
        test_dataset.transform = T.Compose([
            T.Resize(36),
            T.CenterCrop(size),
            CircularCrop(0),
            T.RandomRotation((degree, degree), interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            normalize,
        ])

        test_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            shuffle=False,
            batch_size=256,
            pin_memory=True,
            num_workers=args.num_workers,
        )

        # Sample-weighted accuracy; a mean of per-batch accuracies would
        # over-weight the smaller final batch.
        hits = 0
        seen = 0
        with torch.no_grad(), autocast(enabled=args.fp16):
            for images, labels in test_loader:
                preds = classifier(encoder(images.cuda())).argmax(dim=1)
                hits += (preds == labels.cuda()).sum().item()
                seen += labels.shape[0]
        accuracy = 100.0 * hits / seen

        line_to_print = (
            f'degree: {degree} | accuracy (%) : {accuracy:.2f} '
        )
        file_to_update.write(line_to_print + '\n')
        file_to_update.flush()
        print(line_to_print)

        accuracies.append(accuracy)

    return accuracies


def key_change(saved_dict, prefix='encoder.'):
    """Extract a submodule's weights from a full checkpoint state dict.

    Keeps only the entries whose key starts with ``prefix`` and strips that
    prefix, so the result can be fed to the bare submodule's
    ``load_state_dict``.

    Args:
        saved_dict: mapping from parameter/buffer names to values.
        prefix: submodule prefix to select and strip (default ``'encoder.'``).

    Returns:
        A new dict; ``saved_dict`` is left unmodified.
    """
    cut = len(prefix)
    # Match at the start only: a substring test would also select keys such as
    # 'momentum_encoder.w', and slicing those by len(prefix) yields garbage
    # names like '_encoder.w'.
    return {key[cut:]: val for key, val in saved_dict.items() if key.startswith(prefix)}


def _log_line(file_to_update, line):
    """Append ``line`` to the open log file, flush it, and echo it to stdout."""
    file_to_update.write(line + '\n')
    file_to_update.flush()
    print(line)


def main(args):
    """Linear-probe a pretrained encoder and measure its accuracy under rotation.

    Loads ``<args.path_dir>/800.pth``, keeps its ``encoder.``-prefixed weights
    (via ``key_change``), and loads them into a ResNet-18 or ESCNN encoder
    depending on ``args.arch``. It then trains a linear head with
    ``eval_loop`` and evaluates it at 72 rotation angles with
    ``rotated_infer_loop``. All results are appended to
    ``<args.path_dir>/rotated_eval.log``.

    Raises:
        ValueError: if ``args.arch`` is neither ``'resnet'`` nor ``'escnn'``.
    """
    fix_seed(args.seed)
    # The context manager guarantees the log is closed even if loading,
    # training or evaluation raises.
    with open(os.path.join(args.path_dir, 'rotated_eval.log'), 'a') as file_to_update:
        checkpoint_path = os.path.join(args.path_dir, '800.pth')
        saved_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
        encoder_dict = key_change(saved_dict)

        if args.arch == 'resnet':
            eval_encoder = resnet18().cuda()
        elif args.arch == 'escnn':
            eval_encoder = ESCNNEncoder(guide=args.guide, use_gpool=args.use_gpool, N=args.order).cuda()
        else:
            raise ValueError(f'unsupported arch: {args.arch!r}')
        eval_encoder.load_state_dict(encoder_dict)

        acc, classifier = eval_loop(eval_encoder, file_to_update, 0)
        _log_line(file_to_update, f'linear probe: {acc:.3f}')

        accs = rotated_infer_loop(eval_encoder, classifier, file_to_update)
        _log_line(file_to_update, f'aggregated linear probe: {np.mean(accs):.3f} +- {np.std(accs):.3f}')

if __name__ == '__main__':
    # NOTE: `args` must remain a module-level global: eval_loop and
    # rotated_infer_loop read it directly rather than taking it as a parameter.
    # Options are registered in a fixed order, which is also the --help order.
    parser = argparse.ArgumentParser()
    # Directory holding the 800.pth checkpoint; logs and the classifier go here too.
    parser.add_argument('--path_dir', type=str, default='../experiment/tmp')
    parser.add_argument('--arch', type=str, default='escnn', choices=['resnet', 'escnn'])
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=16)
    # Mixed precision via autocast / GradScaler.
    parser.add_argument('--fp16', action='store_true')
    # ESCNN-only options, forwarded to ESCNNEncoder(guide=..., use_gpool=..., N=order).
    parser.add_argument('--guide', action='store_true')
    parser.add_argument('--order', type=int, default=4)
    parser.add_argument('--use_gpool', action='store_true')

    args = parser.parse_args()
    main(args)
