import os
import math
from re import S
from matplotlib import legend
from matplotlib.legend import Legend

import numpy as np
import torch
from tqdm import tqdm
import logging
logger = logging.getLogger()
import math
from typing import *
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import pandas as pd
import matplotlib.pyplot as plt


class PartDataset(torch.utils.data.Dataset):
    """Subset of ``dataset`` that keeps only samples whose label is in ``lab_list``.

    Kept labels are renumbered by their position in ``lab_list``; e.g.
    ``lab_list=[3, 7]`` maps label 3 -> 0 and label 7 -> 1.

    Args:
        dataset: indexable collection of ``(X, y)`` pairs. It is iterated once,
            in full, during construction to find the matching indices.
        lab_list: labels to keep; their order defines the new label ids.
    """

    def __init__(self, dataset, lab_list):
        self.dataset = dataset
        self.lab_list = lab_list
        # original label -> new contiguous label (its index in lab_list)
        self.lab_map = {lab: new_id for new_id, lab in enumerate(lab_list)}
        print("lab_map", self.lab_map)
        # positions in ``dataset`` whose label is one of the kept labels
        self.used_ids = [
            idx for idx, (_, lab) in enumerate(dataset) if lab in self.lab_list
        ]

    def __len__(self):
        return len(self.used_ids)

    def __getitem__(self, i):
        sample, lab = self.dataset[self.used_ids[i]]
        return sample, self.lab_map[lab]


class MnistNet(nn.Module):
    """Small CNN (two conv layers, two linear layers) for ``[B, 1, 28, 28]`` inputs.

    Args:
        num_of_classes: number of output logits produced by ``fc2``.
    """

    def __init__(self, num_of_classes=10):
        super().__init__()
        # The attribute names below become the state_dict keys
        # (conv1.*, conv2.*, fc1.*, fc2.*), so they must stay stable.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=8, stride=2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, num_of_classes)

    def forward(self, x):
        """Return raw logits of shape ``[B, num_of_classes]``.

        The shape comments assume a ``[B, 1, 28, 28]`` input; the flatten
        step hard-codes the resulting ``32 * 4 * 4`` feature size.
        """
        # conv1 -> [B, 16, 14, 14], pool (k=2, s=1) -> [B, 16, 13, 13]
        h = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=1)
        # conv2 -> [B, 32, 5, 5], pool (k=2, s=1) -> [B, 32, 4, 4]
        h = F.max_pool2d(F.relu(self.conv2(h)), kernel_size=2, stride=1)
        # flatten to [B, 512], then fc1 -> [B, 32]
        h = F.relu(self.fc1(h.view(-1, 32 * 4 * 4)))
        return self.fc2(h)


def certificate_over_dataset(model, dataloader, PREFIX, N_m):
    """Average softmax scores over ``N_m`` checkpoints; return top-2 probabilities.

    For each run index ``k`` in ``range(N_m)``, the ``'state_dict'`` entry of
    the checkpoint at ``PREFIX + str(k)`` is loaded into ``model``. The whole
    ``dataloader`` is then evaluated, and the per-sample softmax probabilities
    are stored. Finally the probabilities are averaged over all checkpoints.

    Args:
        model: network whose state_dict matches the checkpoints. Inputs (and
            the loaded checkpoint tensors) are moved to the device of its
            first parameter, so it must have at least one parameter.
        dataloader: iterable of ``(x, y)`` batches. It must yield samples in
            the same order on every pass (e.g. ``shuffle=False``), because
            labels are collected on the first pass only.
        PREFIX: checkpoint path prefix; the run index is appended to it.
        N_m: number of checkpoints to average; must be >= 1.

    Returns:
        ``(pa, pb, is_acc)``:
        - ``pa``: largest averaged probability per sample.
        - ``pb``: second-largest averaged probability per sample.
        - ``is_acc``: boolean array, True where the argmax class equals the
          label.

    Raises:
        ValueError: if ``N_m < 1``.
    """
    if N_m < 1:
        raise ValueError("N_m must be >= 1, got %r" % (N_m,))

    # Send data wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    sm = nn.Softmax(dim=1)
    model_preds = []
    label_batches = []

    for run_idx in tqdm(range(N_m)):
        state = torch.load(PREFIX + '%d' % run_idx, map_location=device)
        model.load_state_dict(state['state_dict'])

        batch_preds = []
        with torch.no_grad():  # inference only: no autograd graph needed
            for x_in, y_in in dataloader:
                output = model(x_in.to(device)).squeeze(1)  # [bs, n_classes]
                batch_preds.append(sm(output).cpu().numpy())
                if run_idx == 0:
                    # Labels are identical on every pass; collect them once.
                    label_batches.append(y_in.numpy())

        # Concatenate once per run (avoids quadratic re-copying). The result
        # is cast to float64 so that the averaging below runs in float64.
        # An empty loader falls back to a (0, 2) array.
        if batch_preds:
            all_pred = np.concatenate(batch_preds, axis=0).astype(np.float64)
        else:
            all_pred = np.zeros((0, 2))
        model_preds.append(all_pred)  # all_pred: [num_samples, n_classes]

    gx = np.stack(model_preds).mean(0)  # mean over all models
    if label_batches:
        labs = np.concatenate(label_batches)  # [num_samples]
    else:
        labs = np.zeros(0, dtype=np.int64)

    pa = gx.max(1)
    pred_c = gx.argmax(1)  # [num_samples]

    # Probabilities are >= 0, so masking the winner with -1 leaves the
    # runner-up as the new row maximum.
    gx[np.arange(len(pred_c)), pred_c] = -1
    pb = gx.max(1)  # [num_samples]

    is_acc = (pred_c == labs)

    return pa, pb, is_acc


