import numpy as np
import torch
import tensorflow as tf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class Distribution:
    """Base interface for objects that can produce a single random sample.

    Subclasses override :meth:`sample`. The base implementation holds no
    state and its ``sample`` simply returns ``None``.
    """

    def __init__(self):
        """Nothing to initialise for the base class."""

    def sample(self):
        """Return one sample; the base class yields ``None``."""
        return None

class UniformDistribution(Distribution):
    """Elementwise uniform distribution over the box spanned by two corners.

    ``corner_a`` and ``corner_b`` are expected to be torch tensors of the same
    shape (``.cpu().detach().numpy()`` is called on them). They are stored as
    NumPy arrays, so ``sample`` uses only NumPy.

    Attributes:
        corner_a, corner_b: NumPy copies of the two corners.
        dim: ``corner_a.shape`` (a ``torch.Size``), used as the sample shape.
    """

    def __init__(self, corner_a, corner_b):
        super().__init__()
        shape = corner_a.shape
        # Keep host-side NumPy copies; sampling never touches torch.
        self.corner_a = corner_a.cpu().detach().numpy()
        self.corner_b = corner_b.cpu().detach().numpy()
        self.dim = shape
        assert corner_b.shape == self.dim

    def sample(self):
        """Draw one point between the corners as a float32 NumPy array.

        Uses the global ``np.random`` state: ``corner_a + u * (corner_b -
        corner_a)`` with ``u`` drawn from ``np.random.uniform`` (values in
        [0, 1)).
        """
        span = self.corner_b - self.corner_a
        unit = np.random.uniform(size=self.dim)
        return (unit * span + self.corner_a).astype(np.float32)

class DistributionDataset(Dataset):
    """Map-style dataset that returns a fresh draw from a distribution.

    Args:
        distribution: object exposing ``sample()``; it is called once per
            ``__getitem__``.
        length: value reported by ``len()``, i.e. how many items a
            ``DataLoader`` iterates over.
        save: accepted for compatibility but currently unused.
    """

    def __init__(self, distribution, length, save=False):
        self.length = length
        self.distribution = distribution

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access is an independent draw, so the
        # same index may return different values.
        draw = self.distribution.sample
        return draw()

def naive_monte_carlo(model, expected, dist, num_samples=1000):
    """Estimate how often ``model`` predicts ``expected`` on samples from ``dist``.

    Draws ``num_samples`` samples from ``dist`` in batches of 256. For each
    batch it takes the argmax over axis 1 of the model output as the
    predicted class and counts matches with ``expected``. It prints the
    matching and non-matching fractions.

    Args:
        model: callable mapping a batch tensor to scores. The output is
            assumed to have classes along axis 1. If the model exposes a
            ``device`` attribute, batches are moved there before the call.
        expected: label(s) compared with ``==`` against the predicted classes.
        dist: object with a ``sample()`` method.
        num_samples: total number of samples to draw.

    Returns:
        ``(correct, total)``: number of matching predictions and number of
        samples evaluated.
    """
    loader = DataLoader(DistributionDataset(dist, num_samples), batch_size=256, num_workers=0)
    # Models used elsewhere in this file carry a ``device`` attribute; honour
    # it when present so GPU models receive GPU inputs.
    device = getattr(model, "device", None)
    print("EXPECTED", expected)
    correct = 0
    total = 0
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        for batch in tqdm(loader):
            if device is not None:
                batch = batch.to(device)
            output = model(batch).cpu().detach().numpy()
            predicted = np.argmax(output, axis=1)
            correct += np.sum(predicted == expected)
            total += predicted.shape[0]
    # Avoid ZeroDivisionError when no samples were drawn.
    accuracy = correct / total if total else float("nan")
    print("CORRECT\t{:.4f}".format(accuracy))
    print("INCORRECT\t{:.4f}".format(1 - accuracy))
    return correct, total

def compare_monte_carlo(model, true_label, target, dist, num_samples):
    """Count, per input, how many samples keep logit[true_label] > logit[target].

    Each item drawn from ``dist`` is assumed to have shape ``(N, ...)``: one
    perturbed copy of each of ``N`` inputs (TODO confirm against callers).
    A loader batch of shape ``(B, N, ...)`` is flattened to ``(B * N, ...)``
    for the model. The model output (assumed ``(B * N, num_classes)``) is
    reshaped back to ``(B, N, num_classes)``.

    Args:
        model: module exposing a ``device`` attribute; batches are moved there.
        true_label: tensor of per-input true class indices (length ``N``).
        target: tensor of per-input competing class indices (length ``N``).
        dist: object with a ``sample()`` method.
        num_samples: number of draws taken from ``dist``.

    Returns:
        ``(correct, total)``. ``correct`` is a NumPy array with the shape and
        dtype of ``true_label``; for each input it holds the number of draws
        where the true-label logit strictly exceeds the target logit.
        ``total`` is the number of draws evaluated.
    """
    # A batch size of 8192 was used for the 3 x [20] model.
    loader = DataLoader(DistributionDataset(dist, num_samples), batch_size=256, num_workers=2)
    # Labels are loop-invariant: convert them to NumPy once.
    true_np = true_label.cpu().detach().numpy()
    target_np = target.cpu().detach().numpy()
    correct = np.zeros_like(true_np)
    total = 0
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(model.device)
            dims = batch.size()
            batch = batch.reshape((-1,) + dims[2:])
            output = model(batch).cpu().detach().numpy()
            # Infer the number of classes from the model output instead of
            # hard-coding 10.
            output = output.reshape(tuple(dims[:2]) + (-1,))
            rows = np.arange(dims[1])
            margin = output[:, rows, true_np] - output[:, rows, target_np]
            correct += np.sum(margin > 0, axis=0)
            total += dims[0]
    print("CORRECT", correct)
    return correct, total

def get_radius_monte_carlo(model, true_label, target, input, num_samples, num_threshold):
    """Bisect a per-row box half-width ``eps`` in [0, 1] using Monte-Carlo counts.

    On each step, uniform samples are drawn from ``[input - eps, input + eps]``
    and ``compare_monte_carlo`` counts how many keep the ``true_label`` logit
    above the ``target`` logit. Rows whose count reaches ``num_threshold`` raise
    their lower bound to ``eps``; the others lower their upper bound. The counts
    are presumably one entry per row of ``input`` (first axis); verify against
    callers. Bisection presumes the count does not increase as ``eps`` grows.

    Returns:
        ``(eps_low[:, 0].flatten(), {})``: the lower bracket taken at index 0
        of axis 1 and flattened, plus an empty dict.
    """
    lo = torch.zeros(input.size(), dtype=torch.float32, device=input.device)
    hi = torch.ones(input.size(), dtype=torch.float32, device=input.device)

    # Every element's bracket halves on each pass, so the loop runs 20 times
    # before the widest bracket drops below 1e-6.
    while (hi - lo).max() > 1e-6:
        mid = (lo + hi) / 2
        box = UniformDistribution(input - mid, input + mid)
        counts, _ = compare_monte_carlo(model, true_label, target, box, num_samples)
        robust = counts >= num_threshold
        lo[robust] = mid[robust]
        hi[~robust] = mid[~robust]
    print("EPS_LOW", lo)
    return lo[:, 0].flatten(), {}