import os
import argparse
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import grad, Variable
import time
import random
import matplotlib.pyplot as plt
from torchvision.utils import save_image

# Command-line options for the script.
# NOTE: parse_args() below runs at import time, so merely importing this
# module parses sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--nepochs', type=int, default=160)
# NOTE(review): type=eval evaluates the raw command-line string as Python
# code; `choices` only validates the value that eval returns. Consider a
# dedicated string-to-bool converter instead.
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
parser.add_argument('--lr', type=float, default=0.0001)
# Batch sizes for the training loader and for the evaluation/test loaders.
parser.add_argument('--batch_size', type=int, default=1000)
parser.add_argument('--test_batch_size', type=int, default=1000)
# Directory that receives the log file and is searched for the checkpoint.
parser.add_argument('--save', type=str, default='./experiment1')
# In the code visible here, --rectifier, --divide, --lr, --weight_decay and
# --nums are only used to build checkpoint / figure file names.
parser.add_argument('--rectifier', type=str, default='ReLU')
parser.add_argument('--debug', action='store_true')
# CUDA device index (used only when CUDA is available).
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--divide', type=int, default=1)
parser.add_argument('--weight_decay', type=float, default=0)
# 'mnist' or 'svhn'; any other value selects the CIFAR10 code path.
parser.add_argument('--data_type', type=str, default='mnist')
parser.add_argument('--nums', type=int, default=0)
args = parser.parse_args()


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


def norm(dim):
    """Return a ``GroupNorm`` over ``dim`` channels using at most 32 groups."""
    # Narrow layers (fewer than 32 channels) get one group per channel.
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups=num_groups, num_channels=dim)

def norm_(dim):
    """Return a ``GroupNorm`` with one group per channel (``dim`` groups)."""
    return nn.GroupNorm(num_groups=dim, num_channels=dim)

def norm2(dim):
    """Return a ``GroupNorm`` over ``dim`` channels using at most 75 groups."""
    num_groups = dim if dim < 75 else 75
    return nn.GroupNorm(num_groups=num_groups, num_channels=dim)

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single feature axis."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` viewed as a 2-D tensor of shape ``(-1, features)``.

        ``features`` is the product of ``x.shape[1:]``. It is computed with
        plain Python integers, so no temporary tensor is created on every
        call and the empty product for a 1-D input is the int ``1`` (the
        result is then a column of shape ``(len(x), 1)``).
        """
        features = 1
        for size in x.shape[1:]:
            features *= size
        return x.view(-1, features)



def get_data_loaders(data_aug=False):
    """Build the train, test and train-eval ``DataLoader`` objects.

    The dataset is chosen by ``args.data_type``: ``'mnist'``, ``'svhn'``, or
    anything else for CIFAR10. Datasets are downloaded under ``.data/`` when
    missing.

    Args:
        data_aug: when true, the training transform applies a random crop
            (4 pixels of padding) before ``ToTensor``; the evaluation
            transforms are always plain ``ToTensor``.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)``. The training
        loader is shuffled and uses ``args.batch_size``; the two evaluation
        loaders are unshuffled and use ``args.test_batch_size``. All loaders
        use ``drop_last=True``, so a trailing partial batch is skipped.
    """
    # Per-dataset settings: dataset class, download root, crop size, and the
    # keyword arguments that select the train / test split.
    if args.data_type == 'mnist':
        dataset_cls, root, crop_size = datasets.MNIST, '.data/mnist', 28
        train_split, test_split = {'train': True}, {'train': False}
    elif args.data_type == 'svhn':
        dataset_cls, root, crop_size = datasets.SVHN, '.data/svhn', 32
        train_split, test_split = {'split': 'train'}, {'split': 'test'}
    else:
        dataset_cls, root, crop_size = datasets.CIFAR10, '.data/cifar10', 32
        train_split, test_split = {'train': True}, {'train': False}

    augmentations = [transforms.RandomCrop(crop_size, padding=4)] if data_aug else []
    transform_train = transforms.Compose(augmentations + [transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    def make_loader(split_kwargs, transform, batch_size, shuffle):
        # One place for the loader settings shared by all three loaders.
        dataset = dataset_cls(root=root, download=True, transform=transform, **split_kwargs)
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2, drop_last=True
        )

    train_loader = make_loader(train_split, transform_train, args.batch_size, True)
    train_eval_loader = make_loader(train_split, transform_test, args.test_batch_size, False)
    # The CIFAR10 branch previously used args.batch_size here; every test
    # loader now consistently uses args.test_batch_size.
    test_loader = make_loader(test_split, transform_test, args.test_batch_size, False)

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Yield the items of ``iterable`` forever, restarting it when exhausted.

    Lets a DataLoader drive a single endless training loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):
    """
    while True:
        # Each pass asks the iterable for a fresh iterator.
        for item in iterable:
            yield item



