import os
import logging

import torch
from bistiming import Stopwatch
from mkdir_p import mkdir_p
import numpy as np
import joblib
from scipy.special import softmax

from .utils import set_random_seed
from spurious_ml.models.torch_bbox_attack_model import TorchBBoxAttackModel
from spurious_ml.variables import get_file_name
from spurious_ml.datasets import add_spurious_correlation


# Configure root logging with timestamped, level-padded lines at WARNING level.
logging.basicConfig(
    level=logging.WARNING,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
)
logger = logging.getLogger(__name__)

# Directory where trained shadow attack models are saved and looked up.
base_model_dir = './models/mem_inference/'

def get_shadow_attack_model(auto_var, n_features, n_classes, n_channels):
    """Construct the black-box attack model used for membership inference.

    Architecture-level settings are fixed here:
    - cross-entropy loss;
    - the ``shadow_attack_model`` architecture;
    - ``multigpu`` disabled;
    - no data augmentation.

    Training hyper-parameters are read from ``auto_var``.

    Args:
        auto_var: experiment variable container; ``get_var`` is queried for
            each training hyper-parameter.
        n_features: forwarded unchanged to ``TorchBBoxAttackModel``.
        n_classes: forwarded unchanged to ``TorchBBoxAttackModel``.
        n_channels: forwarded unchanged to ``TorchBBoxAttackModel``.

    Returns:
        A ``TorchBBoxAttackModel`` instance.
    """
    # Queried from the config in this exact order.
    hparam_names = (
        "learning_rate", "epochs", "momentum", "optimizer",
        "batch_size", "weight_decay", "grad_clip", "noise_multiplier",
    )
    tuned = {name: auto_var.get_var(name) for name in hparam_names}

    return TorchBBoxAttackModel(
        n_features=n_features,
        n_classes=n_classes,
        n_channels=n_channels,
        loss_name="ce",
        architecture="shadow_attack_model",
        multigpu=False,
        dataaug=None,
        **tuned,
    )

def get_meminf_result_path(auto_var, epochs=70):
    """Return the path of the pickled membership-inference results for a config.

    The file name joins, with ``-``, the experiment variable *names* returned
    by ``auto_var.get_variable_name`` together with ``epochs``. The file lives
    under ``./results/mem_inference``. File existence is not checked.

    Args:
        auto_var: experiment variable container exposing ``get_variable_name``.
        epochs: epoch count embedded in the file name. Defaults to 70, the
            value that used to be hard-coded here. NOTE(review): this is
            presumably the epoch count the stored results were produced with,
            which can differ from this run's ``epochs`` variable — confirm.

    Returns:
        str: path to a ``.pkl`` file.
    """
    batch_size = auto_var.get_variable_name("batch_size")
    ds_name = auto_var.get_variable_name("dataset")
    model_name = auto_var.get_variable_name("model")
    momentum = auto_var.get_variable_name("momentum")
    optimizer = auto_var.get_variable_name("optimizer")
    random_seed = auto_var.get_variable_name("random_seed")
    weight_decay = auto_var.get_variable_name("weight_decay")

    file_name = (f"{batch_size}-{ds_name}-{epochs}-{model_name}-{momentum}-"
                 f"{optimizer}-{random_seed}-{weight_decay}.pkl")
    return os.path.join("./results/mem_inference", file_name)

def get_attack_preds(model, predX, y):
    """Score raw model outputs with a trained attack model.

    ``predX`` is passed through a row-wise softmax, so it is presumably
    logits. Two inputs are then built for the attack model:
    - the probabilities sorted ascending along the last axis;
    - a float32 ``(n, 1)`` column that is 1.0 where the top-1 prediction
      equals ``y`` and 0.0 otherwise.

    Both are handed to ``model.predict_real``, and its result is returned
    unchanged.
    """
    probabilities = softmax(predX, axis=1)
    top1 = probabilities.argmax(1)
    hit = (top1 == y).reshape(-1, 1).astype(np.float32)
    ranked = np.sort(probabilities, axis=-1)
    return model.predict_real(ranked, hit)

def _attack_features(logits, labels):
    """Turn raw model outputs into the black-box attack features.

    Returns ``(sorted_probs, correct)``:
    - ``sorted_probs`` is the row-wise softmax of ``logits`` sorted ascending
      along the last axis;
    - ``correct`` is a float32 ``(n, 1)`` column that is 1.0 where the top-1
      prediction equals ``labels`` and 0.0 otherwise.

    This is the same transformation ``get_attack_preds`` applies at scoring
    time.
    """
    probs = softmax(logits, axis=1)
    correct = (probs.argmax(1) == labels).reshape(-1, 1).astype(np.float32)
    return np.sort(probs, axis=-1), correct


