import os
import torch
import numpy as np
from tqdm import tqdm
from utils.utils_models import load_detectors
from scipy.optimize import minimize, Bounds

def compute_probs_detector(intermediate_layer_nat, intermediate_layer_adv, detector):
    """Score natural and adversarial features with ``detector`` and stack the probabilities.

    Args:
        intermediate_layer_nat: inputs for the natural samples, passed as-is to ``detector``.
        intermediate_layer_adv: inputs for the adversarial samples, passed as-is to ``detector``.
        detector: callable returning tensors of logits (e.g. a ``torch.nn.Module``).

    Returns:
        NumPy array holding ``sigmoid(detector(nat))`` followed by
        ``sigmoid(detector(adv))``, concatenated along the first axis.
    """
    # Inference only: the results are converted to NumPy below, so recording an
    # autograd graph over the whole batch would only waste memory.
    with torch.no_grad():
        detected_nat = torch.sigmoid(detector(intermediate_layer_nat))
        detected_adv = torch.sigmoid(detector(intermediate_layer_adv))

    return np.concatenate((detected_nat.detach().cpu().numpy(), detected_adv.detach().cpu().numpy()))

def mutual_conditional_information(w, prob_adv_x):
    """Objective minimized by ``optimize`` for one sample.

    For the adversarial probabilities ``p`` and the natural ones ``q = 1 - p + 1e-20``
    it returns::

        -( sum_k w_k p_k log(p_k / sum_j w_j p_j) + sum_k w_k q_k log(q_k / sum_j w_j q_j) )

    Args:
        w: weights, one per detector.
        prob_adv_x: adversarial probabilities, one per detector (same length as ``w``).

    Returns:
        The objective value as a NumPy float.
    """
    w = np.asarray(w, dtype=np.float64)
    prob_adv_x = np.asarray(prob_adv_x, dtype=np.float64)
    # 1e-20 keeps the natural probability strictly positive when a detector
    # outputs exactly 1, so the log below stays finite.
    prob_nat_x = 1. - prob_adv_x + 1e-20

    def _weighted_term(prob):
        # The per-k form log(w_k p_k / (w_k * sum_j w_j p_j)) simplifies to
        # log(p_k / sum_j w_j p_j); dropping w_k avoids 0/0 = nan for a zero weight
        # and lets the normalizer be computed once instead of once per k.
        return np.sum(w * prob * np.log(prob / np.sum(w * prob)))

    return -(_weighted_term(prob_adv_x) + _weighted_term(prob_nat_x))

def optimize(prob_losses_adv):
    """Fuse several detectors' probabilities with per-sample optimized weights (Eq. (10)).

    Args:
        prob_losses_adv: 2-D array; each row holds the probabilities of one detector and
            each column is one sample.

    Returns:
        Column array of shape (n_samples, 1). Entry ``i`` is the weighted sum over
        detectors of column ``i``, with weights found by minimizing
        ``mutual_conditional_information``. The weights are bounded to [1e-20, 1]
        and constrained to sum to 1.
    """
    weight_bounds = Bounds(1e-20, 1.)
    # Equality constraint: the detector weights must sum to one.
    sum_to_one = {'type': 'eq', 'fun': lambda weights_: np.sum(weights_) - 1}

    n_detectors = prob_losses_adv.shape[0]
    weights = np.zeros(prob_losses_adv.shape)

    print("Optimization")
    for col in tqdm(range(prob_losses_adv.shape[1]), ncols=70, ascii=True, colour='blue'):
        # Start every sample from uniform weights.
        start = np.ones(n_detectors, ) / float(n_detectors)
        # 1e-20 keeps every probability strictly positive inside the objective's log.
        sample_probs = prob_losses_adv[:, col] + 1e-20
        result = minimize(fun=mutual_conditional_information, x0=start, args=(sample_probs,),
                          constraints=sum_to_one, bounds=weight_bounds, options={'maxiter': 1000})
        weights[:, col] = np.asarray(result.x)

    fused = (weights * prob_losses_adv).sum(axis=0)
    return fused[:, np.newaxis]

def agree_output(args, attack, classifier, device, intermediate_layer_nat, intermediate_layer_adv, n_samples_adv):
    """Fuse the CE / Rao / KL / g detectors on one attack, report statistics and save the result.

    Args:
        args: configuration object; reads ``TRAIN.PGDi.epsilon``, ``TEST.RESULTS.res_dir``
            and ``DATA_NATURAL.data_name``.
        attack: attack identifier, appended to the dataset name in the output file name.
        classifier: model forwarded to ``load_detectors``.
        device: device forwarded to ``load_detectors``.
        intermediate_layer_nat: inputs for the natural samples, fed to every detector.
        intermediate_layer_adv: inputs for the adversarial samples, fed to every detector.
        n_samples_adv: every detector must yield ``2 * n_samples_adv`` scores, with
            the natural samples first and the adversarial samples second.

    Returns:
        The fused probabilities returned by ``optimize``, which are also saved as a
        ``.npy`` file under ``<res_dir><data_name>/all_<epsilon>/``.
    """
    epsilon = args.TRAIN.PGDi.epsilon

    # ---- Load one detector per training loss (all loaded before any scoring) ----
    detectors = []
    for loss_name in ['CE', 'Rao', 'KL', 'g']:
        loaded = load_detectors(args, model=classifier, device=device, loss=loss_name, epsilon=epsilon)
        # Take element [0][-1] of the loader's result and switch it to eval mode.
        detectors.append(loaded[0][-1].eval())

    # ---- Score every sample with every detector: float64 buffer, one row per detector ----
    prob_losses_adv = np.zeros((len(detectors), n_samples_adv * 2))
    for row, detector in enumerate(detectors):
        scores = compute_probs_detector(intermediate_layer_nat, intermediate_layer_adv, detector)
        prob_losses_adv[row, :] = np.squeeze(scores)

    final_prob = optimize(prob_losses_adv)

    # First half are the natural samples, second half the adversarial ones.
    half = len(final_prob) // 2
    prob_natural, prob_adversarial = final_prob[:half], final_prob[half:]

    print('Naturals :', np.mean(prob_natural), np.std(prob_natural))
    print('Adversarial :', np.mean(prob_adversarial), np.std(prob_adversarial))

    # ---- Save the probabilities ----
    data_name = args.DATA_NATURAL.data_name
    path_res = '{}{}/all_{}/'.format(args.TEST.RESULTS.res_dir, data_name, epsilon)
    file_path = '{}/probs_{}_all.npy'.format(path_res, '{}{}'.format(data_name, attack))
    print(file_path)
    os.makedirs(path_res, exist_ok=True)
    np.save(file_path, final_prob)

    return final_prob
