import os
import math
import time
import copy
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

import torchvision
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader

from PIL import Image
from main_escnn import ESCNNEncoder
from main import adjust_learning_rate

from resnet import resnet18
from utils import knn_monitor, fix_seed, count_parameters_in_MB


class RotatedCIFAR10Dataset(Dataset):
    """Flat image-folder dataset for the pre-rendered rotated CIFAR-10 split.

    Every file in ``<base_dir>/train`` (or ``<base_dir>/test``) is one sample.
    The integer class label is parsed from the file name as its second
    ``_``-separated field. Names presumably look like
    ``<index>_<label>_..._<angle>.<ext>``; confirm against the script that
    generated the files.

    Args:
        base_dir: Root directory containing ``train`` and ``test`` sub-folders.
        train: Read from ``train`` when True, otherwise from ``test``.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, base_dir, train=True, transform=None):
        self.base_dir = base_dir
        self.transform = transform
        self.image_dir = os.path.join(base_dir, 'train' if train else 'test')
        # Sorted so the index -> file mapping is deterministic across runs.
        self.image_files = sorted(os.listdir(self.image_dir))

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _label_from_filename(file_name):
        """Return the integer label stored in the 2nd ``_``-separated field."""
        return int(file_name.split('_')[1])

    def __getitem__(self, idx):
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, file_name)
        # Image.open is lazy and keeps the file open until the image is
        # garbage-collected. Decode inside ``with`` so the handle is closed
        # right away. convert('RGB') always returns a fully loaded copy with
        # 3 channels, which matches the 3-channel Normalize used downstream.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self._label_from_filename(file_name)
        return image, label

# Per-channel mean/std normalization, applied right after ToTensor() in both
# the train and test transforms of the linear probe.
normalize = T.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2023, 0.1994, 0.2010],
)

def _build_probe_loaders(size, num_workers, data_dir='../data/cifar10_rotated'):
    """Return ``(train_loader, test_loader)`` over the rotated CIFAR-10 split.

    Train images get a bicubic random-resized crop plus a horizontal flip.
    Test images are resized to 36 and center-cropped. Both end at ``size``
    pixels and are normalized with the module-level ``normalize``.
    """
    train_transform = T.Compose([
        T.RandomResizedCrop(size, interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    test_transform = T.Compose([
        T.Resize(36, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(size),
        T.ToTensor(),
        normalize,
    ])

    train_loader = torch.utils.data.DataLoader(
        dataset=RotatedCIFAR10Dataset(base_dir=data_dir, train=True, transform=train_transform),
        shuffle=True,
        batch_size=256,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=True,
    )
    # No drop_last here: every test image is scored, so the final batch may
    # be smaller than 256.
    test_loader = torch.utils.data.DataLoader(
        dataset=RotatedCIFAR10Dataset(base_dir=data_dir, train=False, transform=test_transform),
        shuffle=False,
        batch_size=256,
        pin_memory=True,
        num_workers=num_workers,
    )
    return train_loader, test_loader


def _probe_loss(encoder, classifier, inputs, targets, device):
    """Cross-entropy of the linear head on frozen (no-grad) encoder features."""
    with torch.no_grad():
        features = encoder(inputs.to(device))
    logits = classifier(features)
    return F.cross_entropy(logits, targets.to(device))


def _train_probe_epoch(encoder, classifier, optimizer, scaler, loader, epoch,
                       epochs, base_lr, fp16):
    """Run one linear-probe training epoch and return the last learning rate set.

    ``epoch`` is 1-based. Steps are numbered globally (``(epoch - 1) * len(loader)``
    onward) so the learning-rate schedule continues smoothly across epochs.
    """
    device = next(classifier.parameters()).device
    classifier.train()
    encoder.eval()
    # Fallback so a caller always gets a number, even for an empty loader.
    lr = optimizer.param_groups[0]['lr']

    first_step = (epoch - 1) * len(loader)
    for step, (inputs, targets) in enumerate(loader, start=first_step):
        lr = adjust_learning_rate(epochs=epochs,
                                  warmup_epochs=0,
                                  base_lr=base_lr,
                                  optimizer=optimizer,
                                  loader=loader,
                                  step=step)
        classifier.zero_grad()

        if fp16:
            with autocast():
                loss = _probe_loss(encoder, classifier, inputs, targets, device)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = _probe_loss(encoder, classifier, inputs, targets, device)
            loss.backward()
            optimizer.step()
    return lr


def _probe_accuracy(encoder, classifier, loader, fp16=False):
    """Return top-1 accuracy (in %) of ``classifier`` on ``encoder`` features.

    Correct predictions are summed over all samples and divided by the total
    sample count. Averaging per-batch accuracies instead would over-weight a
    smaller final batch. Returns NaN when ``loader`` yields no samples.
    """
    device = next(classifier.parameters()).device
    classifier.eval()
    encoder.eval()

    total_hits = 0
    total_seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            if fp16:
                with autocast():
                    preds = classifier(encoder(images)).argmax(dim=1)
            else:
                preds = classifier(encoder(images)).argmax(dim=1)
            total_hits += (preds == labels.to(device)).sum().item()
            total_seen += labels.size(0)

    if total_seen == 0:
        return float('nan')
    return 100.0 * total_hits / total_seen


def eval_loop(encoder, file_to_update, ind=None, *, epochs=100, base_lr=30,
              eval_every=10, data_dir='../data/cifar10_rotated'):
    """Train and evaluate a linear probe on top of a frozen ``encoder``.

    A fresh ``nn.Linear(feature_dim, 10)`` head is trained with SGD
    (momentum 0.9, no weight decay). The learning rate follows
    ``adjust_learning_rate`` with no warmup. Test accuracy is measured every
    ``eval_every`` epochs and always on the final epoch. Each evaluation line
    is appended to ``file_to_update`` and printed.

    NOTE: reads the module-level ``args`` (``arch``, ``num_workers``,
    ``fp16``), so it only works once ``args`` exists, e.g. after the
    ``__main__`` block has parsed the command line.

    Args:
        encoder: Frozen backbone. It is expected on the GPU (``main()`` calls
            ``.cuda()``). For ``arch == 'escnn'`` it must expose ``feature_dim``.
        file_to_update: Writable text file that receives the log lines.
        ind: Run identifier, printed as ``seed`` in the log lines.
        epochs: Number of training epochs (default 100).
        base_lr: Base SGD learning rate passed to the schedule (default 30).
        eval_every: Evaluate every this many epochs (default 10). Must be > 0.
        data_dir: Root of the rotated CIFAR-10 ``train``/``test`` folders.

    Returns:
        Top-1 test accuracy in percent from the last evaluation.

    Raises:
        ValueError: If ``args.arch`` is not ``'resnet'`` or ``'escnn'``.
    """
    if args.arch == 'resnet':
        size, feature_dim = 32, 512
    elif args.arch == 'escnn':
        size, feature_dim = 33, encoder.feature_dim
    else:
        raise ValueError(f'unsupported arch: {args.arch!r}')

    train_loader, test_loader = _build_probe_loaders(size, args.num_workers, data_dir)

    classifier = nn.Linear(feature_dim, 10).cuda()
    optimizer = torch.optim.SGD(
        classifier.parameters(),
        momentum=0.9,
        lr=base_lr,
        weight_decay=0,
    )
    scaler = GradScaler()

    accuracy = float('nan')
    for e in range(1, epochs + 1):
        lr = _train_probe_epoch(encoder, classifier, optimizer, scaler, train_loader,
                                epoch=e, epochs=epochs, base_lr=base_lr, fp16=args.fp16)

        # Always evaluate on the last epoch so the returned accuracy is defined.
        if e % eval_every == 0 or e == epochs:
            accuracy = _probe_accuracy(encoder, classifier, test_loader, fp16=args.fp16)
            line_to_print = (
                f'seed: {ind} | accuracy (%) @ epoch {e} lr {lr:.2f}: {accuracy:.2f} '
            )
            file_to_update.write(line_to_print + '\n')
            file_to_update.flush()
            print(line_to_print)

    return accuracy

def key_change(saved_dict, prefix='encoder.'):
    """Extract the encoder's weights from a full training ``state_dict``.

    Keeps only entries whose key starts with ``prefix`` and strips that
    prefix, so the result can be loaded directly into a bare encoder. A
    leading ``module.`` (added by DataParallel/DDP wrappers) is removed
    first. Keys that merely *contain* the prefix, such as
    ``momentum_encoder.conv1.weight``, are ignored rather than mangled.

    Args:
        saved_dict: Mapping from parameter name to tensor.
        prefix: Sub-module prefix to select and strip (default ``'encoder.'``).

    Returns:
        A new dict containing the encoder parameters with the prefix removed.
    """
    wrapper_prefix = 'module.'
    encoder_dict = {}
    for key, val in saved_dict.items():
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        if key.startswith(prefix):
            encoder_dict[key[len(prefix):]] = val
    return encoder_dict

def main(args):
    """Run five linear-probe evaluations of one pretrained checkpoint.

    Loads ``<path_dir>/800.pth`` and keeps its encoder weights (see
    ``key_change``). For each of 5 runs it builds a fresh ``args.arch``
    encoder on the GPU, loads those weights and calls ``eval_loop``. Per-run
    lines and a final ``mean +- std`` summary are appended to
    ``<path_dir>/four_rotated_eval.log`` and printed.

    NOTE: ``eval_loop`` reads the module-level ``args`` rather than this
    parameter. The two are the same object when run as a script.

    Raises:
        ValueError: If ``args.arch`` is not ``'resnet'`` or ``'escnn'``.
    """
    fix_seed(args.seed)
    accs = []
    log_path = os.path.join(args.path_dir, 'four_rotated_eval.log')
    # ``with`` guarantees the log is flushed and closed even if a run fails.
    with open(log_path, 'a') as file_to_update:
        checkpoint_path = os.path.join(args.path_dir, '800.pth')
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        saved_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']
        encoder_dict = key_change(saved_dict)

        for i in range(5):
            if args.arch == 'resnet':
                eval_encoder = resnet18().cuda()
            elif args.arch == 'escnn':
                eval_encoder = ESCNNEncoder(guide=args.guide, use_gpool=args.use_gpool, N=args.order).cuda()
            else:
                raise ValueError(f'unsupported arch: {args.arch!r}')
            eval_encoder.load_state_dict(encoder_dict)
            accs.append(eval_loop(eval_encoder, file_to_update, i))

        line_to_print = f'aggregated linear probe: {np.mean(accs):.3f} +- {np.std(accs):.3f}'
        file_to_update.write(line_to_print + '\n')
        file_to_update.flush()
        print(line_to_print)

if __name__ == '__main__':
    # (flag, add_argument options) pairs. Registration order is kept, so
    # --help lists the options in the same sequence.
    cli_options = (
        # directory holding 800.pth; the eval log is written here too
        ('--path_dir', dict(default='../experiment/tmp', type=str)),
        ('--arch', dict(default='escnn', type=str, choices=['resnet', 'escnn'])),
        ('--seed', dict(default=42, type=int)),
        # DataLoader worker processes
        ('--num_workers', dict(default=16, type=int)),
        # enables autocast + GradScaler in the linear probe
        ('--fp16', dict(action='store_true')),
        # the next three are forwarded to ESCNNEncoder(guide=, N=, use_gpool=)
        ('--guide', dict(action='store_true')),
        ('--order', dict(default=4, type=int)),
        ('--use_gpool', dict(action='store_true')),
    )
    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    # Must remain a module-level name: eval_loop() reads ``args`` as a global.
    args = cli.parse_args()

    main(args)