def one_hot(x, K):
    """Return an integer one-hot matrix of shape ``(len(x), K)``.

    Row ``i`` holds a 1 in column ``x[i]`` and 0 elsewhere; a label outside
    ``0..K-1`` produces an all-zero row.
    """
    labels_column = x[:, None]
    classes_row = np.arange(K)[None, :]
    matches = labels_column == classes_row
    return np.array(matches, dtype=int)


def accuracy(down_model, feature_fc_model, dataset_loader, rand_x_t_pairs):
    """Evaluate classification accuracy over ``dataset_loader``.

    For every batch, ``down_model`` produces feature maps, the channels of
    ``rand_x_t_pairs`` are appended to each sample along the channel axis,
    and ``feature_fc_model`` maps the result to class scores.

    Args:
        down_model: callable mapping an input batch to ``(N, C, H, W)``
            feature maps.
        feature_fc_model: callable mapping ``(N, C + extra, H, W)`` to
            per-class scores of shape ``(N, num_classes)``.
        dataset_loader: iterable of ``(x, y)`` batches, where ``y`` is a 1-D
            tensor of integer class indices.
        rand_x_t_pairs: ``(extra, H, W)`` tensor appended to every sample.
            Inputs are moved to this tensor's device, so it must live on the
            same device as the models.

    Returns:
        ``(accuracy, last_scores, last_targets)``: the fraction of evaluated
        samples classified correctly, plus the scores and target labels of
        the final batch as NumPy arrays.

    Raises:
        ValueError: if ``dataset_loader`` yields no samples.
    """
    total_correct = 0
    total_seen = 0
    result = None
    target_class = None
    with torch.no_grad():
        for x, y in dataset_loader:
            x = x.to(rand_x_t_pairs.device)
            # Labels are already class indices, so they can be compared
            # with the argmax of the scores directly.
            target_class = y.cpu().numpy()
            down_result = down_model(x)

            # Broadcast the shared extra channels across the actual batch
            # size instead of filling a buffer sized from a global setting.
            extra = rand_x_t_pairs.unsqueeze(0).expand(down_result.size(0), -1, -1, -1)
            result = feature_fc_model(torch.cat([down_result, extra], dim=1))

            predicted_class = np.argmax(result.cpu().numpy(), axis=1)
            total_correct += np.sum(predicted_class == target_class)
            total_seen += len(target_class)

    if total_seen == 0:
        raise ValueError('dataset_loader yielded no samples')

    # Divide by the number of samples actually evaluated: a loader built with
    # drop_last=True may skip the tail of the dataset.
    return total_correct / total_seen, result.cpu().numpy(), target_class


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total


def makedirs(dirname):
    """Create ``dirname`` (including parent directories) unless the path already exists."""
    if os.path.exists(dirname):
        return
    os.makedirs(dirname)


def get_logger(logpath, filepath, package_files=None, displaying=True, saving=True, debug=False):
    """Configure the root logger and record the given source files in it.

    Args:
        logpath: file that log records are appended to when ``saving`` is true.
        filepath: path of a file whose name and full contents are logged.
        package_files: optional iterable of further file paths to log the
            same way; ``None`` (the default) means no extra files.
        displaying: when true, attach a ``StreamHandler`` (console output).
        saving: when true, attach a ``FileHandler`` writing to ``logpath``.
        debug: use ``logging.DEBUG`` instead of ``logging.INFO``.

    Returns:
        The root logger. Each call attaches new handlers, so calling this
        more than once duplicates the output.
    """
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)

    # Log the main file first, then any extra files, each as "path" followed
    # by the file's contents.
    logger.info(filepath)
    with open(filepath, "r") as f:
        logger.info(f.read())

    # `or ()` covers the None default without a shared mutable default list.
    for package_file in package_files or ():
        logger.info(package_file)
        with open(package_file, "r") as package_f:
            logger.info(package_f.read())

    return logger


