import os
import warnings
from typing import List
from tqdm import tqdm

import numpy as np
import torch
import torchvision
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage, NormalizeImage, ModuleWrapper
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter

from main.utils.random import set_random_seed
from main.models.cifar10 import init_resnet18


# Three-element mean / std vectors handed to ``NormalizeImage`` in both the
# train and test image pipelines of ``evaluate_cifar10_subset``.
# NOTE(review): presumably per-channel (RGB) statistics on the raw 0-255 pixel
# scale -- confirm against how they were computed.
CIFAR_MEAN = np.array((125.307, 122.961, 113.8575), dtype=np.float64)
CIFAR_STD = np.array((51.5865, 50.847, 51.255), dtype=np.float64)


def _ensure_cifar10_betons(dataset_path, train_set_path, test_set_path):
    """Create the CIFAR-10 train/test ``.beton`` files unless both already exist."""
    if os.path.exists(train_set_path) and os.path.exists(test_set_path):
        return

    # Keyed by output path so each file is written exactly where the loaders
    # will later read it from, independent of the current working directory.
    splits = {
        train_set_path: torchvision.datasets.CIFAR10(dataset_path, train=True, download=True),
        test_set_path: torchvision.datasets.CIFAR10(dataset_path, train=False, download=True),
    }
    for beton_path, split in splits.items():
        writer = DatasetWriter(beton_path, {
            'image': RGBImageField(),
            'label': IntField()
        })
        writer.from_indexed_dataset(split)


def _image_pipeline(device, augment) -> List[Operation]:
    """Build the FFCV image pipeline; ``augment`` adds random translate + flip."""
    ops: List[Operation] = [SimpleRGBImageDecoder()]
    if augment:
        ops += [RandomTranslate(padding=4), RandomHorizontalFlip()]
    ops += [NormalizeImage(CIFAR_MEAN, CIFAR_STD, np.float32),
            ToTensor(),
            ToDevice(device, non_blocking=True),
            ToTorchImage()]
    return ops


def _label_pipeline(device) -> List[Operation]:
    """Build the FFCV label pipeline (decode, to tensor, to device, squeeze)."""
    return [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]


def evaluate_cifar10_subset(indices: List[int]=np.arange(50000),
                            model=None,
                            optimizer=None,
                            batch_size=128,
                            num_epochs=200,
                            num_workers=8,
                            device=None, 
                            seed=None,
                            return_model=False):
    """Train a model on a subset of the CIFAR-10 train set and score it on the test set.

    The train/test data are read from ``cifar10_train.beton`` and
    ``cifar10_test.beton`` in ``<this file's directory>/../../datasets``. If
    either file is missing, CIFAR-10 is downloaded through torchvision and both
    files are (re)written there.

    Training runs ``num_epochs`` epochs with cross-entropy loss under
    ``autocast`` / ``GradScaler``; a ``CosineAnnealingLR`` schedule
    (``T_max=num_epochs``, ``eta_min=0``) is stepped once per epoch.

    Note: this function calls ``warnings.filterwarnings("ignore")``, which
    silences warnings process-wide, and it leaves the model in eval mode.

    Args:
        indices: Training-set indices passed to the FFCV train ``Loader``.
            Defaults to ``0..49999``. The default array is shared between
            calls; it is only read here.
        model: Model to train. When ``None``, ``init_resnet18()`` is created,
            moved to ``device`` in channels-last format, and paired with a
            fresh SGD optimizer (any ``optimizer`` argument is ignored in that
            case). A caller-supplied model is used as-is and is not moved to
            ``device``.
        optimizer: Optimizer for a caller-supplied ``model``. When ``None``,
            SGD with ``lr=0.1, momentum=0.9, weight_decay=5e-4`` is built over
            ``model.parameters()``.
        batch_size: Batch size for both loaders.
        num_epochs: Number of training epochs.
        num_workers: Worker count for both FFCV loaders.
        device: Target device for batches (and for the default model).
            ``None`` selects CPU.
        seed: When not ``None``, passed to ``set_random_seed`` and to the
            train ``Loader``. ``0`` is honoured as a valid seed.
        return_model: When true, return ``(accuracy, model)`` instead of just
            the accuracy.

    Returns:
        Test accuracy in percent (``correct / total * 100``), or the tuple
        ``(accuracy, model)`` if ``return_model`` is true.
    """
    warnings.filterwarnings("ignore")

    if device is None:
        device = torch.device('cpu')
    # `is not None` rather than truthiness so that seed=0 is not dropped.
    if seed is not None:
        set_random_seed(seed)

    dataset_path = os.path.join(os.path.dirname(__file__), '..', '..', 'datasets')
    train_set_path = os.path.join(dataset_path, 'cifar10_train.beton')
    test_set_path = os.path.join(dataset_path, 'cifar10_test.beton')
    _ensure_cifar10_betons(dataset_path, train_set_path, test_set_path)

    train_loader = Loader(train_set_path,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          order=OrderOption.RANDOM,
                          seed=seed,
                          drop_last=False,
                          pipelines={'image': _image_pipeline(device, augment=True),
                                     'label': _label_pipeline(device)},
                          indices=indices)
    test_loader = Loader(test_set_path,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         order=OrderOption.SEQUENTIAL,
                         drop_last=False,
                         pipelines={'image': _image_pipeline(device, augment=False),
                                    'label': _label_pipeline(device)})

    if model is None:
        # A supplied optimizer cannot reference the parameters of a model that
        # is created here, so it is always replaced in this branch.
        model = init_resnet18().to(device, memory_format=torch.channels_last)
        optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4)
    elif optimizer is None:
        # Caller gave a model only: fall back to the same default optimizer
        # instead of failing inside the scheduler constructor.
        optimizer = SGD(model.parameters(), lr=.1, momentum=.9, weight_decay=5e-4)

    scheduler = CosineAnnealingLR(optimizer,
                                  T_max=num_epochs,
                                  eta_min=0.0)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    # A model returned by a previous call is in eval mode; make sure training
    # actually runs in train mode.
    model.train()
    for _ in tqdm(range(num_epochs), desc="Training"):
        for ims, labs in train_loader:
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        scheduler.step()

    model.eval()
    total_correct, total_num = 0., 0.
    with torch.no_grad():
        for ims, labs in test_loader:
            with autocast():
                out = model(ims)
            total_correct += out.argmax(1).eq(labs).sum().item()
            total_num += ims.shape[0]

    accuracy = total_correct / total_num * 100

    if return_model:
        return accuracy, model

    return accuracy
