import numpy as np
import torch
import torch.nn as nn
from utils.misc import *


def load_resnet50(net, head, ssh, classifier, args):
    """Restore ``net`` and ``head`` from a jointly trained checkpoint.

    The file read is ``<args.resume>/ckpt_epoch_<args.ckpt>.pth`` when
    ``args.ckpt`` is truthy and ``<args.resume>/ckpt.pth`` otherwise.  The
    checkpoint's ``'model'`` state dict is split in two:

    * keys starting with ``head`` have ``"head."`` removed and are loaded
      into ``head``;
    * every other key has ``"encoder."`` rewritten to ``"ext."`` and
      ``"fc."`` rewritten to ``"head.fc."`` and is loaded into ``net``.

    Args:
        net: module that receives the remapped non-head weights.
        head: module that receives the head weights.
        ssh: unused here; kept so the signature stays unchanged.
        classifier: unused here; kept so the signature stays unchanged.
        args: namespace providing ``resume`` (checkpoint directory) and
            ``ckpt`` (epoch number, or a falsy value for ``ckpt.pth``).
    """
    if args.ckpt:
        filename = args.resume + '/ckpt_epoch_{:d}.pth'.format(args.ckpt)
    else:
        filename = args.resume + '/ckpt.pth'

    # Deserialize onto the CPU so the file can be read even when the device it
    # was saved from does not exist here; load_state_dict below copies the
    # values into the modules' own parameters.
    ckpt = torch.load(filename, map_location='cpu')
    state_dict = ckpt['model']

    net_dict = {}
    head_dict = {}
    for key, value in state_dict.items():
        if key.startswith("head"):
            head_dict[key.replace("head.", "")] = value
        else:
            renamed = key.replace("encoder.", "ext.")
            renamed = renamed.replace("fc.", "head.fc.")
            net_dict[renamed] = value

    net.load_state_dict(net_dict)
    head.load_state_dict(head_dict)

    print('Loaded model trained jointly on Classification and SimCLR:', filename)

# def load_resnet50(net, head, ssh, classifier, args):

#     if args.ckpt:
#         filename = args.resume + '/ckpt_epoch_{:d}.pth'.format(args.ckpt)
#     else:
#         filename = args.resume + '/ckpt.pth'
#     ckpt = torch.load(filename)
#     state_dict = ckpt['model']

#     net_dict = {}
#     # print(22222)
#     # print(net)
#     head_dict = {}
#     for k, v in state_dict.items():
#         # print(11111)
#         # print(k)
#         if k[:4] == "head":
#             k = k.replace("head.", "")
#             head_dict[k] = v
#         else:
#             k = k.replace("encoder.", "ext.")
#             k = k.replace("fc.", "head.fc.")
#             net_dict[k] = v

#     net.load_state_dict(net_dict)
#     # head.load_state_dict(head_dict)

#     print('Loaded model trained nossh on Classification and SimCLR:', filename)


def load_ttt(net, head, ssh, classifier, args, ttt=False):
    """Restore ``net`` and ``head`` from a per-corruption checkpoint.

    The file read is ``<args.resume>/<args.corruption>_both_2_15.pth`` when
    ``ttt`` is true and ``<args.resume>/<args.corruption>_both_15.pth``
    otherwise.  Its ``'net'`` and ``'head'`` entries are loaded into the
    corresponding modules.

    Args:
        net: module that receives ``ckpt['net']``.
        head: module that receives ``ckpt['head']``.
        ssh: unused here; kept so the signature stays unchanged.
        classifier: unused here; kept so the signature stays unchanged.
        args: namespace providing ``resume`` and ``corruption``.
        ttt: selects the ``_both_2_15`` file instead of ``_both_15``.
    """
    suffix = '/{}_both_2_15.pth' if ttt else '/{}_both_15.pth'
    filename = args.resume + suffix.format(args.corruption)

    # Deserialize onto the CPU so the file loads regardless of the device it
    # was saved from; load_state_dict copies into the modules' own parameters.
    ckpt = torch.load(filename, map_location='cpu')
    net.load_state_dict(ckpt['net'])
    head.load_state_dict(ckpt['head'])
    print('Loaded updated model from', filename)


