import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time
from collections import defaultdict, OrderedDict
import pickle
from tqdm import tqdm
import random
import math
import argparse
import json
import random
import pickle

from logistic import *
from graph import *

def load_dataset():
    """Load the MNIST training split as one batch of flattened images.

    The dataset is read from (and, if missing, downloaded to) ``./data``.
    A single ``DataLoader`` batch covering the whole dataset is used so that
    every sample ends up in one tensor.

    Returns:
        A ``(data, label)`` pair where ``data`` has been reshaped to
        ``(-1, 28 * 28)`` and ``label`` is the matching label tensor.
    """
    pixels_per_image = 28 * 28

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    mnist_train = datasets.MNIST('./data',
                                 train=True,
                                 download=True,
                                 transform=to_tensor)

    full_batch_loader = DataLoader(mnist_train,
                                   batch_size=len(mnist_train),
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=False)

    # The batch size equals the dataset size, so the loader is drained in
    # full and the last (only) batch it produced is kept.
    batches = list(full_batch_loader)
    data, label = batches[-1]

    return data.reshape(-1, pixels_per_image), label


def compute_average(model_list):
    """Average the linear-layer parameters over all node models.

    Args:
        model_list: sequence of models, each exposing ``linear.weight`` and
            ``linear.bias``.

    Returns:
        A ``(avg_weight, avg_bias)`` pair holding the element-wise mean of
        the weights and of the biases across ``model_list``.
    """
    n_models = len(model_list)

    weight_total = sum(model.linear.weight for model in model_list)
    bias_total = sum(model.linear.bias for model in model_list)

    return weight_total / n_models, bias_total / n_models


# Values of ``args.model`` / ``args.graph`` that main() knows how to handle.
_MODEL_NAMES = ("logistic", "ridge")
_GRAPH_NAMES = (
    "fc",
    "ring",
    "none",
    "artificial",
    "artificial_triple_ring",
    "artificial_torus",
    "artificial_line",
)


