
import argparse
import copy
import pickle
import shutil
from tqdm import tqdm
import paddle
from sklearn import metrics

from networks.models import AutoencoderCifar
from networks.models import AutoencoderMnist
from networks.models import AutoencoderCeleba
from utils.dataloader import mnist_loader, cifar10_loader
import os
import paddle.nn.functional as F
from loguru import logger
import numpy as np
from PIL import Image
from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman
from utils.dataloader import get_dataloader
from collections import Counter


def get_parser():
    """Parse the command-line options of the training script.

    Note: despite the name, this returns the *parsed* ``argparse.Namespace``
    (it calls ``parse_args()`` on ``sys.argv``), not the parser object.
    ``--result_dir`` and ``--save_dir`` are mandatory.
    """
    parser = argparse.ArgumentParser(description='Convolution VAE-Paddle')
    add = parser.add_argument

    # Locations and dataset selection.
    add("--data_root", type=str, default="/root/projects/AttackDefence/data/AttackDefence")  # data lives under this folder
    add('--result_dir', type=str, required=True)
    add('--dataset', type=str, default='cifar10')
    add('--save_dir', type=str, required=True, help='model save directory')

    # Integer run settings, registered in their original order.
    for flag, value in (('--bs', 128), ('--epoches', 100), ('--seed', 1)):
        add(flag, type=int, default=value)
    add('--resume', default='', type=str, help='path to latest checkpoint')
    for flag, value in (('--test_epoch', 20), ('--log_freq', 100), ('--num_workers', 0),
                        ('--input_height', 32), ('--input_width', 32)):
        add(flag, type=int, default=value)

    # Optimiser and model input settings.
    add('--lr', type=float, default=1e-3)
    add('--input_dim', type=int, default=28 * 28)
    add('--input_channels', type=int, default=1)

    return parser.parse_args()


def loss_func(res, images, labels, logits, cls_weight=1e-5):
    """Combined reconstruction + classification objective.

    Args:
        res: reconstruction produced by the model. It is compared to ``images``
            with binary cross-entropy, which requires values in [0, 1]
            (presumably a sigmoid output -- confirm in ``networks.models``).
        images: the input batch, used as the reconstruction target.
        labels: class labels for ``logits`` (hard labels, as expected by
            ``F.cross_entropy`` with its default ``soft_label=False``).
        logits: class scores from the model's classifier head; softmax is
            applied inside ``F.cross_entropy``.
        cls_weight: weight of the classification term. Defaults to 1e-5, the
            value that was previously hard-coded.

    Returns:
        ``(loss, reconstruction_loss, cls_loss)``. Both components use mean
        reduction, and ``loss = cls_weight * cls_loss + reconstruction_loss``.
    """
    reconstruction_loss = F.binary_cross_entropy(res, images, reduction='mean')
    cls_loss = F.cross_entropy(logits, labels)  # default reduction='mean'
    loss = cls_weight * cls_loss + reconstruction_loss
    return loss, reconstruction_loss, cls_loss


def _tile_images(images, nrow=2, ncol=8):
    """Build an ``nrow x ncol`` uint8 mosaic from an NCHW batch.

    ``images`` is a ``np.ndarray`` or an object exposing ``.numpy()`` (e.g. a
    paddle Tensor). Pixels are scaled by 255, offset by 0.5, clipped to
    [0, 255] and truncated to uint8 (i.e. rounded). Single-channel batches give
    an ``(H*nrow, W*ncol)`` array; multi-channel ones give
    ``(H*nrow, W*ncol, C)``. Only the first ``nrow * ncol`` images are used.
    Shorter batches are padded with black tiles instead of failing.

    Raises:
        ValueError: if the batch is empty or ``nrow``/``ncol`` are not positive.
    """
    if nrow < 1 or ncol < 1:
        raise ValueError('nrow and ncol must be positive, got {} and {}'.format(nrow, ncol))
    arr = images if isinstance(images, np.ndarray) else images.numpy()
    if arr.shape[0] == 0:
        raise ValueError('cannot build an image grid from an empty batch')

    arr = ((arr * 255) + 0.5).clip(0, 255).transpose([0, 2, 3, 1])  # NCHW -> NHWC
    # Drop the channel axis only for greyscale. np.squeeze(-1) raises on C == 3.
    if arr.shape[-1] == 1:
        arr = arr[..., 0]

    capacity = nrow * ncol
    arr = arr[:capacity]
    if arr.shape[0] < capacity:
        # e.g. a short final test batch: fill the remaining slots with black tiles
        pad = np.zeros((capacity - arr.shape[0],) + arr.shape[1:], dtype=arr.dtype)
        arr = np.concatenate([arr, pad], axis=0)

    rows = [np.hstack(arr[r * ncol:(r + 1) * ncol]) for r in range(nrow)]
    return np.vstack(rows).astype(np.uint8)


