import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
from sklearn.datasets import make_moons
from nets import MLP
import torchvision

import matplotlib.pyplot as plt

from collections import Counter

import medmnist
from medmnist import INFO

import seaborn as sns

torch.manual_seed(1)


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

parser = ArgumentParser(
    # ``__doc__`` is None for this script (its first statement is an import,
    # not a docstring), so give the help output an explicit description.
    description="Compare training without augmentation, with mixup, and with "
                "Brownian-bridge augmentation on the two-moons toy data or a "
                "MedMNIST dataset.",
    formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "-d",
    "--data",
    dest="data",
    default="breastmnist",
    # Restrict to the datasets the loading code knows how to build, so a typo
    # is reported here instead of surfacing later as an undefined variable.
    choices=(
        "moons",
        "breastmnist", "bloodmnist", "dermamnist", "retinamnist",
        "pathmnist", "organcmnist", "organamnist", "organsmnist",
        "tissuemnist", "octmnist",
    ),
    help="dataset",
    metavar="DATA",
    required=False,
)
parser.add_argument(
    "-de",
    "--device",
    dest="device",
    default="cuda:0",
    help="device",
    metavar="DEVICE",
    required=False,
)

args = parser.parse_args()
data = args.data
device = args.device


# Time grid on which the Brownian bridges are sampled (nt points in [0, 0.01]).
nt = 10
t = torch.linspace(0, 0.01, nt).to(device)
dt = (t[1] - t[0]).item()

N = 50000    # number of training points for the synthetic 'moons' data
N_t = 10000  # number of held-out points for the synthetic 'moons' data

width = 512
depth = 1

MEDMNIST_DATASETS = (
    'breastmnist', 'bloodmnist', 'dermamnist', 'retinamnist',
    'pathmnist', 'organcmnist', 'organamnist', 'organsmnist',
    'tissuemnist', 'octmnist',
)


def _load_medmnist_split(data_class, split, transform):
    """Load one MedMNIST split as (flattened images, 1-D label tensor)."""
    ds = data_class(split=split, transform=transform, download=True, root='/path/to/data/')
    images, labels = [], []
    # Single pass over the dataset so every image is transformed only once.
    for img, lbl in ds:
        images.append(img)
        labels.append(torch.tensor(lbl))
    return torch.stack(images, dim=0).reshape(len(ds), -1), torch.cat(labels, dim=0)


if data == 'moons':
    X_data, y_data = make_moons(N, noise=0.25)
    X_data = torch.tensor(X_data).float()
    y_data = torch.tensor(y_data)
    # Turn the last 100 training points into outliers: shift them by +(2, 2)
    # if their label is 1 and by -(2, 2) if it is 0.
    X_data[-100:] = X_data[-100:] + 2 * (2*(y_data[-100:] > 0) - 1).unsqueeze(-1) * torch.ones(100, 2)

    # Clean held-out set; the training loop evaluates on X_data_t / y_data_t,
    # which were previously never defined for this dataset.
    X_data_t, y_data_t = make_moons(N_t, noise=0.25)
    X_data_t = torch.tensor(X_data_t).float()
    y_data_t = torch.tensor(y_data_t)

elif data in MEDMNIST_DATASETS:
    # Map pixel values from [0, 1] to [-1, 1].
    tsfm = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[.5], std=[.5])
    ])
    DataClass = getattr(medmnist, INFO[data]['python_class'])

    X_data, y_data = _load_medmnist_split(DataClass, 'train', tsfm)
    print('Class Statistics:')
    print(Counter(y_data.tolist()))

    # The 'val' split is used as the evaluation set.
    X_data_t, y_data_t = _load_medmnist_split(DataClass, 'val', tsfm)

else:
    raise ValueError(
        "Unsupported dataset {!r}; expected 'moons' or one of {}".format(data, MEDMNIST_DATASETS)
    )


