import numpy as np
import pickle
import torch
import utils
from torch.nn.functional import relu
from torchvision import datasets
from metric_suite.metrics import Metrics
from tqdm import tqdm

import datetime
import argparse

# Command-line options for the counterfactual-explanation (CFX) experiment.
# NOTE: `args.use_multiclass` is not a CLI flag; main() attaches it to `args`
# after it has decided which dataset is being used.
parser = argparse.ArgumentParser()
parser.add_argument("--use_wandb", action='store_true', default=False)
parser.add_argument("--network", type=str, default="150x2x50")
parser.add_argument("--network_type", type=str, default="MLP")
parser.add_argument("--use_cpu", action='store_true')
parser.add_argument("--dataset", type=str, default="news")
parser.add_argument("--distance_weight", type=float, default=0.01, help="weight given to minimizing distance metric in CFX construction")
parser.add_argument("--stepsize", type=float, default=1e-1, help="Stepsize for CFX computation")
parser.add_argument("--global_iter_cap", type=int, default=100, help="Maximum iterations per CFX")
parser.add_argument("--num_iterations", type=int, default=600, help="iterations per CFX inner loop")
parser.add_argument("--l0_budget", type=int, default=None, help="l0 Budget used for JSMA map")

args = parser.parse_args()

# Model.get_cfx only changes the distance weight between restarts (it is
# multiplied by a constant factor); a zero weight stays zero, so a single
# attempt is all that is allowed in that case.
if args.distance_weight == 0.0:
    args.global_iter_cap = 1

# Use the GPU unless --use_cpu was given, but fall back to the CPU when CUDA
# is not available instead of failing later on the first `.to(device)` call.
use_gpu = not args.use_cpu
if use_gpu and not torch.cuda.is_available():
    print("CUDA is not available; falling back to CPU.")
    use_gpu = False

# Number of weight samples drawn per parameter tensor by sample_net().
sample_num = 500
# Not referenced in the visible part of this file; kept for compatibility.
rob_std_dev = 1
# Fraction of the per-feature data range used by main() as the magnitude of
# the random input perturbations.
x_pert_budget = 5/100

# For MLP networks a single weight "sample" is drawn.
if args.network_type == 'MLP':
    sample_num = 1

# Filled by Model.get_cfx with one distance weight per counterfactual found.
dist_weights_used = []

device = torch.device('cuda' if use_gpu else 'cpu')

def load_model(network, net_type: str):
    """Load pickled network parameters for the dataset selected in ``args``.

    The ``.h5`` files are read with ``pickle`` (they are not HDF5 files).
    NOTE(review): unpickling runs arbitrary code, so only load trusted files.

    Args:
        network: File stem of the stored network; it also selects the
            post-processing applied to the loaded parameters.
        net_type: Either ``"MLP"`` or ``"BNN"``.

    Returns:
        A ``(means, devs)`` pair. ``devs`` is ``None`` for BNN networks whose
        name starts with ``bnn``, ``hybrid`` or ``mlp_init_bnn`` and for MLP
        networks whose name has exactly three ``x``-separated parts.

    Raises:
        AssertionError: If ``net_type`` is neither ``"MLP"`` nor ``"BNN"``.
    """
    assert net_type in ["MLP", "BNN"], "Invalid network type"

    if net_type != 'BNN':
        weights_file = "Networks/MLPS/{}/".format(args.dataset) + network + ".h5"
        with open(weights_file, 'rb') as handle:
            means = pickle.load(handle)

        if len(network.split("x")) == 3:
            # Three-part names: reorder the weight axes and insert singleton
            # axes in place; the parameters are returned as loaded otherwise.
            for layer in means:
                layer[0] = layer[0].permute(1, 0, 2).unsqueeze(-2)
                layer[1] = layer[1].unsqueeze(-2).unsqueeze(-2)
            return means, None

        # Otherwise transpose the weights in place and pair every parameter
        # with an all-zero deviation of the same shape before conversion.
        devs = []
        for layer in means:
            layer[0] = layer[0].T
            devs.append([torch.zeros_like(layer[0]), torch.zeros_like(layer[1])])
        means, devs = utils.conv_net_to_torch(means, devs)
        return means, devs

    weights_file = "Networks/BNNS/{}/".format(args.dataset) + network + ".h5"
    stores_means_only = network.startswith(('bnn', "hybrid", "mlp_init_bnn"))
    with open(weights_file, 'rb') as handle:
        payload = pickle.load(handle)

    if stores_means_only:
        means = payload
        for layer in means:
            # Keep only the first 100 entries along the last axis.
            head = torch.tensor(layer[0][..., :100])
            tail = torch.tensor(layer[1][..., :100])
            layer[0], layer[1] = head, tail
        return means, None

    means, devs = payload
    means, devs = utils.conv_net_to_torch(means, devs)
    return means, devs

