import faulthandler
# Install the fault handler before the heavy imports below, so that a fatal
# signal (e.g. a segfault inside a native extension) dumps the Python traceback.
faulthandler.enable()

import os
from os.path import join
import random
from math import log

import numpy as np
# import tensorflow as tf
import torch
from torch import nn
# import tensorflow as tf
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.nn.parallel
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

from utils.model import get_model
from utils.dataset import get_dataset
from utils.optim import get_optim
from utils.train import train_step
# from utils.train import generator_max_eigen, critic_max_eigen
from utils.tool import cycle
from collections import Counter

from config import args
from gragh import My_Graph

# Presumably the name of a scratch/temporary directory; nothing in the code
# visible here reads it -- TODO confirm it is still used.
TEMP_DIR = 'temp_store'


def evaluate(model, valid_data, compute_loss=False):
    """Return the classification accuracy of ``model`` over ``valid_data``.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores
            (class dimension is dim 1). It is switched to eval mode and left
            in that mode.
        valid_data: iterable of ``(inputs, labels)`` batches.
        compute_loss: when true, also return the cross-entropy loss.

    Returns:
        The accuracy in percent, or ``(accuracy, loss)`` when ``compute_loss``
        is true. ``loss`` is the mean cross-entropy per sample over the whole
        of ``valid_data``.

    Raises:
        ZeroDivisionError: if ``valid_data`` yields no samples.
    """
    model.eval()

    # Evaluate on whichever device the model lives on; fall back to the
    # current CUDA device for parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    n_seen = 0
    n_correct = 0
    loss_sum = 0.0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for x_, y_ in valid_data:
            x_, y_ = x_.to(device), y_.to(device)
            out = model(x_)
            _, predicted = torch.max(out, 1)
            n_seen += y_.size(0)
            n_correct += torch.eq(predicted, y_).sum().item()
            if compute_loss:
                # Sum (not mean) per batch so that unequal batch sizes are
                # weighted correctly in the final per-sample average.
                loss_sum += nn.functional.cross_entropy(out, y_, reduction='sum').item()

    accuracy = n_correct / float(n_seen) * 100.0
    if compute_loss:
        return accuracy, loss_sum / n_seen
    return accuracy


# deprecated
# evaluate accuracy on inputs perturbed with random Gaussian noise of scale epsilon (not a gradient-based attack)
def evaluate_adversarial(model, valid_data, epsilon=0.05):
    """Return the accuracy of ``model`` on noise-perturbed inputs.

    Each input batch ``x`` is replaced by ``x + epsilon * N(0, 1)`` noise
    before being classified. Despite the name, no gradient-based attack is
    applied: the perturbation is purely random, so no backward pass is run
    and the model's parameter gradients are left untouched.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores
            (class dimension is dim 1). It is switched to eval mode and left
            in that mode.
        valid_data: iterable of ``(inputs, labels)`` batches.
        epsilon: scale of the Gaussian noise added to the inputs.

    Returns:
        The accuracy in percent.

    Raises:
        ZeroDivisionError: if ``valid_data`` yields no samples.
    """
    model.eval()

    # Evaluate on whichever device the model lives on; fall back to the
    # current CUDA device for parameter-free models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for x_, y_ in valid_data:
            x_, y_ = x_.to(device), y_.to(device)
            x_perturbed = x_ + epsilon * torch.randn_like(x_)
            out = model(x_perturbed)
            _, predicted = torch.max(out, 1)
            n_seen += y_.size(0)
            n_correct += torch.eq(predicted, y_).sum().item()
    return n_correct / float(n_seen) * 100.0