def _encoder_weights(state_dict):
    """Return the entries whose key starts with ``encoder``, with ``"encoder."`` removed."""
    return {
        key.replace("encoder.", ""): value
        for key, value in state_dict.items()
        if key.startswith("encoder")
    }


def corrupt_resnet50(ext, args):
    """Overwrite the weights of ``ext`` with encoder weights from a checkpoint.

    First tries ``<args.restore>/simclr.pth``.  If anything in that attempt
    fails (for example ``args.restore`` is an epoch number rather than a
    directory, or the file does not exist), falls back to
    ``<args.resume>/ckpt_epoch_<args.restore>.pth``.  In both cases only the
    ``encoder``-prefixed entries of the checkpoint's ``'model'`` state dict
    are used.

    Args:
        ext: module whose state dict is replaced.
        args: namespace providing ``restore`` and ``resume``.
    """
    try:
        # SSL trained encoder
        simclr = torch.load(args.restore + '/simclr.pth', map_location='cpu')
        ext.load_state_dict(_encoder_weights(simclr['model']))

        print('Corrupted encoder trained by SimCLR')

    # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt and
    # SystemExit propagate instead of silently triggering the fallback load.
    except Exception:
        # Jointly trained encoder
        filename = args.resume + '/ckpt_epoch_{}.pth'.format(args.restore)

        ckpt = torch.load(filename, map_location='cpu')
        ext.load_state_dict(_encoder_weights(ckpt['model']))
        print('Corrupted encoder jontly trained on Classification and SimCLR')


def build_resnet50(args):
    """Build the ResNet50 backbone, projection head and linear classifier.

    The number of classes follows ``args.dataset``: 10 for ``'cifar10'``,
    100 for ``'cifar100'``, and for ``'cifar7'`` 7 unless ``args.modified``
    exists and is falsy (then 10).

    Args:
        args: namespace providing ``dataset`` and optionally ``modified``.

    Returns:
        Tuple ``(net, ext, head, ssh, classifier)`` where ``ssh`` is the
        ``SupConResNet``, ``head`` and ``ext`` are its ``head`` and
        ``encoder`` attributes, ``classifier`` is the ``LinearClassifier``
        and ``net`` is ``ExtractorHead(ext, classifier)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    print('Building ResNet50...')
    if args.dataset == 'cifar10':
        classes = 10
    elif args.dataset == 'cifar7':
        classes = 7 if getattr(args, 'modified', True) else 10
    elif args.dataset == "cifar100":
        classes = 100
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

    # Imported after validation so a bad dataset name fails before the model
    # modules are loaded.
    from models.BigResNet import SupConResNet, LinearClassifier
    from models.SSHead import ExtractorHead

    classifier = LinearClassifier(num_classes=classes).cuda()  # classifier
    ssh = SupConResNet().cuda()  # backbone + projection head
    head = ssh.head  # projection head
    ext = ssh.encoder  # backbone
    net = ExtractorHead(ext, classifier).cuda()  # backbone + classifier
    return net, ext, head, ssh, classifier

def build_slim_resnet50(args):
    """Build the slim ResNet50 backbone, projection head and linear classifier.

    The number of classes follows ``args.dataset``: 10 for ``'cifar10'``,
    100 for ``'cifar100'``, and for ``'cifar7'`` 7 unless ``args.modified``
    exists and is falsy (then 10).

    Args:
        args: namespace providing ``dataset`` and optionally ``modified``.

    Returns:
        Tuple ``(net, ext, head, ssh, classifier)`` where ``ssh`` is the
        ``SupConSlimResNet``, ``head`` and ``ext`` are its ``head`` and
        ``encoder`` attributes, ``classifier`` is the ``LinearClassifier``
        and ``net`` is ``ExtractorHead(ext, classifier)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    print('Building Slim_ResNet50...')
    if args.dataset == 'cifar10':
        classes = 10
    elif args.dataset == 'cifar7':
        classes = 7 if getattr(args, 'modified', True) else 10
    elif args.dataset == "cifar100":
        classes = 100
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

    # Imported after validation so a bad dataset name fails before the model
    # modules are loaded.
    from models.BigResNet import SupConSlimResNet, LinearClassifier
    from models.SSHead import ExtractorHead

    classifier = LinearClassifier(num_classes=classes).cuda()  # classifier
    ssh = SupConSlimResNet().cuda()  # backbone + projection head
    head = ssh.head  # projection head
    ext = ssh.encoder  # backbone
    net = ExtractorHead(ext, classifier).cuda()  # backbone + classifier
    return net, ext, head, ssh, classifier