def sample_net(means, std_devs):
    """Draw ``sample_num`` Gaussian samples for every parameter tensor.

    ``means`` and ``std_devs`` are nested sequences indexed as
    ``[layer][sub_layer]``. Each tensor gets two trailing axes appended, is
    expanded to ``sample_num`` along the last axis, sampled with
    ``torch.normal`` and moved to ``device``.

    Returns:
        A list of lists with the same nesting as ``means``.
    """
    def _draw(mu, sigma):
        # Broadcast mean and deviation across the sample axis, then sample.
        mu_wide = mu.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, sample_num)
        sigma_wide = sigma.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, sample_num)
        return torch.normal(mu_wide, sigma_wide).to(device)

    return [
        [
            _draw(mu, std_devs[layer_idx][param_idx])
            for param_idx, mu in enumerate(layer_means)
        ]
        for layer_idx, layer_means in enumerate(means)
    ]

class Model(torch.nn.Module):
    """ReLU network evaluated jointly over a batch of weight samples.

    ``params`` holds one ``[weight, bias]`` pair per layer. The last axis of
    each tensor indexes the weight samples; the network output is averaged
    over that axis. Whether the output is returned raw or passed through a
    sigmoid is controlled by the module-level ``args.use_multiclass``.
    """

    def __init__(self, params):
        """Store ``params`` and derive layer count, sizes and sample count."""
        # Initialise nn.Module's internal state so the standard Module API
        # (parameters(), state_dict(), calling the instance) is usable.
        super().__init__()
        self.params = params
        self.layer_num = len(params)
        # Number of weight samples, read from the last axis of the first bias.
        self.weight_batch_size = self.params[0][1].shape[-1]
        # Output width of every layer, read from the first axis of its bias.
        self.layer_sizes = [self.params[i][1].shape[0] for i in range(self.layer_num)]

    def _expand_input(self, x):
        """Move ``x`` to ``device`` and broadcast it over the weight samples.

        Inputs with more than one axis are flattened first. Three trailing
        axes are appended and the last one is expanded to
        ``self.weight_batch_size``.
        """
        x = x.to(device)
        if len(x.shape) > 1:
            x = x.flatten()
        x = x.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, self.weight_batch_size)
        assert x.shape[0] == self.params[0][0].shape[0], "Shape mismatch between input and first weight matrix"
        return x

    def _affine(self, x, layer_idx):
        """Apply the affine map of layer ``layer_idx`` (no activation)."""
        weight = self.params[layer_idx][0]
        bias = self.params[layer_idx][1]
        return torch.sum(x * weight, dim=0)[:, None, :] + bias

    def _forward_op(self, x):
        '''Expects x as dim (input, 1, batch_size, 1 (for weight samples))'''
        for i in range(len(self.layer_sizes) - 1):
            x = relu(self._affine(x, i))
        x = self._affine(x, -1)
        # Average over the weight-sample axis before the output non-linearity.
        if args.use_multiclass:
            return torch.mean(x, dim=-1)
        else:
            return torch.sigmoid(torch.mean(x, dim=-1))

    def forward(self, x):
        """Return the sample-averaged network output for a single input."""
        return self._forward_op(self._expand_input(x))

    def get_clean_acc(self, test_set, test_labels):
        """Return the fraction of ``test_set`` entries classified correctly.

        Multiclass predictions use the arg-max of the output; otherwise the
        scalar output is rounded.
        """
        correct = 0
        # Pure evaluation: no gradients are needed here.
        with torch.no_grad():
            for i in tqdm(range(len(test_set)), desc="Running test set"):
                output = self.forward(test_set[i])
                if args.use_multiclass:
                    predicted = torch.argmax(output).item()
                else:
                    predicted = round(output.item())
                if predicted == test_labels[i].item():
                    correct += 1
        return correct / len(test_set)

    def check_validity(self, pred, target):
        '''Returns Bool, if valid CF then True, else False'''
        # The previous `pred.size() == 1` test compared a torch.Size with an
        # int and was therefore never true, so only this path ever executed.
        if args.use_multiclass:
            # torch.argmax avoids the tensor -> numpy conversion, which fails
            # for tensors that require grad.
            return torch.argmax(pred).item() == target
        return round(pred.item()) == target

    def get_cfx(self, x, true_class, target_class, stepsize=1e-1, dist_weight = 0.5, num_iterations=1000):
        """Search for a counterfactual of ``x`` classified as ``target_class``.

        ``utils.attack`` is restarted from ``x`` up to
        ``args.global_iter_cap`` times; after every attempt the distance
        weight is multiplied by 0.97.

        Returns:
            The counterfactual on the CPU, or ``None`` when no valid one was
            found. On success the distance weight of the last attempt is
            appended to the module-level ``dist_weights_used`` list.

        Raises:
            AssertionError: If ``true_class`` equals ``target_class``.
        """
        assert true_class != target_class, "The target and the true class cannot be the same."
        x = x.to(device)
        x_cf = x.clone().detach().type(torch.float64).to(device)
        weight_used = dist_weight
        attempts = 0
        is_valid = self.check_validity(self.forward(x_cf), target_class)
        while not is_valid and attempts < args.global_iter_cap:
            attempts += 1
            # Remember the weight actually handed to this attempt; it is
            # decayed below in preparation for a possible next attempt.
            weight_used = dist_weight
            # Every attempt restarts from the original input.
            x_cf = x.clone().detach().type(torch.float64).to(device)
            x_cf = utils.attack(self, x, x_cf, target_class, true_class, num_iterations=num_iterations, stepsize=stepsize, dist_weight=dist_weight, use_multiclass=args.use_multiclass, l0_budget=args.l0_budget)
            dist_weight *= 0.97
            is_valid = self.check_validity(self.forward(x_cf), target_class)
        if not is_valid:
            return None
        dist_weights_used.append(weight_used)
        return x_cf.cpu()

    def get_local_lipschitz(self, x):
        """Return a product of per-layer spectral norms evaluated at ``x``.

        For every hidden layer the sample-averaged weight matrix is masked
        with the activation pattern observed at ``x`` (pre-activation mean
        over the weight samples > 0) before ``utils.spectral_norm`` is
        applied; the last layer's averaged weights are used unmasked.
        """
        x = self._expand_input(x)
        activation_list = []
        for i in range(len(self.layer_sizes) - 1):
            x = self._affine(x, i)
            activation_list.append(torch.mean(x, dim=-1) > 0)
            x = relu(x)
        local_lips = 1
        for i in range(len(self.params) - 1):
            mean_weight = torch.mean(self.params[i][0], dim=-1)
            local_lips *= utils.spectral_norm(mean_weight[..., None] * activation_list[i][None, ...])
        local_lips *= utils.spectral_norm(
            torch.mean(self.params[-1][0], dim=-1))
        return local_lips


