import torch
import torch.nn as nn
import numpy as np
from time import time
import os
import sys
from collections import OrderedDict
import matplotlib.pyplot as plt
import argparse 
import torch.utils.data
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

from train import TrainerVaDE
from datetime import datetime
from pathlib import Path

from sklearn.datasets import make_moons


from preprocess import get_mnist


# Based on GIN
def _scatter_panel(n_cols, index, points, colors, title):
    """Draw one borderless scatter panel at position ``index`` of a 1 x ``n_cols`` grid."""
    plt.subplot(1, n_cols, index)
    plt.scatter(points[:, 0], points[:, 1], c=colors, s=6, alpha=0.3)
    plt.xticks([])
    plt.yticks([])
    plt.title(title, fontsize=16, family='serif')


def draw(latent, data, labels, z = None, preds = None, dir=None):
    """Plot latents and observed data side by side and save the figure as ``<dir>og.png``.

    Without ``preds`` a 1x2 figure is drawn ("Ground truth" latents and
    "Observed data"), both coloured by ``labels``. With ``preds`` a 1x3 figure
    is drawn whose third panel shows ``z`` coloured by ``preds``.

    Args:
        latent: 2-D array-like; columns 0 and 1 are plotted.
        data: 2-D array-like of observations; columns 0 and 1 are plotted.
        labels: colour values for the latent and data panels.
        z: 2-D array-like of predicted latents; required when ``preds`` is given.
        preds: colour values for ``z``. When ``None`` the third panel is omitted.
        dir: string prepended verbatim to ``'og.png'`` (so it should end with a
            path separator). ``None`` saves ``og.png`` in the working directory.

    Raises:
        ValueError: if ``preds`` is given but ``z`` is not.
    """
    # Previously ``dir=None`` crashed on ``None + 'og.png'``.
    prefix = '' if dir is None else dir

    if preds is None:
        panels = [
            (latent, labels, 'Ground truth'),
            (data, labels, 'Observed data'),
        ]
    else:
        # Fail before opening a figure so no half-drawn figure is left behind.
        if z is None:
            raise ValueError("`z` must be provided when `preds` is given")
        panels = [
            (latent, labels, 'Original latents'),
            (data, labels, 'Observed data'),
            (z, preds, 'Predicted latents'),
        ]

    plt.figure(figsize=(12, 3.5))
    for index, (points, colors, title) in enumerate(panels, start=1):
        _scatter_panel(len(panels), index, points, colors, title)

    plt.tight_layout()
    plt.savefig(prefix + 'og.png')
    plt.close()
    
####################################################################
def draw_all(latent, latent_gmm, data, data_hat, labels_gmm, preds, data_labels, epoch, dir=None):
    """Save a 2x2 diagnostic figure as ``<dir>learnt_latents_<epoch>.png``.

    Panels, in reading order:
      1. ``latent`` (the encoder's z) coloured by ``preds``;
      2. ``latent_gmm`` coloured by ``labels_gmm``;
      3. ``data_hat`` (the decoder's output) coloured by ``preds``;
      4. ``data`` coloured by ``data_labels``.
    Only columns 0 and 1 of each point set are plotted.

    Args:
        epoch: inserted into the output file name.
        dir: string prepended verbatim to the file name (should end with a
            path separator). When ``None``, a module-level string ``dir`` is
            used if one exists (as set by this file's ``__main__`` block,
            matching the old behaviour); otherwise the working directory.
    """
    if dir is None:
        # Backward compatibility: this function used to read the module-level
        # ``dir``. When the module is imported that name is absent from
        # globals(), so fall back to the working directory instead of
        # concatenating with the builtin ``dir`` function.
        dir = globals().get('dir')
        if not isinstance(dir, str):
            dir = ''

    panels = (
        (latent, preds, "Encoder's z"),
        # Bug fix: this panel previously used the undefined name ``labels``.
        (latent_gmm, labels_gmm, 'Latent GMM'),
        (data_hat, preds, "Decoder's output"),
        (data, data_labels, "Observed distribution"),
    )

    plt.figure(figsize=(12, 3.5))
    for index, (points, colors, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, index)
        plt.scatter(points[:, 0], points[:, 1], c=colors, s=6, alpha=0.3)
        plt.xticks([])
        plt.yticks([])
        plt.title(title, fontsize=16, family='serif')

    plt.tight_layout()
    plt.savefig(dir + 'learnt_latents_{}.png'.format(epoch))
    plt.close()