def get_dp_result(folder_prefix, saved_model_name):
    """Return the ``eps`` column of ``<folder_prefix><saved_model_name>/all_exp.csv``.

    Values are returned in CSV row order, so ``result[k]`` is the value on
    the ``k``-th data row. This holds regardless of which index pandas
    assigned to the frame.

    Args:
        folder_prefix: root prefix. It is concatenated (not path-joined) with
            ``saved_model_name``.
        saved_model_name: model directory, relative to ``folder_prefix``.

    Returns:
        list of the ``eps`` values, one per CSV data row.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
        KeyError: if the CSV has no ``eps`` column.
    """
    filename = os.path.join(folder_prefix + saved_model_name, 'all_exp.csv')
    print(saved_model_name)
    df = pd.read_csv(filename)

    # Positional extraction. ``df.loc[i, 'eps']`` is label-based and fails
    # when read_csv infers an index column (e.g. header one field short).
    epss = list(df['eps'].to_numpy())

    return epss


import argparse

# Selects the experiment family. Defaults below are chosen per family via
# conditional expressions; the two families currently share the same values.
is_insdp = False

parser = argparse.ArgumentParser()

# --- dataset settings ---
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument(
    '--folder_prefix',
    type=str,
    default=('root/folder/path/' if is_insdp else 'root/folder/path/'),
)

# --- smoothing settings ---
parser.add_argument('--N_m', type=int, default=1000)
parser.add_argument('--epoch', type=int, default=(3 if is_insdp else 3))
# Evaluate setting

if __name__ == '__main__':
    # The module-level ``logger`` is the root logger, whose default level is
    # WARNING. Without this call every ``logger.info`` below is discarded.
    logging.basicConfig(level=logging.INFO)

    args = parser.parse_args()
    args = vars(args)
    print(args)

    # Two-class model evaluated on the MNIST digits 0 and 1.
    model = MnistNet(num_of_classes=2).cuda()
    model.eval()
    dataPath = "./data/"
    test_dataset = datasets.MNIST(
        dataPath,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize((0.1307,), (0.3081,))
        ]))
    test_dataset = PartDataset(test_dataset, [0, 1])
    print("test dataset size", len(test_dataset))
    # shuffle=False matters: certificate_over_dataset reads labels only on its
    # first pass and assumes every pass yields samples in the same order.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=64,
                                              shuffle=False)

    saved_epoch = args['epoch']
    # NOTE(review): ``noise`` / ``noises`` are never used below, and their
    # lengths (8 / 13) do not match the 9 model names. Confirm whether they
    # are meant to pair with saved_model_names.
    if is_insdp:
        saved_model_names = [
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
        ]
        noise = [1, 2, 3, 4, 5, 8, 10, 15]
    else:

        noises = [3, 2.7, 2.5, 2.3, 2.1, 1.8, 1.5, 1, 0.5, 1.7, 1.9, 0.6, 0.8]
        saved_model_names = [
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
        ]

    for saved_model_name in saved_model_names:
        model_dir = args['folder_prefix'] + saved_model_name
        epss = get_dp_result(args['folder_prefix'], saved_model_name)
        # 'eps' value of the evaluated epoch (CSV row saved_epoch - 1).
        epsilon = epss[saved_epoch - 1]

        PREFIX = model_dir + f'model.pt.tar.epoch_{saved_epoch}.run_'
        pa_exp, pb_exp, is_acc = certificate_over_dataset(
            model, test_loader, PREFIX, args['N_m'])

        # Write one tab-separated row per test sample, after a header row.
        output_fname = os.path.join(
            model_dir,
            "Epoch%dM%dEps%.4f.txt" % (saved_epoch, args['N_m'], epsilon))
        lines = ["idx\tpa_exp\tpb_exp\tis_acc"]
        lines.extend(
            "{}\t{}\t{}\t{}".format(idx, pa, pb, acc)
            for idx, (pa, pb, acc) in enumerate(zip(pa_exp, pb_exp, is_acc)))
        with open(output_fname, 'w') as f:
            f.write("\n".join(lines) + "\n")

        # Guard against an empty evaluation set (would divide by zero).
        acc_rate = float(np.mean(is_acc)) if len(is_acc) else float('nan')
        logger.info("is_acc %.4f ", acc_rate)
        logger.info("save to %s", output_fname)
