import argparse
import argparse
import collections
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils
import torch.utils.data.distributed
import torchvision
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
from utils_lt import BNFeatureHookLT, lr_cosine_policy
from imbalance_cifar import IMBALANCECIFAR10

# import wandb

"""
TODO:
1. compute statistics for each class: BN+Conv
2. realize multi backbones
"""



def clip_cifar(image_tensor):
    """
    Clamp a normalized CIFAR batch, in place, to the valid pixel range.

    For each of the three channels the values are limited to
    ``[-mean / std, (1 - mean) / std]``, i.e. the normalized images of 0 and 1.
    The tensor is indexed as ``image_tensor[:, channel]`` and the same
    (mutated) tensor object is returned.
    """
    channel_mean = np.array([0.4914, 0.4822, 0.4465])
    channel_std = np.array([0.2023, 0.1994, 0.2010])

    # Normalized-space bounds of every channel, computed up front.
    lower_bounds = -channel_mean / channel_std
    upper_bounds = (1 - channel_mean) / channel_std

    for channel, (low, high) in enumerate(zip(lower_bounds, upper_bounds)):
        image_tensor[:, channel] = image_tensor[:, channel].clamp(low, high)
    return image_tensor


def denormalize_cifar(image_tensor):
    """
    Undo the CIFAR mean/std normalization, in place.

    Each channel is mapped through ``x * std + mean`` and clamped to
    ``[0, 1]``. The tensor is indexed as ``image_tensor[:, channel]`` and the
    same (mutated) tensor object is returned.
    """
    channel_mean = np.array([0.4914, 0.4822, 0.4465])
    channel_std = np.array([0.2023, 0.1994, 0.2010])

    for channel, (mu, sigma) in enumerate(zip(channel_mean, channel_std)):
        restored = image_tensor[:, channel] * sigma + mu
        image_tensor[:, channel] = restored.clamp(0, 1)

    return image_tensor


def save_images(args, images, targets, ipc_ids):
    """
    Write every image of a batch to disk as a JPEG, grouped by class.

    Image ``i`` is stored at
    ``<args.syn_data_path>/new<class:03d>/class<class:03d>_id<ipc:03d>.jpg``.

    Args:
        args: namespace providing ``syn_data_path`` (output root directory).
        images: tensor indexed as ``images[i]`` -> (C, H, W); values are
            scaled by 255 and truncated to ``uint8``, so they should already
            lie in ``[0, 1]``.
        targets: 1-D tensor of class ids, or a higher-rank tensor whose
            per-sample ``argmax`` is the class id.
        ipc_ids: per-image index within its class; each element must be
            convertible with ``int()``.
    """
    # One device->host transfer for the whole batch instead of one per image.
    pixels = (images.detach().cpu().numpy().transpose((0, 2, 3, 1)) * 255).astype(np.uint8)
    targets_are_labels = targets.ndimension() == 1

    for index in range(images.shape[0]):
        if targets_are_labels:
            class_id = targets[index].item()
        else:
            class_id = targets[index].argmax().item()
        ipc_id = int(ipc_ids[index])

        # save into separate folders; makedirs also creates the output root
        # (and any missing parents) and tolerates already-existing directories
        dir_path = "{}/new{:03d}".format(args.syn_data_path, class_id)
        os.makedirs(dir_path, exist_ok=True)
        place_to_store = dir_path + "/class{:03d}_id{:03d}.jpg".format(class_id, ipc_id)

        Image.fromarray(pixels[index]).save(place_to_store)



def _load_teacher(args, num_classes=10):
    """Build the ResNet-18 teacher, load ``args.arch_path`` into it and freeze it.

    Returns the unwrapped (non-DataParallel) module, on the GPU, in eval mode
    and with ``requires_grad`` disabled on every parameter.
    """
    model_teacher = torchvision.models.get_model("resnet18", num_classes=num_classes)
    # Swap the stem for a 3x3 stride-1 convolution and drop the max-pool.
    model_teacher.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model_teacher.maxpool = nn.Identity()
    model_teacher = nn.DataParallel(model_teacher).cuda()

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.arch_path)
    # Try the DataParallel wrapper first; if that raises (e.g. the keys do not
    # match), fall back to loading into the wrapped module directly.
    try:
        model_teacher.load_state_dict(checkpoint["state_dict"])
    except RuntimeError:
        model_teacher.module.load_state_dict(checkpoint["state_dict"])

    model_teacher = model_teacher.module
    model_teacher.eval()
    for param in model_teacher.parameters():
        param.requires_grad = False
    return model_teacher