def build_model(args):
    """Build the CIFAR ResNet classifier and its self-supervised branch.

    Args:
        args: namespace providing ``dataset``, ``group_norm``, ``depth``,
            ``width`` and ``shared``; ``modified``, ``detach``, ``ssl`` and
            ``parallel`` are optional.  ``args.shared`` is rewritten in place
            from the string ``'none'`` to ``None``.

    Returns:
        Tuple ``(net, ext, head, ssh)``: the full ``ResNetCifar`` network, the
        feature extractor taken from it, the self-supervised head, and
        ``ExtractorHead(ext, head)``.  ``net`` and ``ssh`` are wrapped in
        ``DataParallel`` when ``args.parallel`` is set.

    Raises:
        ValueError: if ``args.dataset`` or ``args.shared`` is not supported.
        NotImplementedError: if ``args.ssl`` is neither ``'rotation'`` nor
            ``'contrastive'`` when sharing up to layer3.
    """
    print('Building model...')
    if args.dataset == 'cifar10':
        classes = 10
    elif args.dataset == 'cifar7':
        classes = 7 if getattr(args, 'modified', True) else 10
    elif args.dataset == "cifar100":
        classes = 100
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError('Unsupported dataset: {!r}'.format(args.dataset))

    # Imported after validation so a bad dataset name fails before the model
    # modules are loaded.
    from models.ResNet import ResNetCifar as ResNet
    from models.SSHead import ExtractorHead

    if args.group_norm == 0:
        norm_layer = nn.BatchNorm2d
    else:
        def gn_helper(planes):
            return nn.GroupNorm(args.group_norm, planes)
        norm_layer = gn_helper

    # NOTE(review): ``detach`` is taken before ``'none'`` is normalized to
    # None below, so the string 'none' can reach ResNet -- confirm intended.
    detach = args.shared if getattr(args, 'detach', False) else None
    net = ResNet(args.depth, args.width, channels=3, classes=classes,
                 norm_layer=norm_layer, detach=detach).cuda()
    if args.shared == 'none':
        args.shared = None

    if args.shared == 'layer3' or args.shared is None:
        from models.SSHead import extractor_from_layer3
        ext = extractor_from_layer3(net)
        if not hasattr(args, 'ssl') or args.ssl == 'rotation':
            head = nn.Linear(64 * args.width, 4)
        elif args.ssl == 'contrastive':
            head = nn.Sequential(
                nn.Linear(64 * args.width, 64 * args.width),
                nn.ReLU(inplace=True),
                nn.Linear(64 * args.width, 16 * args.width)
            )
        else:
            raise NotImplementedError
    elif args.shared == 'layer2':
        from models.SSHead import extractor_from_layer2, head_on_layer2
        ext = extractor_from_layer2(net)
        head = head_on_layer2(net, args.width, 4)
    else:
        # Previously ``ext``/``head`` stayed unbound and the next line failed
        # with UnboundLocalError.
        raise ValueError('Unsupported shared setting: {!r}'.format(args.shared))
    ssh = ExtractorHead(ext, head).cuda()

    if getattr(args, 'parallel', False):
        net = torch.nn.DataParallel(net)
        ssh = torch.nn.DataParallel(ssh)
    return net, ext, head, ssh


