import torch
import sys
import torch.nn.functional as F
sys.path.append('..')
from utils import Averager, clip_perturbed_image
from loss import *


def clean_sample_selection(model, test_loader, C, rho, tau, device):
    """Select confidently predicted samples to use as pseudo-labelled "clean" data.

    A sample is a candidate when the model's maximum softmax probability is
    strictly greater than ``tau``. Candidates are grouped by predicted class.
    When a class has more than ``k = len(test_loader.dataset) * rho / C``
    candidates, only its ``int(k)`` most confident ones are kept; otherwise all
    of them are kept in loader order.

    Args:
        model: classifier returning logits over ``C`` classes for a batch ``x``.
        test_loader: iterable of ``(x, label, idx)`` batches (the label is
            ignored) that exposes ``.dataset`` supporting ``len()``.
        C: number of classes.
        rho: fraction of the dataset to retain, split evenly across classes.
        tau: confidence threshold (strict ``>``).
        device: device that ``x`` and ``idx`` are moved to.

    Returns:
        Tensor of shape ``(N, 3)`` with columns
        ``[confidence, sample index, predicted class]``, grouped by class in
        ascending class order. All columns share the confidence dtype, so with
        float32 confidences the index column is exact only below 2**24.

    Raises:
        RuntimeError: if no sample is confident enough to be selected.
    """
    model.eval()

    # Every confident prediction becomes one (conf, idx, pred) row, in loader order.
    rows = []
    for x, _, idx in test_loader:
        x, idx = x.to(device), idx.to(device)
        conf, pred = torch.max(F.softmax(model(x), dim=-1), dim=-1)
        confident = conf > tau
        if confident.any():
            rows.append(torch.stack([conf[confident],
                                     idx[confident].to(conf.dtype),
                                     pred[confident].to(conf.dtype)], dim=1))

    # k stays fractional for the size comparison; topk keeps int(k) rows.
    k = len(test_loader.dataset) * rho / C
    clean_dataset = []
    if rows:
        candidates = torch.cat(rows)
        for c in range(C):
            rows_c = candidates[candidates[:, 2] == c]
            if len(rows_c) == 0:
                continue
            if len(rows_c) > k:
                _, top = torch.topk(rows_c[:, 0], k=int(k))
                rows_c = rows_c[top]
            clean_dataset.append(rows_c)

    if not clean_dataset:
        raise RuntimeError(f"no sample has confidence above tau={tau}")
    return torch.vstack(clean_dataset)


