import argparse
import copy
import os
from pprint import pformat
import numpy as np
import torch
import torch.distributed as dist
from torch import nn

import optim
from decentralized_opt.data_parallel import DataParallel
from decentralized_opt.tensor_tools import all_reduce_tensors
from decentralized_opt.communication_graph import CommunicationGraph
import decentralized_opt.log as log


def get_parser():
    """Build the command-line parser for a decentralized training run.

    Options are registered in a fixed order (experiment setup, algorithm
    hyper-parameters, training, distributed, miscellaneous) so that ``--help``
    lists them in that order. No parsing is performed here.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    # (flags, add_argument keyword arguments), in registration order.
    option_specs = [
        # Experiment setup
        (('--optimizer',), dict(default='CHOCO_SGD', type=str, help='Optimizer')),
        (('--dataset',), dict(default='mnist', type=str, help='Select the dataset')),
        (('--sort',), dict(default=False, action='store_true', help='Divide training data by label')),
        (('--model',), dict(default='default', type=str, help='Model structure')),
        (('--graph_type',), dict(default='er', type=str)),
        (('--graph_params',), dict(nargs="+", type=float)),
        (('--kappa',), dict(default=0.8, type=float, help='Coefficients of the ring graph')),
        (('--runs',), dict(default=1, type=int, help='Number of runs')),
        (('--val_interval',), dict(default=10, type=int, help='Each evaluation interval')),
        (('--cal_grad_norm',), dict(default=False, action='store_true',
                                    help='Whether to calculate the gradient norm according to val_interval')),
        (('--grad_norm_form',), dict(default='avg_norm', type=str, help='How to calculate full gradient norm')),

        # DNSGD-specific
        (('--K',), dict(default=10, type=int)),
        (('--Khat',), dict(default=10, type=int)),
        # DNASA-specific
        (('--alpha',), dict(default=1, type=float)),

        # Training
        (('--epochs',), dict(default=1, type=int, help='Number of epochs')),
        (('--max_iters',), dict(default=0, type=int, help='Maximum number of iterations')),
        (('--batch_size',), dict(default=100, type=int, help='Batch size per worker')),
        (('--val_batch_size',), dict(default=100, type=int, help='Batch size of validation')),
        (('--lr',), dict(default=1e-3, type=float, help='Learning rate')),
        (('--momentum',), dict(default=0, type=float, help='Momentum')),
        (('--num_workers',), dict(default=0, type=int, help='Number of workers for data loader')),
        (('--cpu',), dict(default=False, action='store_true', help='Use CPU')),
        (('--deterministic',), dict(default=False, action='store_true', help='Use fixed random seed')),

        # Distributed
        (('--backend',), dict(default='nccl', type=str, help='Distributed backend',
                              choices=['nccl', 'gloo', 'mpi'])),
        (('--n_peers',), dict(default=None, type=int, help='Number of iterations before synchroning')),
        # In single-GPU mode every iteration is executed strictly in rank order.
        (('--single_gpu_sequential',), dict(default=False, action='store_true',
                                            help='Force single GPU training sequentially across ranks')),

        # Miscellaneous
        (('--data_path',), dict(default=None, type=str, help='Path to the data folder')),
        (('--output_dir',), dict(default=None, type=str, help='Path to the output folder')),
        (('-v', '--verbosity'), dict(default='INFO', type=str, help='Verbosity of log',
                                     choices=['DEBUG', 'INFO', 'WARN'])),
    ]

    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


def parse_args(parser=None):
    """Parse the command line and attach launcher-provided environment values.

    ``MASTER_ADDR`` and ``MASTER_PORT`` are stored as strings; ``RANK``,
    ``WORLD_SIZE``, ``LOCAL_RANK``, ``WORLD_LOCAL_SIZE`` and ``WORLD_NODE_RANK``
    are stored as ints. Each is attached under its lower-cased name, and
    ``node_rank`` is set as an alias of ``world_node_rank``.

    Args:
        parser: parser to use; ``get_parser()`` is used when ``None``.

    Returns:
        argparse.Namespace: parsed options plus the environment-derived fields.

    Raises:
        KeyError: if one of the environment variables above is unset.
        ValueError: if an integer-valued variable is not a valid int.
    """
    args = (parser if parser is not None else get_parser()).parse_args()

    int_vars = ('RANK', 'WORLD_SIZE', 'LOCAL_RANK', 'WORLD_LOCAL_SIZE', 'WORLD_NODE_RANK')
    for key in ('MASTER_ADDR', 'MASTER_PORT') + int_vars:
        raw = os.environ[key]
        setattr(args, key.lower(), int(raw) if key in int_vars else raw)

    args.node_rank = args.world_node_rank
    return args


def validate_args(args):
    """Reconcile the requested device and backend with what is available.

    Falls back to CPU (with a warning) when a GPU was requested but CUDA is
    unavailable. Sets ``args.device`` to ``cpu`` or to
    ``cuda:<local_rank % device_count>``. Forces the ``gloo`` backend when
    running on CPU with ``nccl``.

    Args:
        args: namespace with ``cpu``, ``backend`` and ``local_rank``; mutated in place.

    Returns:
        The same ``args`` object.
    """
    if not (args.cpu or torch.cuda.is_available()):
        log.warn('GPU is not availabel, using CPU instead')
        args.cpu = True

    if args.cpu:
        args.device = torch.device('cpu')
        # NCCL only supports CUDA tensors.
        if args.backend == 'nccl':
            log.warn('Setting backend to gloo when using CPU')
            args.backend = 'gloo'
    else:
        gpu_index = args.local_rank % torch.cuda.device_count()
        args.device = torch.device('cuda:%d' % gpu_index)

    return args


def init(args):
    """Set up logging, validate settings, seed RNGs and join the process group.

    Side effects:
      * configures ``log`` (rank, optional output directory, verbosity);
      * mutates ``args`` through ``validate_args`` (``device``, possibly
        ``cpu``/``backend``);
      * sets global cuDNN flags and, with ``--deterministic``, seeds the torch,
        NumPy and CUDA RNGs with this rank's id;
      * calls ``dist.init_process_group`` (rendezvous presumably via the
        MASTER_ADDR/MASTER_PORT environment variables -- confirm launcher setup).

    Args:
        args: namespace produced by ``parse_args``.
    """
    log.set_rank(args.rank)
    if args.output_dir is not None:
        log.set_directory(args.output_dir)
    log.set_level(args.verbosity)

    args = validate_args(args)

    log.info('Configurations:\n' + pformat(args.__dict__))

    log.info('world_size = %d, batch_size = %d, device = %s, backend = %s',
             args.world_size, args.batch_size, args.device, args.backend)

    if args.deterministic:
        # Seed per rank: each rank is reproducible but draws a different stream.
        torch.manual_seed(args.rank)
        np.random.seed(args.rank)
        torch.cuda.manual_seed(args.rank)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may select different kernels from run to run, which
        # defeats `cudnn.deterministic`; keep it off when determinism is requested.
        torch.backends.cudnn.benchmark = False
    elif not args.cpu:
        torch.backends.cudnn.benchmark = True

    dist.init_process_group(args.backend, world_size=args.world_size, rank=args.rank)


def wrap_model(model, args):
    """Wrap ``model`` for decentralized training and build its optimizer.

    A ``CommunicationGraph`` is built from ``args.world_size``, ``args.rank``,
    ``args.graph_type`` and ``args.graph_params``. The model is wrapped in
    ``DataParallel``. The ``optim`` class named by ``args.optimizer`` (DNSGD,
    DSGD, DSGT or DNASA) is then instantiated with the shared hyper-parameters
    (``lr``, ``world_size``, ``rank``, ``G``, ``kappa``) plus its
    algorithm-specific ones.

    Args:
        model: the bare model to wrap.
        args: parsed configuration.

    Returns:
        tuple: ``(wrapped_model, optimizer)``.

    Raises:
        NotImplementedError: if ``args.optimizer`` is not one of the names above.
    """
    graph = CommunicationGraph(args.world_size, rank=args.rank, graph_type=args.graph_type,
                               graph_params=args.graph_params)

    # Algorithm-specific keyword arguments. They are built lazily so that only
    # the attributes relevant to the chosen optimizer are read from `args`.
    specific_kwargs = {
        'DNSGD': lambda: {'K': args.K, 'K_hat': args.Khat},
        'DSGD': dict,
        'DSGT': dict,
        'DNASA': lambda: {'alpha': args.alpha},
    }
    build_extra = specific_kwargs.get(args.optimizer)
    if build_extra is None:
        raise NotImplementedError(f"Optimizer {args.optimizer} is not implemented")

    wrapped = DataParallel(model)
    optimizer_cls = getattr(optim, args.optimizer)
    optimizer = optimizer_cls(wrapped, lr=args.lr, world_size=args.world_size, rank=args.rank,
                              G=graph, kappa=args.kappa, **build_extra())
    return wrapped, optimizer


def train(model, criterion, optimizer, train_loader, args,
          val_loader=None,
          val_train_loader=None,
          scheduler=None,
          single_gpu_sequential=False):
    """Run the decentralized training loop on this rank.

    Every rank executes this function. The ``dist.barrier`` calls (and
    whatever communication ``optimizer.step`` performs) are collective, so all
    ranks must reach them in the same order.

    Args:
        model: wrapped model (as returned by ``wrap_model``).
        criterion: loss called as ``criterion(output, target)``.
        optimizer: decentralized optimizer. ``optimizer.step(i)`` must return a
            pair whose second element is a gradient tensor or ``None``.
        train_loader: iterable of ``(data, target)`` batches for this rank.
        args: parsed configuration. Reads ``epochs``, ``max_iters``
            (0 = unlimited), ``val_interval`` (<= 0 disables periodic
            evaluation), ``cal_grad_norm``, ``rank``, ``world_size``, ``device``.
        val_loader: passed to ``validate_acc`` every ``val_interval`` steps.
        val_train_loader: passed to ``validate_norm`` when ``args.cal_grad_norm``.
        scheduler: optional LR scheduler, stepped once per epoch.
        single_gpu_sequential: if True, ranks run their forward/backward one
            after another, separated by barriers, so they can share one GPU.

    Returns:
        tuple ``(train_res, val_res, training_res)``:
        ``train_res`` -- list of ``[val_step, full_grad_norm]``;
        ``val_res`` -- list of ``[val_step, accuracy]``;
        ``training_res`` -- list of ``[iteration, loss, step_grad_norm]``.
    """
    # Per-step records: [iteration, loss, norm of the gradient returned by optimizer.step].
    training_res = []
    # Periodic full-gradient-norm records: [val_step, full gradient norm].
    train_res = []
    # Periodic validation records: [val_step, validation accuracy].
    val_res = []
    i = 0
    val_step = 0
    # Latest evaluation results, printed in the per-epoch summary. Both start
    # as NaN so the summary works before the first evaluation. The full
    # gradient norm stays NaN when `cal_grad_norm` is off, instead of
    # mislabelling the last step's gradient norm.
    full_grad_norm = float('nan')
    acc = float('nan')
    stop_training = False

    model.train()
    optimizer.zero_grad()

    log.set_allowed_ranks(list(range(args.world_size)))
    log.info(f"Rank{args.rank}: Begin Training")

    for epoch in range(1, args.epochs + 1):
        if args.rank == 0:
            log.info(f"Epoch {epoch} Begin")

        dist.barrier()

        for data_cpu, target_cpu in train_loader:
            i += 1

            if single_gpu_sequential:
                # Ranks take turns in rank order; each computes its own loss once.
                for r in range(dist.get_world_size()):
                    if args.rank == r:
                        loss = _forward_backward(model, criterion, data_cpu, target_cpu, args.device)
                    dist.barrier()
            else:
                loss = _forward_backward(model, criterion, data_cpu, target_cpu, args.device)

            _, grad = optimizer.step(i)
            optimizer.zero_grad()

            step_grad_norm = grad.norm(2).item() if grad is not None else 0.0
            training_res.append([i, loss.item(), step_grad_norm])

            if single_gpu_sequential:
                dist.barrier()

            if args.val_interval > 0 and i % args.val_interval == 0:
                val_step += 1

                if args.cal_grad_norm:
                    full_grad_norm = validate_norm(model, val_train_loader, args)
                    train_res.append([val_step, full_grad_norm])

                acc = validate_acc(model, val_loader, args)
                val_res.append([val_step, acc])

            # Checked after the step so exactly `max_iters` steps are executed.
            if args.max_iters > 0 and i >= args.max_iters:
                stop_training = True
                if args.rank == 0:
                    log.info(f"Reached max_iters ({args.max_iters}), stopping training.")
                break

        # Per-rank summary, printed in rank order.
        for r in range(dist.get_world_size()):
            if args.rank == r:
                log.info(f"[Epoch {epoch}] Rank{args.rank}: "
                         f"full gradient norm = {full_grad_norm:.6f}"
                         f" | acc = {acc:.6f}")

            dist.barrier()

        if scheduler is not None:
            log.debug('schedule.step() called')
            scheduler.step()

        if args.rank == 0:
            log.info(f"[Epoch {epoch}] Finished Training!")

        # Leave the epoch loop too, instead of spinning through empty epochs.
        if stop_training:
            break

    return train_res, val_res, training_res


def _forward_backward(model, criterion, data_cpu, target_cpu, device):
    """Move one batch to ``device``, run forward and backward, and return the loss."""
    data = data_cpu.to(device=device, non_blocking=True)
    target = target_cpu.to(device=device, non_blocking=True)
    loss = criterion(model(data), target)
    loss.backward()
    return loss


def validate_norm(model, val_train_loader, args):
    """Compute the gradient norm of the full training objective.

    The objective is ``nn.CrossEntropyLoss(reduction="sum")`` over
    ``val_train_loader``, divided by the number of samples seen (mean-loss
    gradient). Two forms are supported via ``args.grad_norm_form``:

    * ``'avg_norm_3'`` -- each rank computes the gradient norm of its own
      model. The norms are then combined across ranks with
      ``all_reduce_tensors``; presumably that averages in place -- TODO
      confirm against ``decentralized_opt.tensor_tools``. With
      ``args.single_gpu_sequential`` the ranks run one after another.
    * ``'avg_params_1'`` -- a copy of the model has its parameters combined
      across ranks with ``all_reduce_tensors``. Rank 0 computes the gradient
      norm of that copy and broadcasts it, so every rank returns the same value.

    Every rank must call this function, because it issues collective operations.

    Args:
        model: the (wrapped) model. Its gradients are cleared by ``'avg_norm_3'``.
        val_train_loader: iterable of ``(data, target)`` batches.
        args: reads ``grad_norm_form``, ``rank``, ``device``, ``single_gpu_sequential``.

    Returns:
        float: the gradient norm.

    Raises:
        ValueError: if ``args.grad_norm_form`` is not a supported form.
    """
    form = args.grad_norm_form
    if form not in ('avg_norm_3', 'avg_params_1'):
        raise ValueError(
            f"Unsupported grad_norm_form {form!r}; expected 'avg_norm_3' or 'avg_params_1'")

    if form == 'avg_norm_3':
        model.zero_grad()
        model.train()

        if args.rank == 0:
            log.info("Begin calculate full gradient norm...")

        if args.single_gpu_sequential:
            for r in range(dist.get_world_size()):
                if args.rank == r:
                    _accumulate_mean_gradient(model, val_train_loader, args.device)
                dist.barrier()
        else:
            _accumulate_mean_gradient(model, val_train_loader, args.device)

        # A tensor (not a Python float) is required so the collective can
        # write the combined value back.
        norm_tensor = torch.tensor([_grad_norm(model)], dtype=torch.float64, device=args.device)
        for req in all_reduce_tensors([norm_tensor]):
            if req is not None:
                req.wait()

        model.zero_grad()
        return norm_tensor.item()

    # form == 'avg_params_1'
    avg_model = copy.deepcopy(model)
    for req in all_reduce_tensors(list(avg_model.module.parameters())):
        if req is not None:
            req.wait()

    norm_tensor = torch.zeros(1, dtype=torch.float64, device=args.device)
    if args.rank == 0:
        # Only rank 0 evaluates the combined model. No per-rank barriers here:
        # the other ranks are not in this branch, so barriers would deadlock.
        log.info("Begin calculate full gradient norm...")
        avg_model.zero_grad()
        avg_model.train()
        _accumulate_mean_gradient(avg_model, val_train_loader, args.device)
        norm_tensor.fill_(_grad_norm(avg_model))
    del avg_model

    # Every rank returns rank 0's value.
    dist.broadcast(norm_tensor, src=0)
    return norm_tensor.item()


def _accumulate_mean_gradient(net, loader, device):
    """Backpropagate summed cross-entropy over ``loader`` into ``net``'s gradients,
    then divide them by the number of samples seen. Returns that count; the
    gradients are not rescaled when the loader is empty."""
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_samples = 0
    for data_cpu, target_cpu in loader:
        data = data_cpu.to(device=device, non_blocking=True)
        target = target_cpu.to(device=device, non_blocking=True)
        criterion(net(data), target).backward()
        total_samples += data.size(0)

    if total_samples > 0:
        for p in net.parameters():
            if p.grad is not None:
                p.grad /= total_samples
    return total_samples


def _grad_norm(net):
    """Return the global L2 norm of all populated parameter gradients of ``net`` (0.0 if none)."""
    squares = [p.grad.detach().norm() ** 2 for p in net.parameters() if p.grad is not None]
    if not squares:
        return 0.0
    return torch.sqrt(torch.sum(torch.stack(squares))).item()


def validate_acc(model, val_loader, args):
    """Return the top-1 accuracy of ``model`` on ``val_loader``.

    The model is switched to eval mode for the pass, so dropout and batch-norm
    behave as at inference time. Its previous train/eval mode is restored
    afterwards. With ``args.single_gpu_sequential`` the ranks evaluate one
    after another, separated by ``dist.barrier()``, so every rank must call
    this function in that mode.

    Args:
        model: module whose output has class scores along dim 1.
        val_loader: iterable of ``(data, target)`` batches.
        args: reads ``rank``, ``device`` and ``single_gpu_sequential``.

    Returns:
        float: ``correct / total``, or NaN when the loader yields no samples.
    """
    if args.rank == 0:
        log.info("Begin calculate validation accuracy...")

    was_training = model.training
    model.eval()
    try:
        if args.single_gpu_sequential:
            for r in range(dist.get_world_size()):
                if args.rank == r:
                    correct, total = _count_correct(model, val_loader, args.device)
                dist.barrier()
        else:
            correct, total = _count_correct(model, val_loader, args.device)
    finally:
        model.train(was_training)

    return correct / total if total > 0 else float('nan')


def _count_correct(model, loader, device):
    """Return ``(correct, total)`` top-1 prediction counts of ``model`` over ``loader``."""
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            _, predicted = torch.max(model(data), 1)

            correct += (predicted == target).sum().item()
            total += target.size(0)
    return correct, total
