# https://github.com/pytorch/examples/issues/116

import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from tqdm import tqdm
import random

from utils import weights_init, compute_gan_loss

import numpy as np


def generate_P(mode, size):
    """Build the ``size x size`` mixing (gossip) matrix for a communication topology.

    Row ``i`` holds the weights node ``i`` uses to combine all nodes'
    parameters during the communication step. For every supported topology
    each row sums to 1.

    Args:
        mode: Topology name, one of:
            ``"all"``         complete graph, every weight ``1/size``;
            ``"single"``      identity (no communication);
            ``"ring"``        self and the two ring neighbours, ``1/3`` each;
            ``"star"``        node 0 is the hub;
            ``"meshgrid"``    2-D grid with Metropolis-Hastings weights;
            ``"exponential"`` node ``i`` mixes itself and nodes ``i + 2**k``
                              (mod ``size``) with uniform weights.
        size: Number of nodes.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(size, size)``. The matrix is
        also printed (and, for ``"meshgrid"``, the grid shape first).

    Raises:
        ValueError: If ``mode`` is unknown, or ``size <= 0`` for ``"meshgrid"``.
    """
    result = torch.zeros((size, size))
    if mode == "all":
        result = torch.ones((size, size)) / size
    elif mode == "single":
        for i in range(size):
            result[i][i] = 1
    elif mode == "ring":
        # Accumulate with += so that for size < 3 (where the left/right
        # neighbours coincide with each other or with node i) the row still
        # sums to 1 instead of silently losing weight.
        for i in range(size):
            for j in (i, (i - 1) % size, (i + 1) % size):
                result[i][j] += 1 / 3
    elif mode == "star":
        # Node 0 is the hub: row 0 and column 0 are uniform (1/size); every
        # other node keeps 1 - 1/size for itself and 1/size for the hub.
        for i in range(size):
            result[i][i] = 1 - 1 / size
            result[0][i] = 1 / size
            result[i][0] = 1 / size
    elif mode == "meshgrid":
        if size <= 0:
            raise ValueError(f"meshgrid topology needs size > 0, got {size}")
        # Largest nrow <= sqrt(size) dividing size -> the most square grid.
        nrow = int(np.sqrt(size))
        while size % nrow != 0:
            nrow -= 1
        ncol = size // nrow
        print((nrow, ncol), flush=True)
        topo = np.zeros((size, size))
        for i in range(size):
            topo[i][i] = 1.0
            if (i + 1) % ncol != 0:  # right neighbour within the same grid row
                topo[i][i + 1] = 1.0
                topo[i + 1][i] = 1.0
            if i + ncol < size:  # neighbour in the next grid row
                topo[i][i + ncol] = 1.0
                topo[i + ncol][i] = 1.0
        # Neighbour sets include the node itself.
        neighbours = [np.nonzero(topo[i])[0] for i in range(size)]
        for i in range(size):
            for j in neighbours[i]:
                if i != j:
                    # Metropolis-Hastings weight 1 / (1 + max(deg_i, deg_j)).
                    topo[i][j] = 1.0 / max(len(neighbours[i]), len(neighbours[j]))
            # topo[i][i] is still 1.0 here, so this sets it to
            # 1 - (sum of the off-diagonal weights of row i).
            topo[i][i] = 2.0 - topo[i].sum()
        result = torch.tensor(topo, dtype=torch.float)
    elif mode == "exponential":
        # x[k] is non-zero for k == 0 and k a power of two.
        x = np.array([1.0 if (i & (i - 1)) == 0 else 0 for i in range(size)])
        x /= x.sum()
        topo = np.empty((size, size))
        for i in range(size):
            # Row i is x shifted right by i: topo[i][j] = x[(j - i) % size].
            topo[i] = np.roll(x, i)
        result = torch.tensor(topo, dtype=torch.float)
    else:
        # An all-zero matrix would wipe every node's parameters during
        # gossip, so refuse unknown topologies loudly.
        raise ValueError(f"unknown topology mode: {mode!r}")
    print(result, flush=True)
    return result


def _build_models(model, data, nz, node, device):
    """Create one (generator, discriminator) pair per node on ``device``.

    Raises:
        ValueError: If the (model, data) combination is not implemented.
    """
    if model == 'vgan' and data == 'mnist':
        from vgan import VanillaDiscriminator, VanillaGenerator

        generators, discriminators = [], []
        # G then D per node, preserving the original construction (RNG) order.
        for _ in range(node):
            generators.append(VanillaGenerator(nz).to(device))
            discriminators.append(VanillaDiscriminator().to(device))
        return generators, discriminators
    raise ValueError(
        f"unsupported model/data combination {model!r}/{data!r}; "
        "only 'vgan'/'mnist' is implemented"
    )


def _gossip(models, mixing):
    """Synchronously replace node i's parameters by sum_j mixing[i][j] * params_j.

    All nodes read the parameters as they were *before* this call, so the
    update does not depend on the order in which nodes are processed.
    """
    snapshots = [[p.detach().clone() for p in m.parameters()] for m in models]
    with torch.no_grad():
        for i, net in enumerate(models):
            for k, param in enumerate(net.parameters()):
                mixed = torch.zeros_like(param)
                for j, snapshot in enumerate(snapshots):
                    mixed += mixing[i][j] * snapshot[k]
                param.copy_(mixed)


def _average_parameters(models):
    """Return the mean over nodes of each parameter, in ``parameters()`` order."""
    averaged = []
    with torch.no_grad():
        for group in zip(*(net.parameters() for net in models)):
            total = torch.zeros_like(group[0])
            for param in group:
                total += param.detach()
            total /= len(group)
            averaged.append(total)
    return averaged