def zoo(x, y, model, adapt, args, device, clean=True):
    """Compute the adapter's training loss on one (sub)batch.

    ``adapt(x)`` produces a perturbation ``delta``. It is scaled by
    ``args.ad_scale``, added to ``x`` and clipped with ``clip_perturbed_image``
    to give ``x_tilda``.

    * ``args.zo`` false: ``model`` is differentiated directly (first order).
    * ``args.zo`` true: ``model`` is a black box. The gradient of the
      per-sample loss w.r.t. ``x_tilda`` is estimated with ``args.q`` random
      unit directions and forward differences of step ``args.mu``. The
      returned surrogate ``mean_b <x_tilda_b, grad_est_b>`` has a gradient
      w.r.t. the adapter parameters equal to that estimate pushed through
      ``x_tilda``.

    Args:
        x: input batch.
        y: pseudo-labels for ``x`` when ``clean`` is true; ignored otherwise.
        model: frozen classifier returning logits.
        adapt: trainable module whose output is added to ``x``.
        args: namespace read for ``ad_scale``, ``zo``, ``mu``, ``q``, ``sigma``.
        device: device the random directions are moved to.
        clean: use cross-entropy against ``y`` if true, else ``im_loss``.

    Returns:
        ``(loss_tr, delta_norm, loss_tmps)``:
        - ``loss_tr``: scalar training loss that carries an autograd graph to
          ``adapt``.
        - ``delta_norm``: batch-mean norm of ``delta``.
        - ``loss_tmps``: in ZO mode, the mean perturbed loss of each query; an
          empty list otherwise.
    """
    # generate perturbation
    delta = adapt(x)
    x_tilda = clip_perturbed_image(x, x + args.ad_scale * delta)

    # delta norm regularizer.
    # NOTE(review): with ord=1 and a 2-tuple `dim`, torch.linalg.norm computes the
    # matrix 1-norm (max absolute column sum) of each trailing 2-D slice, not the
    # element-wise L1 norm -- confirm this is the intended regularizer.
    delta_norm = torch.linalg.norm(delta, ord=1, dim=(-2, -1)).mean()

    loss_tmps = []
    if not args.zo:
        logits_tilda = model(x_tilda)
        if clean:
            # Batch mean, so the caller's total loss stays a scalar for .backward().
            loss_tr = F.cross_entropy(logits_tilda, y)
        else:
            loss_tr = im_loss(logits_tilda)
        return loss_tr, delta_norm, loss_tmps

    def per_sample_loss(logits):
        # Per-sample objective used by the finite differences below.
        if clean:
            return F.cross_entropy(logits, y, reduction='none')
        return im_loss(logits, reduce=False)

    x_temp = x_tilda.detach()
    batch_size = x_temp.shape[0]
    # Shape that broadcasts one scalar per sample, e.g. (B, 1, 1, 1) for images.
    per_sample = (batch_size,) + (1,) * (x_temp.dim() - 1)
    feature_dims = tuple(range(1, x_temp.dim()))

    with torch.no_grad():
        # Forward inference at the unperturbed point.
        loss_0 = per_sample_loss(model(x_temp))

        grad_est = torch.zeros_like(x_temp)
        for _ in range(args.q):
            # Random direction, normalized to unit L2 norm per sample.
            u = torch.normal(0, args.sigma, size=tuple(x_temp.shape)).to(device)
            u = u / torch.sqrt(torch.sum(u ** 2, dim=feature_dims)).reshape(per_sample)

            loss_tmp = per_sample_loss(model(x_temp + args.mu * u))
            loss_diff = loss_tmp - loss_0
            grad_est += u * loss_diff.reshape(per_sample) / (args.mu * args.q)
            loss_tmps.append(loss_tmp.detach().cpu().mean())

    # Surrogate g(x) * a: grad_est is a constant here, so autograd propagates the
    # ZO estimate into `adapt` through x_tilda. This MUST run outside no_grad,
    # otherwise loss_tr has no graph and the adapter receives no ZO signal.
    loss_tr = torch.sum(x_tilda * grad_est, dim=feature_dims).mean()

    return loss_tr, delta_norm, loss_tmps


def _clean_label_table(clean_sets):
    # Sort the clean set by sample index so batch indices can be matched with searchsorted.
    clean_idx = clean_sets[:, 1].long()
    order = torch.argsort(clean_idx)
    return clean_idx[order], clean_sets[order, 2].long()


def _split_clean(idx, sorted_idx, sorted_lbl):
    # Mask of batch rows found in the clean set, plus their pseudo-labels in batch order.
    idx = idx.long()
    if sorted_idx.numel() == 0:
        return torch.zeros_like(idx, dtype=torch.bool), sorted_lbl[:0]
    pos = torch.searchsorted(sorted_idx, idx).clamp(max=sorted_idx.numel() - 1)
    mask = sorted_idx[pos] == idx
    return mask, sorted_lbl[pos[mask]]


def _tta_accuracy(model, adapt, loader, args, device, use_adapter):
    # Accuracy of the frozen model on `loader`, optionally on adapter-perturbed inputs.
    avgr = Averager()
    adapt.eval()
    model.eval()
    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)
            if use_adapter:
                x = clip_perturbed_image(x, x + args.ad_scale * adapt(x))
            ypred = model(x).argmax(dim=-1)
            avgr.update((ypred == y).float().mean().item(), nrep=len(y))
    return avgr.avg


