import os
import argparse
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import grad, Variable
import time
import random
import matplotlib.pyplot as plt
from torchvision.utils import save_image

# Command-line configuration for this training script.
parser = argparse.ArgumentParser()
# Number of passes over the training loader (multiplied by batches per epoch
# to get the total iteration count in the main loop).
parser.add_argument('--nepochs', type=int, default=10)
# NOTE(review): type=eval passes the raw command-line string through eval()
# before the `choices` check is applied; consider a dedicated boolean parser.
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False])
# Base learning rate handed to the Adam optimizers and the decay schedule.
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--test_batch_size', type=int, default=1000)
# Output directory for the log file and the saved checkpoints.
parser.add_argument('--save', type=str, default='./experiment1')
# Activation used in the feature layers; the main script handles 'ReLU' and 'CELU'.
parser.add_argument('--rectifier', type=str, default='ReLU')
# NOTE(review): --debug is parsed but not referenced elsewhere in the visible code.
parser.add_argument('--debug', action='store_true')
# Index used to build the 'cuda:<gpu>' device string when CUDA is available.
parser.add_argument('--gpu', type=int, default=0)
# Parsed at import time; `args` is a module-level global read by code further
# down in this file.
args = parser.parse_args()


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) ``nn.Conv2d``."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


def norm(dim):
    """Return a ``nn.GroupNorm`` over ``dim`` channels with at most 32 groups."""
    group_count = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups=group_count, num_channels=dim)


def norm_(dim):
    """Return a ``nn.GroupNorm`` over ``dim`` channels with one group per channel.

    The number of groups always equals ``dim`` (the original spelled this as
    the no-op ``min(dim, dim)``), so each channel is normalized on its own.
    """
    return nn.GroupNorm(dim, dim)


def norm2(dim):
    """Return a ``nn.GroupNorm`` over ``dim`` channels with at most 75 groups."""
    group_count = dim if dim < 75 else 75
    return nn.GroupNorm(num_groups=group_count, num_channels=dim)


class Flatten(nn.Module):
    """Reshape a batched tensor to 2-D, keeping the leading axis as the batch."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` viewed as ``(-1, F)``, F being the product of ``x.shape[1:]``.

        A 1-D input of length N is returned with shape ``(N, 1)``.
        """
        # Multiply the trailing sizes as plain Python ints: this avoids
        # allocating a tensor and calling .item() on every forward pass, and
        # keeps the result an int even when there are no trailing dimensions.
        num_features = 1
        for size in x.shape[1:]:
            num_features *= size
        return x.view(-1, num_features)


class RunningAverageMeter(object):
    """Keeps the most recent value and an exponential moving average of it."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Drop all history: no last value, average back at zero."""
        self.avg = 0
        self.val = None

    def update(self, val):
        """Record ``val`` and fold it into the moving average.

        The first value after a reset seeds the average directly; later
        values are blended in with weight ``1 - momentum``.
        """
        no_history = self.val is None
        if no_history:
            blended = val
        else:
            keep = self.momentum
            blended = self.avg * keep + val * (1 - keep)
        self.avg = blended
        self.val = val


def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000):
    """Create the MNIST data loaders used for training and evaluation.

    Args:
        data_aug: When true, training images get a random 28x28 crop (after
            4 pixels of padding) before being converted to tensors.
        batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the two unshuffled evaluation loaders.

    Returns:
        ``(train_loader, test_loader, train_eval_loader)``. All three use two
        worker processes and drop the last incomplete batch.
    """
    mnist_root = '.data/mnist'

    train_steps = [transforms.ToTensor()]
    if data_aug:
        train_steps.insert(0, transforms.RandomCrop(28, padding=4))
    transform_train = transforms.Compose(train_steps)
    transform_test = transforms.Compose([transforms.ToTensor()])

    def build_loader(use_train_split, transform, size, shuffle):
        # Every loader shares the same root, download flag, worker count and
        # drop_last setting; only split, transform, size and shuffling vary.
        dataset = datasets.MNIST(root=mnist_root, train=use_train_split, download=True, transform=transform)
        return DataLoader(dataset, batch_size=size, shuffle=shuffle, num_workers=2, drop_last=True)

    train_loader = build_loader(True, transform_train, batch_size, True)
    train_eval_loader = build_loader(True, transform_test, test_batch_size, False)
    test_loader = build_loader(False, transform_test, test_batch_size, False)

    return train_loader, test_loader, train_eval_loader


