import torchvision
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import RandomRotation
from torchvision.transforms.functional import rotate, to_pil_image

from sklearn.svm import LinearSVC, SVC
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier 
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons

import seaborn as sns

import nets as nets
from brownian_utils import brownian_bridge_ab, compute_sample_grads, importance_weights
from utils import *

import matplotlib.pyplot as plt 
import numpy as np

args = get_parser().parse_args()

# Prefer the GPU this experiment was configured for ("cuda:7"), but fall back
# to the default GPU, or to the CPU, on machines that do not have it instead
# of failing at the first `.to(device)` call.
if torch.cuda.device_count() > 7:
    device = "cuda:7"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
mixup_type = args.mixup_type


# ---- Hyper-parameters --------------------------------------------------------
epochs = 100
bs = 200     # batch size
T  = 1       # time horizon: the time grids below span [0, T]
eps = 1e-6   # small offset added to `label_t_e`
d   = 2      # input dimension (the moons data is 2-D)
h   = 4      # 2nd positional arg of nets.MLP (presumably hidden width -- confirm)
d_z = 2      # number of classes; also the one-hot / output size
n_l = 2      # 3rd positional arg of nets.MLP (presumably number of layers -- confirm)
bn  = True   # forwarded to nets.MLP as `bn` (presumably batch norm -- confirm)
n_t = 10     # number of points on the time grid



mlp = nets.MLP (d, h, n_l, d_z, bn=bn).to(device)
opt = optim.Adam(mlp.parameters(), lr=1e-4)
# NOTE(review): `sched.step()` is never called in the visible code.
sched = optim.lr_scheduler.ExponentialLR(opt, 0.9998)

# Time grids: n_t evenly spaced times in [0, T], repeated to shape
# (bs, n_t, d) for inputs and (bs, n_t, d_z) for labels.
t = torch.linspace(0, T, n_t).unsqueeze(0).unsqueeze(-1).repeat(bs, 1, d).float().to(device)
t_label = torch.linspace(0, T, n_t).unsqueeze(0).unsqueeze(-1).repeat(bs, 1, d_z).float().to(device)
# NOTE(review): linspace above spaces points T / (n_t - 1) apart, not T / n_t.
dt = T / n_t
# t (T - t) / T, the variance of a standard Brownian bridge on [0, T]; shape (bs, n_t).
var_t = t[:,:,0] * (T - t[:,:,0]) / T

# Candidate time-dependent label weights:
#   label_t_s = 1 - 4 var_t / T  -> 1 at t = 0 and t = T, 0 at t = T / 2
#   label_t_e = t / T + eps      -> linear ramp from eps to 1 + eps
label_t_s = 1 - 4 * var_t / T  
label_t_e = t[:,:,0] / T + eps
label_t   = label_t_s

# Save the selected profile for inspection.
plt.plot(label_t[0].cpu().numpy())
plt.savefig('var.pdf')
plt.close('all')
# Flags used for the plotting grid extent and output file names; in this
# file the training data is always the moons dataset regardless of `spiral`.
spiral=False
unbalance=False
def make_spirals(n_samples, noise):
    """Sample two interleaved 2-D spiral arms with additive Gaussian noise.

    Both arms share the same random angles ``theta = sqrt(U) * 2 * pi`` with
    ``U ~ Uniform[0, 1)``. Arm A has radius ``2*theta + pi``. Arm B has radius
    ``-2*theta - pi``, so before noise it is arm A mirrored through the origin.

    Args:
        n_samples: number of points per arm.
        noise: standard deviation of the isotropic Gaussian noise added to
            every point.

    Returns:
        ``(points, labels)``: ``points`` has shape (2 * n_samples, 2).
        ``labels`` has shape (2 * n_samples,), with 0.0 for arm A and 1.0
        for arm B. Rows are shuffled together.
    """
    angle = np.sqrt(np.random.rand(n_samples)) * 2 * np.pi

    def _noisy_arm(radius, tag):
        # Points on the arm, jittered, with the class tag appended as a column.
        clean = np.column_stack((np.cos(angle) * radius, np.sin(angle) * radius))
        jittered = clean + np.random.randn(n_samples, 2) * noise
        return np.column_stack((jittered, np.full(n_samples, tag, dtype=float)))

    rows = np.vstack((
        _noisy_arm(2 * angle + np.pi, 0.0),
        _noisy_arm(-2 * angle - np.pi, 1.0),
    ))
    np.random.shuffle(rows)
    return rows[:, :2], rows[:, -1]


