import sys

import torch

import models
from models.res_adapt import ResNet18_adapt
from utils import *
from args import parse_train_args
from datasets import make_dataset

from pdb import set_trace
from collections import OrderedDict


def loss_compute(args, model, criterion, outputs, targets):
    """Compute the training loss, optionally adding separate decay terms.

    Args:
        args: parsed training arguments. Reads ``loss`` ('CrossEntropy' or
            'MSE'), ``device`` and ``sep_decay``; when ``sep_decay`` is set also
            ``weight_decay`` and ``feature_decay_rate``.
        model: network whose ``fc`` layer supplies the last-layer weight (and
            bias, when it has one).
        criterion: loss module applied to the logits ``outputs[0]``.
        outputs: model outputs; ``outputs[0]`` are the logits and ``outputs[1]``
            the features that are decayed (presumably the inputs to ``fc`` --
            TODO confirm against the model definitions).
        targets: integer class labels.

    Returns:
        Scalar loss tensor. When ``args.sep_decay`` is false this is just the
        criterion loss.

    Raises:
        ValueError: if ``args.loss`` is not 'CrossEntropy' or 'MSE'.
    """
    logits = outputs[0]
    if args.loss == 'CrossEntropy':
        loss = criterion(logits, targets)
    elif args.loss == 'MSE':
        # Pass num_classes explicitly. Otherwise one_hot infers the width from
        # max(targets) + 1, which is too narrow whenever the batch lacks the
        # highest class, and the target would no longer match the logits' shape.
        one_hot = torch.nn.functional.one_hot(targets, num_classes=logits.shape[-1])
        loss = criterion(logits, one_hot.to(device=args.device, dtype=logits.dtype))
    else:
        raise ValueError("Unsupported loss: {}".format(args.loss))

    # Now decide whether to add weight decay on last weights and last features
    if args.sep_decay:
        features = outputs[1]
        w = model.fc.weight
        b = model.fc.bias
        lamb = args.weight_decay / 2
        lamb_feature = args.feature_decay_rate / 2
        weight_penalty = torch.sum(w ** 2)
        # fc may be built without a bias, in which case ``b`` is None.
        if b is not None:
            weight_penalty = weight_penalty + torch.sum(b ** 2)
        loss = loss + lamb * weight_penalty + lamb_feature * torch.sum(features ** 2)

    return loss


def cal_sparsity(output):
    """Measure the fraction of near-zero activations for every tensor in ``output``.

    ``output`` maps a name to a sequence of tensors. Each tensor contributes one
    entry ``"<name>.conv<k>"`` (k counted from 1). Its value is the fraction of
    elements strictly below 1e-6, as a Python float.
    """
    result = {}
    for layer, acts in output.items():
        for pos, act in enumerate(acts, start=1):
            # Elements below the threshold (including negatives) count as sparse.
            result[layer + ".conv" + str(pos)] = (act < 1e-6).float().mean().item()
    return result


def format_dict(x):
    """Render ``x`` as ``"key: value "`` pairs with values to two decimals.

    Every pair, including the last, is followed by one space. An empty mapping
    yields an empty string.
    """
    return "".join("{}: {:.2f} ".format(key, value) for key, value in x.items())


def fix_weight(model):
    """Freeze every parameter except biases and anything under ``fc``.

    A parameter is frozen when its name contains neither "bias" nor "fc".
    Each frozen parameter name is printed.
    """
    for pname, tensor in model.named_parameters():
        keep_trainable = "bias" in pname or "fc" in pname
        if keep_trainable:
            continue
        print("fix {}".format(pname))
        tensor.requires_grad = False


def reformat_dict(output):
    """Flatten ``{name: [v0, v1, ...]}`` into ``{"name.0": v0, "name.1": v1, ...}``.

    Indices start at 0. Values are stored as-is, not copied.
    """
    return {
        layer + "." + str(idx): feat
        for layer, feats in output.items()
        for idx, feat in enumerate(feats)
    }