def _build_mixing_matrix(topo, nodes):
    """Build the ``nodes x nodes`` communication weight matrix for ``topo``.

    Supported topologies (numeric alias / name):
        '1' / 'sep'            identity (no communication); also used
                               whenever ``nodes == 1``
        '2' / 'full'           uniform averaging over all nodes
        '3' / 'exp'            links at power-of-two offsets
        '4' / 'ring'           self plus both ring neighbours, weight 1/3
        '5' / 'dense'          self, next neighbour and even-offset links
        '6' / 'directed-ring'  self plus previous neighbour, weight 1/2

    NOTE(review): the entries are not normalised afterwards, so rows need not
    sum to 1 (e.g. 'ring' with 2 nodes gives rows summing to 2/3) -- confirm
    this is acceptable for the consumers of the matrix.

    Raises:
        ValueError: if ``topo`` is not one of the values listed above.
    """
    if topo in ('1', 'sep') or nodes == 1:
        return np.eye(nodes)

    if topo in ('2', 'full'):
        return np.ones((nodes, nodes)) / nodes

    if topo in ('3', 'exp'):
        peer = int(log(nodes - 1, 2) + 1)
        weight = 1 / (peer + 1)
        matrix = np.eye(nodes) / (peer + 1)
        for i in range(nodes):
            for j in range(peer):
                matrix[(i + 2 ** j) % nodes, i] = weight
        return matrix

    if topo in ('4', 'ring'):
        matrix = np.eye(nodes) / 3
        for i in range(nodes):
            matrix[i, (i - 1 + nodes) % nodes] = 1 / 3
            matrix[i, (i + 1) % nodes] = 1 / 3
        return matrix

    if topo in ('5', 'dense'):
        peer = nodes // 2
        weight = 1 / (peer + 1)
        matrix = np.eye(nodes) / (peer + 1)
        for i in range(nodes):
            matrix[i, (i + 1) % nodes] = weight
            for j in range(peer):
                matrix[(i + 2 * j) % nodes, i] = weight
        return matrix

    if topo in ('6', 'directed-ring'):
        matrix = np.eye(nodes) / 2
        for i in range(nodes):
            matrix[i, (i - 1 + nodes) % nodes] = 1 / 2
        return matrix

    raise ValueError(
        f"unknown topology {topo!r}; expected one of '1'/'sep', '2'/'full', "
        "'3'/'exp', '4'/'ring', '5'/'dense', '6'/'directed-ring'"
    )