def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    The iterable is restarted every time it is exhausted. If a complete pass
    produces no items at all, the generator stops instead of spinning forever.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # An empty iterable can never yield anything; restarting it again
            # would only busy-loop.
            return


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates):
    """Build a piecewise-constant learning-rate schedule.

    The base rate is the global ``args.lr`` scaled by
    ``batch_size / batch_denom``. ``boundary_epochs`` are converted to
    iteration counts, and ``decay_rates[i]`` scales the base rate for
    iterations that fall before the i-th boundary (the last entry applies
    once every boundary has been passed).

    Returns:
        A function mapping an iteration index to its learning rate.
    """
    base_rate = args.lr * batch_size / batch_denom

    boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
    vals = [base_rate * decay for decay in decay_rates]

    def learning_rate_fn(itr):
        # Pick the first boundary the iteration has not reached yet.
        for idx, boundary in enumerate(boundaries):
            if itr < boundary:
                return vals[idx]
        # Past every boundary: use the entry right after the last one.
        return vals[len(boundaries)]

    return learning_rate_fn


def one_hot(x, K):
    """Return an integer array whose row i has a 1 in column ``x[i]``.

    ``x`` is indexed as ``x[:, None]`` and compared against ``0..K-1``;
    entries outside that range produce an all-zero row.
    """
    class_ids = np.arange(K)[None, :]
    matches = x[:, None] == class_ids
    return np.asarray(matches, dtype=int)


def accuracy(down_model, feature_fc_model, dataset_loader, rand_x_t_pairs):
    """Return the fraction of samples from ``dataset_loader`` classified correctly.

    Each batch is passed through ``down_model``; ``rand_x_t_pairs`` is then
    appended to every sample along the channel axis and the result is fed to
    ``feature_fc_model``, whose row-wise argmax is taken as the prediction.

    Args:
        down_model: Module mapping an input batch to a 4-D feature map.
        feature_fc_model: Module mapping the augmented feature map to one
            score per class.
        dataset_loader: Iterable of ``(inputs, integer_labels)`` batches.
        rand_x_t_pairs: 3-D tensor of extra channels whose last two sizes
            match the feature map. It must already live on the models'
            device; inputs are moved to that same device.

    Returns:
        Correct predictions divided by the number of samples actually
        evaluated (``0.0`` if the loader yields nothing). Counting evaluated
        samples keeps the ratio right when the loader drops a partial batch.
    """
    target_device = rand_x_t_pairs.device
    total_correct = 0
    total_seen = 0
    for x, y in dataset_loader:
        down_result = down_model(x.to(target_device))
        batch = down_result.shape[0]
        # Broadcast the shared extra channels across the batch and append
        # them along the channel axis in one call instead of per sample.
        extra = rand_x_t_pairs.unsqueeze(0).expand(batch, -1, -1, -1)
        augmented = torch.cat([down_result, extra], dim=1)
        predicted_class = feature_fc_model(augmented).argmax(dim=1)
        total_correct += (predicted_class.cpu() == y.cpu()).sum().item()
        total_seen += batch
    if total_seen == 0:
        return 0.0
    return total_correct / total_seen


def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def makedirs(dirname):
    """Create ``dirname`` and any missing parents; do nothing if it already exists.

    Raises:
        FileExistsError: If ``dirname`` exists but is not a directory.
    """
    # exist_ok=True removes the window between a separate existence check and
    # the creation call, during which another process could create the path.
    os.makedirs(dirname, exist_ok=True)


