import sys
from functools import partial

import math
import torch.utils
import torch.utils.data.distributed
import networkx as nx
from networkx.algorithms.dag import topological_sort
import torch.nn.functional as F
from tqdm import trange, tqdm

from reactnet import reactnet

sys.path.append('.')
import mobilenetv2
from utils import ScaledWeightConv2d
from cifar_torch import cifar100, cifar10
from resnet_pcifar import ResNet18
from extract_cov import combine_cov
from baseline.greedy.reactnet import reactnet as reactnet_orig

from baseline.utils.utils import *
from baseline.utils import KD_loss


def _build_teacher(dataset):
    """Build the frozen teacher for ``dataset`` and return it wrapped in DataParallel on CUDA.

    cifar100 uses a MobileNetV2 with ScaledWeightConv2d layers; cifar10 uses ResNet18.
    All parameters get ``requires_grad = False`` and the model is put in eval mode.

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset == 'cifar100':
        model_teacher = mobilenetv2.MobileNetV2(
            num_classes=100,
            first_layer_type=ScaledWeightConv2d,
            expansion_layer=ScaledWeightConv2d,
            # depthwise_layer=ScaledWeightConv2d, # skip depthwise for now
            pointwise_layer=ScaledWeightConv2d,
            shortcut_layer=ScaledWeightConv2d,
            conv2_layer=ScaledWeightConv2d,
        )
        state_dict = torch.load('/d1/xxx/DBQ/mbnet_cifar100_state_dict.pt', map_location='cpu')
    elif dataset == 'cifar10':
        model_teacher = ResNet18(10)
        state_dict = torch.load('/d1/xxx/pytorch-cifar/checkpoint/ckpt.pth', map_location='cpu')['net']
        # keys carry a "module." prefix (presumably saved from a DataParallel wrapper); strip it
        state_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
    else:
        raise ValueError(f"unsupported dataset: {dataset!r}")
    model_teacher.load_state_dict(state_dict, strict=False)

    model_teacher = nn.DataParallel(model_teacher).cuda()
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()
    return model_teacher


def _pretrained_checkpoint_path(dataset, base_dir, variant_C):
    """Return the path of the pretrained student checkpoint that fine-tuning starts from."""
    if dataset == 'cifar100':
        return base_dir + '/models/checkpoint_ba.pth.tar'
    if variant_C:
        return base_dir + '/log_cifar10_C_1/model_best.pth.tar'
    return base_dir + '/log_cifar10_1/model_best.pth.tar'


def _compute_orig_sv(pretrained_state, num_classes, train_loader):
    """Eigendecompose the input-patch covariance of each layer of the pretrained ReActNet.

    Layers are visited in a topological order of ``make_graph(model)``, and ``'fc'`` is skipped.
    One training-data iterator is threaded through all layers via ``dataset_cov``.

    Returns:
        ``{layer_name: {'log_s': log(clip(eigenvalues, min=0) + 1e-6), 'V': eigenvectors}}``
        where ``V`` holds the eigenvectors as columns, as returned by ``torch.linalg.eigh``.
    """
    model_orig = reactnet_orig(num_classes=num_classes).cuda()
    model_orig.load_state_dict({k.replace('module.', ''): v for k, v in pretrained_state.items()})
    modules = dict(model_orig.named_modules())
    orig_sv = {}
    cov_iter = None
    for layer_name in list(topological_sort(make_graph(model_orig))):
        if layer_name == 'fc':
            continue
        print(layer_name)
        # a "*_down" entry names the pair "<name>1"/"<name>2"; only the first one is probed
        # NOTE(review): assumes both halves see the same input — confirm
        probe = modules[layer_name + ('1' if layer_name.endswith('_down') else '')]
        dim = math.prod(probe.weight.shape[1:])
        model_orig.eval()
        cov, cov_iter = dataset_cov(
            train_loader, cov_iter, None,
            partial(extract_patches, model_orig, probe),
            dim, 1.0)
        s, V = eigh(cov)
        log_s = torch.log(torch.clip(s, min=0.0) + 1e-6)
        orig_sv[layer_name] = {'log_s': log_s, 'V': V}
    return orig_sv


def main():
    """Fine-tune a binary ReActNet student by distillation from a frozen teacher on CIFAR.

    Reads the module-level ``args`` parsed under ``__main__``. The student is initialised from
    a pretrained checkpoint under ``args.base_dir``. Training resumes from
    ``args.save/checkpoint.pth.tar`` if that file exists. After each epoch the
    q_orig / q_ours distances between the student's binarized weights and the *pretrained*
    weights are printed.
    """
    if not torch.cuda.is_available():
        sys.exit("CUDA is not available; this script requires a GPU.")
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # reject unknown datasets before loading anything (an assert would vanish under -O)
    if args.dataset == 'cifar100':
        num_classes = 100
    elif args.dataset == 'cifar10':
        num_classes = 10
    else:
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    model_teacher = _build_teacher(args.dataset)

    model_student = reactnet(num_classes=num_classes, variant_C=args.variant_C)
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    criterion_kd = KD_loss.DistributionLoss()

    # 4-D tensors and parameters named "*conv*" get weight decay; all other parameters do not
    weight_parameters = [p for pname, p in model_student.named_parameters()
                         if p.ndimension() == 4 or 'conv' in pname]
    weight_ids = {id(p) for p in weight_parameters}
    other_parameters = [p for p in model_student.parameters() if id(p) not in weight_ids]

    optimizer = torch.optim.Adam(
        [{'params': other_parameters},
         {'params': weight_parameters, 'weight_decay': args.weight_decay}],
        lr=args.learning_rate, )

    # linear decay to 0 over args.epochs; train() steps it once per epoch
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1.0 - step / args.epochs), last_epoch=-1)
    start_epoch = 0
    best_top1_acc = 0

    if args.dataset == 'cifar100':
        train_loader, val_loader = cifar100(args.batch_size, args.workers)
    else:
        train_loader, val_loader = cifar10(args.batch_size, args.workers)

    # Keep the pretrained weights in their own variable: they are the reference for the q
    # metrics and must not be replaced by the resume checkpoint loaded further down.
    pretrained_state = torch.load(
        _pretrained_checkpoint_path(args.dataset, args.base_dir, args.variant_C))['state_dict']

    sv_cache = "orig_sv.pth"
    if os.path.exists(sv_cache):
        orig_sv = torch.load(sv_cache, map_location='cuda')
    else:
        orig_sv = _compute_orig_sv(pretrained_state, num_classes, train_loader)
        torch.save(orig_sv, sv_cache)

    model_student.load_state_dict(pretrained_state, strict=False)

    resume_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(resume_tar):
        logging.info('loading checkpoint {} ..........'.format(resume_tar))
        resume = torch.load(resume_tar)
        start_epoch = resume['epoch'] + 1
        best_top1_acc = resume['best_top1_acc']
        model_student.load_state_dict(resume['state_dict'], strict=False)
        # the training loop saves the optimizer state; restore it so Adam's moments survive a resume
        if 'optimizer' in resume:
            optimizer.load_state_dict(resume['optimizer'])
        logging.info("loaded checkpoint {} epoch = {}".format(resume_tar, resume['epoch']))

    # fast-forward the per-epoch learning-rate schedule to the resume point
    for _ in range(start_epoch):
        scheduler.step()

    def log_q():
        """Print q_orig / q_ours of the current student against the pretrained weights."""
        cur_state = model_student.state_dict()
        print(f"q_orig={calc_q_orig(pretrained_state, cur_state, orig_sv.keys())}")
        print(f"q_ours={calc_q_ours(pretrained_state, cur_state, orig_sv)}")

    log_q()

    for epoch in range(start_epoch, args.epochs):
        train(epoch, train_loader, model_student, model_teacher, criterion_kd, optimizer, scheduler)
        _, valid_top1_acc, _ = validate(epoch, val_loader, model_student, criterion, args)
        log_q()

        is_best = bool(valid_top1_acc > best_top1_acc)
        if is_best:
            best_top1_acc = valid_top1_acc

        save_checkpoint({
            'epoch': epoch,
            'state_dict': model_student.state_dict(),
            'best_top1_acc': best_top1_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)

    training_time = (time.time() - start_t) / 3600
    print('total training time = {} hours'.format(training_time))


def binary_weights(weights, shape):
    """Binarize ``weights`` with a per-slice scale and a clipped straight-through gradient.

    ``weights`` is viewed as ``shape`` (which must have 4 dims). Along dim 0 (the output
    channel for a conv weight), each slice becomes ``alpha * sign(w)``, where ``alpha`` is
    the detached mean of ``|w|`` over dims 1-3. The forward value is that binarized tensor.
    The gradient flows as if the output were ``clamp(w, -1, 1)``: 1 inside [-1, 1], 0 outside.
    """
    w = weights.view(shape)
    alpha = w.abs().mean(dim=[1, 2, 3], keepdim=True).detach()
    hard = alpha * w.sign()
    soft = w.clamp(-1.0, 1.0)
    # value of `hard`, gradient of `soft`
    return soft + (hard.detach() - soft.detach())


def _paired_weights(orig_states, cur_states, key):
    """Return ``(orig_w, cur_w)`` for layer ``key``, with ``cur_w`` binarized to ``orig_w``'s shape.

    Pretrained weights are read from ``module.<name>.weight`` and the student's from
    ``module.<name>.weights``. The student tensor is reshaped to the pretrained shape by
    ``binary_weights``. A key ending in ``_down`` names the pair ``<key>1`` / ``<key>2``,
    whose weights are concatenated along dim 0.
    """
    if key.endswith('_down'):
        names = [f'module.{key}1', f'module.{key}2']
    else:
        names = [f'module.{key}']
    orig_w = torch.cat([orig_states[f'{name}.weight'] for name in names], dim=0)
    cur_w = torch.cat([cur_states[f'{name}.weights'] for name in names], dim=0)
    return orig_w, binary_weights(cur_w, orig_w.shape)


@torch.no_grad()
def calc_q_orig(orig_states, cur_states, keys):
    """Plain weight-space distance between binarized student weights and pretrained weights.

    For each layer key: sum the squared difference over dims 1-3 of each output slice, then
    average over output slices.

    Returns:
        ``{key: float}``.
    """
    orig_states = dict(orig_states)
    cur_states = dict(cur_states)

    results = {}
    for key in list(keys):
        orig_w, cur_w = _paired_weights(orig_states, cur_states, key)
        dist = ((cur_w - orig_w) ** 2).sum([1, 2, 3]).mean(0)
        results[key] = dist.item()
    return results


@torch.no_grad()
def calc_q_ours(orig_states, cur_states, orig_sv):
    """Input-statistics-weighted distance between binarized student and pretrained weights.

    ``orig_sv[key]`` holds ``log_s`` (log eigenvalues) and ``V`` (eigenvectors as *columns*,
    as produced by ``torch.linalg.eigh``) of that layer's input-patch covariance. The weight
    difference of each output slice is projected onto the eigenbasis with ``V.T`` and each
    coordinate is scaled by ``exp(log_s)``. The squared norm is then averaged over slices.
    NOTE(review): scaling by exp(log_s) (not its square root) gives delta^T C^2 delta; confirm
    this exponent is intended.

    Returns:
        ``{key: float}``.
    """
    orig_states = dict(orig_states)
    cur_states = dict(cur_states)

    results = {}
    for key, sv in orig_sv.items():
        orig_w, cur_w = _paired_weights(orig_states, cur_states, key)
        # (fan_in, out_slices): one column per output slice
        delta = (cur_w - orig_w).reshape(orig_w.shape[0], -1).T
        # eigenvector i is column i of V, so project with V.T before scaling by eigenvalue i
        projected = torch.exp(sv['log_s'])[:, None] * (sv['V'].T @ delta)
        results[key] = (projected ** 2).sum(0).mean().item()
    return results


def eigh(cov, jitter=1e-6):
    """Eigendecomposition of a symmetric matrix with one regularised retry.

    Returns ``(eigenvalues, eigenvectors)`` as ``torch.linalg.eigh`` does: eigenvalues in
    ascending order, eigenvectors as the columns of the second tensor.

    If the first attempt raises ``torch.linalg.LinAlgError`` (e.g. an ill-conditioned input),
    a retry is made. The retry works in float64, adds ``jitter`` times the mean absolute
    diagonal to the diagonal, and casts the result back to ``cov``'s dtype. If the retry
    also fails, its error propagates.
    """
    try:
        return torch.linalg.eigh(cov)
    except torch.linalg.LinAlgError:
        # Repeating the identical call cannot succeed; regularise and raise precision instead.
        cov64 = cov.double()
        scale = cov64.diagonal(dim1=-2, dim2=-1).abs().mean().item() or 1.0
        eye = torch.eye(cov64.shape[-1], dtype=cov64.dtype, device=cov64.device)
        s, V = torch.linalg.eigh(cov64 + (jitter * scale) * eye)
        return s.to(cov.dtype), V.to(cov.dtype)


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the input that ``layer`` sees for batch ``X``, as a feature-by-sample matrix.

    The layer input comes from ``model.inputs_for(layer, X)``. For a ``Conv2d`` it is
    unfolded with the layer's own kernel/dilation/padding/stride into shape
    ``(channels*kh*kw, batch*n_positions)``. Any other layer's input is returned unchanged.
    NOTE(review): callers such as dataset_cov expect ``(dim, N)``; confirm ``inputs_for``
    already returns that layout for non-conv layers.
    """
    inputs = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return inputs
    # (batch, channels*kh*kw, n_positions)
    columns = F.unfold(
        inputs,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )
    patch_dim = columns.shape[1]
    # patch dimension first, then merge batch and spatial positions into one sample axis
    return columns.transpose(0, 1).reshape(patch_dim, -1)