def main_worker(rank, args):
    """Run the training loop of one distributed worker.

    Seeds the RNGs, joins the process group, builds the data loaders, model
    and optimizers, then repeatedly calls ``train_step`` until ``args.step``
    reaches the total number of steps, logging scalars to TensorBoard under
    ``logs/<model_name>``.

    Args:
        rank: index of this worker; used as the distributed rank and to pick
            the GPU (``rank % args.n_gpu``).
        args: run configuration. It is mutated: ``device``, ``total_steps``,
            ``step`` and ``outer_step`` are set here. ``args.step`` is not
            advanced in this function -- presumably ``train_step`` does so.

    Raises:
        ValueError: if ``args.topo`` names an unknown topology.
        SystemExit: with code 0 when the classification loss becomes NaN.
    """
    # Code reproducibility: seed every RNG in use.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.nodes, rank=rank)

    train_loader, test_loader = get_dataset(args.nodes, rank, args.batchsize//args.nodes, args.label_num, args.s_rate)
    batches_in_epoch = len(train_loader)

    # Workers are mapped round-robin onto the available GPUs.
    first_device = 0
    device = (first_device + rank) % args.n_gpu
    torch.cuda.set_device(device)
    print("gpu ", device, " will be used.\n")
    args.device = device
    net = get_model(args.dataset)
    net.to(args.device)

    import torchsummary
    if rank == 0:
        if args.dataset == 'mnist':
            torchsummary.summary(model=net,input_size=(1, 28, 28))
        elif args.dataset == 'cifar10':
            torchsummary.summary(model=net,input_size=(3, 32, 32))

    # Create the optimizer for the model parameters (the "x" player).
    optim_x = get_optim(net.parameters(), x=True)

    # TensorBoard run name.
    if args.model_name is None:
        model_name = f'{args.dataset}_{args.model}_{args.optim}_{args.lr_x}_{args.lr_y}_eps_{args.epsilon}_i_{args.n_inner}_bs_{args.batchsize}_m_{args.label_num}'
        if 'tiada' in args.optim:
            model_name += f'_tiada_{args.alpha}_{args.beta}'
    else:
        model_name = args.model_name
    model_name += f"_topo_{args.topo}_{args.which_sum}_opt_comm_{args.opt_comm}_rank_{rank}"

    log_dir = join('logs', model_name)
    writer = SummaryWriter(log_dir)

    # (TensorBoard tag, key in the ``results`` dict) logged after every step.
    step_scalars = (
        ('x grad norm', 'x_grad_norm'),
        ('y grad norm', 'y_grad_norm'),
        ('classification loss', 'classification_loss'),
        ('total loss', 'total_loss'),
        ('total grad norm x', 'x_total_grad_sum'),
        ('total grad norm y', 'y_total_grad_sum'),
        ('state grad norm sum x', 'x_state_sum_sum'),
        ('state grad norm sum y', 'y_state_sum_sum'),
    )

    # The writer buffers events; close it on every exit path (normal end,
    # exception, or the NaN early exit) so pending events are flushed to disk.
    try:
        matrix = _build_mixing_matrix(args.topo, args.nodes)
        Weight_matrix = torch.from_numpy(matrix)
        graph = My_Graph(rank=rank, world_size=dist.get_world_size(), weight_matrix=matrix)
        out_edges, in_edges = graph.get_edges()

        # Training loop bookkeeping.
        num_epoch = args.num_epoch
        total_steps = num_epoch * batches_in_epoch
        args.total_steps = total_steps
        args.step = 0
        args.outer_step = 0
        data_iter = iter(cycle(train_loader))
        # Evaluate and record accuracies num_record times over the run.
        num_record = num_epoch
        num_recorded = 0
        record_gap = total_steps // num_record

        if rank == 0:
            print("total_steps : ", total_steps)

        # The perturbation (the "y" player) has the shape of one input batch.
        data_sample, _ = next(data_iter)
        delta = torch.zeros_like(data_sample).to(args.device)
        delta.requires_grad_()
        optim_y = get_optim([delta], x=False)
        if 'tiada' in args.optim:
            optim_x.opponent_optim = optim_y

        while args.step < total_steps:
            # The evaluation helpers switch the model to eval mode, so
            # training mode is restored at the start of every iteration.
            net.train()

            results, delta = train_step(args, data_iter, delta, net, optim_x, optim_y, Weight_matrix, out_edges, in_edges, rank)

            args.outer_step += 1

            if rank == 0:
                print('Step: ', args.step)

            if args.step >= record_gap * (num_recorded+1):
                results['train_acc'] = evaluate(net, train_loader)
                results['adv_acc'] = evaluate_adversarial(net, test_loader, args.epsilon)
                results['test_acc'], results['test_loss'] = evaluate(net, test_loader, compute_loss=True)
                num_recorded += 1

                for tag in ('train_acc', 'adv_acc', 'test_acc', 'test_loss'):
                    writer.add_scalar(tag, results[tag], global_step=args.step)

            # Per-step losses and gradient statistics.
            for tag, key in step_scalars:
                writer.add_scalar(tag, results[key], global_step=args.step)

            # Exit if nan: mark the run and stop this worker (exit code 0,
            # as before).
            if np.isnan(results["classification_loss"]):
                writer.add_text("nan", "nan", global_step=args.step)
                raise SystemExit(0)
    finally:
        writer.close()


if __name__ == "__main__":
    # Start args.nodes worker processes; each one runs main_worker with its
    # process index as the first argument, followed by the shared config.
    mp.spawn(fn=main_worker, args=(args,), nprocs=args.nodes)