# Pickled tabular datasets: name -> (file prefix under data/, input dimension).
_TABULAR_DATASETS = {
    "GERMAN": ("german", 20),
    "DIABETES": ("diabetes", 8),
    "news": ("news", 58),
    "spam": ("spam", 57),
}


def _load_pickle(path):
    """Return the object unpickled from *path*."""
    # NOTE(review): pickle can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_dataset(name):
    """Load the train/test split and metadata of dataset *name*.

    Returns:
        ``(x_train, y_train, x_test, y_test, data_dim, num_classes,
        use_multiclass, targets)``. ``targets`` is the content of
        ``MNIST_targets.h5`` for MNIST and ``None`` for every other dataset.
    """
    if name == "MNIST":
        test_split = datasets.MNIST(root='../data', train=False, download=True, transform=None)
        train_split = datasets.MNIST(root='../data', train=True, download=True, transform=None)
        targets = _load_pickle("MNIST_targets.h5")
        return (train_split.data / 255., train_split.targets,
                test_split.data / 255., test_split.targets,
                784, 10, True, targets)

    prefix, data_dim = _TABULAR_DATASETS[name]
    x_train, y_train = _load_pickle("data/{}_train.h5".format(prefix))
    x_test, y_test = _load_pickle("data/{}_test.h5".format(prefix))
    if name == "GERMAN":
        # The GERMAN pickles are wrapped into tensors here; the other
        # datasets are used as loaded, with only their labels cast to int16.
        x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
        x_test, y_test = torch.tensor(x_test), torch.tensor(y_test)
    else:
        y_train = y_train.type(torch.int16)
        y_test = y_test.type(torch.int16)
    return x_train, y_train, x_test, y_test, data_dim, 2, False, None


