import os
import shutil

import torch

from tqdm import tqdm
import numpy as np
import torch.nn.functional as F


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save`` to ``filename``.

    When ``is_best`` is truthy, the freshly written file is additionally
    copied to ``model_best.pth.tar`` in the current working directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')


def save_config_file(model_checkpoints_folder, args):
    """Write ``args`` as YAML to ``config.yml`` inside ``model_checkpoints_folder``.

    The folder (including missing parents) is created when needed, and the
    config file is written whether or not the folder already existed; an
    existing ``config.yml`` is overwritten.
    """
    # PyYAML is only needed here; import locally so the module itself does not
    # require it at import time.
    import yaml

    os.makedirs(model_checkpoints_folder, exist_ok=True)
    with open(os.path.join(model_checkpoints_folder, 'config.yml'),
              'w') as outfile:
        yaml.dump(args, outfile, default_flow_style=False)


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    ``output`` holds per-class scores with classes along dim 1, ``target``
    holds one class index per row. The result is a list with one
    single-element tensor per entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)
        scale = 100.0 / n_samples

        # Indices of the largest_k highest scores per row, transposed so row j
        # holds every sample's j-th best guess.
        _, ranked = torch.topk(output, largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]


def load_victim(epochs, dataset, model, arch, loss, device, discard_mlp=False,
                watermark="False", entropy="False"):
    """Load a trained victim checkpoint into ``model`` (non-strict) and return it.

    The checkpoint is read from
    ``/checkpoint/$USER/SimCLR/{epochs}{arch}{loss}TRAIN/{dataset}_checkpoint_{epochs}_{loss}<SUFFIX>.pth.tar``
    where ``<SUFFIX>`` is ``WATERMARK`` when ``watermark == "True"``, otherwise
    ``ENTROPY`` when ``entropy == "True"``, otherwise empty. Note the flags are
    compared against the *string* ``"True"``.

    With ``discard_mlp``, every key starting with ``backbone.fc`` is dropped
    before loading (kept for older checkpoints; the current architecture has
    no ``backbone.fc`` layers).
    """
    if watermark == "True":
        suffix = "WATERMARK"
    elif entropy == "True":
        suffix = "ENTROPY"
    else:
        suffix = ""

    run_dir = f"/checkpoint/{os.getenv('USER')}/SimCLR/{epochs}{arch}{loss}TRAIN"
    ckpt_path = f"{run_dir}/{dataset}_checkpoint_{epochs}_{loss}{suffix}.pth.tar"
    checkpoint = torch.load(ckpt_path, map_location=device)

    weights = checkpoint['state_dict']
    if discard_mlp:
        weights = {
            name: tensor
            for name, tensor in weights.items()
            if not name.startswith('backbone.fc')
        }
    model.load_state_dict(weights, strict=False)
    return model


def load_watermark(epochs, dataset, model, arch, loss, device):
    """Load watermark-head weights from the WATERMARK checkpoint into ``model``.

    The checkpoint path is
    ``/checkpoint/$USER/SimCLR/{epochs}{arch}{loss}TRAIN/{dataset}_checkpoint_{epochs}_{loss}WATERMARK.pth.tar``.
    Weights are taken from ``'watermark_state_dict'`` if present, otherwise
    from ``'mlp_state_dict'``, and loaded strictly (``load_state_dict``
    default).

    Returns:
        ``model`` with the loaded weights.

    Raises:
        KeyError: if the checkpoint has neither key.
    """
    checkpoint = torch.load(
        f"/checkpoint/{os.getenv('USER')}/SimCLR/{epochs}{arch}{loss}TRAIN/{dataset}_checkpoint_{epochs}_{loss}WATERMARK.pth.tar",
        map_location=device)
    try:
        state_dict = checkpoint['watermark_state_dict']
    except KeyError:
        # Older checkpoints store the same head under 'mlp_state_dict'.
        state_dict = checkpoint['mlp_state_dict']

    model.load_state_dict(state_dict)
    return model