def bb(t, a, b, var=1):
    '''
    Samples a Brownian Bridge from a to b.

    One bridge is drawn per row. It is pinned to ``a`` at ``t[0]`` and to ``b``
    at ``t[-1]``. The time grid does not have to span [0, 1]; the pinning term
    is normalised to the grid's own length.

    Args:
        t:   1-D tensor of evenly spaced, increasing times (at least two points).
             Only the first spacing ``t[1] - t[0]`` is used as the step size.
        a:   (batch, dim) start points. Integer tensors are accepted and promoted.
        b:   (batch, dim) end points, same shape as ``a``.
        var: multiplier on the standard deviation of each Brownian increment,
             so each increment has variance ``var**2 * dt``.

    Returns:
        Floating tensor of shape (batch, len(t), dim) on ``t``'s device.
    '''
    dt = t[1] - t[0]

    # Draw the noise on t's device in t's (floating) dtype. ``a`` may be an
    # integer one-hot tensor (see bb_simplex), so its dtype is not used.
    dW = torch.randn(a.shape[0], t.shape[-1], a.shape[-1], device=t.device, dtype=t.dtype) * dt.sqrt() * var
    # Zero the first increment *before* summing, so the path starts at 0.
    # Each later point then has variance (t_i - t_0) * var**2.
    dW[:, 0] = 0
    W = dW.cumsum(1) + a.unsqueeze(1)

    # Fraction of the way through the time window: 0 at t[0], 1 at t[-1].
    s = (t - t[0]) / (t[-1] - t[0])

    return W - s.view(1, -1, 1) * (W[:, -1] - b).unsqueeze(1)

def bb_simplex(t, a, b, var=1):
    '''
    Sample bridges from ``a`` to ``b`` with ``bb`` and map every point onto the
    probability simplex.

    The mapping takes absolute values and then normalises the last dimension
    to sum to one. ``var`` is forwarded to ``bb`` as the noise scale.
    '''
    bridge = bb(t, a, b, var=var).abs()
    return bridge / bridge.sum(-1, keepdim=True)

def ce(fX, y):
    '''
    Cross-entropy between target distributions ``y`` and predicted
    probabilities ``fX``, summed over the last dimension.

    Args:
        fX: predicted probabilities (e.g. softmax outputs), floating dtype.
        y:  targets with the same shape as ``fX``; one-hot integer or soft float.

    Returns:
        Tensor with the last dimension of ``fX`` reduced away.
    '''
    # Clamp exact zeros to the smallest positive normal number. A target entry
    # of 0 against a probability of 0 then contributes 0 instead of nan, and a
    # target class predicted with probability 0 gives a large finite loss
    # instead of inf. Non-zero probabilities are unchanged.
    log_p = fX.clamp_min(torch.finfo(fX.dtype).tiny).log()
    return (-y * log_p).sum(-1)

class Small(nn.Module):
    """A single linear layer followed by a softmax over the last dimension.

    Args:
        d: number of input features.
        k: number of output classes.
    """

    def __init__(self, d, k):
        super().__init__()
        self.w = nn.Linear(d, k)

    def forward(self, x):
        """Return softmax probabilities over the ``k`` outputs for each row of ``x``."""
        logits = self.w(x)
        return logits.softmax(dim=-1)

iters = 500  # optimisation steps per run
runs = 10    # independent runs (fresh network each) per method
k = int(y_data.max().item()) + 1  # number of classes, assuming labels 0..k-1
d = X_data.shape[-1]              # input feature dimension
bs = 500                          # minibatch size

# A scatter of the first two features only visualises the data when the
# inputs are 2-D (e.g. 'moons'). For flattened images it would just plot
# pixel 0 against pixel 1 for every training sample.
if d == 2:
    plt.scatter(X_data[:, 0], X_data[:, 1], c=y_data)
    plt.savefig('scatter.pdf')
    plt.close('all')

# Distribution of the mixup interpolation weights.
beta_dist = torch.distributions.beta.Beta(0.2, 0.2, validate_args=None)
# When True, the partner labels y_ are randomly permuted in the training loop.
label_noise = True

# Noise scale for the Brownian bridges sampled below. The loop used ``var``
# without it being defined anywhere; 1 matches the default of bb/bb_simplex.
var = 1.0