def main():
    """Generate counterfactuals for the first test points and save metrics.

    Loads the network and dataset selected on the command line, prints the
    clean accuracy, builds a counterfactual for each correctly classified
    test point (at most 50), measures how much the counterfactual moves under
    random input perturbations, and writes the averaged metrics to
    ``<dataset>_<network>.txt``.
    """
    print("running CFX finder for Network: {0} of type: {1} with WANDB: {2} and DATASET: {3}".format(args.network, args.network_type, args.use_wandb, args.dataset))
    stored_as_samples = (
        args.network.startswith(('bnn', "hybrid", "mlp_init_bnn"))
        or len(args.network.split("x")) >= 3
    )
    if not stored_as_samples:
        means, devs = load_model(args.network, args.network_type)
        samples = sample_net(means, devs)
    else:
        # The loaded parameters are used directly as the weight samples.
        samples, _ = load_model(args.network, args.network_type)
        for sample in samples:
            sample[0], sample[1] = sample[0].to(device), sample[1].to(device)

    assert args.dataset in ["MNIST", "GERMAN", "DIABETES", 'news', 'spam'], "Invalid dataset"
    (x_train, y_train, x_test, y_test,
     data_dim, num_classes, use_multiclass, targets) = _load_dataset(args.dataset)
    # Model reads this flag from the global `args`.
    args.use_multiclass = use_multiclass

    model = Model(samples)

    clean_acc = model.get_clean_acc(x_test.view(-1, data_dim), y_test)
    print("Model clean accuracy: ", clean_acc*100, "%")

    metric = Metrics((x_train.view(-1, data_dim), y_train), device=device)

    data_range = (torch.max(x_test, dim=0).values - torch.min(x_test, dim=0).values)
    model.data_range = data_range.clone()
    # Per-feature magnitude of the uniform noise added to the inputs below.
    budget = data_range * x_pert_budget

    # Area to build new metrics

    # Area end
    l2_costs = []
    input_var_ratio = []
    implausibility = []
    lof_predict = []
    failed_count = 0
    attempted = 0
    model.sgdl_samps = metric.get_SGDL_samples(model, target_range=y_train.unique(), number_of_samps=10)

    num_perturbations = 50
    # Never index past the end of a test set with fewer than 50 points.
    num_points = min(50, len(x_test))
    for i in tqdm(range(num_points), desc="Generating Counterfactuals"):
        x = x_test[i]
        true_class = y_test[i]

        # Only explain points the model classifies correctly.
        output = model.forward(x)
        if args.use_multiclass:
            predicted_class = output.argmax().item()
        else:
            predicted_class = round(output.item())
        if predicted_class != true_class:
            continue
        attempted += 1

        if args.dataset == "MNIST":
            target = targets[i]
        else:
            target = np.random.randint(num_classes)
        # Resample until the target differs from the true class.
        while target == true_class:
            target = np.random.randint(num_classes)

        cfx = model.get_cfx(x, true_class, target, dist_weight=args.distance_weight, stepsize=args.stepsize, num_iterations=args.num_iterations)
        if cfx is None:
            failed_count += 1
            continue

        input_var_distances = []
        for _ in range(num_perturbations):
            # Uniform noise in [-budget, budget] for every feature.
            noise = budget * (torch.rand(x.shape) * 2 - 1)
            x_pert = x + noise
            cfx_pert = model.get_cfx(x_pert, true_class, target, dist_weight=args.distance_weight, stepsize=args.stepsize, num_iterations=args.num_iterations)
            if cfx_pert is None:
                # No counterfactual for the perturbed input: fall back to the
                # distance between the counterfactual and the clean input.
                input_var_distances.append(metric.get_cost(cfx.flatten(), x.flatten(), metric='l2'))
            else:
                input_var_distances.append(metric.get_cost(cfx.flatten(), cfx_pert.flatten(), metric='l2'))
        l2_costs.append(metric.get_cost(cfx.flatten(), x.flatten(), metric='l2'))
        input_var_ratio.append(torch.stack(input_var_distances).mean()/l2_costs[-1])
        implausibility.append(metric.implausibility_soa_metric(cfx.flatten(), target))
        lof_predict.append(metric.lof.predict(cfx.flatten().unsqueeze(0)))

    # Success rate over the points for which a counterfactual was actually
    # attempted; skipped (misclassified) points are neither success nor failure.
    if attempted:
        success_fraction = (attempted - failed_count) / attempted
    else:
        success_fraction = float('nan')

    results = {
        'robustness_ratio': np.mean(input_var_ratio),
        'CFX_success_fraction': success_fraction,
        'average_implausibility': np.mean(implausibility),
        'average_lof_predict': np.mean(lof_predict),
        'l2_cost': np.mean(l2_costs),
    }

    with open('{0}_{1}.txt'.format(args.dataset, args.network), 'w') as file:
        file.writelines(f'{key}: {value}\n' for key, value in results.items())

# Call main() only when this file is executed directly.
if __name__ == '__main__':
    main()