def _register_bn_hooks(model_teacher, args, num_classes=10):
    """Attach a ``BNFeatureHookLT`` (with ``pre=True``) to every BatchNorm2d layer.

    Returns ``(hooks, statistics_loaded)`` where ``statistics_loaded`` is true
    only if every hook reports a truthy ``load_tag``.
    """
    hooks = []
    statistics_loaded = True
    for name, module in model_teacher.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        # TODO: this might cause name conflict when using multiple backbones
        full_name = str(model_teacher.__class__) + "=" + name
        hook = BNFeatureHookLT(
            module,
            save_path=args.statistic_path,
            name=full_name,
            training_momentum=args.training_momentum,
            class_number=num_classes,
            alpha=args.alpha,
        )
        hook.set_hook(pre=True)
        statistics_loaded = statistics_loaded and bool(hook.load_tag)
        hooks.append(hook)
    return hooks, statistics_loaded


def _collect_class_statistics(model_teacher, hooks, args):
    """Run the training set through the teacher once, then ``save()`` every hook.

    Before each forward pass the batch labels are handed to every hook via
    ``set_label``; what the hooks accumulate is defined by ``BNFeatureHookLT``.
    """
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    if args.long_tail:
        trainset = IMBALANCECIFAR10(
            root=args.train_data_path,
            train=True,
            imb_type='exp',
            imb_factor=args.IR,
            rand_number=args.rand_value,
            download=False,
            transform=transform_train,
        )
    else:
        trainset = torchvision.datasets.CIFAR10(
            root=args.train_data_path, train=True, download=False, transform=transform_train
        )
    train_loader = torch.utils.data.DataLoader(
        trainset, num_workers=4, batch_size=256, drop_last=False, shuffle=True
    )

    with torch.no_grad():
        print("Start to compute statistics")
        for data, targets in tqdm(train_loader):
            data = data.cuda()
            targets = targets.cuda()
            for hook in hooks:
                hook.set_label(targets)
            model_teacher(data)

        for hook in hooks:
            hook.save()
        print("Finish computing statistics")


def _synthesize_batch(model_teacher, hooks, criterion, rescale, targets, args, log_every=100):
    """Optimize one batch of random-noise images towards ``targets``.

    The loss is the teacher's cross-entropy plus ``args.r_bn`` times the
    weighted sum of every hook's ``r_feature``. Returns the optimized input
    tensor (normalized space, clamped by ``clip_cifar`` after every step).
    """
    inputs = torch.randn((targets.shape[0], 3, 32, 32), requires_grad=True, device="cuda", dtype=torch.float)
    optimizer = optim.Adam([inputs], lr=args.lr, betas=[0.5, 0.9], eps=1e-8)
    lr_scheduler = lr_cosine_policy(args.lr, 0, args.iteration)  # 0 - do not use warmup

    for iteration in range(args.iteration):
        # learning rate scheduling
        lr_scheduler(optimizer, iteration, iteration)

        # apply random jitter offsets
        off1 = random.randint(0, args.jitter)
        off2 = random.randint(0, args.jitter)
        inputs_jit = torch.roll(inputs, shifts=(off1, off2), dims=(2, 3))

        # the hooks need the labels of this batch before the forward pass
        for hook in hooks:
            hook.set_label(targets)

        optimizer.zero_grad()
        outputs = model_teacher(inputs_jit)

        # R_cross classification loss
        loss_ce = criterion(outputs, targets)

        # R_feature loss, read from the hooks after the forward pass
        loss_r_bn_feature = torch.stack(
            [hook.r_feature.to(loss_ce.device) * weight for hook, weight in zip(hooks, rescale)]
        ).sum()

        loss = loss_ce + args.r_bn * loss_r_bn_feature

        if iteration % log_every == 0 and args.verifier:
            print("------------iteration {}----------".format(iteration))
            print("loss_ce", loss_ce.item())
            print("loss_r_bn_feature", loss_r_bn_feature.item())
            print("loss_total", loss.item())
            metrics = {
                'crop/loss_ce': loss_ce.item(),
                'crop/loss_r_bn_feature': loss_r_bn_feature.item(),
                'crop/loss_total': loss.item(),
            }
            print(metrics)

        # do image update
        loss.backward()
        optimizer.step()

        # clip color outliers
        inputs.data = clip_cifar(inputs.data)

    # to reduce memory consumption by states of the optimizer we deallocate memory
    optimizer.state = collections.defaultdict(dict)
    return inputs