def trainer(args, model, trainloader, epoch_id, criterion, optimizer, scheduler, logfile, sparsityFlipCal=None):
    """Run one training epoch, logging loss, accuracy and activation sparsity.

    Args:
        args: parsed training arguments. Reads ``device``, ``epochs``, ``loss``,
            ``sep_decay``, ``save_feature``, ``save_path``,
            ``bias_only_decrease`` and ``flip_change_compare_with_last``.
        model: network called as ``model(inputs, return_mid_features=True)``.
            ``outputs[0]`` are logits and ``outputs[2]`` is the mapping of
            intermediate features used for the sparsity statistics.
        trainloader: iterable of ``(inputs, targets, index)`` batches.
        epoch_id: zero-based epoch index.
        criterion: loss module, forwarded to ``loss_compute``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler, stepped once at the end of the epoch.
        logfile: open text file passed to ``print_and_save``.
        sparsityFlipCal: optional sparsity-flip tracker. When ``None``, the
            flip updates and the flip summary are skipped.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    sparsity = AverageDict()

    # if epoch_id < args.only_bias_free:
    #     fix_weight(model)

    print_and_save('\nTraining Epoch: [%d | %d] LR: %f' % (epoch_id + 1, args.epochs, scheduler.get_last_lr()[-1]), logfile)
    for batch_idx, (inputs, targets, index) in enumerate(trainloader):

        inputs, targets = inputs.to(args.device), targets.to(args.device)

        # The previous iteration left the model in eval mode for the accuracy pass.
        model.train()

        outputs = model(inputs, return_mid_features=True)
        mid_features = outputs[2]

        sparsity.update(cal_sparsity(mid_features))

        if sparsityFlipCal is not None:
            # On the first epoch the tracker is updated with set_compare=True
            # (presumably recording the reference pattern -- confirm in SparsityFlipCal).
            if epoch_id == 0:
                sparsityFlipCal.update(reformat_dict(mid_features), index, set_compare=True)
            else:
                sparsityFlipCal.update(reformat_dict(mid_features), index)

        if args.save_feature and batch_idx == 0:
            torch.save(mid_features, args.save_path + "/feature_epoch_{}".format(str(epoch_id + 1).zfill(3)))

        # loss_compute returns the plain criterion loss when args.sep_decay is off,
        # so one call covers both cases.
        loss = loss_compute(args, model, criterion, outputs, targets)

        if sparsityFlipCal is not None and epoch_id > 1 and batch_idx == 0:
            compare_summary = sparsityFlipCal.compare(feature_as_compare=args.flip_change_compare_with_last)
            print_and_save("sparse flip summary: {}".format(format_dict(compare_summary)), logfile)

        optimizer.zero_grad()
        loss.backward()
        if args.bias_only_decrease:
            model.clip_grad_bias()
        optimizer.step()

        # Measure accuracy with the updated weights in eval mode. no_grad avoids
        # building an autograd graph for this extra, evaluation-only forward pass.
        model.eval()
        with torch.no_grad():
            outputs = model(inputs)

        prec1, prec5 = compute_accuracy(outputs[0], targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        if batch_idx % 10 == 0:
            print_and_save('[epoch: %d] (%d/%d) | Loss: %.4f | top1: %.4f | top5: %.4f ' %
                           (epoch_id + 1, batch_idx + 1, len(trainloader), losses.avg, top1.avg, top5.avg), logfile)
            print_and_save("sparsity: {}".format(format_dict(sparsity.val)), logfile)

    print_and_save("sparsity epoch mean is {}".format(format_dict(sparsity.avg)), logfile)

    scheduler.step()


def measure_distance(model, init_weights):
    """Return the L2 distance of each float32 state entry from its initial value.

    Args:
        model: module whose current ``state_dict()`` is compared.
        init_weights: mapping with the same keys holding the reference tensors.

    Returns:
        dict mapping each float32 state key to ``||current - initial||`` as a
        Python float. Entries of other dtypes (e.g. integer counters) are skipped.
    """
    current = model.state_dict()
    return {
        name: torch.norm(tensor - init_weights[name]).item()
        for name, tensor in current.items()
        if tensor.dtype == torch.float32
    }


def train(args, model, trainloader, checkpoint_every=50):
    """Train ``model`` for ``args.epochs`` epochs and log to ``<save_path>/train_log.txt``.

    After every epoch the distance of the model's float32 state from its
    initial snapshot is logged. Every ``checkpoint_every`` epochs the
    ``state_dict`` is saved as ``<save_path>/epoch_XXX.pth``.

    Args:
        args: parsed training arguments (reads ``save_path`` and ``epochs``,
            and is forwarded to the factory helpers and ``trainer``).
        model: network to train in place.
        trainloader: iterable of training batches, forwarded to ``trainer``.
        checkpoint_every: checkpoint interval in epochs. The default of 50
            matches the previous fixed behaviour. Pass 0 or None to disable
            checkpoints.
    """
    criterion = make_criterion(args)
    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)

    # The context manager flushes and closes the log even if training raises.
    with open('%s/train_log.txt' % (args.save_path), 'w') as logfile:

        print_and_save('# of model parameters: ' + str(count_network_parameters(model)), logfile)
        print_and_save('--------------------- Training -------------------------------', logfile)

        # Detached copies so later in-place parameter updates don't alter the snapshot.
        init_weights = OrderedDict(
            (key, item.detach().clone()) for key, item in model.state_dict().items()
        )

        sparsityFlipCal = SparsityFlipCal()
        for epoch_id in range(args.epochs):

            trainer(args, model, trainloader, epoch_id, criterion, optimizer, scheduler, logfile, sparsityFlipCal=sparsityFlipCal)
            # measure distance with the init model
            distance_dict = measure_distance(model, init_weights)
            print_and_save("model distance: {}".format(format_dict(distance_dict)), logfile)

            if checkpoint_every and (epoch_id + 1) % checkpoint_every == 0:
                torch.save(model.state_dict(), args.save_path + "/epoch_" + str(epoch_id + 1).zfill(3) + ".pth")


def main():
    """Script entry point: parse arguments, build the data and model, then train."""
    args = parse_train_args()
    set_seed(manualSeed=args.seed)

    # LBFGS is rejected up front; the script exits with this message.
    if args.optimizer == 'LBFGS':
        sys.exit('Support for training with 1st order methods!')

    if torch.cuda.is_available():
        device_name = "cuda:" + str(args.gpu_id)
    else:
        device_name = "cpu"
    device = torch.device(device_name)
    args.device = device

    trainloader, _, num_classes = make_dataset(
        args.dataset, args.data_dir, args.batch_size, args.sample_size, SOTA=args.SOTA
    )

    # Pick the constructor and its keyword arguments, then move the model to the device.
    if args.model == "ResNet18_adapt":
        net = ResNet18_adapt(width=args.width, num_classes=num_classes, fc_bias=args.bias)
    elif args.model == "MLP":
        net = models.__dict__[args.model](
            hidden=args.width,
            depth=args.depth,
            fc_bias=args.bias,
            mlp_bias=args.mlp_bias,
            mlp_bias_multiply=args.mlp_bias_multiply,
            num_classes=num_classes,
            affine=(not args.no_affine),
        )
    else:
        net = models.__dict__[args.model](
            num_classes=num_classes,
            fc_bias=args.bias,
            ETF_fc=args.ETF_fc,
            fixdim=args.fixdim,
            SOTA=args.SOTA,
            norm_bias=args.norm_bias,
        )
    model = net.to(device)

    train(args, model, trainloader)


# Start training only when run as a script, not when imported.
if __name__ == '__main__':
    main()