############################################################################

def init_weights(m):
    """Orthogonally initialise the weight matrix of ``nn.Linear`` modules.

    Meant to be passed to ``nn.Module.apply``. Any module that is not an
    ``nn.Linear`` is left untouched, and biases keep PyTorch's default
    initialisation.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight.data)

def get_model():
    """Return a tiny MLP mapping 2 -> 4 -> 2 with a LeakyReLU in between.

    Every ``nn.Linear`` in the network is initialised by ``init_weights``.
    """
    width = 4
    net = nn.Sequential(
        nn.Linear(2, width),
        nn.LeakyReLU(),
        nn.Linear(width, 2),
    )
    # ``Module.apply`` returns the module itself, so this returns ``net``.
    return net.apply(init_weights)



def get_4squaredata(dim, n_data_points, noise_var):
    """Sample 2-D points from four axis-aligned boxes, one box per class.

    Each point gets a class drawn uniformly from {0, 1, 2, 3}. The point is
    its class corner plus a uniform offset scaled per axis by a box size drawn
    once per call from U[0, 1).

    NOTE(review): ``dim`` and ``noise_var`` are accepted for signature parity
    with the other generators but are ignored. The output is always 2-D and
    noiseless.

    Returns:
        ``(data, labels)`` with ``data`` of shape ``(n_data_points, 2)`` and
        ``labels`` of shape ``(n_data_points,)``.
    """
    corners = torch.FloatTensor([[0, 0], [0, 3], [3, 0], [5, 5]])
    classes = torch.randint(4, size=(n_data_points,))
    box_sizes = torch.rand(4, 2)
    offsets = torch.rand(n_data_points, 2) * box_sizes[classes]
    return corners[classes] + offsets, classes


def get_3squaredata(dim, n_data_points, noise_var):
    """Sample three box-shaped clusters embedded in the first two of ``dim`` coordinates.

    Columns 0-1 hold the class corner plus a uniform offset scaled by a fixed
    per-class box size. Every column then receives ``0.1 * N(0, 1)`` noise.

    NOTE(review): ``noise_var`` is ignored; the noise scale is hard-coded to
    0.1. Requires ``dim >= 2``.

    Returns:
        ``(data, labels)`` with ``data`` of shape ``(n_data_points, dim)``
        (float32) and ``labels`` of shape ``(n_data_points,)``.
    """
    corners = torch.FloatTensor([[0, 0], [0, 3], [3, 0]])
    classes = torch.randint(3, size=(n_data_points,))
    box_sizes = torch.FloatTensor([[0.7, 0.3], [0.9, 0.8], [0.2, 1.3]])

    samples = torch.zeros(n_data_points, dim, dtype=torch.float32)
    samples[:, :2] = corners[classes] + torch.rand(n_data_points, 2) * box_sizes[classes]
    return samples + 0.1 * torch.randn(n_data_points, dim), classes

def get_dif3squaredata(dim, n_data_points, noise_var):
    """Sample three linearly-transformed box clusters placed in columns 1-2 of ``dim`` columns.

    For each point a class is drawn uniformly from {0, 1, 2}. A uniform offset
    scaled by the class's (jittered) box size is right-multiplied by the
    class's (jittered) 2x2 matrix and added to the class mean. The result is
    written into columns 1 and 2 of a zero ``(n_data_points, dim)`` matrix.
    Finally ``noise_var * N(0, 1)`` is added to every entry.

    Note: ``noise_var`` multiplies a standard normal, so it acts as a standard
    deviation rather than a variance. Requires ``dim >= 3``.

    Returns:
        ``(full_data, labels)`` with ``full_data`` of shape
        ``(n_data_points, dim)`` (float32) and ``labels`` of shape
        ``(n_data_points,)``.
    """
    means = torch.FloatTensor([[0, 0], [0, 3], [3, 0]])
    labels = torch.randint(3, size=(n_data_points,))
    # The draws below keep the original order (randint, randn, randn, rand,
    # randn) so results under a given torch seed are unchanged.
    orient = torch.FloatTensor([[1.5, 0.5], [1.75, 2], [0.5, 3]]) + 0.1 * torch.randn(3, 2)
    rot = torch.FloatTensor([[[1.3, 1.7], [-1.7, 1.3]],
                             [[1.7, 1.3], [-1.3, 1.7]],
                             [[1, 0], [0, 1]]]) + 0.1 * torch.randn(3, 2, 2)

    data = torch.rand(n_data_points, 2) * orient[labels]
    # Batched row-vector @ matrix: (n, 1, 2) @ (n, 2, 2) -> (n, 1, 2).
    # This replaces a per-sample Python loop over torch.matmul.
    data = torch.matmul(data.unsqueeze(1), rot[labels]).squeeze(1)

    full_data = torch.zeros(n_data_points, dim, dtype=torch.float32)
    full_data[:, 1:3] = means[labels] + data
    full_data += noise_var * torch.randn(n_data_points, dim)
    return full_data, labels

def get_rand3squaredata(dim, n_data_points, noise_var):
    """Sample three randomly rotated box clusters placed in columns 1-2 of ``dim`` columns.

    For each point a class is drawn uniformly from {0, 1, 2}. A uniform offset
    scaled by the class's (jittered) box size is right-multiplied by a random
    2x2 rotation ``[[c, s], [-s, c]]`` and added to the class mean. Here
    ``c ~ U[0, 1)`` and ``s = sqrt(1 - c^2)``, so every angle lies in the
    first quadrant. The result is written into columns 1 and 2 of a zero
    ``(n_data_points, dim)`` matrix, and ``noise_var * N(0, 1)`` is added to
    every entry.

    Randomness comes from two sources: the rotation cosines use NumPy's global
    RNG, and everything else uses torch's. Seed both for reproducibility.
    ``noise_var`` multiplies a standard normal, so it acts as a standard
    deviation. Requires ``dim >= 3``.

    Returns:
        ``(full_data, labels)`` with ``full_data`` of shape
        ``(n_data_points, dim)`` (float32) and ``labels`` of shape
        ``(n_data_points,)``.
    """
    means = torch.FloatTensor([[0, 0], [0, 3], [3, 0]])
    labels = torch.randint(3, size=(n_data_points,))
    orient = torch.FloatTensor([[1.5, 0.5], [2, 2], [0.5, 3]]) + 0.15 * torch.randn(3, 2)

    cos = np.random.rand(3)
    sin = np.sqrt(1 - cos * cos)
    # Bug fix: the third matrix used to be [[c, s], [-c, s]], which is not a
    # rotation (it shears and rescales). All classes now use [[c, s], [-s, c]].
    rot = torch.FloatTensor([[[c, s], [-s, c]] for c, s in zip(cos, sin)])

    data = torch.rand(n_data_points, 2) * orient[labels]
    # Batched row-vector @ matrix: (n, 1, 2) @ (n, 2, 2) -> (n, 1, 2).
    # This replaces a per-sample Python loop over torch.matmul.
    data = torch.matmul(data.unsqueeze(1), rot[labels]).squeeze(1)

    full_data = torch.zeros(n_data_points, dim, dtype=torch.float32)
    full_data[:, 1:3] = means[labels] + data
    full_data += noise_var * torch.randn(n_data_points, dim)
    return full_data, labels

def get_intersecting3(dim, n_data_points, noise_var):
    """Sample three Gaussian clusters, two of them crossing, in columns 1-2 of ``dim`` columns.

    Classes 0 and 2 are both centred at the origin. Class 0 is stretched along
    the first axis (scales 7 / 0.5) and class 2 along the second (0.5 / 7),
    so they intersect. Class 1 is centred at (5, 5) with scale 2 on both axes.
    Each point is ``mean + N(0, 1) * scale`` per axis, written into columns 1
    and 2 of a zero ``(n_data_points, dim)`` matrix. Every entry then receives
    ``0.1 * N(0, 1)`` noise.

    NOTE(review): ``noise_var`` is ignored; the noise scale is hard-coded to
    0.1. Requires ``dim >= 3``.

    Returns:
        ``(full_data, labels)`` with ``full_data`` of shape
        ``(n_data_points, dim)`` (float32) and ``labels`` of shape
        ``(n_data_points,)``.
    """
    centres = torch.FloatTensor([[0, 0], [5, 5], [0, 0]])
    classes = torch.randint(3, size=(n_data_points,))
    axis_scales = torch.FloatTensor([[7, 0.5], [2, 2], [0.5, 7]])

    cluster_points = centres[classes] + torch.randn(n_data_points, 2) * axis_scales[classes]

    embedded = torch.zeros(n_data_points, dim, dtype=torch.float32)
    embedded[:, 1:3] = cluster_points
    embedded += 0.1 * torch.randn(n_data_points, dim)
    return embedded, classes

def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate, seed=1):
    """Generate the 2-D "pinwheel" toy dataset (code adapted from Johnson et al., 2016).

    Each class is a spiral arm. Gaussian features with per-axis std
    ``(radial_std, tangential_std)`` are shifted by +1 along the first axis.
    They are then rotated by the class's base angle plus
    ``rate * exp(radial feature)``, and scaled by 10. Rows are shuffled before
    returning.

    Args:
        radial_std: std of the first (radial) feature.
        tangential_std: std of the second (tangential) feature.
        num_classes: number of arms; base angles are evenly spaced over 2*pi.
        num_per_class: points per arm.
        rate: how strongly the angle grows with the radial feature.
        seed: passed to ``np.random.seed`` before sampling. This reseeds
            NumPy's *global* RNG; the default 1 reproduces the previous
            hard-coded behaviour. Pass ``None`` to leave the global RNG state
            untouched.

    Returns:
        ``(data, labels)``: a float tensor of shape
        ``(num_classes * num_per_class, 2)`` and an int32 tensor of the same
        length.
    """
    rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

    if seed is not None:
        np.random.seed(seed)

    features = np.random.randn(num_classes * num_per_class, 2) \
        * np.array([radial_std, tangential_std])
    features[:, 0] += 1.
    labels = np.repeat(np.arange(num_classes), num_per_class)

    angles = rads[labels] + rate * np.exp(features[:, 0])
    rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
    rotations = np.reshape(rotations.T, (-1, 2, 2))

    feats = 10 * np.einsum('ti,tij->tj', features, rotations)

    # Shuffle points and labels together by carrying the label as a 3rd column.
    data = np.random.permutation(np.hstack([feats, labels[:, None]]))
    labels = data[:, 2].astype(int)
    return torch.Tensor(data[:, 0:2]), torch.Tensor(labels).int()

def get_dataloader(data, labels, batch_size=128, shuffle=True, num_workers=1):
    """Wrap ``data`` and ``labels`` in a ``DataLoader`` over a ``TensorDataset``.

    Args:
        data: tensor whose first dimension indexes samples.
        labels: tensor with the same first dimension as ``data``.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch (default ``True``, as before). Pass
            ``False`` when batch order must follow the order of ``data``,
            e.g. when collecting per-sample outputs to save next to ``labels``.
        num_workers: loader worker processes (default 1, as before).

    Returns:
        A ``DataLoader`` yielding ``(data_batch, label_batch)`` pairs.
    """
    dataset = TensorDataset(data, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers)


if __name__ == "__main__":
    # Timestamp naming this invocation's output folders. Called ``now`` so it
    # no longer shadows the ``time`` function imported at the top of the file.
    now = datetime.now()
    # ``dir`` shadows the builtin but is kept as a module-level name because
    # ``draw_all`` in this file reads a module-level ``dir``.
    dir = "./outputs/{}_{}_{}_{}/".format(now.day, now.hour, now.minute, now.second)
    Path(dir).mkdir(parents=True, exist_ok=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=20,
                        help="number of iterations")
    parser.add_argument("--patience", type=int, default=50,
                        help="Patience for Early Stopping")
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='learning rate')
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size")
    parser.add_argument('--pretrain', type=int, default=1,
                        help='Whether to run the pretraining stage before training (1 = yes, 0 = no)')
    parser.add_argument('--output_dir', type=str, default=dir,
                        help='Output dir')
    parser.add_argument('--pretrained_path', type=str,
                        default=dir + "pretrained_parameters.pth",
                        help='Path of the pretrained-parameters file')
    parser.add_argument("--in_dim", type=int, default=5,
                        help="Input dimension")
    parser.add_argument("--latent_dim", type=int, default=2,
                        help="Latent dimension")
    parser.add_argument("--n_classes", type=int, default=3,
                        help="Num classes")
    parser.add_argument("--noise_var", type=float, default=0.001,
                        help="Variance of Gaussian noise")
    parser.add_argument("--num_runs", type=int, default=5,
                        help="Number of independent training runs")
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Running on", device)

    mdir = "./saved_models/{}_{}_{}_{}/".format(now.day, now.hour, now.minute, now.second)
    Path(mdir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): the generator's noise level is hard-coded to 0.01;
    # ``--noise_var`` is parsed but not used here.
    data, labels = get_rand3squaredata(args.in_dim, 5000, 0.01)
    #data, labels = make_pinwheel_data(0.25, 0.05, args.n_classes, 500, 0.2)
    #args.in_dim = 2
    output_dir = args.output_dir

    # Save a preview of the informative coordinates (get_rand3squaredata puts
    # the clusters in columns 1 and 2) instead of blocking on plt.show().
    plt.scatter(data[:, 1], data[:, 2], c=labels)
    plt.savefig(output_dir + "data.png")
    plt.close()

    torch.save(data, output_dir + "data.pth")
    torch.save(labels, output_dir + "labels.pth")

    for run in range(args.num_runs):
        # No separator is appended, so e.g. "<output_dir>0encoded.pth" is
        # written with the run index acting as a file-name prefix.
        args.output_dir = output_dir + str(run)
        dataloader = get_dataloader(data, labels, batch_size=args.batch_size)

        vade = TrainerVaDE(args, device, dataloader, covariance="full")
        if args.pretrain:
            vade.pretrain()

        vade.train()

        torch.save(vade.VaDE.state_dict(), mdir + "full{}_{}.pth".format(now.minute, str(run)))

        # Encode in the original order: the training loader shuffles, which
        # would leave encoded.pth misaligned with data.pth / labels.pth.
        encode_loader = DataLoader(TensorDataset(data, labels),
                                   batch_size=args.batch_size, shuffle=False)
        # Assumes vade.VaDE is an nn.Module; inputs follow it to its device.
        model_device = next(vade.VaDE.parameters()).device
        chunks = []
        with torch.no_grad():
            for x, _ in encode_loader:
                # presumably encode returns (mean, log-variance); only the first is kept
                encoded_x, _ = vade.VaDE.encode(x.to(model_device))
                chunks.append(encoded_x.detach().cpu().numpy())
        encoded_data = np.concatenate(chunks, axis=0)
        print(encoded_data.shape)
        torch.save(torch.Tensor(encoded_data), args.output_dir + "encoded.pth")