def main_syn(args):
    """Synthesize images for 10 classes from a frozen ResNet-18 teacher.

    Steps:
      1. Load and freeze the teacher from ``args.arch_path``.
      2. Hook every BatchNorm2d layer; if any hook could not load its
         statistics from ``args.statistic_path``, compute them with one pass
         over the training set and save them.
      3. For every ipc index in ``[args.ipc_start, args.ipc_end)`` and every
         class, optimize random noise into an image, ``args.batch_size``
         images at a time. Batches are ordered class-major: all ipc indices of
         class 0 first, then class 1, and so on.
      4. If ``args.store_best_images`` is set, denormalize the images of the
         last iteration and write them with ``save_images``.

    Raises:
        NotImplementedError: if ``args.ipc_strategy`` is not ``"fix"``.
    """
    num_classes = 10  # TODO: this is for cifar10 for now

    if args.ipc_strategy != "fix":
        raise NotImplementedError(
            "ipc strategy {!r} is not supported yet; only 'fix' is implemented".format(args.ipc_strategy)
        )

    os.makedirs(args.syn_data_path, exist_ok=True)

    model_teacher = _load_teacher(args, num_classes)

    bn_hooks, statistics_loaded = _register_bn_hooks(model_teacher, args, num_classes)
    if statistics_loaded:
        print("Load statistics from {}".format(args.statistic_path))
    else:
        _collect_class_statistics(model_teacher, bn_hooks, args)

    # switch the hooks out of the ``pre=True`` mode used above
    for hook in bn_hooks:
        hook.set_hook(pre=False)

    num_ipc = len(range(args.ipc_start, args.ipc_end))
    if num_ipc == 0:
        # An empty range used to crash with IndexError; there is simply nothing to do.
        print("Nothing to synthesize: ipc range [{}, {}) is empty".format(args.ipc_start, args.ipc_end))
        torch.cuda.empty_cache()
        return

    # Class-major schedule: targets 0,0,..,1,1,.. paired with ipc ids start..end-1 repeated per class.
    targets_all = torch.arange(num_classes).repeat_interleave(num_ipc)
    ipc_ids_all = torch.arange(args.ipc_start, args.ipc_end).repeat(num_classes)

    criterion = nn.CrossEntropyLoss().cuda()
    # the first BN layer gets an extra multiplier, all others weight 1
    rescale = [args.first_bn_multiplier] + [1.0] * (len(bn_hooks) - 1)

    for start in range(0, targets_all.shape[0], args.batch_size):
        targets = targets_all[start:start + args.batch_size].to('cuda')
        ipc_ids = ipc_ids_all[start:start + args.batch_size].to('cuda')
        print(f"targets is set as: \n{targets}\n, ipc_ids is set as: \n{ipc_ids}")

        inputs = _synthesize_batch(model_teacher, bn_hooks, criterion, rescale, targets, args)

        if args.store_best_images:
            # despite the flag name, the images of the last iteration are stored
            best_inputs = denormalize_cifar(inputs.data.clone())
            save_images(args, best_inputs, targets, ipc_ids)

    torch.cuda.empty_cache()


