import copy
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader

import ramps
from hmix_args_c import device, process_args
from hmix_image_dataset import load_image_dataset
from hmix_model import MixCNNCIFAR, MixCNNSTL, MixLeNet

# Parse command-line options once, at import time. Most functions below read
# their settings from this module-level `args` (seed_torch even uses it for
# its default argument, which is evaluated when that function is defined).
args = process_args()


def seed_torch(seed=args.seed):
    """Seed every random source this script draws from so runs repeat.

    Seeds Python's ``random``, NumPy's global generator, and torch's CPU and
    CUDA generators, then asks cuDNN for deterministic kernels.

    Note: assigning ``PYTHONHASHSEED`` here only affects subprocesses started
    afterwards; the running interpreter's hash seed is fixed at startup.

    Args:
        seed: integer seed; defaults to ``args.seed`` as parsed at import.
    """
    # Python / NumPy level sources.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # Torch: CPU generator (both public aliases), the current CUDA device,
    # then every CUDA device.
    torch_seeders = (
        torch.random.manual_seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in torch_seeders:
        seeder(seed)

    # cudnn.enabled / cudnn.benchmark are intentionally left untouched.
    torch.backends.cudnn.deterministic = True


# Seed all RNGs with args.seed as soon as this module is imported, before
# main() creates any model or data loader.
seed_torch()


def select_model(model_name):
    """Return the network class (not an instance) registered under a CLI name.

    Known names are "cnn", "cnnstl" and "lenet"; any other name raises
    KeyError.
    """
    registry = dict(
        cnn=MixCNNCIFAR,
        cnnstl=MixCNNSTL,
        lenet=MixLeNet,
    )
    return registry[model_name]


def make_optimizer(model, optim, stepsize):
    """Build the optimizer named ``optim`` over ``model``'s parameters.

    Args:
        model: module whose ``parameters()`` are optimized.
        optim: one of "adam", "adagrad", "adamw", "sgd", "adadelta".
        stepsize: learning rate.

    Returns:
        A configured ``torch.optim`` optimizer. Every choice except
        "adadelta" applies ``args.weight_decay``. Adam uses
        betas=(0.5, 0.99), SGD uses momentum 0.9 and Adadelta uses rho 0.95.

    Raises:
        ValueError: if ``optim`` is not one of the supported names.
    """
    if optim == "adam":
        return torch.optim.Adam(model.parameters(),
                                lr=stepsize,
                                betas=(0.5, 0.99),
                                weight_decay=args.weight_decay)
    if optim == "adagrad":
        return torch.optim.Adagrad(model.parameters(),
                                   lr=stepsize,
                                   weight_decay=args.weight_decay)
    if optim == "adamw":
        return torch.optim.AdamW(model.parameters(),
                                 lr=stepsize,
                                 weight_decay=args.weight_decay)
    if optim == "sgd":
        return torch.optim.SGD(model.parameters(),
                               lr=stepsize,
                               momentum=0.9,
                               weight_decay=args.weight_decay)
    if optim == "adadelta":
        return torch.optim.Adadelta(model.parameters(),
                                    lr=stepsize,
                                    rho=0.95)
    # Fail loudly with the valid choices instead of an UnboundLocalError.
    raise ValueError(
        "unknown optimizer {!r}; expected one of "
        "'adam', 'adagrad', 'adamw', 'sgd', 'adadelta'".format(optim))


def make_scheduler(optim):
    """Wrap ``optim`` in a MultiStepLR configured from the CLI.

    The learning rate is multiplied by ``args.scheduler_gamma`` each time
    the scheduler's step count reaches one of ``args.milestones``. The
    trainer steps it once per epoch.
    """
    milestones = args.milestones
    gamma = args.scheduler_gamma
    return MultiStepLR(optim, milestones=milestones, gamma=gamma)


def create_model(selected_model, dim=None, ema=False):
    """Instantiate ``selected_model(dim)``.

    When ``ema`` is true, every parameter is detached in place, so none of
    them require gradients. The EMA copy is then only moved by explicit
    weight averaging, never by backpropagation.
    """
    net = selected_model(dim)
    if not ema:
        return net
    for weight in net.parameters():
        weight.detach_()
    return net


def update_ema_variables(model, ema_model, alpha, epoch):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``, where
    ``decay`` is derived from ``alpha``:

    * with ``args.ema_update``: ``alpha`` scaled by
      ``ramps.sigmoid_rampup(epoch + 1, args.ema_end)``;
    * otherwise: ``min(1 - 1 / (epoch - args.ema_start + 1), alpha)``. This
      is a plain running mean of the per-call snapshots, starting as an exact
      copy at ``epoch == args.ema_start``, until that mean's weight would
      exceed ``alpha``.

    The effective ``decay`` is printed on every call.
    """
    if not args.ema_update:
        decay = min(1 - 1 / (epoch - args.ema_start + 1), alpha)
    else:
        decay = ramps.sigmoid_rampup(epoch + 1, args.ema_end) * alpha
    print(decay)

    paired = zip(ema_model.parameters(), model.parameters())
    for shadow, live in paired:
        shadow.data.mul_(decay).add_(live, alpha=1. - decay)


def check_mean_teacher(epoch):
    """Return True when the EMA (mean-teacher) update should run this epoch.

    That requires ``args.mean_teacher`` to be enabled and ``epoch`` to have
    reached ``args.ema_start``.
    """
    if args.mean_teacher and not epoch < args.ema_start:
        return True
    return False


class trainer():
    """Train PU classifiers with heuristic mixup and an EMA (mean-teacher) target.

    ``models``, ``ema_models``, ``optimizers`` and ``schedulers`` are dicts
    zipped by position (via ``.values()``), so they must share the same keys
    in the same insertion order.

    A "heuristic pool" of the highest-entropy labeled positives
    (``self.h_*_x``) is refreshed from ``XYPtrainSet`` and used as mixup
    partners for uncertain unlabeled samples.
    """

    def __init__(self, models, ema_models, optimizers, schedulers,
                 XYPtrainLoader, XYUtrainLoader, XYvalidLoader, XYtestLoader,
                 XYPtrainSet):
        self.models = models
        self.ema_models = ema_models.values()
        self.optimizers = optimizers.values()
        self.schedulers = schedulers.values()

        self.XYPtrainLoader = XYPtrainLoader
        self.XYUtrainLoader = XYUtrainLoader
        self.XYvalidLoader = XYvalidLoader
        self.XYtestLoader = XYtestLoader
        # Unshuffled pass over the labeled positives, used only to score
        # candidates for the heuristic pool.
        self.XYPtrainSet = DataLoader(XYPtrainSet,
                                      batch_size=1000,
                                      shuffle=False)

        # Heuristic pool (filled by update_hinput); empty until then.
        self.h_inputs_x = torch.Tensor([]).to(device)
        self.h_features_x = torch.Tensor([]).to(device)
        self.h_entropys_x = torch.Tensor([]).to(device)
        self.h_preds_x = torch.Tensor([]).to(device)

    @staticmethod
    def _next_batch(batch_iter, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when exhausted."""
        try:
            return next(batch_iter), batch_iter
        except StopIteration:
            batch_iter = iter(loader)
            return next(batch_iter), batch_iter

    def get_hinput(self, features_ori):
        """Draw one random heuristic-pool index per unlabeled sample.

        Only ``features_ori.size(0)`` (the unlabeled batch size) is used.

        Returns:
            int64 tensor on ``device`` of indices into ``self.h_inputs_x``.

        Raises:
            RuntimeError: if the heuristic pool is empty.
        """
        n_unlabeled = features_ori.size(0)
        # update_hinput keeps at most args.h_positive samples, and fewer when
        # the positive set is smaller; never index past the real pool.
        pool_size = min(args.h_positive, len(self.h_inputs_x))
        if pool_size <= 0:
            raise RuntimeError(
                "heuristic pool is empty; update_hinput must run first")
        h_input_b_idx = np.random.randint(pool_size, size=n_unlabeled)

        return torch.tensor(h_input_b_idx, device=device, dtype=torch.int64)

    def update_hinput(self, net):
        """Refill the heuristic pool with the highest-entropy labeled positives.

        High entropy means that, for binary classification, the predicted
        probability is near 0.5. Every labeled positive is scored with ``net``.
        The top ``args.h_positive`` by binary entropy (ties kept in dataset
        order) are stored in ``self.h_inputs_x``, ``self.h_features_x``,
        ``self.h_entropys_x`` and ``self.h_preds_x``, all on ``device``.

        ``net`` is scored in eval mode so this pass does not update
        normalization running statistics. Its previous train/eval mode is
        restored afterwards.
        """
        was_training = net.training
        net.eval()
        inputs, features, entropys, preds = [], [], [], []
        try:
            with torch.no_grad():
                for batch, _ in self.XYPtrainSet:
                    data = batch.to(device, non_blocking=True)
                    outputs_x, features_x = net(data, flag_feature=True)
                    preds_x = torch.sigmoid(outputs_x)
                    # Binary entropy computed from logits for stability.
                    entropys_x = -(preds_x * F.logsigmoid(outputs_x) +
                                   (1. - preds_x) * F.logsigmoid(-outputs_x))
                    # Accumulate on CPU to keep GPU memory bounded.
                    inputs.append(batch)
                    features.append(features_x.cpu())
                    entropys.append(entropys_x.view(data.size(0)).cpu())
                    preds.append(preds_x.view(data.size(0)).cpu())
        finally:
            net.train(was_training)

        if not entropys:
            # No labeled positives at all: leave an empty pool.
            self.h_inputs_x = torch.Tensor([]).to(device)
            self.h_features_x = torch.Tensor([]).to(device)
            self.h_entropys_x = torch.Tensor([]).to(device)
            self.h_preds_x = torch.Tensor([]).to(device)
            return

        all_entropys = torch.cat(entropys)
        # Highest entropy first; stable so equal entropies keep dataset order.
        order = torch.sort(all_entropys, descending=True,
                           stable=True).indices[:args.h_positive]

        self.h_inputs_x = torch.cat(inputs)[order].to(device)
        self.h_features_x = torch.cat(features)[order].to(device)
        self.h_entropys_x = all_entropys[order].to(device)
        self.h_preds_x = torch.cat(preds)[order].to(device)

    def train(self, epoch):
        """Run one epoch (``args.val_iterations`` updates) for every model.

        Per model: mixup training on positive/unlabeled batches, then the
        optional EMA update, a heuristic-pool refresh (from epoch
        ``args.start_hmix - 1`` on), a scheduler step, and test evaluation.

        Returns:
            ``(results_val, results_test)`` dicts keyed by model name.
            Validation lists stay empty (validation is disabled). Test lists
            hold ``net.error(self.XYtestLoader)``.
        """
        results_val = {}
        results_test = {}

        for model_name in self.models.keys():
            results_val[model_name] = []
            results_test[model_name] = []

        for net, ema_net, opt, scheduler, model_name in zip(
                self.models.values(), self.ema_models, self.optimizers,
                self.schedulers, self.models.keys()):
            XYUtrainLoader_iter = iter(self.XYUtrainLoader)
            XYPtrainLoader_iter = iter(self.XYPtrainLoader)
            net.train()
            ema_net.train()

            # With start_hmix <= 0 no earlier epoch has filled the heuristic
            # pool, so build it now rather than index an empty tensor.
            if epoch >= args.start_hmix and len(self.h_inputs_x) == 0:
                self.update_hinput(net)

            for _ in range(args.val_iterations):
                # One positive and one unlabeled batch; each loader cycles
                # independently when exhausted.
                (data_p, _), XYPtrainLoader_iter = self._next_batch(
                    XYPtrainLoader_iter, self.XYPtrainLoader)
                (data_u, _), XYUtrainLoader_iter = self._next_batch(
                    XYUtrainLoader_iter, self.XYUtrainLoader)

                # Two-column targets [1 - y, y]: labeled positives -> [0, 1],
                # unlabeled provisionally -> [1, 0].
                target_p = torch.ones(data_p.shape[0]).to(device,
                                                          non_blocking=True)
                target_u = torch.zeros(data_u.shape[0]).to(device,
                                                           non_blocking=True)
                target_p = target_p[:, None]
                target_u = target_u[:, None]
                target_p_ = torch.cat((1. - target_p, target_p), dim=1)
                target_u_ = torch.cat((1. - target_u, target_u), dim=1)
                data_p = data_p.to(device, non_blocking=True)
                data_u = data_u.to(device, non_blocking=True)

                # Positives first, then unlabeled; idx_p / idx_u address them.
                targets_ = torch.cat((target_p_, target_u_), dim=0)
                data = torch.cat((data_p, data_u), dim=0)
                idx_p = slice(0, len(data_p))
                idx_u = slice(len(data_p), len(data))

                with torch.no_grad():
                    # Soft targets from the EMA model, in the same
                    # [1 - p, p] layout.
                    outputs = ema_net(data)
                    p = torch.sigmoid(outputs)
                    p = p.detach()
                    targets_elr = torch.cat([1. - p, p], dim=1)

                    if epoch >= args.start_hmix:
                        outputs_ori, features_ori = net(data_u,
                                                        flag_feature=True)

                        # Per-sample confidence: 1 for labeled positives,
                        # the current net's sigmoid for unlabeled samples.
                        p_indicator_u = torch.sigmoid(outputs_ori).detach()
                        p_indicator_p = torch.ones(len(data_p),
                                                   dtype=torch.float32,
                                                   device=device)
                        p_indicator_p = p_indicator_p[:, None]
                        p_indicator = torch.cat((p_indicator_p, p_indicator_u),
                                                dim=0)
                        p_indicator = p_indicator.view(p_indicator.size(0))

                        # Relabel unlabeled samples the net already scores
                        # >= p_upper as positive.
                        target_u_fix = torch.zeros(data_u.shape[0]).to(
                            device, non_blocking=True)
                        target_u_fix[p_indicator[idx_u] >= args.p_upper] = 1.
                        target_u_fix = target_u_fix[:, None]
                        target_u_fix_ = torch.cat(
                            (1. - target_u_fix, target_u_fix), dim=1)
                        targets_ = torch.cat((target_p_, target_u_fix_), dim=0)

                        # Pool members are labeled positives: target [0, 1].
                        h_p = torch.ones(len(self.h_inputs_x),
                                         dtype=torch.float32,
                                         device=device)
                        h_p = h_p[:, None]
                        h_targets_elr = torch.cat([1. - h_p, h_p], dim=1)

                if epoch >= args.start_hmix:
                    # Heuristic partners: positive positions get a random
                    # batch sample, unlabeled positions a pool sample.
                    h_input_b_idx = self.get_hinput(features_ori)
                    h_target_b = torch.ones(len(h_input_b_idx),
                                            dtype=torch.float32,
                                            device=device)
                    h_target_b = h_target_b[:, None]
                    h_target_b_ = torch.cat([1. - h_target_b, h_target_b],
                                            dim=1)
                    idx1 = torch.tensor(
                        np.random.randint(data.size(0), size=data_p.size(0)))
                    data_b1 = torch.cat(
                        [data[idx1], self.h_inputs_x[h_input_b_idx]], dim=0)
                    targets_b1 = torch.cat([targets_[idx1], h_target_b_],
                                           dim=0)
                    targets_elr_b1 = torch.cat(
                        [targets_elr[idx1], h_targets_elr[h_input_b_idx]])

                    # Random partners: the same idx1 for positive positions,
                    # fresh random batch samples for unlabeled positions.
                    idx2 = torch.tensor(
                        np.random.randint(data.size(0), size=data_u.size(0)))
                    idx = torch.cat([idx1, idx2])
                    data_b, targets_b, targets_elr_b = data[idx], targets_[
                        idx], targets_elr[idx]

                    # 1 -> keep the random partner (positives and confident
                    # unlabeled: >= p_upper or <= p_lower); 0 -> use the
                    # heuristic partner (uncertain unlabeled).
                    p_indicator[p_indicator >= args.p_upper] = 1.
                    p_indicator[p_indicator <= args.p_lower] = 1.
                    p_indicator[idx_p] = 1.
                    p_indicator[p_indicator != 1.] = 0.

                    # Per-sample select: move the batch axis last so the
                    # (N,) indicator broadcasts over it.
                    data_b = (
                        p_indicator * data_b.swapdims(0, -1) +
                        (1. - p_indicator) * data_b1.swapdims(0, -1)).swapdims(
                            0, -1)
                    targets_b = (p_indicator * targets_b.swapdims(0, -1) +
                                 (1. - p_indicator) *
                                 targets_b1.swapdims(0, -1)).swapdims(0, -1)
                    targets_elr_b = (
                        p_indicator * targets_elr_b.swapdims(0, -1) +
                        (1. - p_indicator) *
                        targets_elr_b1.swapdims(0, -1)).swapdims(0, -1)
                else:
                    # Before start_hmix: plain mixup with a shuffled batch.
                    idx = torch.randperm(data.size(0))
                    data_b, targets_b, targets_elr_b = data[idx], targets_[
                        idx], targets_elr[idx]

                data_a, targets_a, targets_elr_a = data, targets_, targets_elr

                # Mixing weight from Beta(alpha, alpha), folded so that
                # data_a always dominates (l >= 0.5).
                l = np.random.beta(args.alpha, args.alpha)
                l = max(l, 1. - l)

                mix_layer = args.mix_layer

                # NOTE(review): presumably the Mix* models mix data_a and
                # data_b with weight l at layer mix_layer; confirm there.
                outputs = net(data_a, data_b, l, mix_layer)
                mix_targets = l * targets_a + (1. - l) * targets_b
                mix_targets_elr = l * targets_elr_a + (1. - l) * targets_elr_b

                # Two-column probabilities, clamped so log() stays finite.
                logits = torch.sigmoid(outputs)
                logits_ = torch.cat([1. - logits, logits], dim=1)
                logits_ = torch.clamp(logits_, 1e-4, 1. - 1e-4)

                # Soft-target cross-entropy, separately for P and U rows.
                loss_p = -(mix_targets[idx_p] *
                           (logits_[idx_p]).log()).sum(1).mean()
                loss_p = loss_p * args.positive_weight

                loss_u = -(mix_targets[idx_u] *
                           (logits_[idx_u]).log()).sum(1).mean()
                loss_u = loss_u * args.unlabeled_weight

                # ELR-style term: mean log(1 - <EMA target, prediction>).
                loss_elr = ((
                    1. - (mix_targets_elr * logits_).sum(dim=1)).log()).mean()
                loss_elr = loss_elr * args.elr_weight

                # Prediction entropy; penalized when entropy_weight > 0.
                loss_ent = -(logits_ *
                             logits_.log()).sum(1).mean() * args.entropy_weight

                loss = loss_p + loss_u + loss_ent + loss_elr

                print(
                    "model:{0}\tlossP:{1}\tlossU:{2}\tlossEnt:{3}\tlossElr:{4}"
                    .format(model_name, loss_p.item(), loss_u.item(),
                            loss_ent.item(), loss_elr.item()))
                opt.zero_grad()  # clear gradients for next train
                loss.backward()  # backpropagation, compute gradients
                opt.step()  # apply gradients

            if check_mean_teacher(epoch):
                # update parameters of ema_net (once per epoch)
                update_ema_variables(net, ema_net, args.ema_decay, epoch)

            # Refresh the pool from the epoch before mixup starts onwards,
            # so it is ready at epoch start_hmix.
            if epoch >= args.start_hmix - 1:
                self.update_hinput(net)

            scheduler.step()

            # Validation is disabled; only the test split is evaluated.
            results_test[model_name].extend(net.error(self.XYtestLoader))

        return results_val, results_test

    def run(self, Epochs):
        """Train for ``Epochs`` epochs, printing one metrics row per epoch.

        Returns:
            ``(results_val, results_test)`` keyed by model name, each a list
            of four per-epoch series [precision, recall, error_rate,
            num_predicted_positive]. Validation series stay empty.
        """
        results_test = {}
        results_val = {}
        # [precision,recall,error_rate, num_predicted_positive]
        for model_name in self.models.keys():
            results_test[model_name] = [[], [], [], []]
            results_val[model_name] = [[], [], [], []]

        print("Epoch\t" + "".join([
            "{}/precision\t{}/recall\t{}/error\t{}/NPP\t".format(
                model_name, model_name, model_name, model_name)
            for model_name in sorted(self.models.keys())
        ]))

        for epoch in range(Epochs):
            _results_val, _results_test = self.train(epoch)
            print("{}\t".format(epoch) + "".join([
                "{:-8}\t{:-8}\t{:-8}\t{:-8}\t".format(
                    round(_results_test[model_name][0], 4),
                    round(_results_test[model_name][1], 4),
                    round(_results_test[model_name][2], 4),
                    round(_results_test[model_name][3], 4),
                ) for model_name in sorted(_results_test.keys())
            ]))

            for model_name in self.models.keys():
                for k in range(4):
                    results_test[model_name][k].append(
                        _results_test[model_name][k])

        return results_val, results_test


def main():
    """Build data, models, optimizers and schedulers from ``args`` and train.

    Returns:
        ``(results_val, results_test)`` as produced by ``trainer.run``.
    """
    print("using:", device)

    # dataset setup; `dim` is forwarded to the model constructor.
    # NOTE: the class prior returned by the loader is not used here; only
    # args.prior is reported below.
    (XYPtrainLoader, XYUtrainLoader, XYtrainLoader, XYvalidLoader,
     XYtestLoader, XYPtrainSet, dim,
     prior) = load_image_dataset(args.dataset, args.labeled, args.batchsize,
                                 args.positive_label_list)

    # model setup (the EMA model gets detached, non-trainable parameters)
    selected_model = select_model(args.model)
    model = create_model(selected_model, dim=dim)
    ema_model = create_model(selected_model, dim=dim, ema=True)
    models = {
        "p3mix_c": copy.deepcopy(model).to(device),
    }
    ema_models = {
        "p3mix_c": copy.deepcopy(ema_model).to(device),
    }

    # trainer setup: one optimizer/scheduler per trainable model, same keys
    optimizers = {
        k: make_optimizer(v, args.optim, args.stepsize)
        for k, v in models.items()
    }
    schedulers = {k: make_scheduler(v) for k, v in optimizers.items()}

    print("    batchsize: {}".format(args.batchsize))
    print("    model: {}".format(selected_model))
    print("    beta: {}".format(args.beta))
    print("    gamma: {}".format(args.gamma))
    print("    mix layer: {}".format(args.mix_layer))
    print("    alpha: {}".format(args.alpha))
    print("    unlabeled weight: {}".format(args.unlabeled_weight))
    print("    p_lower: {}".format(args.p_lower))
    print("    p_upper: {}".format(args.p_upper))
    print("    prior: {}".format(args.prior))
    print("")

    # run training
    PUtrainer = trainer(models, ema_models, optimizers, schedulers,
                        XYPtrainLoader, XYUtrainLoader, XYvalidLoader,
                        XYtestLoader, XYPtrainSet)
    return PUtrainer.run(args.epoch)


if __name__ == '__main__':
    # Run from this script's own directory so relative paths resolve the same
    # way regardless of the launch directory. sys.path[0] is not reliable for
    # this: it is the current directory under `python -m` and is not the
    # script directory at all under `python -P` / PYTHONSAFEPATH.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    print("working dir: {}".format(os.getcwd()))
    main()