def print_args(args, get_str=False):
    """Print every public attribute of ``args`` (sorted by name).

    The delimiter comes from ``args.delimiter`` if present, else ``args.sep``,
    else ``";"``; the ``sep`` and ``delimiter`` attributes themselves and any
    name starting with ``_`` are excluded.

    With ``get_str`` the keys and values are printed as two delimiter-joined
    lines and returned as ``(keys_str, values_str)``. Otherwise each pair is
    printed on its own line followed by a footer, and ``None`` is returned.
    """
    delimiter = ";"
    if "delimiter" in args:
        delimiter = args.delimiter
    elif "sep" in args:
        delimiter = args.sep

    print("###################################################################")
    print("args: ")
    excluded = ("sep", "delimiter")
    names = sorted(
        name for name in dir(args)
        if not name.startswith("_") and name not in excluded
    )
    settings = [getattr(args, name) for name in names]

    if not get_str:
        for name, setting in zip(names, settings):
            print(name, ": ", setting, flush=True)
        print("ARGS FINISHED", flush=True)
        print("######################################################")
        return None

    keys_str = delimiter.join(map(str, names))
    values_str = delimiter.join(map(str, settings))
    print(keys_str)
    print(values_str)
    return keys_str, values_str


def get_query_label_dataset(victim_model, dataset, query_num, args):
    """Label the first ``query_num`` samples of ``dataset`` with victim features.

    The samples are fed (in order, batches of 256, moved with ``.cuda()``)
    through ``victim_model.get_image_representation(images,
    n=args.last_n_blocks, avgpool=False)`` in eval mode without gradients.
    The original labels are discarded.

    Args:
        victim_model: model exposing ``get_image_representation``.
        dataset: indexable dataset yielding ``(image, label)`` pairs.
        query_num: number of leading samples to query.
        args: namespace read for ``workers`` and ``last_n_blocks``.

    Returns:
        A dataset of ``(image, victim_feature)`` pairs (CPU tensors) with
        exactly ``query_num`` items.

    Raises:
        ValueError: if ``query_num`` exceeds ``len(dataset)``.
    """
    if len(dataset) < query_num:
        raise ValueError("Query number is greater than length of dataset")
    query_dataset = torch.utils.data.Subset(dataset, range(0, query_num))
    assert len(query_dataset) == query_num

    query_loader = torch.utils.data.DataLoader(
        query_dataset, batch_size=256,
        shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

    # Collect one TensorDataset per batch and concatenate them once at the end;
    # concatenating inside the loop would nest ConcatDatasets one level per
    # batch and make every index lookup recurse through all of them.
    batch_datasets = []
    victim_model.eval()
    with torch.no_grad():
        for images, _ in tqdm(query_loader):
            images = images.cuda()
            victim_features = victim_model.get_image_representation(
                images, n=args.last_n_blocks, avgpool=False)
            batch_datasets.append(torch.utils.data.TensorDataset(
                images.detach().cpu(), victim_features.detach().cpu()))

    if batch_datasets:
        query_label_dataset = torch.utils.data.ConcatDataset(batch_datasets)
    else:
        # ConcatDataset rejects an empty list; mirror the empty-query result.
        query_label_dataset = torch.utils.data.TensorDataset(
            torch.tensor([]), torch.tensor([]))

    assert len(query_label_dataset) == query_num
    return query_label_dataset


class TransformTwice:
    """Wrap ``transform`` so one call returns a 2-tuple of independent applications."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        # Each view is produced by a separate call, so random transforms differ.
        return tuple(self.transform(inp) for _ in range(2))


class TransformThreeTimes:
    """Wrap ``transform`` so one call returns a 3-tuple of independent applications."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        views = []
        for _ in range(3):
            views.append(self.transform(inp))
        return tuple(views)


class TransformFourTimes:
    """Wrap ``transform`` so one call returns a 4-tuple of independent applications."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        apply = self.transform
        return tuple(apply(inp) for _ in range(4))


def interleave_offsets(batch, nu):
    """Return boundaries splitting ``batch`` items into ``nu + 1`` near-equal groups.

    Groups differ in size by at most one; the leftover items go to the
    trailing groups. The result starts at 0 and ends at ``batch``.
    """
    n_groups = nu + 1
    base, extra = divmod(batch, n_groups)
    offsets = [0]
    for g in range(n_groups):
        # The last `extra` groups each take one of the leftover items.
        size = base + 1 if g >= n_groups - extra else base
        offsets.append(offsets[-1] + size)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    """Swap matching slices between the first tensor and each other tensor in ``xy``.

    Every tensor is cut along dim 0 at ``interleave_offsets(batch,
    len(xy) - 1)`` into ``len(xy)`` slices. For each i >= 1, slice i of
    ``xy[0]`` is exchanged with slice i of ``xy[i]``. The slices are then
    re-concatenated, and a list of the resulting tensors is returned.
    """
    num_others = len(xy) - 1
    bounds = interleave_offsets(batch, num_others)
    pieces = []
    for tensor in xy:
        pieces.append([tensor[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])])
    for j in range(1, num_others + 1):
        pieces[0][j], pieces[j][j] = pieces[j][j], pieces[0][j]
    return [torch.cat(parts, dim=0) for parts in pieces]


def linear_rampup(current, rampup_length=None):
    """Ramp a weight linearly from 0 to 1 over ``rampup_length`` steps.

    Returns ``clip(current / rampup_length, 0, 1)`` as a Python float.
    A ``rampup_length`` of ``0`` or ``None`` (the default) means "no ramp-up"
    and yields ``1.0``. Without this check, the default ``None`` would make
    ``current / None`` raise TypeError.
    """
    if rampup_length is None or rampup_length == 0:
        return 1.0
    return float(np.clip(current / rampup_length, 0.0, 1.0))


class SemiLoss:
    """MixMatch loss: soft-target cross-entropy (labeled) + squared error (unlabeled).

    Calling the instance returns ``(Lx, Lu, w)`` where
      * ``Lx`` is the batch mean of ``-sum(targets_x * log_softmax(outputs_x))``,
      * ``Lu`` is the element-wise mean of ``(softmax(outputs_u) - targets_u) ** 2``,
      * ``w = args.lambda_u * linear_rampup(epoch, args.epochs)``.
    """

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, args):
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        labeled = -(log_probs_x * targets_x).sum(dim=1).mean()

        probs_u = torch.softmax(outputs_u, dim=1)
        unlabeled = ((probs_u - targets_u) ** 2).mean()

        weight = args.lambda_u * linear_rampup(epoch, args.epochs)
        return labeled, unlabeled, weight


class SemiCELoss(object):
    """MixMatch-style loss using soft-target cross-entropy for both terms.

    Calling the instance returns ``(Lx, Lu, w)`` where
      * ``Lx`` is the batch mean of ``-sum(targets_x * log_softmax(outputs_x))``,
      * ``Lu`` is the batch mean of ``-sum(targets_u * log_softmax(outputs_u))``,
      * ``w = args.lambda_u * linear_rampup(epoch, args.epochs)``.
    """

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, args):
        # No softmax of outputs_u is needed here: both terms work on log-probs.
        Lx = -torch.mean(
            torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        Lu = -torch.mean(
            torch.sum(F.log_softmax(outputs_u, dim=1) * targets_u, dim=1))

        return Lx, Lu, args.lambda_u * linear_rampup(epoch, args.epochs)


class SemiMSELoss:
    """Regression variant of the MixMatch loss (raw outputs, no softmax).

    Calling the instance returns ``(Lx, Lu, w)`` where ``Lx`` and ``Lu`` are
    the element-wise mean squared errors of the labeled and unlabeled outputs
    against their targets, and ``w = args.lambda_u * linear_rampup(epoch,
    args.epochs)``.
    """

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, args):
        labeled = ((outputs_x - targets_x) ** 2).mean()
        unlabeled = ((outputs_u - targets_u) ** 2).mean()
        weight = args.lambda_u * linear_rampup(epoch, args.epochs)
        return labeled, unlabeled, weight


def _mix_match(inputs_x, targets_x, unlabeled_dataloader, model, current_epoch,
               args, num_views):
    """Compute the MixMatch loss with ``num_views`` augmented views per unlabeled sample.

    Shared implementation behind ``mix_match_K2`` and ``mix_match_K3``.

    Args:
        inputs_x: labeled input batch. It is concatenated with the unlabeled
            views after those are moved with ``.cuda()``, so it must already
            be on the same device.
        targets_x: targets for ``inputs_x``; passed through softmax unless
            ``args.mixmatch_mse`` is set.
        unlabeled_dataloader: loader whose batches are ``(views, _)`` with
            exactly ``num_views`` tensors in ``views`` (e.g. produced with
            ``TransformTwice`` / ``TransformThreeTimes``).
        model: object exposing ``model.module.get_image_representation``
            (presumably a ``DataParallel`` wrapper — confirm with callers).
        current_epoch: epoch fed to the ramp-up of the unlabeled loss weight.
        args: namespace read for ``mixmatch_mse``, ``last_n_blocks``, ``T``,
            ``alpha`` (and ``lambda_u``/``epochs`` inside the loss object).
        num_views: number K of augmented views per unlabeled sample.

    Returns:
        The scalar loss ``Lx + w * Lu``.

    Raises:
        ValueError: if a loader batch does not hold exactly ``num_views`` views.
    """
    if not args.mixmatch_mse:
        targets_x = F.softmax(targets_x, dim=1)
    batch_size = inputs_x.size(0)

    # A fresh iterator is created on every call, so each call uses the first
    # batch the loader yields (a new random batch only if the loader shuffles).
    # NOTE(review): with num_workers > 0 this restarts the workers every call;
    # confirm whether a persistent iterator over the loader was intended.
    views, _ = next(iter(unlabeled_dataloader))
    views = list(views)
    if len(views) != num_views:
        raise ValueError(
            f"expected {num_views} unlabeled views per batch, got {len(views)}")
    views = [view.cuda() for view in views]

    represent = model.module.get_image_representation
    with torch.no_grad():
        # Guess labels for the unlabeled samples by averaging over the K views.
        outputs = [represent(view, n=args.last_n_blocks, avgpool=False)
                   for view in views]
        if not args.mixmatch_mse:
            probs = [torch.softmax(out, dim=1) for out in outputs]
            # sum(rest, first) == first + rest[0] + ... (same order as a + b + c)
            avg = sum(probs[1:], probs[0]) / num_views
            sharpened = avg ** (1 / args.T)  # temperature sharpening
            targets_u = sharpened / sharpened.sum(dim=1, keepdim=True)
        else:
            targets_u = sum(outputs[1:], outputs[0]) / num_views
        targets_u = targets_u.detach()

    all_inputs = torch.cat([inputs_x] + views, dim=0)
    all_targets = torch.cat([targets_x] + [targets_u] * num_views, dim=0)

    # MixUp with a coefficient biased towards the un-permuted operand (lam >= 0.5).
    lam = np.random.beta(args.alpha, args.alpha)
    lam = max(lam, 1 - lam)
    idx = torch.randperm(all_inputs.size(0))
    mixed_input = lam * all_inputs + (1 - lam) * all_inputs[idx]
    mixed_target = lam * all_targets + (1 - lam) * all_targets[idx]

    # Interleave labeled and unlabeled samples between the forward chunks so the
    # batch-norm statistics of each chunk are computed correctly; the second
    # interleave call swaps the slices back.
    mixed_chunks = interleave(list(torch.split(mixed_input, batch_size)), batch_size)
    logits = [represent(chunk, n=args.last_n_blocks, avgpool=False)
              for chunk in mixed_chunks]
    logits = interleave(logits, batch_size)
    logits_x = logits[0]
    logits_u = torch.cat(logits[1:], dim=0)

    loss_function = SemiMSELoss() if args.mixmatch_mse else SemiLoss()
    Lx, Lu, w = loss_function(logits_x, mixed_target[:batch_size], logits_u,
                              mixed_target[batch_size:], current_epoch, args)
    return Lx + w * Lu


def mix_match_K2(inputs_x, targets_x, unlabeled_dataloader, model, current_epoch, args):
    """MixMatch loss with K=2 augmented views per unlabeled sample.

    ``unlabeled_dataloader`` must yield ``((view1, view2), _)`` batches.
    See ``_mix_match`` for the parameters and the returned value.
    """
    return _mix_match(inputs_x, targets_x, unlabeled_dataloader, model,
                      current_epoch, args, num_views=2)


def mix_match_K3(inputs_x, targets_x, unlabeled_dataloader, model, current_epoch, args):
    """MixMatch loss with K=3 augmented views per unlabeled sample.

    ``unlabeled_dataloader`` must yield ``((view1, view2, view3), _)`` batches.
    See ``_mix_match`` for the parameters and the returned value.
    """
    return _mix_match(inputs_x, targets_x, unlabeled_dataloader, model,
                      current_epoch, args, num_views=3)