def parse_args(argv=None):
    """Parse the command-line options of the synthesis script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        argparse.Namespace with all options; ``syn_data_path`` is rewritten to
        ``os.path.join(<--syn-data-path>, <--exp-name>)``.
    """
    parser = argparse.ArgumentParser("SRe2L: recover data from pre-trained model")

    # Data save flags
    parser.add_argument(
        "--exp-name", type=str, default="test", help="name of the experiment, subfolder under syn_data_path"
    )
    parser.add_argument("--syn-data-path", type=str, default="./syn_data", help="where to store synthetic data")
    parser.add_argument("--store-best-images", action="store_true", help="whether to store best images")

    # Optimization related flags
    parser.add_argument("--batch-size", type=int, default=100, help="number of images to optimize at the same time")
    parser.add_argument("--iteration", type=int, default=1000, help="num of iterations to optimize the synthetic data")
    parser.add_argument("--lr", type=float, default=0.1, help="learning rate for optimization")
    parser.add_argument("--jitter", default=4, type=int, help="random shift on the synthetic data")
    parser.add_argument(
        "--r-bn", type=float, default=0.05, help="coefficient for BN feature distribution regularization"
    )
    parser.add_argument(
        "--first-bn-multiplier", type=float, default=10.0, help="additional multiplier on first bn layer of R_bn"
    )

    # Model related flags
    parser.add_argument(
        "--arch-name", type=str, default="resnet18", help="arch name from pretrained torchvision models"
    )
    parser.add_argument("--arch-path", type=str, default="", help="path to the teacher checkpoint")
    parser.add_argument("--verifier", action="store_true", help="whether to evaluate synthetic data with another model")
    parser.add_argument(
        "--verifier-arch",
        type=str,
        default="mobilenet_v2",
        help="arch name from torchvision models to act as a verifier",
    )
    parser.add_argument("--ipc-start", default=0, type=int, help="first per-class image index to synthesize")
    parser.add_argument(
        "--ipc-end", default=1, type=int, help="per-class image index to stop at (exclusive)"
    )
    parser.add_argument(
        "--ipc-strategy",
        default="fix",
        type=str,
        choices=["fix", "same_total_number", "same_head_number"],
        help="how images-per-class are distributed over the classes",
    )
    parser.add_argument(
        "--IR", default=None, type=float, help="imbalance factor passed to the long-tail CIFAR-10 dataset"
    )

    # Statistics related flags
    parser.add_argument("--statistic-path", type=str, default="./statistics", help="where to store statistics")
    parser.add_argument("--training-momentum", type=float, default=0.8, help="momentum for training statistics")
    parser.add_argument("--alpha", type=float, default=0.5, help="alpha for training statistics")
    parser.add_argument("--train-data-path", type=str, default="/nas/dataset/dataset_distillation/cifar/cifar10", help="path to training data")
    parser.add_argument("--rand-value", type=int, default=0, help="random seed for the experiment")
    parser.add_argument("--long-tail", action="store_true", help="use long tail version of cifar")

    args = parser.parse_args(argv)

    # every experiment writes into its own sub-folder of the output root
    args.syn_data_path = os.path.join(args.syn_data_path, args.exp_name)
    return args

def get_distillation_ipc_config(args):
    """Return ``(ipc_end, cls_ipc)``: the largest per-class count and the list of 10 per-class counts.

    Every class gets ``args.ipc_end`` images when any of the following holds:
    the strategy is ``"fix"``, ``args.IR`` is ``None`` or above 0.99, or
    ``args.ipc_end`` is 1.

    Otherwise, for ``"same_total_number"``, the ``args.ipc_end * 10`` images
    are split proportionally to the class sizes reported by
    ``IMBALANCECIFAR10(...).get_cls_num_list()``, with at least one image per
    class; the rounding difference is added to class 0.

    Raises:
        NotImplementedError: for any other strategy. This is raised before the
            dataset is constructed.
    """
    uniform = (
        args.ipc_strategy == "fix"
        or args.IR is None
        or args.IR > 0.99
        or args.ipc_end == 1
    )
    if uniform:
        return args.ipc_end, [args.ipc_end] * 10

    # Fail fast: no point in building the dataset for a strategy we cannot serve.
    if args.ipc_strategy != "same_total_number":
        raise NotImplementedError(
            "ipc strategy {!r} is not implemented".format(args.ipc_strategy)
        )

    # NOTE(review): root and rand_number are hard-coded here rather than taken
    # from args.train_data_path / args.rand_value -- confirm this is intended.
    imb_cifar10_dataset = IMBALANCECIFAR10(root='./data/cifar10', imb_type='exp', imb_factor=args.IR, rand_number=0)
    cls_num = np.asarray(imb_cifar10_dataset.get_cls_num_list(), dtype=float)

    cls_ratio = cls_num / np.sum(cls_num)
    # make sure each cls has at least 1 image
    cls_ipc = [max(1, int(round(x))) for x in cls_ratio * args.ipc_end * 10]
    # make sure the sum of cls_ipc is equal to args.ipc_end*10
    # NOTE(review): when the 1-image floor pushes the sum above the budget the
    # correction is negative, and class 0 may end up smaller than class 1.
    cls_ipc[0] += args.ipc_end * 10 - sum(cls_ipc)

    return cls_ipc[0], cls_ipc

if __name__ == "__main__":

    args = parse_args()
    args.milestone = 1

    # if not wandb.api.api_key:
    #     wandb.login(key='')
    # wandb.init(project='sre2l-cifar', name=args.exp_name + "_2")
    # global wandb_metrics
    # wandb_metrics = {}
    main_syn(args)
    # wandb.finish()
