import os
import math
from re import S
from matplotlib import legend
from matplotlib.legend import Legend

import numpy as np
import torch
from tqdm import tqdm
import logging
logger = logging.getLogger()
import math
from typing import *
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import pandas as pd
import matplotlib.pyplot as plt


class PartDataset(torch.utils.data.Dataset):
    """View of ``dataset`` restricted to the labels listed in ``lab_list``.

    Kept labels are renumbered by their position in ``lab_list``: the sample
    whose original label is ``lab_list[k]`` is returned with label ``k``.

    Args:
        dataset: indexable/iterable dataset yielding ``(X, y)`` pairs.
        lab_list: original labels to keep; their order defines the new labels.
    """

    def __init__(self, dataset, lab_list):
        self.dataset = dataset
        self.lab_list = lab_list
        # original label -> new contiguous label (position in lab_list)
        self.lab_map = {label: new_label for new_label, label in enumerate(lab_list)}
        print("lab_map", self.lab_map)
        # NOTE: this walks the whole dataset once (loading every sample) just to
        # read its label; indices of the kept samples are stored in order.
        self.used_ids = [
            idx for idx, (_, target) in enumerate(dataset)
            if target in self.lab_list
        ]

    def __len__(self):
        return len(self.used_ids)

    def __getitem__(self, i):
        sample, orig_label = self.dataset[self.used_ids[i]]
        return sample, self.lab_map[orig_label]


class CifarNet(nn.Module):
    """Compact CNN classifier for 3-channel images.

    Four 3x3 convolution + ReLU stages with 32, 64, 64 and 128 channels. The
    first three stages are followed by 2x2 average pooling and the last by
    global average pooling. A linear head then maps the 128 pooled features
    to ``num_of_classes`` logits.

    Args:
        num_of_classes: number of output logits (default 10).
    """

    def __init__(self, num_of_classes=10):
        super().__init__()

        layers = []
        # (in_channels, out_channels) of the three stages that halve H and W.
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 64)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ]
        # Last conv stage, then global average pooling to a flat 128-d vector.
        layers += [
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(start_dim=1, end_dim=-1),
        ]
        # The layer order fixes the checkpoint keys (net.0, net.3, net.6,
        # net.9, fc). It must stay identical for saved state_dicts to load.
        self.net = nn.Sequential(*layers)
        self.fc = nn.Linear(128, num_of_classes, bias=True)

    def forward(self, x):
        """Return class logits; for a [batch, 3, H, W] input the result is [batch, num_of_classes]."""
        return self.fc(self.net(x))


def cal_bound(eps, delta, barc, ja, jb):
    """Evaluate ``k = ln(num / den) / (2 * eps)``.

    The terms are built from ``t = exp(eps) - 1``::

        num = ja * t + delta * barc
        den = jb * t + delta * barc

    NOTE(review): the inputs look like DP (eps, delta) parameters and
    probability bounds ja/jb; confirm against the derivation this implements.

    Args:
        eps: positive scalar; also divides the log term.
        delta: scalar multiplied by ``barc`` in both numerator and denominator.
        barc: scalar multiplied by ``delta``.
        ja: coefficient of ``t`` in the numerator.
        jb: coefficient of ``t`` in the denominator.

    Returns:
        float: the value ``k``.

    Raises:
        ZeroDivisionError: if ``eps == 0`` or the denominator is zero.
        ValueError: if ``num / den <= 0`` (outside the domain of ``math.log``).
    """
    # expm1 computes exp(eps) - 1 without catastrophic cancellation; the naive
    # form loses ~1e-8 relative precision already at eps = 1e-10.
    growth = math.expm1(eps)
    smoothing = delta * barc
    logterm = (ja * growth + smoothing) / (jb * growth + smoothing)
    return 1 / (2 * eps) * math.log(logterm)


def cal_ja(eps, delta, jb):
    """Return ``exp(2 * eps) * jb + (1 + exp(eps)) * delta``.

    Args:
        eps: scalar exponent parameter.
        delta: scalar scaled by ``1 + exp(eps)``.
        jb: scalar scaled by ``exp(2 * eps)``.

    Returns:
        float: the combined value.
    """
    scaled_jb = math.exp(2 * eps) * jb
    delta_term = (1 + math.exp(eps)) * delta
    return scaled_jb + delta_term