for type_ in ['none', 'mixup', 'bridge', 'bridge-if']:
    acc = torch.zeros(runs, iters)      # validation accuracy after every step
    class_acc = torch.zeros(runs, k)    # per-class validation accuracy at the end of each run
    norm_ids = torch.zeros(runs, k)     # per-class mean L1 norm of the input gradient
    norms = {i: [] for i in range(k)}   # per-sample L1 norms of the input gradient, by class
    for r_idx in range(runs):
        losses = []
        net = MLP(d, width, depth, k, bn=True).to(device)
        opt = optim.Adam(net.parameters(), lr=1e-3)

        for i_idx in range(iters):
            net.train()
            opt.zero_grad()

            # Primary batch X comes from the front of a random permutation and
            # the partner batch X_ from the back. Each X row is then paired
            # with its nearest (L1 distance) row of X_.
            inds = torch.randperm(X_data.shape[0])
            X = X_data[inds[:bs]].to(device)
            X_ = X_data[inds[-bs:]].to(device)
            reshuf_d = torch.cdist(X, X_, 1).min(-1)[1]
            reshuf = torch.randperm(bs)
            X_ = X_[reshuf_d]
            y = F.one_hot(y_data[inds[:bs]], k).to(device)
            y_ = F.one_hot(y_data[inds[-bs:]], k).to(device)
            y_ = y_[reshuf_d]
            if label_noise:
                # Randomly permute the partner labels so they no longer match X_.
                y_ = y_[reshuf]

            orig_sh = (X.shape[0], t.shape[0], X.shape[-1])

            # Bridges from X to X_ (inputs), and from y to y_ mapped onto the
            # simplex (targets), flattened to (bs * nt, features).
            bbX = bb(t, X, X_, var=var).reshape(-1, d)
            bbX.requires_grad = True

            bby = bb_simplex(t, y, y_, var=var).reshape(-1, k)

            if type_ == 'bridge':
                fX = net(bbX)
                loss = ce(fX, bby).mean()
            elif type_ == 'bridge-dynkin':
                fX = net(bbX)
                dce = ce(fX, bby).reshape(y.shape[0], -1, 1)
                E_dce = dce.mean(0).unsqueeze(0)
                loss = dce[:, 0, 0].mean() + ((E_dce[:, 1:] - dce[:, :-1]) * dt).sum(0).mean()
            elif type_ == 'bridge-int':
                fX = net(bbX)
                dce = ce(fX, bby).reshape(y.shape[0], -1, 1)
                loss = dce[:, 0, 0].mean() + (dce).sum(1).mean()
            elif type_ == 'bridge-if':
                fX = net(bbX)
                l = ce(fX, bby).mean()
                gfX = grad(l, bbX, retain_graph=True, create_graph=True)[0]
                loss = (ce(fX, bby) + ((bbX * gfX).sum(-1) - 0.5 * ((gfX) ** 2).sum(-1))).sum(-1).mean()
            elif type_ == 'bridge-if-diff':
                fX = net(bbX)
                l = ce(fX, bby).mean()
                gfX = grad(l, bbX, retain_graph=True)[0].reshape(orig_sh)
                bbX = bbX.reshape(orig_sh)
                dbbX = bbX[:, 1:] - bbX[:, :-1]
                loss = (ce(fX, bby).reshape(X.shape[0], nt)[:, :-1] - (dbbX / gfX[:, :-1]).sum(-1) + 0.5 * ((1 / gfX[:, :-1]) ** 2 * dt).sum(-1)).sum(-1).mean()
            elif type_ == 'mixup':
                lambdas = beta_dist.sample(sample_shape=(X.shape[0], 1)).to(device)
                X_X = lambdas * X + (1 - lambdas) * X_
                fX = net(X_X)
                # lambdas is (bs, 1) so it broadcasts over features above; drop
                # the trailing axis here so each sample's loss is weighted by its
                # own lambda instead of broadcasting to a (bs, bs) matrix.
                lam = lambdas.squeeze(-1)
                loss = (lam * ce(fX, y) + (1 - lam) * ce(fX, y_)).mean()
            else:
                # No augmentation: train on both batches as ordinary samples.
                X_X = torch.cat((X, X_))
                y_y = torch.cat((y, y_))
                fX = net(X_X)
                loss = ce(fX, y_y).mean()

            losses.append(loss.item())

            loss.backward()
            opt.step()

            with torch.no_grad():
                net.eval()
                class_ = net(X_data_t.to(device)).max(-1)[1]
                accuracy = (class_.to('cpu') == y_data_t).float().mean().item()
                acc[r_idx, i_idx] = accuracy

            if i_idx == iters - 1:
                # End of run, with the network still in eval mode from above:
                # record per-class input-gradient norms on the training data
                # and per-class validation accuracy.
                for K_idx in range(k):
                    class_inds = y_data == K_idx
                    X_in = X_data[class_inds].to(device)
                    X_in.requires_grad = True
                    fX = net(X_in)
                    # Cross-entropy against the uniform distribution over k classes.
                    l = ce(fX.to('cpu'), torch.ones(fX.shape[0], k) / k).mean()
                    gfX = grad(l, X_in)[0]
                    norm_id = gfX.abs().sum(-1)
                    norms[K_idx].extend(norm_id.cpu().tolist())
                    norm_ids[r_idx, K_idx] = norm_id.mean().item()

                    class_inds_t = y_data_t == K_idx
                    with torch.no_grad():
                        class_ = net(X_data_t[class_inds_t].to(device)).max(-1)[1]
                    accuracy = (class_.to('cpu') == y_data_t[class_inds_t]).float().mean().item()
                    class_acc[r_idx, K_idx] = accuracy

    # Report each method's final validation accuracy (mean and std over runs).
    print('{} accuracy : {:.3f} ({:.4f})'.format(type_, acc.mean(0)[-1], acc.std(0)[-1]))
    plt.plot(torch.arange(acc.shape[-1]), acc.mean(0), label=type_)
    plt.fill_between(torch.arange(acc.shape[-1]), acc.mean(0) - acc.std(0), acc.mean(0) + acc.std(0), alpha=0.2)
    plt.legend()