def test(dataloader, model, **kwargs):
    """Evaluate ``model`` on every batch of ``dataloader``.

    The model is switched to eval mode for the pass and to train mode before
    returning.  Inputs and labels are moved to CUDA.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches.  When ``inputs``
            is a list or tuple, only its first element is evaluated.
        model: callable module; invoked as ``model(inputs, **kwargs)``.
        **kwargs: forwarded unchanged to ``model``.

    Returns:
        Tuple ``(error, correct, losses)``: ``error`` is ``1 - correct.mean()``,
        ``correct`` is a numpy array marking which samples were predicted
        correctly, and ``losses`` is a numpy array of per-sample cross-entropy
        values.
    """
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()
    correct = []
    losses = []
    for inputs, labels in dataloader:
        # Several entries may be bundled per sample; only the first is scored.
        if isinstance(inputs, (list, tuple)):
            inputs = inputs[0]
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            outputs = model(inputs, **kwargs)
            losses.append(criterion(outputs, labels).cpu())
            _, predicted = outputs.max(1)
            correct.append(predicted.eq(labels).cpu())
    correct = torch.cat(correct).numpy()
    losses = torch.cat(losses).numpy()
    model.train()
    return 1 - correct.mean(), correct, losses

# def test_ensemble(dataloader, model, width_mult, **kwargs):
#     criterion = nn.CrossEntropyLoss(reduction='none').cuda()
#     model.eval()
#     correct = []
#     losses = []
#     for batch_idx, (inputs, labels) in enumerate(dataloader):
#         if type(inputs) == list:
#             inputs = inputs[0]
#         inputs, labels = inputs.cuda(), labels.cuda()
#         with torch.no_grad():
#             if width_mult == 1.0:
#                 widths_train = [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
#             elif width_mult == 0.875:
#                 widths_train = [0.25, 0.375, 0.5, 0.625, 0.75, 0.875]
#             elif width_mult == 0.75:
#                 widths_train = [0.25, 0.375, 0.5, 0.625, 0.75]
#             elif width_mult == 0.625:
#                 widths_train = [0.25, 0.375, 0.5, 0.625]
#             elif width_mult == 0.5:
#                 widths_train = [0.25, 0.375, 0.5]
#             elif width_mult == 0.375:
#                 widths_train = [0.25, 0.375]
#             for width_mult in widths_train:
#                 model.apply(
#                     lambda m: setattr(m, 'width_mult', width_mult))
#                 if width_mult == 0.25:
#                     outputs1 = model(inputs, **kwargs)
#                 elif width_mult == 0.375:
#                     outputs2 = model(inputs, **kwargs)
#                 elif width_mult == 0.5:
#                     outputs3 = model(inputs, **kwargs)
#                 elif width_mult == 0.625:
#                     outputs4 = model(inputs, **kwargs)
#                 elif width_mult == 0.75:
#                     outputs5 = model(inputs, **kwargs)
#                 elif width_mult == 0.875:
#                     outputs6 = model(inputs, **kwargs)
#                 else:
#                     outputs7 = model(inputs, **kwargs)
#             # outputs = model(inputs, **kwargs)
#             if width_mult == 1.0:
#                 outputs = (outputs1 + outputs2 + outputs3 + outputs4 + outputs5 + outputs6 + outputs7) / 7.0
#                 # outputs = outputs1 * 0.1 + outputs2 * 0.2 + outputs3 * 0.3 + outputs4 * 0.4
#             elif width_mult == 0.875:
#                 outputs = (outputs1 + outputs2 + outputs3 + outputs4 + outputs5 + outputs6) / 6.0
#             elif width_mult == 0.75:
#                 outputs = (outputs1 + outputs2 + outputs3 + outputs4 + outputs5) / 5.0
#                 # outputs = outputs1 * (1/6) + outputs2 * (1/3) + outputs3 * (1/2)
#             elif width_mult == 0.625:
#                 outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4.0
#             elif width_mult == 0.5:
#                 outputs = (outputs1 + outputs2 + outputs3) / 3.0
#                 # outputs = outputs1 * (1/3) + outputs2 * (2/3)
#             elif width_mult == 0.375:
#                 outputs = (outputs1 + outputs2) / 2.0
#             # print(66666)
#             # print(outputs)
#             loss = criterion(outputs, labels)
#             losses.append(loss.cpu())
#             _, predicted = outputs.max(1)
#             # print(66666)
#             # print(predicted)
#             correct.append(predicted.eq(labels).cpu())
#     correct = torch.cat(correct).numpy()
#     losses = torch.cat(losses).numpy()
#     model.train()
#     return 1-correct.mean(), correct, losses

