# Adapted from the PyTorch MNIST example:
# https://github.com/pytorch/examples/blob/master/mnist/main.py#L10

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import data_proc as dp
from PIL import Image
import itertools
import math
import numpy as np
import sys
sys.path.append('topology-decision-boundaries/src')
import TopologicalData
import random
import ripser
import clustering
import matplotlib.pyplot as plt
import scipy


class Net(nn.Module):
    """Small CNN producing log-probabilities over 2 classes.

    conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
    -> dropout -> flatten -> linear(9216->128) -> ReLU -> dropout
    -> linear(128->2) -> log-softmax over dim 1.

    Attribute names must stay stable: saved state dicts refer to them.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (single input channel).
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Parameter-free regularisation layers.
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        # 9216 = 64 * 12 * 12, i.e. the layer sizes assume a 28x28 input image.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (batch, 2)."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.dropout1(F.max_pool2d(h, 2))
        h = F.relu(self.fc1(torch.flatten(h, 1)))
        logits = self.fc2(self.dropout2(h))
        return F.log_softmax(logits, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one epoch of optimisation of ``model`` over ``train_loader``.

    Every batch is moved to ``device``; the NLL loss of the model output
    against the targets is back-propagated and ``optimizer`` takes one step.
    ``args`` and ``epoch`` are accepted for signature compatibility with the
    upstream example but are unused (progress logging is disabled).
    """
    model.train()
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()