def save_image(random_res, save_path, nrow=2, ncol=8):
    """Save the first ``nrow * ncol`` images of an NCHW batch as one PNG grid.

    Args:
        random_res: NCHW batch as ``np.ndarray`` or a tensor with ``.numpy()``.
        save_path: destination path passed to ``PIL.Image.save``.
        nrow: number of grid rows (default 2, as before).
        ncol: number of images per row (default 8, as before).
    """
    Image.fromarray(_tile_images(random_res, nrow, ncol)).save(save_path)


def _single_image_array(random_res):
    """Convert a one-image NCHW batch to a uint8 array that PIL can save.

    ``random_res`` is a ``np.ndarray`` or an object exposing ``.numpy()``
    (e.g. a paddle Tensor) shaped ``(1, C, H, W)``. Pixels are scaled by 255,
    offset by 0.5, clipped to [0, 255] and truncated (i.e. rounded). The result
    is ``(H, W)`` when ``C == 1`` and ``(H, W, C)`` otherwise.

    Raises:
        ValueError: if the input is not a single 4-D image.
    """
    arr = random_res if isinstance(random_res, np.ndarray) else random_res.numpy()
    if arr.ndim != 4 or arr.shape[0] != 1:
        raise ValueError('expected a single image shaped (1, C, H, W), got shape {}'.format(arr.shape))
    arr = ((arr[0] * 255) + 0.5).clip(0, 255).transpose([1, 2, 0])  # CHW -> HWC
    # Drop only the channel axis for greyscale. A blanket squeeze() would also
    # remove H or W when either is 1.
    if arr.shape[-1] == 1:
        arr = arr[..., 0]
    return arr.astype(np.uint8)


def save_an_image(random_res, save_path):
    """Save a single NCHW image (batch of one) as an image file.

    Args:
        random_res: ``(1, C, H, W)`` ``np.ndarray`` or tensor with ``.numpy()``.
        save_path: destination path passed to ``PIL.Image.save``.
    """
    Image.fromarray(_single_image_array(random_res)).save(save_path)


