import hydra
from omegaconf import DictConfig
import logging
import numpy as np
from simclr import train
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.models import resnet18, resnet34
from models import SimCLR
from tqdm import tqdm
from Data.cifar import CIFAR10

logger = logging.getLogger(__name__)


class AverageMeter(object):
    """Running tracker for a scalar metric.

    Keeps the most recent value (``val``), the weighted total (``sum``),
    the number of observations (``count``) and their mean (``avg``).
    """

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Zero out every accumulated statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` observations."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


class LinModel(nn.Module):
    """Encoder followed by a single linear classification layer.

    Args:
        encoder: module applied to the input first; its output is fed to the
            linear layer, so it must produce ``feature_dim`` features.
        feature_dim: input width of the linear layer.
        n_classes: output width of the linear layer.
    """

    def __init__(self, encoder: nn.Module, feature_dim: int, n_classes: int):
        super().__init__()
        # `enc` is registered before `lin`, so parameter order is unchanged.
        self.enc = encoder
        self.feature_dim = feature_dim
        self.n_classes = n_classes
        self.lin = nn.Linear(feature_dim, n_classes)

    def forward(self, x):
        """Return the linear layer's output on the encoded input."""
        features = self.enc(x)
        return self.lin(features)


def run_epoch(model, dataloader, epoch, optimizer=None, scheduler=None, filter_ratio=0.55):
    """Run one epoch with per-batch filtering of high-score samples.

    For every batch, each sample is scored by the summed squared error
    between the softmax probabilities and the one-hot label. Only the
    ``1 - filter_ratio`` fraction of samples with the lowest score
    contributes to the cross-entropy loss.

    Args:
        model: network mapping a batch of inputs to class logits.
        dataloader: iterable yielding ``(inputs, labels, extra)`` triples;
            the third element is ignored.
        epoch: epoch number, used only in the progress-bar text.
        optimizer: when given, the model is put in train mode and updated
            after every batch. When ``None`` the model is put in eval mode
            and no gradients are computed.
        scheduler: optional LR scheduler, stepped once per batch (only when
            ``optimizer`` is given).
        filter_ratio: fraction of each batch that is dropped from the loss.
            Must satisfy ``0 <= filter_ratio < 1``.

    Returns:
        ``(mean_loss, mean_accuracy)``, both weighted by batch size. Accuracy
        is measured on the whole batch, not only on the kept samples.

    Raises:
        ValueError: if ``filter_ratio`` is outside ``[0, 1)``.
    """
    if not 0.0 <= filter_ratio < 1.0:
        raise ValueError("filter_ratio must be in [0, 1), got {}".format(filter_ratio))

    training = optimizer is not None
    # Layers such as BatchNorm/Dropout must not behave as in training when
    # this function is used for evaluation.
    if training:
        model.train()
    else:
        model.eval()

    loss_meter = AverageMeter('loss')
    acc_meter = AverageMeter('acc')
    # Wrapping the loader directly lets tqdm pick up the number of batches.
    loader_bar = tqdm(dataloader)
    phase = "Train" if training else "Test"
    keep_fraction = 1 - filter_ratio

    for x, y, _ in loader_bar:
        x, y = x.cuda(), y.cuda()

        with torch.set_grad_enabled(training):
            logits = model(x)
            per_sample_loss = F.cross_entropy(logits, y, reduction='none')

        # The filtering score only selects samples; it never needs gradients.
        with torch.no_grad():
            probs = F.softmax(logits, dim=1)
            one_hot = F.one_hot(y, num_classes=logits.shape[1]).to(probs.dtype)
            filtering_score = F.mse_loss(probs, one_hot, reduction='none').sum(dim=1)
            # Keep at least one sample so the mean below is never taken over
            # an empty selection (which would give NaN).
            n_keep = max(1, int(x.shape[0] * keep_fraction))
            _, order = torch.sort(filtering_score)
            index = order[:n_keep]

        loss = per_sample_loss[index].mean()

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

        acc = (logits.argmax(dim=1) == y).float().mean()
        loss_meter.update(loss.item(), x.size(0))
        acc_meter.update(acc.item(), x.size(0))
        loader_bar.set_description("{} epoch {}, loss: {:.4f}, acc: {:.4f}"
                                   .format(phase, epoch, loss_meter.avg, acc_meter.avg))

    return loss_meter.avg, acc_meter.avg


def evaluate_epoch(model, dataloader, epoch, optimizer=None, scheduler=None):
    """Evaluate the model on paired clean / poisoned inputs.

    Args:
        model: network mapping a batch of inputs to class logits.
        dataloader: iterable yielding
            ``((x, x_poison), (y, y_poison), extra)`` per batch; ``extra``
            is ignored.
        epoch: epoch number, used only in the progress-bar text.
        optimizer: unused; accepted for signature parity with ``run_epoch``.
        scheduler: unused; accepted for signature parity with ``run_epoch``.

    Returns:
        ``(clean_loss, clean_acc, poison_acc)``, each averaged over the
        loader and weighted by batch size.
    """
    model.eval()

    loss_meter_clean = AverageMeter('loss_clean')
    acc_meter_clean = AverageMeter('acc_clean')
    acc_meter_poison = AverageMeter('acc_poison')
    # Wrapping the loader directly lets tqdm pick up the number of batches.
    loader_bar = tqdm(dataloader)

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): `y_poison` is not used — poisoned inputs are scored
        # against the clean labels `y`, so "poison acc" is accuracy on
        # poisoned inputs w.r.t. the true class. If an attack-success rate
        # is wanted instead, compare against `y_poison`; confirm intent.
        for (x, x_poison), (y, _y_poison), _ in loader_bar:
            x, y = x.cuda(), y.cuda()
            x_poison = x_poison.cuda()
            batch_size = x.size(0)

            logits_clean = model(x)
            loss_clean = F.cross_entropy(logits_clean, y)
            logits_poison = model(x_poison)

            acc_clean = (logits_clean.argmax(dim=1) == y).float().mean()
            acc_poison = (logits_poison.argmax(dim=1) == y).float().mean()

            loss_meter_clean.update(loss_clean.item(), batch_size)
            acc_meter_clean.update(acc_clean.item(), batch_size)
            acc_meter_poison.update(acc_poison.item(), batch_size)

            loader_bar.set_description(
                "Test epoch {}, clean loss: {:.4f}, clean acc: {:.4f}, poison acc: {:.4f}"
                .format(epoch, loss_meter_clean.avg, acc_meter_clean.avg, acc_meter_poison.avg))

    return loss_meter_clean.avg, acc_meter_clean.avg, acc_meter_poison.avg


def get_lr(step, total_steps, lr_max, lr_min):
    """Cosine-annealed value between ``lr_max`` (step 0) and ``lr_min`` (step == total_steps)."""
    progress = step / total_steps
    cosine = np.cos(progress * np.pi)
    span = lr_max - lr_min
    return lr_min + span * 0.5 * (1 + cosine)


@hydra.main(config_path='/localscratch/liuboya2/DefenseBackdoorAttack/', config_name='simclr_config.yml')
def finetune(args: DictConfig, corruption_ratio=0.45, corruption_type='blend') -> None:
    """Train a linear classifier on top of a frozen, pre-trained SimCLR encoder.

    Loads the encoder checkpoint ``simclr_{backbone}_epoch{load_epoch}.pt``,
    trains only the linear head on the corrupted training set and, after each
    epoch, evaluates on the test set. Whenever the training loss improves, the
    model is saved to ``simclr_lin_{backbone}_best.pth`` and the matching test
    metrics are remembered.

    Args:
        args: Hydra config. Fields read here: ``batch_size``, ``workers``,
            ``backbone``, ``projection_dim``, ``load_epoch``, ``momentum``,
            ``learning_rate`` and ``finetune_epochs``.
        corruption_ratio: forwarded to the dataset as ``noise_rate``.
        corruption_type: forwarded to the dataset as ``noise_type``.

    Raises:
        ValueError: if ``args.backbone`` is not one of the supported names.
    """
    n_classes = 10
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): the dataset root is relative to the current working
    # directory, which Hydra may change per run — confirm this is intended.
    train_set = CIFAR10(root='./data/',
                        download=True,
                        train=True,
                        transform=train_transform,
                        noise_type=corruption_type,
                        noise_rate=corruption_ratio)
    test_set = CIFAR10(root='./data/',
                       download=True,
                       train=False,
                       transform=test_transform,
                       noise_type=corruption_type,
                       noise_rate=corruption_ratio)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              drop_last=True)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    # Prepare model. An explicit lookup replaces `eval(args.backbone)` so a
    # config string can never be executed as code; these are the only two
    # backbones imported by this file.
    backbones = {'resnet18': resnet18, 'resnet34': resnet34}
    if args.backbone not in backbones:
        raise ValueError("Unsupported backbone {!r}; expected one of {}"
                         .format(args.backbone, sorted(backbones)))
    base_encoder = backbones[args.backbone]

    pre_model = SimCLR(base_encoder, projection_dim=args.projection_dim).cuda()
    pre_model.load_state_dict(torch.load('simclr_{}_epoch{}.pt'.format(args.backbone, args.load_epoch)))
    model = LinModel(pre_model.enc, feature_dim=pre_model.feature_dim, n_classes=n_classes)
    model = model.cuda()

    # Freeze the encoder. Setting `requires_grad` on the Module object only
    # creates a plain attribute; the flag has to be cleared on each parameter.
    for param in model.enc.parameters():
        param.requires_grad = False
    parameters = [param for param in model.parameters() if param.requires_grad]  # trainable parameters.

    optimizer = torch.optim.SGD(
        parameters,
        0.2,   # lr = 0.1 * batch_size / 256, see section B.6 and B.7 of SimCLR paper.
        momentum=args.momentum,
        weight_decay=0.001,
        nesterov=True)

    # Cosine annealing lr. The scheduler is stepped once per batch inside
    # `run_epoch`, so the horizon is the number of batches over all
    # fine-tuning epochs (guarded against zero to avoid division by zero).
    total_steps = max(1, args.finetune_epochs * len(train_loader))
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step,
            total_steps,
            args.learning_rate,  # lr_lambda computes multiplicative factor
            1e-3))

    optimal_loss, optimal_acc, optimal_poison_acc = 1e5, 0., 0.
    for epoch in range(1, args.finetune_epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, epoch, optimizer, scheduler)
        test_loss, test_acc, test_poison_acc = evaluate_epoch(model, test_loader, epoch)

        # Model selection is driven by the training loss; the test metrics of
        # that epoch are the ones reported as "best".
        if train_loss < optimal_loss:
            optimal_loss = train_loss
            optimal_acc = test_acc
            optimal_poison_acc = test_poison_acc
            logger.info("==> New best results")
            torch.save(model.state_dict(), 'simclr_lin_{}_best.pth'.format(args.backbone))

        print("Current Best Test Acc: {:.4f}|Poison Acc: {:.4f}".format(optimal_acc, optimal_poison_acc))
    logger.info("Best Test Acc: {:.4f}|Poison Acc: {:.4f}".format(optimal_acc, optimal_poison_acc))


if __name__ == '__main__':
    # Reminder: the corrupted training set must be identical across runs,
    # i.e. the same samples must be corrupted every time.
    # NOTE(review): no seed is set in the visible code — presumably the
    # dataset class fixes which samples are corrupted; confirm.
    finetune()