def make_graph(model):
    """Build the layer-dependency DAG of ``model`` by walking backwards from ``'fc'``.

    For every reached layer ``name``, an edge ``dep -> name`` is added for each ``dep`` in
    ``model.dependent_layers(name)`` (presumably the layers feeding ``name``). Each layer is
    expanded once, depth-first. Edges are inserted in the same order as a recursive walk,
    so topological sorts of the result are unaffected.

    Returns:
        A ``networkx.DiGraph``.
    """
    graph = nx.DiGraph()
    expanded = {'fc'}
    # explicit DFS stack of (layer, iterator over its remaining dependencies)
    stack = [('fc', iter(model.dependent_layers('fc')))]
    while stack:
        consumer, pending = stack[-1]
        for producer in pending:
            graph.add_edge(producer, consumer)
            if producer not in expanded:
                expanded.add(producer)
                stack.append((producer, iter(model.dependent_layers(producer))))
                break
        else:
            stack.pop()
    return graph


def dataset_cov(train_loader, loader_instance, permute_indices, get_patches, dim, sample_portion=1.0):
    """Accumulate the covariance of patch features over ``ceil(len(train_loader) * sample_portion)`` batches.

    Args:
        train_loader: yields ``(inputs, labels)`` batches; labels are ignored.
        loader_instance: an iterator over ``train_loader`` to continue from, or ``None`` to
            start a fresh one. The iterator restarts whenever it is exhausted.
        permute_indices: optional index applied to the feature rows of each patch matrix.
        get_patches: maps a CUDA input batch to a ``(dim, n_samples)`` patch matrix.
        dim: number of feature rows.
        sample_portion: fraction of ``len(train_loader)`` batches to consume.

    Returns:
        ``(cov, iterator)``: the float32 ``(dim, dim)`` covariance, and the iterator, so a
        later call can continue from where this one stopped.

    Statistics are kept in float64 on CUDA. Every batch after the first is merged by
    ``combine_cov``, which must update ``count``/``mean``/``cov`` in place, since they are
    never reassigned here (presumably; it lives in ``extract_cov``).
    """
    count = torch.zeros((), device='cuda', dtype=torch.long)
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    cov = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    n_batches = math.ceil(len(train_loader) * sample_portion)
    batches = iter(train_loader) if loader_instance is None else loader_instance

    for _ in trange(n_batches):
        try:
            inputs, _labels = next(batches)
        except StopIteration:
            batches = iter(train_loader)
            inputs, _labels = next(batches)
        patches = get_patches(inputs.cuda(non_blocking=True))  # (dim, batch*height*width)
        if permute_indices is not None:
            patches = patches[permute_indices]

        batch_count = torch.tensor(patches.size(1), dtype=torch.long)
        batch_mean = patches.mean(1)
        batch_cov = torch.cov(patches)

        if count != 0:
            combine_cov(count, mean, cov, batch_count, batch_mean, batch_cov)
        else:
            # first batch: its statistics become the running statistics
            count.copy_(batch_count)
            mean.copy_(batch_mean)
            cov.copy_(batch_cov)

    return cov.float(), batches