def make_moons(n_samples, noise=None, shuffle=True, use_noise=False):
    """Generate a two-moons dataset with either structured or Gaussian noise.

    This local version shadows ``sklearn.datasets.make_moons`` imported at the
    top of the file. When ``use_noise`` is False, it applies a structured
    perturbation instead of Gaussian noise:

    * a random half of each moon is shifted up by ``2*noise``, and the other
      half is shifted down by ``2*noise``;
    * each point's x coordinate is scaled about the moon's centre (x = 0 for
      the outer moon, x = 1 for the inner one).

    Args:
        n_samples: total number of points. The outer moon gets
            ``n_samples // 2`` points and the inner moon the rest, so odd
            totals are supported.
        noise: perturbation magnitude; ``None`` means no perturbation.
        shuffle: if True, permute the rows of ``X`` and ``y`` together.
        use_noise: if True, skip the structured perturbation and instead add
            i.i.d. Gaussian noise with standard deviation ``noise``.

    Returns:
        ``(X, y)``: ``X`` is a float array of shape (n_samples, 2), and ``y``
        is an ``np.intp`` array of shape (n_samples,) holding 0 for the outer
        moon and 1 for the inner moon.
    """
    if noise is None:
        noise = 0.0

    def _random_half(size):
        # Boolean mask with exactly size // 2 randomly chosen True entries.
        mask = np.zeros(size, dtype=np.bool_)
        mask[np.random.choice(size, size=size // 2, replace=False)] = True
        return mask

    n_samples_out = n_samples // 2
    # The remainder goes to the inner moon so the total is exactly n_samples.
    n_samples_in = n_samples - n_samples_out
    outer_circ_x = np.cos(np.linspace(0, np.pi, n_samples_out))
    outer_circ_y = np.sin(np.linspace(0, np.pi, n_samples_out))
    inner_circ_x = 1 - np.cos(np.linspace(0, np.pi, n_samples_in))
    inner_circ_y = 1 - np.sin(np.linspace(0, np.pi, n_samples_in)) - 0.5

    if not use_noise:
        # Outer moon: the "up" half moves up and is stretched away from x = 0.
        # The other half moves down and is pulled toward x = 0.
        up = _random_half(n_samples_out)
        outer_circ_y[up] = outer_circ_y[up] + 2 * noise
        outer_circ_x[up] = outer_circ_x[up] + outer_circ_x[up] * 2 * noise
        outer_circ_y[~up] = outer_circ_y[~up] - 2 * noise
        # Scale each point by its *own* x value (not the other half's), which
        # also keeps the shapes consistent when the halves differ in size.
        outer_circ_x[~up] = outer_circ_x[~up] - outer_circ_x[~up] * 1.5 * noise

        # Inner moon: the "up" half moves up and is pulled toward x = 1.
        # The other half moves down and is pushed away from x = 1.
        up = _random_half(n_samples_in)
        inner_circ_y[up] = inner_circ_y[up] + 2 * noise
        inner_circ_x[up] = inner_circ_x[up] - (inner_circ_x[up] - 1) * 2 * noise
        inner_circ_y[~up] = inner_circ_y[~up] - 2 * noise
        inner_circ_x[~up] = inner_circ_x[~up] + (inner_circ_x[~up] - 1) * 1.5 * noise

    X = np.vstack(
        [np.append(outer_circ_x, inner_circ_x), np.append(outer_circ_y, inner_circ_y)]
    ).T
    y = np.hstack(
        [np.zeros(n_samples_out, dtype=np.intp), np.ones(n_samples_in, dtype=np.intp)]
    )
    if shuffle:
        # Permute by the actual row count so X and y stay aligned.
        order = np.random.permutation(len(X))
        X, y = X[order], y[order]
    if use_noise:
        X += np.random.normal(scale=noise, size=X.shape)
    return X, y

# ---- Data ----------------------------------------------------------------------
# Training set: 1000 moons points with the structured (use_noise=False)
# perturbation. Each sample also carries its row index (0..999) so that
# per-sample statistics can be tracked across epochs.
moons_x_t, moons_y_t = make_moons(n_samples=1000, noise=0.1, use_noise=False)
moons = torch.utils.data.TensorDataset(
    torch.tensor(moons_x_t).float(),
    torch.tensor(moons_y_t).float(),
    torch.arange(1000),
)

# Validation set: 2000 moons points with Gaussian noise instead.
moons_x, moons_y = make_moons(n_samples=2000, noise=0.1, use_noise=True)
moons_v = torch.utils.data.TensorDataset(
    torch.tensor(moons_x).float(),
    torch.tensor(moons_y).float(),
)

dataset, dataset_v = moons, moons_v

dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
dataloader_v = DataLoader(dataset_v, batch_size=bs, shuffle=False)

# Per-epoch histories appended to by the training loop.
losses, val_losses, acc, val_acc = [], [], [], []


# Distribution the mixing coefficients are drawn from (Beta(1, 1)).
beta_dist = torch.distributions.beta.Beta(1, 1, validate_args=None)
if args.mixup_type == "umix":
    # weights_list[epoch, i] accumulates a 1.0 each time training sample i
    # (identified by its dataset index) is misclassified in that epoch.
    weights_list = torch.zeros(epochs, 1000, dtype=torch.float32, device=device)
import pickle

for epoch in range(epochs):
    mlp.train()
    tot_loss = 0
    for idx, (data, label, indices) in enumerate(dataloader):
        opt.zero_grad()

        # Move the batch to `device` as (bs, features). Input gradients are
        # needed by the "bridge" importance weights further down.
        data = data.squeeze(1).float()
        data = data.reshape(bs, -1).to(device)
        data.requires_grad = True
        label = label.to(device)
        label_non_onehot = label
        label = F.one_hot(label.long(), num_classes=d_z).to(device).float()
        if mixup_type == "none":
            # No data augmentation.
            data = data.reshape(-1, d)
        if mixup_type == "bridge":
            # Data augmentation with a Brownian bridge from every sample to a
            # randomly chosen partner in the same batch.
            input_shuffle = torch.randperm(data.shape[0])

            a = data                  # bridge start points: the batch itself
            b = data[input_shuffle]   # bridge end points: a permutation of it

            a_label = label
            b_label = label[input_shuffle]

            var = 0.3
            bridge = (brownian_bridge_ab(t, a, b, var)[0]).reshape(-1, d)
            bridge_label = brownian_bridge_ab(t_label, a_label, b_label, var, simplex=True)[0].reshape(-1, d_z)
            # NOTE(review): the repeated labels and `lambdas` computed below
            # are not used later in this loop (`label` is then overwritten).
            label = label.unsqueeze(1).repeat(1,t_label.shape[1],1).reshape(-1,d_z)
            label_shuffle = label[input_shuffle].unsqueeze(1).repeat(1,t_label.shape[1],1).reshape(-1,d_z)
            label_non_onehot_rep = label_non_onehot.unsqueeze(1).repeat(1,t_label.shape[1],1).reshape(n_t,data.shape[0]).T.flatten()
            label_non_onehot_shuffle_rep = label_non_onehot[input_shuffle].unsqueeze(1).repeat(1,t_label.shape[1],1).reshape(n_t,data.shape[0]).T.flatten()
            lambdas = ((T - t_label[...,0])/T).reshape(-1,)
            label = bridge_label
            data = bridge
        if mixup_type == "original":
            # Original mixup (Zhang et al., 2018): convex combination of each
            # sample with a random partner; lambdas has shape (bs, 1).
            input_shuffle = torch.randperm(data.shape[0])
            lambdas = beta_dist.sample(sample_shape=(data.shape[0],1)).to(device)
            data = lambdas*data + (1-lambdas)*data[input_shuffle]
        if mixup_type == "umix" and epoch < 50:
            # Warm-up: no data augmentation for the first 50 epochs.
            data = data.reshape(-1, d)
        if mixup_type == "umix" and epoch >= 50:
            # After warm-up, mix inputs exactly like "original" mixup.
            input_shuffle = torch.randperm(data.shape[0])
            lambdas = beta_dist.sample(sample_shape=(data.shape[0],1)).to(device)
            data = lambdas*data + (1-lambdas)*data[input_shuffle]

        # MLP outputs (logits over d_z classes).
        fx = mlp(data, t[:,:,-1].reshape(-1, 1))

        # "bridge": per-sample importance weights built from input gradients.
        # "umix": record which training samples are misclassified this epoch.
        weights = None
        if mixup_type == "bridge":
            loss = F.cross_entropy(fx, label).mean()
            sample_grads = torch.autograd.grad(loss, data, retain_graph=True)[0]
            weights = (sample_grads * data).reshape(bs, n_t, *data.shape[1:]).sum(-1) - \
                        0.5*(sample_grads**2).reshape(bs, n_t, *data.shape[1:]).sum(-1)
        if mixup_type == "umix":
            pred_label = torch.argmax(torch.sigmoid(fx), 1)
            weights_list[epoch, indices] += (pred_label != label_non_onehot).float()

        # (Weighted) loss for gradient updates.
        if mixup_type == "none":
            loss = F.cross_entropy(fx, label, reduction='mean')
        elif mixup_type == "bridge":
            loss = torch.mean(F.cross_entropy(fx, label) - weights.sum(-1))
        elif mixup_type == "original":
            # `lambdas` is (bs, 1); squeeze it so each sample's loss is scaled
            # by its own coefficient instead of broadcasting to (bs, bs).
            lam = lambdas.squeeze(-1)
            loss = torch.mean(lam*F.cross_entropy(fx, label, reduction='none') +
                              (1-lam)*F.cross_entropy(fx, label[input_shuffle], reduction='none'))
        elif mixup_type == "umix" and epoch < 50:
            loss = F.cross_entropy(fx, label, reduction='mean')
        elif mixup_type == "umix" and epoch >= 50:
            # Weight each sample (and its partner) by its mean misclassification
            # rate over the previous 50 epochs; squeeze `lambdas` as above.
            lam = lambdas.squeeze(-1)
            currweights = weights_list[epoch-50:epoch,indices].mean(0)
            currweights_j = currweights[input_shuffle]
            label_j = label[input_shuffle]
            loss = torch.mean(
                currweights*lam*F.cross_entropy(fx, label, reduction='none') +
                currweights_j*(1-lam)*F.cross_entropy(fx, label_j, reduction='none')
                )
        else:
            raise ValueError("unknown mixup_type: {!r}".format(mixup_type))

        tot_loss += loss.item()
        loss.backward()
        opt.step()

    losses.append(tot_loss/(idx+1))
    if weights is not None:
        # Density of the last batch's summed bridge weights, split by class.
        weights = weights.sum(-1).detach().cpu().numpy().flatten()
        sns.kdeplot(x=weights, hue=label_non_onehot.detach().cpu().numpy())
        plt.savefig('weight_density.png')
        plt.close('all')
    mlp.eval()
    # Validation
    with torch.no_grad():
        loss = 0
        accuracy = 0
        for idx, (data_v, label_v) in enumerate(dataloader_v):
            data_v = data_v.reshape(bs, -1).float().to(device)
            # All-zero time input sized to the validation batch itself; the
            # last training batch can have bs * n_t rows in "bridge" mode.
            t_zero = torch.zeros(data_v.shape[0], 1, device=device)
            fx = mlp(data_v, t_zero).cpu().detach()
            label_v_onehot = F.one_hot(label_v.long(), num_classes=d_z).float()
            # Mean batch cross-entropy, computed for every mixup type so that
            # "umix" also gets a validation loss.
            loss += F.cross_entropy(fx, label_v_onehot, reduction="mean").item()
            pred_label = torch.argmax(torch.sigmoid(fx), 1)
            accuracy += (torch.sum(pred_label == label_v)/pred_label.shape[0]).item()
        acc.append(accuracy/(idx+1))
        val_losses.append(loss/(idx+1))
        print('===> Epoch {}: Acc: {:.4f} Val Loss: {:.4f}'.format(epoch, acc[-1], val_losses[-1]))
        plt.plot(range(epoch+1), acc)
        plt.savefig('{}_acc.pdf'.format(mixup_type))
        plt.close('all')
        if d == 2:
            # Per-point losses on validation and training sets, plus the
            # class-0 sigmoid output over a 2-D grid, pickled for plotting.
            savefiguredata = ()
            positions = moons_x
            fx = mlp(torch.tensor(positions).float().to(device))
            label_v_onehot = F.one_hot(torch.tensor(moons_y).long().to(device), num_classes=d_z).float()
            loss = F.cross_entropy(fx, label_v_onehot, reduction="none").detach().cpu().numpy()
            savefiguredata += (positions, moons_y, loss,)

            positions_t = moons_x_t
            fx = mlp(torch.tensor(positions_t).float().to(device))
            label_v_onehot = F.one_hot(torch.tensor(moons_y_t).long().to(device), num_classes=d_z).float()
            loss_t = F.cross_entropy(fx, label_v_onehot, reduction="none").detach().cpu().numpy()
            savefiguredata += (positions_t, moons_y_t, loss_t,)
            print(len(savefiguredata))

            x = np.arange(-17.0,17.0,0.1) if spiral else np.arange(-5.0,5.0,0.1)
            y = np.arange(-17.0,17.0,0.1) if spiral else np.arange(-5.0,5.0,0.1)
            xx,yy = np.meshgrid(x, y) # grid of point
            positions = np.c_[xx.ravel(), yy.ravel()]
            fx = torch.sigmoid(mlp(torch.tensor(positions).float().to(device))).detach().cpu().numpy()
            savefiguredata += ((xx,yy), fx[:,0].reshape(xx.shape),)
            with open('temp_{}_{}_{}_fig1.pickle'.format("unbalance" if unbalance else "balance",
                                                 "spirals" if spiral else "", 
                                                 args.mixup_type), 'wb') as handle:
                pickle.dump(savefiguredata, handle, protocol=pickle.HIGHEST_PROTOCOL)


    plt.plot(losses)
    plt.plot(val_losses)
    plt.savefig('temp_{}_losses.pdf'.format(mixup_type))
    plt.close('all')