def save_checkpoint(state, is_best, save_dir):
    """Serialize ``state`` to ``save_dir`` and optionally mark it as the best model.

    Writes ``checkpoint_<state['epoch']>.pdparams`` via ``paddle.save``. When
    ``is_best`` is true, it also copies that file to ``model_best.pdparams``.
    ``save_dir`` is created if it does not exist.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    checkpoint_file = os.path.join(save_dir, 'checkpoint_{}.pdparams'.format(state.get('epoch')))
    best_file = os.path.join(save_dir, 'model_best.pdparams')
    paddle.save(state, checkpoint_file)
    if is_best:
        shutil.copyfile(checkpoint_file, best_file)


def extract_easy_pattern_image(model, train_loader):
    """Pick, per class, the correctly classified training image with the largest logit.

    Runs ``model`` in eval mode over ``train_loader``. A sample counts only when
    ``argmax(logits)`` equals its label. Among those, the image whose maximum
    logit is highest is kept for each label. The model is switched back to
    train mode afterwards, even if an error occurs.

    Args:
        model: callable returning ``(reconstruction, logits)`` for an input batch.
        train_loader: iterable of dicts with ``'input'`` and ``'target'`` entries.

    Returns:
        dict mapping label -> image (numpy array taken from ``images.numpy()``).
    """
    easy_pattern_logits_records = {}
    easy_pattern_logits_images = {}
    model.eval()
    try:
        # Inference only: no_grad avoids building autograd graphs for the whole train set.
        with paddle.no_grad():
            for data in tqdm(train_loader):
                images = data['input']
                labels = data['target']
                _, logits = model(images)

                pred_label = paddle.argmax(logits, axis=1).numpy()
                labels = labels.numpy().reshape(-1)
                logits = logits.numpy()
                images = images.numpy()

                for i, lab in enumerate(labels):
                    if lab != pred_label[i]:
                        continue
                    # For a correct prediction this is the logit of the true class.
                    log = np.max(logits[i])
                    if lab not in easy_pattern_logits_records or easy_pattern_logits_records[lab] < log:
                        easy_pattern_logits_records[lab] = log
                        easy_pattern_logits_images[lab] = images[i]

        print("ep log: ", easy_pattern_logits_records)
    finally:
        model.train()
    return easy_pattern_logits_images


# (channels, height, width) of the images of each supported dataset. Used to
# reshape model outputs and inputs before writing them as PNGs.
_IMAGE_SHAPES = {
    "cifar10": (3, 32, 32),
    "gtsrb": (3, 32, 32),
    "mnist": (1, 28, 28),
    "celeba": (3, 64, 64),
}


def _build_model(args):
    """Instantiate the autoencoder for ``args.dataset``.

    For mnist and celeba this also overwrites ``args.input_height`` and
    ``args.input_width``. Those fields are presumably consumed by
    ``get_dataloader``, so call this before building the loaders.

    Returns:
        ``(model, easy_pattern_epoch)``: the model and the epoch index at which
        easy-pattern images are extracted.

    Raises:
        ValueError: for an unsupported dataset name.
    """
    easy_pattern_epoch = 1
    if args.dataset == "cifar10":
        model = AutoencoderCifar()
    elif args.dataset == "mnist":
        args.input_height = 28
        args.input_width = 28
        model = AutoencoderMnist()
    elif args.dataset == "gtsrb":
        easy_pattern_epoch = 5
        model = AutoencoderCifar(cls_num=43)
    elif args.dataset == "celeba":
        args.input_height = 64
        args.input_width = 64
        model = AutoencoderCeleba(cls_num=8)
    else:
        raise ValueError("dataset should be one of cifar10, mnist, gtsrb, celeba; got {!r}".format(args.dataset))
    return model, easy_pattern_epoch


def _save_easy_patterns(model, train_loader, args):
    """Run ``extract_easy_pattern_image`` and persist the result in ``args.result_dir``.

    Writes ``easy_pattern_imgs.pkl`` (the returned label -> image dict) and one
    ``easy_pattern_cls_<label>.png`` per label.
    """
    easy_pattern_images = extract_easy_pattern_image(model, train_loader)
    with open(args.result_dir + '/easy_pattern_imgs.pkl', 'wb') as f:
        pickle.dump(easy_pattern_images, f)
    shape = _IMAGE_SHAPES[args.dataset]
    for label, image in easy_pattern_images.items():
        save_an_image(image.reshape(-1, *shape), '%s/easy_pattern_cls_%d.png' % (args.result_dir, label))


def _evaluate(model, test_loader, args, epoch_id):
    """Evaluate on ``test_loader``, log a classification report and save sample grids.

    The model runs in eval mode under ``paddle.no_grad()`` and is switched back
    to train mode afterwards. Sample grids come from the last test batch.

    Returns:
        Mean of the per-batch total test loss, or ``inf`` if the loader yields
        no batches.
    """
    model.eval()
    try:
        with paddle.no_grad():
            total_loss = 0.0
            num_batches = 0
            all_test_labels = []
            all_pred_labels = []
            test_images = test_res = None
            for test_data in test_loader:
                test_images = test_data['input']
                test_labels = test_data['target']
                test_res, test_logits = model(test_images)
                test_loss, _, _ = loss_func(test_res, test_images, test_labels, test_logits)
                total_loss += test_loss.item()
                num_batches += 1
                # reshape(-1) instead of squeeze(): a batch of size 1 must stay iterable.
                all_test_labels.extend(test_labels.numpy().reshape(-1).tolist())
                # argmax of logits equals argmax of softmax(logits); no need for softmax.
                all_pred_labels.extend(paddle.argmax(test_logits, axis=-1).numpy().reshape(-1).tolist())

            if num_batches == 0:
                logger.warning('test loader is empty; skipping evaluation')
                return float('inf')

            rep = metrics.classification_report(all_test_labels, all_pred_labels)
            logger.info('\n' + rep)

            # Reuse the last batch's outputs instead of running the model on it again.
            shape = [-1, *_IMAGE_SHAPES[args.dataset]]
            save_image(test_res.reshape(shape), '%s/random_sampled_%d.png' % (args.result_dir, epoch_id))
            save_image(test_images.reshape(shape), '%s/random_x_%d.png' % (args.result_dir, epoch_id))
    finally:
        model.train()
    # Each batch loss is already a mean, so average over batches, not samples.
    return total_loss / num_batches


def train(args):
    """Train the autoencoder/classifier configured by ``args``.

    Every ``args.test_epoch`` epochs the model is evaluated, sample images are
    written to ``args.result_dir``, and a checkpoint is saved to
    ``args.save_dir``; the checkpoint with the lowest test loss so far is copied
    to ``model_best.pdparams``. At epoch ``easy_pattern_epoch`` (1, or 5 for
    gtsrb) the accumulated label counts are printed and easy-pattern images are
    extracted and saved.
    """
    label_statistic_hash = Counter()
    # NOTE(review): --resume is parsed by get_parser but not used; training always starts at 0.
    start_epoch = 0
    best_test_loss = np.finfo('f').max

    os.makedirs(args.result_dir, exist_ok=True)
    myAE, easy_pattern_epoch = _build_model(args)

    train_loader = get_dataloader(args, train=True)
    test_loader = get_dataloader(args, train=False)

    optimizer = paddle.optimizer.Adam(parameters=myAE.parameters(), learning_rate=args.lr)
    logger.info('start training...')
    for epoch_id in range(start_epoch, args.epoches):
        if epoch_id == easy_pattern_epoch:
            # TODO: for celeba the labels are unbalanced
            # ({0: 46088, 3: 31263, 7: 28115, 4: 19494, 1: 12993, 2: 9871, 6: 8831, 5: 6115}).
            print("label statistic", label_statistic_hash)
            _save_easy_patterns(myAE, train_loader, args)

        for batch_id, data in enumerate(train_loader):
            images = data['input']
            labels = data['target']
            # reshape(-1) keeps a size-1 batch iterable; tolist() gives plain int keys.
            label_statistic_hash.update(labels.numpy().reshape(-1).tolist())

            res, logits = myAE(images)
            loss, recon_loss, cls_loss = loss_func(res, images, labels, logits)
            optimizer.clear_grad()
            loss.backward()
            optimizer.step()

            if (batch_id + 1) % args.log_freq == 0:
                info = 'epoch: {}, batch: {}, recon_loss: {:.4f}, ' \
                       'cls_loss: {:.4f}, total_loss: {:.4f}'.format(
                    epoch_id + 1, batch_id + 1, recon_loss.item(), cls_loss.item(), loss.item()
                )
                logger.info(info)

        if (epoch_id + 1) % args.test_epoch == 0:
            test_avg_loss = _evaluate(myAE, test_loader, args, epoch_id)

            is_best = test_avg_loss < best_test_loss
            best_test_loss = min(test_avg_loss, best_test_loss)
            save_checkpoint({
                'epoch': epoch_id,
                'best_test_loss': best_test_loss,
                'state_dict': myAE.state_dict(),
                'optimizer': optimizer.state_dict()
            }, is_best, args.save_dir)


if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the full training loop.
    train(get_parser())