def train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler):
    """Run one distillation epoch, then step ``scheduler`` once.

    The loss is ``criterion(student_logits, teacher_logits)``. Ground-truth targets are used
    only for the reported top-1/top-5 accuracy. The learning rate of the last parameter
    group is printed before the epoch starts.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted means over the epoch.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, loss_meter, acc1_meter, acc5_meter],
        prefix="Epoch: [{}]".format(epoch))

    model_student.train()
    model_teacher.eval()
    tick = time.time()

    print('learning_rate:', optimizer.param_groups[-1]['lr'])

    for step, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - tick)
        images = images.cuda()
        target = target.cuda().flatten()

        student_logits = model_student(images)
        teacher_logits = model_teacher(images)
        loss = criterion(student_logits, teacher_logits)

        acc1, acc5 = accuracy(student_logits, target, topk=(1, 5))
        batch_size = images.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - tick)
        tick = time.time()
        progress.display(step)

    scheduler.step()
    return loss_meter.avg, acc1_meter.avg, acc5_meter.avg


def validate(epoch, val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    ``epoch`` and ``args`` are accepted for call-site compatibility but are not used.

    Returns:
        ``(loss_avg, top1_avg, top5_avg)``: batch-size-weighted means, accumulated from
        Python scalars.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    model.eval()
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda()
            target = target.cuda().flatten()

            logits = model(images)
            loss = criterion(logits, target)

            pred1, pred5 = accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            # .item() keeps the meters, and hence the returned accuracy that main stores as
            # best_top1_acc in checkpoints, as Python floats instead of 0-d CUDA tensors
            top1.update(pred1[0].item(), n)
            top5.update(pred5[0].item(), n)

            batch_time.update(time.time() - end)
            end = time.time()

            progress.display(i)

        print(' * acc@1 {top1.avg:.3f} acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg


if __name__ == '__main__':
    # Command-line entry point: parse arguments into the module-level `args` read by main(),
    # log to stdout and to log/log.txt, then run main().
    parser = argparse.ArgumentParser("reactnet")
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--save', type=str, default='./models', help='path for saving trained models')
    parser.add_argument('--data', metavar='DIR', help='path to dataset')
    parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
    parser.add_argument('--teacher', type=str, default='resnet34',
                        help='teacher model name (not read by this script; the teacher is chosen by --dataset)')
    parser.add_argument('-j', '--workers', default=20, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')
    parser.add_argument('--base-dir', default='/d1/xxx/DBQ', type=str)
    parser.add_argument('--dataset', type=str, default='cifar100')
    parser.add_argument('--variant_C', default=False, action='store_true')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
    os.makedirs('log', exist_ok=True)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join('log', 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    main()