def train_agda(dataset_list, manual_seed, options):
    """Train ``options['node']`` GAN replicas with decentralized alternating GDA.

    In every iteration each node takes one local step on its own data shard:
    first the discriminator (real batch labelled 1, generated batch labelled
    0), then the generator against the *updated* discriminator. Afterwards
    all nodes mix their generator and discriminator parameters with the
    gossip matrix ``generate_P(options['mode'], options['node'])``.

    Args:
        dataset_list: Sequence of per-node datasets; ``dataset_list[i]`` feeds
            node ``i``. Each item is expected to be an ``(image, label)`` pair.
        manual_seed: Seed re-applied (to ``random`` and ``torch``) before
            ``weights_init`` runs on each node.
        options: dict with keys 'model', 'loss', 'data', 'learning_rate',
            'nz', 'batch_size', 'num_epochs', 'device', 'node', 'mode'.

    Returns:
        ``(gen_param, dis_param)``: one entry per epoch. Each entry is the list
        of node-averaged generator (resp. discriminator) parameter tensors, in
        ``parameters()`` order, at the end of that epoch.

    Raises:
        ValueError: If the (model, data) pair is unsupported, or
            ``options['mode']`` is not a known topology.
    """
    model = options['model']
    loss = options['loss']
    data = options['data']
    lr = options['learning_rate']
    nz = options['nz']
    batch_size = options['batch_size']
    num_epochs = options['num_epochs']
    device = options['device']
    node = options['node']

    generator_list, discriminator_list = _build_models(model, data, nz, node, device)

    for generator, discriminator in zip(generator_list, discriminator_list):
        # Re-seed per node so weights_init draws the same random numbers on
        # every node (identical init only if weights_init touches every param).
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
        generator.apply(weights_init)
        discriminator.apply(weights_init)

    # One optimizer per node, each bound to *that* node's parameters (the
    # previous code bound every optimizer to the last-built networks).
    optim_g_list = [optim.SGD(g.parameters(), lr=lr) for g in generator_list]
    optim_d_list = [optim.SGD(d.parameters(), lr=lr) for d in discriminator_list]

    gen_param = []
    dis_param = []

    connect_matrix = generate_P(mode=options['mode'], size=node)

    print('Training......')

    for _ in range(num_epochs):
        # Fresh with-replacement sampler per node each epoch.
        batch_iters = []
        num_batches = []
        for i in range(node):
            sampler = RandomSampler(dataset_list[i], replacement=True, num_samples=len(dataset_list[i]))
            loader = DataLoader(dataset_list[i], batch_sampler=BatchSampler(sampler, batch_size=batch_size, drop_last=False))
            num_batches.append(len(loader))
            batch_iters.append(iter(loader))

        # Stop when the smallest shard is exhausted so next() never raises
        # StopIteration on a node with fewer batches.
        for _ in range(min(num_batches, default=0)):
            for i in range(node):
                generator = generator_list[i]
                discriminator = discriminator_list[i]

                # (1) Discriminator step: real batch (label 1) ...
                discriminator.zero_grad()
                images, _ = next(batch_iters[i])
                images = images.to(device)
                b_size = images.size(0)
                label = torch.full((b_size,), 1, dtype=images.dtype, device=device)
                loss_real = compute_gan_loss(discriminator(images), label, loss=loss)
                loss_real.backward()

                # ... and generated batch (label 0). Detach so this step does
                # not backpropagate into the generator.
                noises = torch.randn(b_size, nz, device=device)
                images_fake = generator(noises)
                label.fill_(0)
                loss_fake = compute_gan_loss(discriminator(images_fake.detach()), label, loss=loss)
                loss_fake.backward()
                optim_d_list[i].step()

                # (2) Generator step against the already-updated discriminator
                # (alternating, not simultaneous, GDA).
                generator.zero_grad()
                label.fill_(1)
                loss_g = compute_gan_loss(discriminator(images_fake), label, loss=loss)
                loss_g.backward()
                optim_g_list[i].step()

            # Communication step: synchronous gossip averaging.
            _gossip(generator_list, connect_matrix)
            _gossip(discriminator_list, connect_matrix)

        # Save the network-wide average parameters for this epoch.
        gen_param.append(_average_parameters(generator_list))
        dis_param.append(_average_parameters(discriminator_list))

    return gen_param, dis_param

if __name__ == '__main__':
    manual_seed = 123
    options = dict()
    # train_agda only implements the vanilla GAN on MNIST; the previous
    # 'dcgan'/'cifar10' settings built no networks and crashed.
    options['model'] = 'vgan'
    options['loss'] = 'wgan'
    options['data'] = 'mnist'
    options['metric'] = 'frobenius'
    options['learning_rate'] = 0.0002
    options['nz'] = 8
    options['batch_size'] = 500
    options['num_epochs'] = 2
    options['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Required by train_agda (previously missing -> KeyError).
    options['node'] = 4
    options['mode'] = 'ring'
    if options['data'] == 'mnist':
        transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        dataset = datasets.MNIST(root='./data/', download=True, transform=transform)
    elif options['data'] == 'cifar10':
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
        dataset = datasets.CIFAR10(root='./data/', download=True, transform=transform)
    else:
        raise ValueError(f"unknown dataset: {options['data']!r}")

    # train_agda expects one dataset per node: split into equal, disjoint,
    # randomly drawn shards.
    torch.manual_seed(manual_seed)
    permutation = torch.randperm(len(dataset)).tolist()
    shard_size = len(dataset) // options['node']
    dataset_list = [
        torch.utils.data.Subset(dataset, permutation[k * shard_size:(k + 1) * shard_size])
        for k in range(options['node'])
    ]
    train_agda(dataset_list, manual_seed, options)