if __name__ == '__main__':
    # Script entry point: load a trained checkpoint, run the network over the
    # training and test sets, and save t-SNE scatter plots of the logits.

    # Import the t-SNE dependency up front so that a missing package fails
    # before any data loading or inference work is done.
    from tsne import bh_sne

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # The MNIST variant takes 1 input channel, every other dataset 3; the
    # rest of the architecture is identical.
    in_channels = 1 if args.data_type == 'mnist' else 3
    downsampling_layers = [
        nn.Conv2d(in_channels, 64, 3, 1), norm(64), nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 4, 2, 1), norm(64), nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, 4, 2, 1), norm(64), nn.ReLU(inplace=True),
    ]
    # 67 = 64 feature channels + the 3 index channels appended below.
    PDE_layers = [
        nn.Conv2d(67, 67, 3, 1, 1), norm_(67), nn.ReLU(inplace=True),
        nn.Conv2d(67, 64, 3, 1, 1), norm(64), nn.ReLU(inplace=True),
    ]
    fc_layers = [nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    # The sub-networks share their layer objects with `model`, so moving
    # `model` to the device and loading its weights updates them as well.
    down_model = nn.Sequential(*downsampling_layers)
    feature_fc_model = nn.Sequential(*PDE_layers, *fc_layers)
    model = nn.Sequential(*downsampling_layers, *PDE_layers, *fc_layers).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    train_loader, test_loader, train_eval_loader = get_data_loaders(args.data_aug)

    # The checkpoint file name depends on the dataset it was trained on.
    if args.data_type == 'mnist':
        ckpt_name = 'model_pde_%s_%s_%s.pth' % (args.data_type, args.rectifier, str(args.nums))
    elif args.data_type == 'svhn':
        ckpt_name = 'model_pde_%s_%s_%s_%s_%s_%s.pth' % (
            args.data_type, args.divide, args.lr, args.weight_decay, args.rectifier, str(args.nums))
    else:
        ckpt_name = 'model_pde_%s_%s.pth' % (args.rectifier, str(args.nums))
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load on whichever device this run uses.
    ckpt = torch.load(os.path.join(args.save, ckpt_name), map_location=device)

    rand_x_index = ckpt['rand_x_index']
    rand_x_index2 = ckpt['rand_x_index2']
    rand_t_index = ckpt['rand_t_index']
    model.load_state_dict(ckpt['state_dict'])

    # Three index maps stacked along dim 0; appended to every sample's
    # feature maps as extra channels.
    rand_x_t_pairs = torch.cat([rand_x_index, rand_x_index2, rand_t_index], dim=0).to(device)

    def collect_logits(loader):
        """Run the network over `loader`; return per-sample logits and labels."""
        all_logits, all_labels = [], []
        with torch.no_grad():
            for x, y in loader:
                down_logits = down_model(x.to(device))
                coords = rand_x_t_pairs.unsqueeze(0).expand(down_logits.size(0), -1, -1, -1)
                logits = feature_fc_model(torch.cat([down_logits, coords], dim=1))
                all_logits.extend(logits.cpu().numpy())
                all_labels.extend(y.numpy())
        return all_logits, all_labels

    # One pass over each split. The test logits previously came from
    # accuracy(), which only returns its final batch, so the same batch was
    # appended once per training iteration.
    output_logits, target = collect_logits(train_loader)
    output_test_logits, test_target = collect_logits(test_loader)

    # Accuracy is evaluated once, not once per training batch.
    with torch.no_grad():
        train_acc, _, _ = accuracy(down_model, feature_fc_model, train_eval_loader, rand_x_t_pairs)
        val_acc, _, _ = accuracy(down_model, feature_fc_model, test_loader, rand_x_t_pairs)
    print(train_acc, val_acc)

    # Set the figure size before the first figure is created.
    plt.rcParams['figure.figsize'] = (3, 3)
    cmap = plt.cm.coolwarm
    name_fields = (args.data_type, args.divide, args.lr, args.weight_decay, args.rectifier)

    splits = (
        ('train', output_logits, target),
        ('test', output_test_logits, test_target),
    )
    for split, split_logits, split_labels in splits:
        embedded = bh_sne(np.array(split_logits, dtype='float64'))
        # Two colourings per split: the coolwarm colormap and matplotlib's
        # default mapping of the raw labels. The test plots are coloured by
        # the test labels (the cmap variant used the training labels before).
        colourings = (
            (split + '_cmap', cmap(np.divide(split_labels, 10))),
            (split, split_labels),
        )
        for tag, colors in colourings:
            plt.clf()
            plt.xlim([-45, 45])
            plt.ylim([-45, 45])
            plt.scatter(embedded[:, 0], embedded[:, 1], c=colors, s=0.7)
            base_name = ('T-SNE_%s_%s_%s_%s_%s_pde_' % name_fields) + tag + '_output'
            plt.savefig(base_name + '.jpg', bbox_inches='tight')
            plt.savefig(base_name + '.pdf', bbox_inches='tight')