def get_logger(logpath, filepath, package_files=(), displaying=True, saving=True, debug=False):
    """Configure the root logger and write the given source files into the log.

    Args:
        logpath: File the log is appended to when ``saving`` is true.
        filepath: Path of a file whose path and full contents are logged.
        package_files: Iterable of further file paths logged the same way.
        displaying: Attach a ``logging.StreamHandler`` when true.
        saving: Attach a ``logging.FileHandler`` on ``logpath`` (append mode)
            when true.
        debug: Use ``logging.DEBUG`` instead of ``logging.INFO`` as the level.

    Returns:
        The root logger.

    Note:
        Handlers are added on every call, so calling this more than once in
        the same process duplicates the output.
    """
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
    if saving:
        info_file_handler = logging.FileHandler(logpath, mode="a")
        info_file_handler.setLevel(level)
        logger.addHandler(info_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        logger.addHandler(console_handler)

    # The default for package_files is an immutable empty tuple rather than a
    # list shared between calls. The main file and the extra files are dumped
    # by the same loop: first the path, then the contents.
    for source_path in (filepath, *(package_files or ())):
        logger.info(source_path)
        with open(source_path, "r") as source_file:
            logger.info(source_file.read())

    return logger


if __name__ == '__main__':
    # Script entry point: builds the downsampling / feature / classifier
    # networks, trains them on MNIST with three losses (initial-condition,
    # dynamics and classification), evaluates and checkpoints once per epoch,
    # and finally saves a plot of the recorded loss curves.

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    downsampling_layers = [nn.Conv2d(1, 64, 3, 1), norm(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 4, 2, 1),
                           norm(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 4, 2, 1), norm(64), nn.ReLU(inplace=True)]
    fc_layers = [nn.AdaptiveAvgPool2d((1, 1)), Flatten(), nn.Linear(64, 10)]

    # The feature layers take 67 channels: the 64 downsampled feature channels
    # plus the 3 coordinate channels (x, x2, t) concatenated in the loop below.
    if args.rectifier == 'ReLU':
        PDE_layers = [nn.Conv2d(67, 67, 3, 1, 1), norm_(67), nn.ReLU(inplace=True), nn.Conv2d(67, 64, 3, 1, 1),
                      norm(64),
                      nn.ReLU(inplace=True)]
    elif args.rectifier == 'CELU':
        PDE_layers = [nn.Conv2d(67, 67, 3, 1, 1), norm_(67), nn.CELU(inplace=True), nn.Conv2d(67, 64, 3, 1, 1),
                      norm(64),
                      nn.CELU(inplace=True)]
    else:
        # Fail here with a clear message; merely printing would leave
        # PDE_layers undefined and crash a few lines later with a NameError.
        raise ValueError(
            "Clarify Activation Function: --rectifier must be 'ReLU' or 'CELU', got %r" % args.rectifier)

    down_model = nn.Sequential(*downsampling_layers)
    feature_model = nn.Sequential(*PDE_layers)
    fc_model = nn.Sequential(*fc_layers)
    feature_fc_model = nn.Sequential(feature_model, fc_model)

    # `model` and `model_fe` wrap the same module objects, so moving them to
    # the device also moves the sub-models used directly below.
    model = nn.Sequential(down_model, feature_model, fc_model).to(device)
    model_fe = nn.Sequential(feature_model).to(device)

    criterion = nn.CrossEntropyLoss().to(device)

    train_loader, test_loader, train_eval_loader = get_mnist_loaders(
        args.data_aug, args.batch_size, args.test_batch_size
    )

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001]
    )

    # Side length of the coordinate grids built below (they are 6x6).
    down_pixel_size = 6

    # Fixed coordinate grids (values 1/6 .. 1) with a zero time channel; these
    # are the inputs paired with the initial-condition loss.
    x_index = torch.tensor(np.array([[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6],
                                      [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]]) / 6.0,
                           dtype=torch.float32).to(device)

    x_index2 = torch.tensor(np.array([[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3],
                                       [4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]]) / 6.0,
                            dtype=torch.float32).to(device)
    t_index = torch.zeros((1, down_pixel_size, down_pixel_size), dtype=torch.float32).to(device)
    init_x_t_pairs = torch.cat([x_index, x_index2, t_index], dim=0).to(device)

    # Trainable copies of the same grids (time channel initialised to one);
    # they are leaf tensors updated by optimizer_x_t.
    rand_x_index = torch.tensor(np.array([[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6],
                                           [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]]) / 6.0,
                                dtype=torch.float32, requires_grad=True)

    rand_x_index2 = torch.tensor(np.array([[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3],
                                            [4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]]) / 6.0,
                                 dtype=torch.float32, requires_grad=True)

    rand_t_index = torch.ones((1, down_pixel_size, down_pixel_size), dtype=torch.float32, requires_grad=True)

    # Creates the module-level trainable scalars a_00 ... a_b32 (all starting
    # at 1). The names come from the fixed list below, never from user input.
    for num in ['00', '10', '20', '30', '01', '11', '21', '31', '02', '12', '22', '32', 'b01', 'b11', 'b21', 'b31',
                'b02', 'b12', 'b22', 'b32']:
        exec("a_%s = torch.tensor(1, dtype=torch.float32, requires_grad=True)" % num)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_fe = torch.optim.Adam(model_fe.parameters(), lr=args.lr)
    optimizer_para = torch.optim.Adam([a_00, a_10, a_20, a_30, a_01, a_11, a_21, a_31, a_02, a_22, a_12, a_32, a_b01,
                                       a_b11, a_b21, a_b31, a_b02, a_b22, a_b12, a_b32], lr=args.lr)
    optimizer_x_t = torch.optim.Adam([rand_x_index, rand_x_index2, rand_t_index], lr=args.lr)

    # Log the architecture once (logger.info returns None, so there is
    # nothing useful to print from its result).
    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()

    epo = 0

    # Per-iteration loss values, plotted after training finishes.
    init_loss = []
    dyn_loss = []
    task_loss = []

    for itr in range(args.nepochs * batches_per_epoch):

        rand_x_t_pairs = torch.cat([rand_x_index, rand_x_index2, rand_t_index], dim=0).to(device)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)

        x, y = data_gen.__next__()
        x = x.to(device)
        y = y.to(device)

        down_logits = down_model(x)

        # Append the fixed (t = 0) coordinate channels to every sample.
        added_down_logit_init = torch.empty(
            (len(down_logits), len(down_logits[0]) + 3, len(down_logits[0][0]), len(down_logits[0][0]))).to(device)
        for i in range(len(down_logits)):
            added_down_logit_init[i] = torch.cat([down_logits[i], init_x_t_pairs]).to(device)

        # Initial-condition loss: at the fixed coordinates the feature model
        # is pushed to reproduce its own input features.
        u_init = feature_model(added_down_logit_init)
        init_criterion = nn.MSELoss()
        loss_init = init_criterion(down_logits, u_init)

        # Same construction with the trainable coordinate channels.
        added_down_logit_rand = torch.empty(
            (len(down_logits), len(down_logits[0]) + 3, len(down_logits[0][0]), len(down_logits[0][0]))).to(device)
        for i in range(len(down_logits)):
            added_down_logit_rand[i] = torch.cat([down_logits[i], rand_x_t_pairs]).to(device)

        u = feature_model(added_down_logit_rand)
        u_2 = torch.pow(u, 2)
        u_3 = torch.pow(u, 3)

        # Cubic polynomials in u with trainable coefficients; the first two
        # are used as grad_outputs weights for the coordinate gradients.
        u_x_sum = a_01.to(device) + torch.mul(a_11.to(device), u) + torch.mul(a_21.to(device), u_2) + torch.mul(
            a_31.to(device), u_3)
        u_x2_sum = a_b01.to(device) + torch.mul(a_b11.to(device), u) + torch.mul(a_b21.to(device), u_2) + torch.mul(
            a_b31.to(device), u_3)
        basic_sum = a_00.to(device) + torch.mul(a_10.to(device), u) + torch.mul(a_20.to(device), u_2) + torch.mul(
            a_30.to(device), u_3)

        u_t = grad(u, rand_t_index, grad_outputs=u.data.new(u.shape).fill_(1), create_graph=True)[0].to(device)
        u_x = grad(u, rand_x_index, grad_outputs=u_x_sum, create_graph=True)[0].to(device)
        u_x2 = grad(u, rand_x_index2, grad_outputs=u_x2_sum, create_graph=True)[0].to(device)
        ux = basic_sum.sum(0).sum(0) + u_x + u_x2

        # Dynamics loss: mismatch between the time gradient and the
        # polynomial / spatial-gradient expression, scaled by 1 / (128 * 64).
        dyn_criterion = nn.MSELoss()
        loss_dyn = dyn_criterion(u_t, ux) * (1.0 / (128 * 64))

        logits = fc_model(u)

        loss_task = criterion(logits, y)

        init_loss.append(loss_init.item())
        dyn_loss.append(loss_dyn.item())
        task_loss.append(loss_task.item())

        # `itr % 1` is always 0, so this update runs on every iteration.
        if itr % 1 == 0:
            loss_init.backward(retain_graph=True)
            loss_dyn.backward(retain_graph=True)
            optimizer_fe.step()
            optimizer_para.step()
            optimizer.zero_grad()
            optimizer_fe.zero_grad()
            optimizer_x_t.zero_grad()
            optimizer_para.zero_grad()

        # NOTE(review): optimizer_fe.step() above changes feature_model's
        # weights in place while the graph is retained; newer PyTorch releases
        # may reject this backward pass with an in-place-modification error.
        # Confirm on the PyTorch version in use.
        loss_task.backward(retain_graph=True)
        optimizer.step()
        optimizer_x_t.step()
        optimizer.zero_grad()
        optimizer_fe.zero_grad()
        optimizer_x_t.zero_grad()
        optimizer_para.zero_grad()

        batch_time_meter.update(time.time() - end)
        # Restart the timer so the meter tracks time per iteration rather
        # than total time since training began.
        end = time.time()

        if itr % batches_per_epoch == 0:
            epo += 1
            start = time.time()
            with torch.no_grad():
                train_acc = accuracy(down_model, feature_fc_model, train_eval_loader, rand_x_t_pairs)
                val_acc = accuracy(down_model, feature_fc_model, test_loader, rand_x_t_pairs)

                time_measure = time.time() - start

                # Two checkpoint flavours: weights + args only, and a full
                # one that also carries the coefficients and coordinates.
                weights_only = {'state_dict': model.state_dict(), 'args': args}
                full_state = dict(weights_only)
                full_state.update({
                    'a_00': a_00, 'a_10': a_10, 'a_20': a_20, 'a_30': a_30,
                    'a_01': a_01, 'a_11': a_11, 'a_21': a_21, 'a_31': a_31,
                    'a_02': a_02, 'a_12': a_12, 'a_22': a_22, 'a_32': a_32,
                    'b01': a_b01, 'b11': a_b11, 'b21': a_b21, 'b31': a_b31,
                    'b02': a_b02, 'b12': a_b12, 'b22': a_b22, 'b32': a_b32,
                    'rand_x_index': rand_x_index,
                    'rand_x_index2': rand_x_index2,
                    'rand_t_index': rand_t_index,
                })

                if val_acc > best_acc:
                    torch.save(full_state,
                               os.path.join(args.save, 'model_pde_loss_curve_%s_normalization.pth' % args.rectifier))
                    torch.save(weights_only,
                               os.path.join(args.save,
                                            'model_pde_loss_curve_only_dict_%s_normalization.pth' % args.rectifier))
                    best_acc = val_acc

                # The "recent" files are overwritten at every evaluation.
                torch.save(full_state,
                           os.path.join(args.save, 'model_pde_loss_curve_%s_recent_normalization.pth' % args.rectifier))
                torch.save(weights_only,
                           os.path.join(args.save,
                                        'model_pde_loss_curve_only_dict_%s_recent_normalization.pth' % args.rectifier))

                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | Train Acc {:.4f} | Test Acc {:.4f} | Time_measure {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, train_acc, val_acc,
                        time_measure)
                )

    # One point per training iteration, numbered from 1.
    steps = range(1, len(task_loss) + 1)
    plt.plot(steps, init_loss)
    plt.plot(steps, dyn_loss)
    plt.plot(steps, task_loss)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('MNIST loss graph')
    plt.legend(['Init', 'Dyn', 'Task'])
    plt.savefig('loss_curve_of_pde_mnist.pdf')