def run_bbox_inference(auto_var):
    """Train (or load a cached) black-box attack model and score all splits.

    Steps:
    1. Load the stored model outputs from ``get_meminf_result_path``.
    2. Build attack training features from the ``aux_shadow_*`` outputs and
       attack evaluation features from the ``target_tar_*`` outputs.
    3. Fit the attack model, or load it if a saved copy exists.
    4. Score every stored split, including the ``*_mod_*`` variants.

    Args:
        auto_var: experiment variable container. ``get_var("dataset")`` must
            return the 9-tuple
            ``(tar_trnX, tar_trny, tar_tstX, tar_tsty, shadow_trnX,
            shadow_trny, shadow_tstX, shadow_tsty, spurious_ind)``.

    Returns:
        dict with these keys:
        - ``spurious_ind``, ``meminf_result_path`` and
          ``shadow_attack_model_path``;
        - ``history``, only when the model was trained in this call;
        - one ``attack_*_pred`` entry per scored split.
    """
    _ = set_random_seed(auto_var)

    (tar_trnX, tar_trny, tar_tstX, tar_tsty, shadow_trnX, shadow_trny,
     shadow_tstX, shadow_tsty, spurious_ind) = auto_var.get_var("dataset")

    mkdir_p(base_model_dir)
    result = {'spurious_ind': spurious_ind}
    result['meminf_result_path'] = get_meminf_result_path(auto_var)
    # NOTE(review): joblib.load unpickles; only point this at results you produced.
    res = joblib.load(result['meminf_result_path'])

    # Attack training set: outputs on the shadow train split followed by the shadow test split.
    att_trnX, att_trnpred = _attack_features(
        np.concatenate((res['aux_shadow_trn_pred'], res['aux_shadow_tst_pred']), axis=0),
        np.concatenate((shadow_trny, shadow_tsty)))
    # Attack evaluation set: outputs on the target train split followed by the target test split.
    att_tstX, att_tstpred = _attack_features(
        np.concatenate((res['target_tar_trn_pred'], res['target_tar_tst_pred']), axis=0),
        np.concatenate((tar_trny, tar_tsty)))
    # Membership labels: 1 for rows from a train split, 0 for rows from a test split.
    att_trny = np.concatenate((np.ones(len(shadow_trnX)), np.zeros(len(shadow_tstX))), axis=0)
    att_tsty = np.concatenate((np.ones(len(tar_trnX)), np.zeros(len(tar_tstX))), axis=0)

    # The "_shawdowatt" spelling is kept so previously saved models are still found.
    result['shadow_attack_model_path'] = os.path.join(
            base_model_dir, get_file_name(auto_var) + "_shawdowatt.pt")
    # NOTE(review): n_features / n_channels come from the raw shadow inputs,
    # not from the attack-feature matrix; confirm that is what the attack model expects.
    attack_model = get_shadow_attack_model(
        auto_var, shadow_trnX.shape[1:], len(np.unique(shadow_trny)), shadow_trnX.shape[-1])
    attack_model.tst_ds = (att_tstX, att_tstpred, att_tsty)
    if os.path.exists(result['shadow_attack_model_path']):
        attack_model.load(result['shadow_attack_model_path'])
    else:
        with Stopwatch("Train Attack Model", logger=logger):
            history = attack_model.fit(att_trnX, att_trnpred, att_trny)
        attack_model.save(result['shadow_attack_model_path'])
        result['history'] = history

    # (output key, stored-prediction key, true labels), scored in this order.
    eval_sets = (
        ('attack_tar_trn_pred', 'target_tar_trn_pred', tar_trny),
        ('attack_tar_tst_pred', 'target_tar_tst_pred', tar_tsty),
        ('attack_shadow_trn_pred', 'aux_shadow_trn_pred', shadow_trny),
        ('attack_shadow_tst_pred', 'aux_shadow_tst_pred', shadow_tsty),
        ('attack_mod_tar_trn_pred', 'target_mod_tar_trn_pred', tar_trny),
        ('attack_mod_tar_tst_pred', 'target_mod_tar_tst_pred', tar_tsty),
        ('attack_mod_shadow_trn_pred', 'aux_mod_shadow_trn_pred', shadow_trny),
        ('attack_mod_shadow_tst_pred', 'aux_mod_shadow_tst_pred', shadow_tsty),
    )
    for out_key, res_key, labels in eval_sets:
        result[out_key] = get_attack_preds(attack_model, res[res_key], labels)

    return result