# Summary statistics below refer to the last method trained in the loop above.
print('Accuracy : {:.3f} ({:.4f})'.format(acc.mean(0)[-1], acc.std(0)[-1]))
for K_idx in range(k):
    # '\\pm' yields the literal text \pm; a bare '\p' is an invalid escape.
    print('Class {} accuracy {:.3f} \\pm {:.4f} '.format(K_idx, class_acc[:, K_idx].mean(), class_acc[:, K_idx].std()))
print('order = {}'.format(norm_ids.mean(0).argsort()))

plt.savefig('losses_bridge_{}_{}x{}_iter{}.pdf'.format(data, width, depth, iters))
plt.close('all')

# Normalise the per-run class means by their overall maximum, and the
# per-sample norms by the largest per-sample norm of any class.
norm_ids_n = norm_ids / norm_ids.max()
n_max = max(max(vals) for vals in norms.values() if vals)
for K_idx in range(k):
    print('Norm Class {}: {} \\pm {}'.format(K_idx, norm_ids_n[:, K_idx].mean(), norm_ids_n[:, K_idx].std()))
    norms[K_idx] = (torch.tensor(norms[K_idx]) / n_max).tolist()

for K_idx in range(k):
    sns.kdeplot(norms[K_idx], label='Class {}'.format(K_idx))
plt.legend()
plt.xlabel('Uncertainty Score')
plt.tight_layout()
plt.savefig('{}_unc_hist.pdf'.format(data))
plt.close('all')

for K_idx in range(k):
    sns.kdeplot(norm_ids_n[:, K_idx], label='Class {}'.format(K_idx))
plt.legend()
plt.xlabel('Uncertainty Score')
plt.tight_layout()
plt.savefig('{}_unc_hist_avg.pdf'.format(data))
plt.close('all')

import pickle
# Use context managers so the files are flushed and closed deterministically.
with open('{}_unc.pkl'.format(data), 'wb') as fh:
    pickle.dump(norms, fh)
with open('{}_unc_hist.pkl'.format(data), 'wb') as fh:
    pickle.dump(norm_ids_n, fh)