def run_network(model, transform):
    """Print ``model``'s output for the first image returned by ``dp.load_data()``.

    The image tensor is converted to an 8-bit greyscale PIL image, passed
    through ``transform`` and given a leading batch dimension first.
    """
    images, _ = dp.load_data()
    first_image = Image.fromarray(images[0].numpy(), mode='L')
    batch = transform(first_image)[None, ...]  # add the batch dimension
    print(model.forward(batch))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    percentage_correct = 100. * correct / len(test_loader.dataset)
    print('End of epoch: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset), percentage_correct))

    return test_loss, percentage_correct


def _parse_trainer_args():
    """Parse the command-line training settings (flags follow the PyTorch MNIST example).

    Reads ``sys.argv`` via ``argparse``, so unrecognised flags make the
    process exit. ``--test-batch-size``, ``--dry-run``, ``--log-interval``
    and ``--save-model`` are parsed but not used by ``trainer``.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=5, metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    return parser.parse_args()


def trainer(combin, seed, mode='train'):
    """Train, or evaluate, the binary CNN for one pair of MNIST digits.

    Args:
        combin: the digit pair ``(idx1, idx2)``; only its first two items are used.
        seed: passed to ``torch.manual_seed`` and used in the checkpoint name
            ``models/full_data/mnist_cnn_<idx1><idx2>_<seed>.pt``.
        mode: ``'train'`` trains a fresh ``Net`` on
            ``dp.two_MNIST(transform, idx1, idx2, 'No')`` and saves its state
            dict to the checkpoint. Any other value loads that checkpoint and
            evaluates it on each task (0, 1), (0, 2), ..., (0, 9).

    Returns:
        ``None`` in ``'train'`` mode. Otherwise ``(losses, accuracies)``:
        two lists with one entry per evaluation task, in task order.
    """
    idx1, idx2 = combin[0], combin[1]
    args = _parse_trainer_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    loader_kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        loader_kwargs.update({'num_workers': 1, 'pin_memory': True})

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    checkpoint = 'models/full_data/mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.pt'

    if mode == 'train':
        dataset = dp.two_MNIST(transform, idx1, idx2, 'No')
        # Shuffle on every device. Shuffling used to be switched on only together
        # with the CUDA-specific loader options, so CPU runs trained on batches
        # in dataset order.
        train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, **loader_kwargs)

        model = Net().to(device)
        optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            # Reports loss/accuracy on the training data itself (no held-out split here).
            test(model, device, train_loader)
            scheduler.step()

        torch.save(model.state_dict(), checkpoint)
        return None

    # Evaluation: put the model on the same device that ``test`` moves the data to,
    # and remap checkpoint tensors so GPU-saved weights also load on CPU-only hosts.
    model = Net().to(device)
    model.load_state_dict(torch.load(checkpoint, map_location=device))

    eval_tasks = [[0, k] for k in range(1, 10)]
    losses, accuracies = [], []
    for task_idx1, task_idx2 in eval_tasks:
        dataset = dp.two_MNIST(transform, task_idx1, task_idx2, 'No')
        test_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)
        task_loss, task_accuracy = test(model, device, test_loader)
        losses.append(task_loss)
        accuracies.append(task_accuracy)

    return losses, accuracies

def compute_true_boundaries():
    """Save a ground-truth boundary array for every pair of digits 0-9.

    For each pair (idx1, idx2) with idx1 < idx2, every sample from
    ``dp.load_data(idx1, idx2, 'No')`` becomes one row: its flattened pixels
    followed by ``abs(label - 1)``. For 0/1 labels this flips the label,
    which presumably lines it up with the class-0 probability stored by
    ``compute_decision_boundaries``. Rows are saved to
    ``boundaries/full_data/true_boundary_<idx1><idx2>.npy``.
    """
    for first, second in itertools.combinations([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 2):
        images, targets = dp.load_data(first, second, 'No')
        rows = [
            np.append(images[k].flatten().numpy(), abs(targets[k].numpy() - 1))
            for k in range(len(images))
        ]
        out_path = ('boundaries/full_data/true_boundary_'
                    + str(first) + str(second) + '.npy')
        np.save(out_path, np.array(rows))


def compute_decision_boundaries():
    """Save each trained model's class-0 probability on its own digit pair.

    For every seed in 0-2 and every pair (idx1, idx2) of digits 0-9, loads
    ``models/full_data/mnist_cnn_<idx1><idx2>_<seed>.pt``, evaluates it on
    every sample of ``dp.load_data(idx1, idx2, 'No')`` and saves rows of
    ``[raw flattened pixels..., exp(log-prob of class 0)]`` to
    ``boundaries/full_data/notransform_mnist_cnn_<idx1><idx2>_<seed>.npy``.
    The stored pixels are the untransformed values; the network sees the
    normalised image.
    """
    combin = list(itertools.combinations([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 2))

    # Same ToTensor + Normalize constants as the transform built in ``trainer``;
    # loop-invariant, so built once.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    for seed in range(3):
        for indices in combin:

            print("Computing DB for Seed: " + str(seed) + ", digits: " + str(indices))

            idx1, idx2 = indices

            filename = 'models/full_data/mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.pt'
            model = Net()
            # map_location: checkpoints saved from a CUDA run must still load on CPU.
            model.load_state_dict(torch.load(filename, map_location='cpu'))
            # Disable dropout. Without eval(), every stored probability was computed
            # with random dropout masks, making the boundaries non-deterministic.
            model.eval()

            data, _ = dp.load_data(idx1, idx2, 'No')

            db = []
            with torch.no_grad():
                for i in range(len(data)):
                    img = Image.fromarray(data[i].numpy(), mode='L')
                    log_probs = model(transform(img)[None, ...])
                    # The network outputs log-probabilities; exponentiate class 0.
                    result = math.exp(log_probs[0][0].item())
                    # Raw (untransformed) pixels followed by the probability.
                    db.append(np.append(data[i].flatten().numpy(), result))

            decision_boundary = np.array(db)

            filename = 'boundaries/full_data/notransform_mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.npy'
            np.save(filename, decision_boundary)


def train_all():
    """Train one model per seed in 0-2 and per pair of digits 0-9 (135 runs)."""
    digit_pairs = list(itertools.combinations([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 2))
    for seed, pair in itertools.product(range(3), digit_pairs):
        print("\nSeed: " + str(seed) + ", digits: " + str(pair))
        trainer(pair, seed, mode='train')


def most_pers(dgm, n):
    """Return the ``n`` most persistent points of diagram ``dgm`` as a list.

    Points are ``(birth, death)`` pairs and persistence is ``death - birth``.
    The result runs from most to least persistent (ties keep input order)
    and is shorter than ``n`` when ``dgm`` has fewer points.
    """
    ranked = sorted(dgm, key=lambda pt: -(pt[1] - pt[0]))
    return ranked[:n]


def compute_diagrams(num_filt=100, maxdim=1, num_points=10, seed=0):
    """Compute truncated persistence diagrams for the 0-vs-k tasks of one seed.

    For each pair (0, 1), ..., (0, 9) two diagrams are appended, in this order:
    one for the trained model's boundary file
    ``boundaries/full_data/notransform_mnist_cnn_<idx1><idx2>_<seed>.npy``,
    one for the true boundary ``boundaries/full_data/true_boundary_<idx1><idx2>.npy``.
    So even indices are models and odd indices are true boundaries.
    Each diagram comes from ``diagram`` and is cut to its ``num_points``
    most persistent points by ``most_pers``.
    """
    def _truncated(path):
        # Load a boundary array and reduce it to its most persistent H1 points.
        boundary = np.load(path)
        return most_pers(diagram(boundary, num_filt, maxdim), num_points)

    diagrams = []
    for pair in [[0, k] for k in range(1, 10)]:
        idx1, idx2 = pair

        print("Computing diagram for Seed: " + str(seed) + ", digits: " + str(pair))
        diagrams.append(_truncated(
            'boundaries/full_data/notransform_mnist_cnn_' + str(idx1) + str(idx2) + '_' + str(seed) + '.npy'))

        print("Computing diagram for true boundary, digits: " + str(pair))
        diagrams.append(_truncated(
            'boundaries/full_data/true_boundary_' + str(idx1) + str(idx2) + '.npy'))

    return diagrams


def diagram(db, num_filt, maxdim):
    """Return the dimension-1 persistence diagram of a subsample of ``db``.

    ``db`` rows are ``[features..., value]``. The first 10% of rows (rounded
    down) are kept. Their features are divided by 255 and their values rounded
    to the nearest integer. The first half of those rows (rounded down), i.e.
    roughly the first 5% of ``db``, is passed to ``ripser.ripser`` with
    ``maxdim``. ``num_filt`` is unused. ``maxdim`` must be at least 1, since
    ``dgms[1]`` is returned.
    """
    n_keep = math.floor(0.1 * len(db))
    features = db[:n_keep, :-1] / 255
    values = np.rint(db[:n_keep, -1])

    n_used = math.floor(0.5 * n_keep)
    cloud = np.hstack([features, values[:, None]])[:n_used]

    return ripser.ripser(cloud, maxdim=maxdim)['dgms'][1]


def cluster(D, k=10, max_iter=10, verbose=True):
    """Fuzzy-cluster the persistence diagrams ``D`` into ``k`` clusters.

    Thin wrapper around ``clustering.pd_fuzzy``. Returns its two results
    ``(r, M)`` as a tuple; ``r`` is used downstream as per-diagram cluster
    memberships (see that module for the exact meaning of ``M``).
    """
    print("Clustering...")
    memberships, centres = clustering.pd_fuzzy(D, k, verbose=verbose, max_iter=max_iter)
    return memberships, centres


def compute_taskwise_loss(seeds=range(3)):
    """Evaluate every saved 0-vs-k model on every 0-vs-k task and save the results.

    For each seed, row i of the saved arrays is the model trained on digits
    0 vs (i+1), and column j is the task 0 vs (j+1), as returned by
    ``trainer(..., mode='test')``. Results are written to
    ``results/loss_model_0vAll_seed_<seed>.npy`` and
    ``results/pc_model_0vAll_seed_<seed>.npy``.

    Args:
        seeds: seeds to evaluate. Defaults to 0, 1 and 2: the seeds trained by
            ``train_all`` and loaded by ``mean_performance``. The previous
            hard-coded ``range(1, 3)`` skipped seed 0, so a fresh pipeline run
            crashed when ``mean_performance`` loaded the seed-0 file.
    """
    combin = [[0, k] for k in range(1, 10)]

    for seed in seeds:
        tot_loss, tot_pc = [], []
        for labels in combin:
            loss, pc = trainer(labels, seed, mode='test')
            tot_loss.append(np.array(loss))
            tot_pc.append(np.array(pc))

        print(tot_loss)
        print(tot_pc)

        np.save('results/loss_model_0vAll_seed_' + str(seed) + '.npy', np.array(tot_loss))
        np.save('results/pc_model_0vAll_seed_' + str(seed) + '.npy', np.array(tot_pc))


def mean_performance(r=None, do_print=False):
    combin = [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], [0, 8], [0, 9]]
    pc = []
    task_to_clust = [[], [], []]

    load_r = False
    if r is None:
        r = []
        load_r = True

    for seed in range(3):
        pc_temp = np.load('results/pc_model_0vAll_seed_' + str(seed) + '.npy')
        pc.append(pc_temp)
        if load_r:
            r.append(np.load('results/r_mnist_seed_' + str(seed) + '.npy'))
        for i in range(9):
            task_to_clust[seed].append(r[seed][2 * i].argsort()[::-1][0])

    # ALL

    results = np.zeros((9, 27))
    for j in range(9):
        for seed in range(3):
            for i in range(9):
                results[j][seed*9+i] = pc[seed][i][j]

    # TOP 3

    # task_to_clust[seed][task] gives the cluster centre assigned to the task
    # we want to get the top-3 models assigned to each task
    # so start by fixing seed and task
    results_top3 = np.zeros((9, 9))
    for seed in range(3):
        for j in range(9):
            centre = task_to_clust[seed][j]
            # centre corresponds to the centre for task j (i.e., 0v(j+1))
            # now we want the top 3 models assigned to this centre

            cand = []
            for i in range(9):
                cand.append(r[seed][i*2][centre])
            cand = np.array(cand)
            top3_idx = cand.argsort()[-3:][::-1]

            # then top3_idx gives the model with the largest assignment to each task
            # so we simply need the average of pc[top3_idx][j] over all seeds

            for i in range(len(top3_idx)):
                results_top3[j][seed*3+i] = pc[seed][top3_idx[i]][j]

    # TOP 2

    # task_to_clust[seed][task] gives the cluster centre assigned to the task
    # we want to get the top-3 models assigned to each task
    # so start by fixing seed and task
    results_top2 = np.zeros((9, 6))
    for seed in range(3):
        for j in range(9):
            centre = task_to_clust[seed][j]
            # centre corresponds to the centre for task j (i.e., 0v(j+1))
            # now we want the top 3 models assigned to this centre

            cand = []
            for i in range(9):
                cand.append(r[seed][i * 2][centre])
            cand = np.array(cand)
            top2_idx = cand.argsort()[-2:][::-1]

            # then top3_idx gives the model with the largest assignment to each task
            # so we simply need the average of pc[top3_idx][j] over all seeds

            for i in range(len(top2_idx)):
                results_top2[j][seed * 2 + i] = pc[seed][top2_idx[i]][j]

    # BOTTOM 3

    results_bot3 = np.zeros((9, 9))
    for seed in range(3):
        for j in range(9):
            centre = task_to_clust[seed][j]

            cand = []
            for i in range(9):
                cand.append(r[seed][i * 2][centre])
            cand = np.array(cand)
            bot3_idx = cand.argsort()[:3]

            for i in range(len(bot3_idx)):
                results_bot3[j][seed * 3 + i] = pc[seed][bot3_idx[i]][j]

    # TOP 1

    # task_to_clust[seed][task] gives the cluster centre assigned to the task
    # we want to get the top-3 models assigned to each task
    # so start by fixing seed and task
    results_top1 = np.zeros((9, 3))
    for seed in range(3):
        for j in range(9):
            centre = task_to_clust[seed][j]
            # centre corresponds to the centre for task j (i.e., 0v(j+1))
            # now we want the top 3 models assigned to this centre

            cand = []
            for i in range(9):
                cand.append(r[seed][i * 2][centre])
            cand = np.array(cand)
            top1_idx = cand.argsort()[-1:][::-1]

            # then top3_idx gives the model with the largest assignment to each task
            # so we simply need the average of pc[top3_idx][j] over all seeds

            for i in range(len(top1_idx)):
                results_top1[j][seed + i] = pc[seed][top1_idx[i]][j]

    # RESULTS

    diff_top1_avg = []
    diff_top2_avg = []
    diff_top3_avg = []
    diff_top3_bot3 = []

    diff_top1_avg_pc = []
    diff_top2_avg_pc = []
    diff_top3_avg_pc = []

    for i in range(9):
        mean_bot3 = np.mean(results_bot3[i])
        sem_bot3 = scipy.stats.sem(results_bot3[i])

        mean = np.mean(results[i])
        sem = scipy.stats.sem(results[i])

        mean_top3 = np.mean(results_top3[i])
        sem_top3 = scipy.stats.sem(results_top3[i])

        mean_top2 = np.mean(results_top2[i])
        sem_top2 = scipy.stats.sem(results_top2[i])

        mean_top1 = np.mean(results_top1[i])
        sem_top1 = scipy.stats.sem(results_top1[i])

        if do_print:
            print('Task bot3: 0v' + str(i + 1) + ', mean = ' + str(mean_bot3) + ' +- ' + str(sem_bot3))
            print('Task all : 0v' + str(i + 1) + ', mean = ' + str(mean) + ' +- ' + str(sem))
            print('Task top3: 0v' + str(i + 1) + ', mean = ' + str(mean_top3) + ' +- ' + str(sem_top3))

            print(mean_bot3)
            print(mean)
            print(mean_top3)
            print(mean_top2)
            print(mean_top1)

        diff_top1_avg.append(mean_top1 - mean)
        diff_top2_avg.append(mean_top2 - mean)
        diff_top3_avg.append(mean_top3 - mean)
        diff_top3_bot3.append(mean_top3 - mean_bot3)

        diff_top1_avg_pc.append((mean_top1 - mean)/mean * 100)
        diff_top2_avg_pc.append((mean_top2 - mean) / mean * 100)
        diff_top3_avg_pc.append((mean_top3 - mean) / mean * 100)

    print('Top3 vs avg: ' + str(np.mean(np.array(diff_top3_avg))) + ' +- ' + str(scipy.stats.sem(diff_top3_avg)))
    print('Top2 vs avg: ' + str(np.mean(np.array(diff_top2_avg))) + ' +- ' + str(scipy.stats.sem(diff_top2_avg)))
    print('Top1 vs avg: ' +str(np.mean(np.array(diff_top1_avg))) + ' +- ' + str(scipy.stats.sem(diff_top1_avg)))

    print('Top3 vs avg %: ' + str(np.mean(np.array(diff_top3_avg_pc))) + ' +- ' + str(scipy.stats.sem(diff_top3_avg_pc)))
    print('Top2 vs avg %: ' + str(np.mean(np.array(diff_top2_avg_pc))) + ' +- ' + str(scipy.stats.sem(diff_top2_avg_pc)))
    print('Top1 vs avg %: ' + str(np.mean(np.array(diff_top1_avg_pc))) + ' +- ' + str(scipy.stats.sem(diff_top1_avg_pc)))
    # print(diff_top3_avg)
    # print(diff_top2_avg)
    # print(diff_top1_avg)
    print()


def compute_clusters(iter_val=5, save=False, print_results=False):
    """Sweep the clustering iteration count and report model-selection performance.

    Computes the 0-vs-k diagrams (21 most persistent points each) once for
    seeds 0-2. Then, for every ``max_iter`` in 2..19, clusters each seed's
    diagrams into 9 clusters and calls ``mean_performance`` on the resulting
    memberships; that function prints its own summary.

    Args:
        iter_val: currently unused (the sweep over 2..19 ignores it); kept for
            backward compatibility.
        save: write each seed's memberships to ``results/r_mnist_seed_<seed>.npy``.
            The files are overwritten at every sweep step, so they end up holding
            the ``max_iter=19`` result.
        print_results: print rounded memberships and each row's top-3 clusters.

    Returns:
        ``(r_accum, M)``: the per-seed memberships from the last sweep step,
        and the ``M`` returned by ``cluster`` for the last seed of that step.
    """
    diagrams = [compute_diagrams(num_filt=100, maxdim=1, num_points=21, seed=seed)
                for seed in range(3)]

    r_accum, M = [], None
    for iterations in range(2, 20):
        r_accum = []
        for seed in range(3):
            r, M = cluster(diagrams[seed], k=9, max_iter=iterations, verbose=False)

            if save:
                np.save('results/r_mnist_seed_' + str(seed) + '.npy', r)

            if print_results:
                print(np.around(r, 3))
                for row in r:
                    print(row.argsort()[-3:][::-1])

            r_accum.append(r)

        print("ITERATIONS: " + str(iterations))
        # mean_performance prints its report and returns None; printing its return
        # value only added a stray "None" line to the output.
        mean_performance(r_accum)
    return r_accum, M


# NOTE: The string literal below is dead code kept only for reference. It appears
# to be an earlier per-iteration "membership vs. task accuracy" analysis (it uses
# x_all, y_all and iterations like compute_clusters does). Being a bare string, it
# is never executed, and the names it uses (seed, r, x_all, y_all, iterations) are
# not defined at module scope.
'''def old_code():
    pc = np.load('results/pc_model_0vAll_seed_' + str(seed) + '.npy')

            # r = np.load('results/r_values_10_iter_conv.npy')
            # print(np.round(pc, 2))
            # print(np.round(r, 2))
            # for i in range(len(r)):
            #     print(r[i].argsort()[::-1])

            task_to_clust = []
            for i in range(9):
                task_to_clust.append(r[2*i].argsort()[::-1][0])

            x, y = [], []

            for i in range(9):
                for j in range(9):
                    temp_r = r[2 * i][task_to_clust[j]]
                    temp_pc = pc[i][j]
                    # print('percentage: ' + str(temp_pc) + ', membership: ' + str(temp_r))
                    x.append(temp_r)
                    y.append(temp_pc)

            x_all.append(x)
            y_all.append(y)

        # pear = scipy.stats.pearsonr(x, y)
        # print("Iterations: " + str(iterations) + ', Pearson: ' + str(pear))

        # plot(x, y, save=False)
        plot_all(x_all, y_all, iterations, seed)'''


def plot(x, y, save=False, iterations=None):
    """Scatter ``y`` against ``x`` with a least-squares line, then show or save it.

    Axis limits are fixed to x in [-0.05, 1.05] and y in [45, 105].

    Args:
        x, y: equal-length sequences of points.
        save: if True, save to ``results/figure_<iterations>_iter.png`` instead
            of showing the figure.
        iterations: value used in the saved file name; required when ``save`` is
            True. Previously the function referenced an undefined ``iterations``,
            so ``save=True`` always raised ``NameError``.

    Raises:
        ValueError: if ``save`` is True and ``iterations`` is None.
    """
    if save and iterations is None:
        raise ValueError("plot(save=True) requires 'iterations' for the output file name")

    plt.scatter(x, y, s=10)
    # Degree-1 least-squares fit, drawn over the distinct x values.
    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)), color='r')
    plt.xlim(-0.05, 1.05)
    plt.ylim(45, 105)
    if save:
        plt.savefig('results/figure_' + str(iterations) + '_iter.png')
    else:
        plt.show()
    plt.close()


def plot_all(x_all, y_all, iterations, seed):
    """Scatter membership vs. task performance per seed and overlay one pooled linear fit.

    Args:
        x_all, y_all: per-seed sequences of x values (plotted in [-0.05, 1.05])
            and y values (plotted in [45, 105] on a log y-axis). Seeds are
            coloured red, green, blue, cycling if there are more than three.
        iterations: unused (only referenced by the commented-out ``savefig``).
        seed: only used in the printed Pearson line.

    Prints the Pearson correlation of all points pooled together, then shows the figure.
    """
    # ``import scipy`` alone does not load the ``stats`` submodule on older SciPy.
    from scipy import stats

    for xs, ys, colour in zip(x_all, y_all, itertools.cycle(['r', 'g', 'b'])):
        plt.scatter(xs, ys, color=colour, s=10)

    plt.xlim(-0.05, 1.05)
    plt.ylim(45, 105)

    # Pool every seed's points; concatenation also copes with unequal lengths.
    x = np.concatenate([np.ravel(xs) for xs in x_all])
    y = np.concatenate([np.ravel(ys) for ys in y_all])

    plt.xlabel('Membership value')
    plt.ylabel('Task performance (%)')
    plt.yscale('log')

    pear = stats.pearsonr(x, y)
    print('Seed: ' + str(seed) + ', Pearson: ' + str(pear))

    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)), color='k')
    plt.show()
    # plt.savefig('results/all_seeds/figure_' + str(iterations) + '_iter.png')
    plt.close()


if __name__ == '__main__':

    # Full pipeline:
    # 1. train all pair models;
    # 2. store the model and true boundaries;
    # 3. evaluate every 0-vs-k model on every 0-vs-k task;
    # 4. cluster the boundary diagrams and report model selection.
    train_all()
    compute_decision_boundaries()
    compute_true_boundaries()
    compute_taskwise_loss()

    # Run the expensive diagram + clustering sweep once. It used to be called twice,
    # with the first result discarded.
    r, M = compute_clusters()

    # mean_performance(r=r)