def _predict_probs(model, dataloader, device):
    """Run ``model`` once over ``dataloader`` and collect its softmax outputs.

    Args:
        model: network already holding the checkpoint weights to evaluate.
        dataloader: iterable of ``(x, y)`` tensor batches; ``y`` must be a CPU
            tensor because it is converted with ``.numpy()``.
        device: device the inputs are moved to (the model's device).

    Returns:
        ``(probs, labels)``: ``probs`` is a float64 array with one row of class
        probabilities per sample, in loader order. ``labels`` is the list of
        the corresponding ``y`` values.
    """
    batch_probs = []
    labels = []
    for x_in, y_in in dataloader:
        # squeeze(1) is a no-op for [bs, num_classes] logits with num_classes > 1;
        # kept from the original implementation.
        logits = model(x_in.to(device)).squeeze(1)
        batch_probs.append(torch.softmax(logits, dim=1).detach().cpu().numpy())
        labels.extend(y_in.numpy())
    if not batch_probs:
        # Empty loader: return the same (0, 2) array the original code produced.
        return np.zeros((0, 2)), labels
    # A single concatenate (not one per batch, which was quadratic). Casting to
    # float64 keeps the averaging over models at the original precision.
    return np.concatenate(batch_probs, axis=0).astype(np.float64), labels