def test_ensemble(dataloader, model, width_mult, **kwargs):
    """Evaluate an ensemble of ``model`` run at several widths.

    For each batch the model is run once per width in the ensemble selected
    by ``width_mult`` (each width is written to the ``width_mult`` attribute
    of every submodule via ``model.apply``), and the outputs are averaged
    with equal weights before computing loss and predictions.  Because the
    target width is always the last one applied, the model is left set to
    ``width_mult`` afterwards.  The model is switched to eval mode for the
    pass and to train mode before returning.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches.  When ``inputs``
            is a list or tuple, only its first element is evaluated.
        model: module invoked as ``model(inputs, **kwargs)``.
        width_mult: target width; one of ``1.0``, ``0.75`` or ``0.5``.
        **kwargs: forwarded unchanged to ``model``.

    Returns:
        Tuple ``(error, correct, losses)``: ``error`` is ``1 - correct.mean()``,
        ``correct`` is a numpy array marking which samples were predicted
        correctly, and ``losses`` is a numpy array of per-sample cross-entropy
        values.

    Raises:
        ValueError: if ``width_mult`` is not one of the supported widths.
    """
    ensembles = {
        1.0: (0.25, 0.5, 0.75, 1.0),
        0.75: (0.25, 0.5, 0.75),
        0.5: (0.25, 0.5),
    }
    if width_mult not in ensembles:
        # Previously this surfaced as UnboundLocalError inside the batch loop.
        raise ValueError('Unsupported width_mult: {!r}'.format(width_mult))
    # Resolved once; a separate loop variable keeps the parameter untouched.
    widths = ensembles[width_mult]

    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()
    correct = []
    losses = []
    for inputs, labels in dataloader:
        # Several entries may be bundled per sample; only the first is scored.
        if isinstance(inputs, (list, tuple)):
            inputs = inputs[0]
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            summed = None
            for width in widths:
                # Bind ``width`` as a default so the callback does not depend
                # on the loop variable's later value.
                model.apply(lambda m, w=width: setattr(m, 'width_mult', w))
                member_out = model(inputs, **kwargs)
                # Left-to-right accumulation, same order as the explicit sums.
                summed = member_out if summed is None else summed + member_out
            outputs = summed / float(len(widths))
            losses.append(criterion(outputs, labels).cpu())
            _, predicted = outputs.max(1)
            correct.append(predicted.eq(labels).cpu())
    correct = torch.cat(correct).numpy()
    losses = torch.cat(losses).numpy()
    model.train()
    return 1 - correct.mean(), correct, losses

def pair_buckets(o1, o2):
    """Split paired boolean outcomes into the four agreement buckets.

    Returns a tuple of element-wise masks, in this order: both true,
    only ``o1`` true, only ``o2`` true, neither true.
    """
    not_o1 = np.logical_not(o1)
    not_o2 = np.logical_not(o2)
    both = np.logical_and(o1, o2)
    first_only = np.logical_and(o1, not_o2)
    second_only = np.logical_and(not_o1, o2)
    neither = np.logical_and(not_o1, not_o2)
    return both, first_only, second_only, neither


def count_each(tuple):
    """Return a list holding ``item.sum()`` for every item of the input."""
    totals = []
    for bucket in tuple:
        totals.append(bucket.sum())
    return totals


def plot_epochs(all_err_cls, all_err_ssh, fname, use_agg=True):
    """Plot both per-epoch error curves, as percentages, and save to ``fname``.

    ``all_err_cls`` is drawn in red and labelled 'classifier';
    ``all_err_ssh`` is drawn in blue and labelled 'self-supervised'.  When
    ``use_agg`` is true the matplotlib backend is switched to 'agg' first.
    """
    import matplotlib.pyplot as plt
    if use_agg:
        plt.switch_backend('agg')

    curves = (
        (all_err_cls, 'r', 'classifier'),
        (all_err_ssh, 'b', 'self-supervised'),
    )
    for errors, colour, tag in curves:
        percent = np.asarray(errors) * 100
        plt.plot(percent, color=colour, label=tag)

    plt.xlabel('epoch')
    plt.ylabel('test error (%)')
    plt.legend()
    plt.savefig(fname)
    plt.close()


@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of ``softmax(x)`` taken along dimension 1."""
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    weighted = probs * log_probs
    return -weighted.sum(1)