def offline_tta(args, model, adapt, train_loader, test_loader, device=None):
    """Train the input adapter ``adapt`` against a frozen (black-box) ``model``.

    Confident predictions on ``train_loader`` are first selected with
    ``clean_sample_selection`` and used as pseudo-labels. In each epoch
    ``1..args.steps``, every batch is split into clean samples (cross-entropy
    against their pseudo-label) and the remaining noisy samples (``im_loss``).
    Both subsets go through ``zoo``, and the adapter is updated on::

        loss_noisy + args.wclean * loss_clean + args.wdelta * loss_norm

    Every ``args.eval_interval`` epochs, including epoch 0 (no adapter),
    accuracy against the loader's ground-truth labels is measured on
    ``train_loader`` and printed.

    Args:
        args: namespace read for ``optim``, ``lr``, ``rho``, ``tau``, ``steps``,
            ``eval_interval``, ``wclean``, ``wdelta``, ``ad_scale`` and the
            fields used by ``zoo``.
        model: classifier exposing ``out_dim``; it is frozen here.
        adapt: trainable module producing input perturbations.
        train_loader: iterable of ``(x, y, idx)`` batches used for both
            adaptation and evaluation.
        test_loader: not used by this function; kept for interface compatibility.
        device: device batches are moved to.

    Returns:
        List of accuracies, one per evaluated epoch.
    """
    # black-box model: frozen, never updated
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    # optimizer
    parameters = adapt.parameters()
    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=0.9, weight_decay=1e-5)
    else:
        optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=1e-5)

    accs = []
    with torch.no_grad():
        clean_sets = clean_sample_selection(model, train_loader, model.out_dim, args.rho, args.tau, device)
    print(f"number of clean data selected: {clean_sets.shape[0]}")
    # Index-aligned lookup: a boolean isin() on both sides would return labels in
    # clean-set order (grouped by class), not in the batch order of clean_x.
    sorted_idx, sorted_lbl = _clean_label_table(clean_sets)

    for epoch in range(0, args.steps + 1):
        avg_clean = Averager()
        avg_noisy = Averager()
        avg_norm = Averager()
        if epoch > 0:
            adapt.train()
            model.eval()
            for data_bat in train_loader:
                # Ground-truth labels (data_bat[1]) are deliberately not used for training.
                x, idx = data_bat[0].to(device), data_bat[2].to(device)
                clean_mask, clean_y = _split_clean(idx, sorted_idx, sorted_lbl)
                clean_x, noisy_x = x[clean_mask], x[~clean_mask]
                n_clean, n_noisy = clean_x.shape[0], noisy_x.shape[0]

                optimizer.zero_grad()

                # Skip zoo() on an empty subset: its batch means over zero samples are NaN.
                zero = torch.tensor(0.0).to(device)
                if n_clean > 0:
                    loss_clean, norm_clean, loss_clean_tmps = zoo(clean_x, clean_y, model, adapt, args, device,
                                                                  clean=True)
                else:
                    loss_clean, norm_clean, loss_clean_tmps = zero, zero, [0.0]
                if n_noisy > 0:
                    loss_noisy, norm_noisy, loss_noisy_tmps = zoo(noisy_x, None, model, adapt, args, device,
                                                                  clean=False)
                else:
                    loss_noisy, norm_noisy, loss_noisy_tmps = zero, zero, [0.0]
                # Sample-weighted average of the two subsets' delta norms.
                loss_norm = (norm_clean * n_clean + norm_noisy * n_noisy) / x.shape[0]

                loss = loss_noisy + args.wclean * loss_clean + args.wdelta * loss_norm
                if args.zo:
                    loss_record_clean = torch.mean(torch.tensor(loss_clean_tmps)).item()
                    loss_record_noisy = torch.mean(torch.tensor(loss_noisy_tmps)).item()
                else:
                    loss_record_clean = loss_clean.item()
                    loss_record_noisy = loss_noisy.item()

                loss.backward()
                optimizer.step()
                # Record plain floats so the averagers never hold autograd graphs.
                avg_clean.update(loss_record_clean)
                avg_noisy.update(loss_record_noisy)
                avg_norm.update(loss_norm.item())

        if epoch % args.eval_interval == 0:
            acc = _tta_accuracy(model, adapt, train_loader, args, device, use_adapter=epoch > 0)
            accs.append(acc)
            print(f"epoch {epoch:.1f}, loss_clean = {avg_clean.avg}, acc = {acc:.3f}, loss_noisy = {avg_noisy.avg},"
                  f" delta norm = {avg_norm.avg:.3f}.", flush=True)

    return accs