def certificate_over_dataset(model, dataloader, PREFIX, N_m):
    """Average the softmax outputs of ``N_m`` saved checkpoints over a dataset.

    Checkpoint ``k`` (``k = 0 .. N_m - 1``) is read from ``PREFIX + str(k)``.
    It must be a dict whose ``'state_dict'`` entry is loaded into ``model``
    before running it over ``dataloader``. Per-sample class probabilities are
    then averaged over all checkpoints.

    Args:
        model: network matching the checkpoints' architecture; its weights are
            overwritten.
        dataloader: iterable of ``(x, y)`` batches. It must yield samples in
            the same order on every pass (e.g. ``shuffle=False``), because
            predictions are aligned by position and labels come from the first
            pass only.
        PREFIX: checkpoint path prefix; the run index is appended to it.
        N_m: number of checkpoints to average (must be >= 1).

    Returns:
        ``(pa, pb, is_acc)``: per-sample highest averaged probability,
        second-highest averaged probability, and a boolean array telling
        whether the top class equals the label.

    Raises:
        ValueError: if ``N_m < 1``.
    """
    if N_m < 1:
        raise ValueError("N_m must be >= 1, got %d" % N_m)

    # Follow the model's device instead of hard-coding .cuda(). A model with
    # no parameters falls back to CUDA, as before.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    model_preds = []
    labs = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for run_idx in tqdm(range(N_m)):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(PREFIX + '%d' % run_idx, map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            probs, run_labels = _predict_probs(model, dataloader, device)
            if run_idx == 0:
                labs = run_labels
            model_preds.append(probs)

    gx = np.stack(model_preds).mean(0)  # [num_samples, num_classes], mean over models
    labs = np.array(labs)  # [num_samples]

    pa = gx.max(1)
    pred_c = gx.argmax(1)  # [num_samples]

    # Mask the winning class and take the runner-up. Softmax outputs are
    # >= 0, so the -1 can never be chosen again.
    gx[np.arange(len(pred_c)), pred_c] = -1
    pb = gx.max(1)  # [num_samples]

    is_acc = pred_c == labs
    return pa, pb, is_acc


def get_dp_result(folder_prefix, saved_model_name):
    """Read the ``eps`` column of a model folder's ``all_exp.csv``.

    The file read is ``folder_prefix + saved_model_name + '/all_exp.csv'``.
    Values are returned in file (row) order; the caller indexes them by
    ``epoch - 1``.

    Args:
        folder_prefix: root folder prefix, concatenated as-is (the default
            ``--folder_prefix`` ends with '/').
        saved_model_name: sub-folder name of the saved model.

    Returns:
        list[float]: the ``eps`` values, one per CSV row.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        KeyError: if the CSV has no ``eps`` column.
    """
    filename = folder_prefix + saved_model_name + '/all_exp.csv'
    print(saved_model_name)
    df = pd.read_csv(filename)
    # Read the column directly. Unlike df.loc[i, 'eps'] over range(len(df)),
    # this does not assume a default 0..n-1 index and returns plain floats.
    return df['eps'].tolist()


import argparse

# Selects which --folder_prefix default is used and which model list the
# __main__ block evaluates.
is_insdp = True

parser = argparse.ArgumentParser()

# Dataset setting.
# NOTE(review): --dataset is parsed, but the __main__ block always loads
# CIFAR-10; confirm whether this flag is still needed.
parser.add_argument('--dataset', type=str, default='mnist')

# Root folder prefix that the saved-model folder names are appended to.
# Each mode has its own placeholder default (currently identical).
parser.add_argument(
    '--folder_prefix',
    type=str,
    default='root/folder/path/' if is_insdp else 'root/folder/path/',
)

# Smoothing setting: number of saved model runs (checkpoints) to average.
parser.add_argument('--N_m', type=int, default=1000)
# Saved epoch to evaluate (1-based: used as epoch_{N} and eps index N - 1).
parser.add_argument('--epoch', type=int, default=1)
# Evaluate setting
# Evaluate setting

def _write_certificate_file(output_fname, pa_exp, pb_exp, is_acc):
    """Write per-sample certificate values as a tab-separated text file.

    The file starts with the header ``idx pa_exp pb_exp is_acc`` (tab
    separated), followed by one row per test sample in the order produced by
    ``certificate_over_dataset``.
    """
    # The with-block closes the file even if writing fails; buffered writes
    # replace the previous per-line flush=True.
    with open(output_fname, 'w') as out_file:
        print("idx\tpa_exp\tpb_exp\tis_acc", file=out_file)
        for idx, (pa, pb, acc) in enumerate(zip(pa_exp, pb_exp, is_acc)):
            print("{}\t{}\t{}\t{}".format(idx, pa, pb, acc), file=out_file)


if __name__ == '__main__':
    # `logger` is the root logger, whose default level is WARNING, so the
    # logger.info summaries below were silently discarded. Enable INFO output.
    logging.basicConfig(level=logging.INFO)

    args = parser.parse_args()
    args = vars(args)
    print(args)

    saved_epoch = args['epoch']
    if saved_epoch < 1:
        # epss[saved_epoch - 1] would otherwise silently index from the end.
        parser.error("--epoch must be >= 1, got %d" % saved_epoch)

    model = CifarNet(num_of_classes=2).cuda()
    model.eval()
    dataPath = "./data/"
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    test_dataset = datasets.CIFAR10(dataPath,
                                    train=False,
                                    transform=transform_test)
    # Keep only CIFAR-10 classes 0 and 2; they are relabelled 0 and 1.
    test_dataset = PartDataset(test_dataset, [0, 2])
    print("test dataset size", len(test_dataset))
    # shuffle=False is required: certificate_over_dataset takes labels from the
    # first model's pass and aligns predictions across models by position.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=64,
                                              shuffle=False)

    # Calculate the expectation and bound of p_A and p_B
    if is_insdp:
        saved_model_names = [
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
        ]
        # NOTE(review): `noises` is unused below, and its length (8) does not
        # match the 9 model names; confirm the intended pairing.
        noises = [1, 2, 3, 4, 5, 6, 7, 8]
    else:
        saved_model_names = [
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
            'path/to/models/',
        ]
        # NOTE(review): `noises` is unused below, and its length (10) does not
        # match the 9 model names; confirm the intended pairing.
        noises = [
            0.5,
            0.8,
            1,
            1.3,
            1.7,
            2,
            2.3,
            2.6,
            3,
            4,
        ]

    for saved_model_name in saved_model_names:
        model_dir = args['folder_prefix'] + saved_model_name
        epss = get_dp_result(args['folder_prefix'], saved_model_name)
        epsilon = epss[saved_epoch - 1]

        PREFIX = model_dir + f'model.pt.tar.epoch_{saved_epoch}.run_'
        # Inference only: no autograd graph is needed for the predictions.
        with torch.no_grad():
            pa_exp, pb_exp, is_acc = certificate_over_dataset(
                model, test_loader, PREFIX, args['N_m'])

        # prepare output file
        output_fname = os.path.join(
            model_dir,
            "Epoch%dM%dEps%.4f.txt" % (saved_epoch, args['N_m'], epsilon))
        _write_certificate_file(output_fname, pa_exp, pb_exp, is_acc)

        logger.info("is_acc %.4f ", float(np.mean(is_acc)))
        logger.info("save to %s", output_fname)