def _resolve_log_path(log_path):
    """Return a file path for the history pickle, creating parent directories.

    A path that names a directory (trailing separator, or an existing
    directory) cannot be opened for writing, so ``history.pkl`` inside that
    directory is used instead.
    """
    if log_path.endswith(("/", os.sep)) or os.path.isdir(log_path):
        log_path = os.path.join(log_path, "history.pkl")

    parent = os.path.dirname(log_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return log_path


def _build_mixing_matrix(args):
    """Build the mixing matrix selected by ``args.graph``.

    Raises:
        ValueError: if ``args.graph`` is not one of ``_GRAPH_NAMES``.
    """
    if args.graph == "fc":
        return torch.ones((args.n_nodes, args.n_nodes)) / args.n_nodes
    if args.graph == "ring":
        return Ring(args.n_nodes).w
    if args.graph == "none":
        return torch.eye(args.n_nodes)

    # The "artificial*" topologies start from a base graph and pass its
    # matrix through generate_graph3 together with the requested
    # average spectral gap.
    if args.graph == "artificial":
        base_graph = Ring(args.n_nodes)
    elif args.graph == "artificial_triple_ring":
        base_graph = kRing(args.n_nodes, 3)
    elif args.graph == "artificial_torus":
        # NOTE(review): the torus dimensions are hard-coded and do not depend
        # on args.n_nodes -- presumably this only matches n_nodes == 8 * 25;
        # confirm against Torus. main() rejects a size mismatch explicitly.
        base_graph = Torus(8, 25)
    elif args.graph == "artificial_line":
        base_graph = Line(args.n_nodes)
    else:
        raise ValueError(
            f"unknown graph {args.graph!r}; expected one of {_GRAPH_NAMES}")

    return generate_graph3(base_graph.w, args.average_spectral_gap)


def _mix_parameters(model_list, mixing_matrix):
    """Replace every node's linear parameters by their mixing-matrix average.

    Node ``i`` receives ``sum_j mixing_matrix[i, j] * params_j``.  All new
    values are computed from the pre-mixing parameters before any node is
    overwritten.
    """
    n_nodes = len(model_list)

    # The mixing is a plain parameter update, not part of the loss, so no
    # autograd graph needs to be recorded for it.
    with torch.no_grad():
        mixed_weights = []
        mixed_biases = []
        for i in range(n_nodes):
            # Start from the node's own contribution, then add the others in
            # index order (keeps the floating-point summation order fixed).
            weight = mixing_matrix[i, i] * model_list[i].linear.weight
            bias = mixing_matrix[i, i] * model_list[i].linear.bias
            for j in range(n_nodes):
                if j == i:
                    continue
                weight += mixing_matrix[i, j] * model_list[j].linear.weight
                bias += mixing_matrix[i, j] * model_list[j].linear.bias
            mixed_weights.append(weight)
            mixed_biases.append(bias)

        for model, weight, bias in zip(model_list, mixed_weights, mixed_biases):
            model.linear.weight.data = weight
            model.linear.bias.data = bias


def main(args):
    """Run the decentralized SGD experiment and pickle its history.

    Each of ``args.n_nodes`` models takes one SGD step per round, after which
    the nodes' linear parameters are averaged according to the mixing matrix
    of the chosen topology.  Every 10 rounds the loss and squared gradient
    norm at the node-averaged parameters are printed and recorded.

    Args:
        args: namespace providing ``seed``, ``model``, ``batch_size``,
            ``n_nodes``, ``lr``, ``graph``, ``average_spectral_gap``,
            ``num_round`` and ``log_path``.

    Raises:
        ValueError: if ``args.model`` or ``args.graph`` is not supported, or
            if the mixing matrix size does not match ``args.n_nodes``.
    """
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Fail fast on bad configuration, before the dataset is downloaded and
    # before any training time is spent.
    if args.model not in _MODEL_NAMES:
        raise ValueError(
            f"unknown model {args.model!r}; expected one of {_MODEL_NAMES}")
    if args.graph not in _GRAPH_NAMES:
        raise ValueError(
            f"unknown graph {args.graph!r}; expected one of {_GRAPH_NAMES}")
    log_path = _resolve_log_path(args.log_path)

    data, label = load_dataset()
    if args.model == "logistic":
        model_cls = LogisticClassification
    else:
        model_cls = RidgeClassification
    model_list = [model_cls(data, label, args.batch_size)
                  for _ in range(args.n_nodes)]

    optimizer_list = [optim.SGD(model.parameters(), lr=args.lr, momentum=0)
                      for model in model_list]

    mixing_matrix = _build_mixing_matrix(args)
    if len(mixing_matrix) != args.n_nodes:
        # Mixing indexes the matrix with node indices; a larger matrix would
        # silently use only a sub-block of it.
        raise ValueError(
            f"mixing matrix for graph {args.graph!r} has "
            f"{len(mixing_matrix)} rows but n_nodes is {args.n_nodes}")

    print(mixing_matrix)
    history = {"loss": [], "grad": []}
    history["spectral_gap"] = calc_spectral_gap(mixing_matrix)
    history["average_spectral_gap"] = calc_average_spectral_gap(mixing_matrix)
    print(history)

    for r in tqdm(range(args.num_round)):
        for optimizer in optimizer_list:
            optimizer.zero_grad()

        # Local step: every node computes its loss, backpropagates and
        # updates its own parameters.
        loss_list = [model.compute_loss() for model in model_list]

        for node_loss in loss_list:
            node_loss.backward()

        for optimizer in optimizer_list:
            optimizer.step()

        for optimizer in optimizer_list:
            optimizer.zero_grad()

        # Communication step: average parameters with the neighbours.
        _mix_parameters(model_list, mixing_matrix)

        if r % 10 == 0:
            avg_weight, avg_bias = compute_average(model_list)
            loss, grad = model_list[0].get_full_gradient(avg_weight, avg_bias)
            grad_sq_norm = (grad**2).sum()
            print(loss, grad_sq_norm)
            history["loss"].append(loss.item())
            history["grad"].append(grad_sq_norm.item())

    with open(log_path, mode='wb') as f:
        pickle.dump(history, f)
        
def build_arg_parser():
    """Build the command-line parser for the experiment script.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace carries the
        attributes ``model``, ``batch_size``, ``num_round``, ``n_nodes``,
        ``lr``, ``log_path``, ``graph``, ``average_spectral_gap`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default="logistic", type=str,
                        choices=["logistic", "ridge"],
                        help="logistic or ridge")
    parser.add_argument('--batch_size', default=10, type=int, help="batch size")
    parser.add_argument('--num_round', default=2000, type=int, help="number of rounds")
    parser.add_argument('--n_nodes', default=10, type=int, help="number of nodes")
    parser.add_argument('--lr', default=0.01, type=float, help="stepsize")
    # NOTE(review): the default is a directory-style path, while the results
    # are pickled into a single file -- confirm the intended default.
    parser.add_argument('--log_path', default="log/", type=str, help="path where results will be stored.")
    # The choices mirror the topology names that main() dispatches on.
    parser.add_argument('--graph', default="fc", type=str,
                        choices=["fc", "ring", "none", "artificial",
                                 "artificial_triple_ring", "artificial_torus",
                                 "artificial_line"],
                        help="underlying network topology. fc, ring, none, artificial, "
                             "artificial_triple_ring, artificial_torus, or artificial_line")
    # Raw string: the formula contains backslashes ("\f" would otherwise be
    # interpreted as a form feed, "\s" and "\l" as invalid escapes).
    parser.add_argument('--average_spectral_gap', default=100, type=float,
                        help=r"1/n \sum_{i=2}^n \frac{\lambda_i^2}{1 - \lambda_i^2}")
    parser.add_argument('--seed', default=0, type=int)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)
    
