import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import wandb
import torchvision
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset_old, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
from resnet_big import SupConResNet

import random

from datetime import datetime


def _build_arg_parser():
    """Build the command-line parser for the distribution-matching script.

    Returns:
        argparse.ArgumentParser: parser with every option ``main`` reads.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='DisentangleConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
    # S: the same to training model, M: multi architectures, W: net width, D: net depth,
    # A: activation function, P: pooling layer, N: normalization layer.
    # In main, only 'S' / 'SS' enable periodic evaluation; other modes evaluate once at the end.
    parser.add_argument('--eval_mode', type=str, default='SS', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=3, help='the number of evaluating randomly initialized models')
    # it can be small for speeding up with little performance drop
    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=30000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    parser.add_argument('--load_debiased', action='store_true', help='whether to load debiased model.')
    parser.add_argument('--load_biased', action='store_true', help='whether to load biased model.')
    parser.add_argument('--shuffle_bias', action='store_true', help='whether to shuffle biases')
    parser.add_argument('--reweight', action='store_true', help='whether to reweight the samples')
    parser.add_argument('--conflict_percent', type=str, default='0.5', help='conflict sample percentage')
    parser.add_argument('--dim_start', type=int, default=0, help='start of the dimension range (only recorded in the run name)')
    parser.add_argument('--eval_interval', type=int, default=2000, help='evaluate the synthetic data every this many iterations')
    parser.add_argument('--dim_end', type=int, default=1152, help='end of the dimension range (only recorded in the run name)')
    parser.add_argument('--debias_model_path', type=str, default='', help='checkpoint of the debiased model (used with --load_debiased)')
    parser.add_argument('--bias_model_path', type=str, default='', help='checkpoint of the biased model (used with --load_biased)')
    parser.add_argument('--bias_label_model_path', type=str, default='', help='checkpoint of the bias-label model (used with --use_bias_label_model)')
    parser.add_argument('--supcon_model_path', type=str, default='', help='checkpoint of the SupCon ResNet-18 feature model (used with --use_sim_matrix / --use_kde)')
    parser.add_argument('--use_bias_label_model', action='store_true', help='whether to use bias label model to compute bias label')
    parser.add_argument('--use_sim_matrix', action='store_true', help='whether to use supervised contrastive learning to compute weight.')
    parser.add_argument('--random_start', action='store_true', help='choose a random start point to match')
    parser.add_argument('--use_kde', action='store_true', help='whether to use kde to compute the sample weights.')
    parser.add_argument('--cutoff_score', type=float, default=130, help='cutting off score for smoothing')
    parser.add_argument('--kernel_h', type=float, default=0.1, help='the width for computing KDE')
    parser.add_argument('--kernel_temp', type=float, default=0.1, help='the temperature for computing KDE.')
    parser.add_argument('--mimic_selecmix', action='store_true', help='flag named after SelecMix; currently unused by this script.')
    return parser


def _compute_kde_weights(sim_feature, cutoff_score, temp=0.1, h=0.1):
    """Return softmax-normalised per-sample weights from a Gaussian KDE.

    For each row of ``sim_feature`` the Gaussian kernel ``exp(-d**2 / h)`` to
    every row (including itself) is averaged into a density estimate. The
    inverse density is squashed by ``1 / (1 + exp(inv_density - cutoff_score))``,
    so samples whose inverse density exceeds ``cutoff_score`` (isolated points)
    get values near 0. The result is turned into weights by ``softmax(. / temp)``.

    Args:
        sim_feature: 2-D tensor of features, one row per sample
            (``torch.cdist`` requires at least 2-D).
        cutoff_score: threshold on the inverse density.
        temp: softmax temperature; smaller values sharpen the weights.
        h: kernel width.

    Returns:
        1-D tensor with one weight per row, summing to 1.
    """
    dist = torch.cdist(sim_feature, sim_feature, compute_mode='donot_use_mm_for_euclid_dist')
    kernel = torch.exp(-(dist ** 2) / h)  # ablation study on this(h), default value is 0.1
    inv_density = 1 / kernel.mean(dim=-1)
    rebalanced = 1 / (1 + torch.exp(inv_density - cutoff_score))
    # Explicit dim=0: identical to the implicit choice for 1-D input, without the deprecation warning.
    return torch.nn.functional.softmax(rebalanced / temp, dim=0)


def main():
    """Condense a (biased) training set by distribution matching (DM).

    For each of ``num_exp`` experiments:

    1. load every training sample via ``get_dataset`` and index it by class;
    2. create ``ipc`` synthetic images per class (Gaussian noise, or random
       real images when ``--init real``);
    3. for ``Iteration + 1`` steps, draw a freshly initialised network
       (optionally overwritten with the ``--debias_model_path`` weights) and
       update the synthetic images with SGD so that, per class, their mean
       embedding matches the real images' mean embedding (or the
       KDE-weighted sum when ``--use_kde``);
    4. at every iteration in ``eval_it_pool``, train ``num_eval`` fresh
       models on the synthetic set with ``evaluate_synset_old``, log
       accuracies and an image grid to wandb, and save a PNG to
       ``save_path``.

    The final synthetic data and accuracies are saved with ``torch.save``
    under ``save_path``. The script exits via ``SystemExit`` on unsupported
    option combinations.
    """
    args = _build_arg_parser().parse_args()
    if args.use_sim_matrix and args.use_kde:
        raise SystemExit("cannot use both sim matrix and kde")
    if args.shuffle_bias:
        # Only the unshuffled matching path exists below. With this flag the loss
        # would stay a constant tensor and ``loss.backward()`` would raise.
        raise SystemExit("--shuffle_bias is not implemented in this script")
    if 'BN' in args.model:
        # There is no matching implementation for BN models (e.g. ConvNetBN).
        # Reject them before any work instead of exiting silently (status 0)
        # after the first evaluation.
        raise SystemExit("models containing 'BN' (e.g. ConvNetBN) are not supported by this script")

    args.method = 'DM'
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.dsa = args.dsa_strategy not in ['none', 'None']
    args.dc_aug_param = None

    run_name = datetime.now().strftime("%m/%d/%Y, %H:%M:%S") + '_' + args.method + '_' + args.dataset + '_' + args.model + '_pct' + str(args.conflict_percent) + '_ipc_' + str(args.ipc)
    run_name += '_' + str(args.dim_start) + ' : ' + str(args.dim_end)

    wandb.init(sync_tensorboard=False, project='DebiasingDC', job_type='CleanRepo', config=args, name=run_name)

    os.makedirs(args.data_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    # Iterations at which the synthetic set is evaluated and visualised.
    if args.eval_mode == 'S' or args.eval_mode == 'SS':
        eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_interval).tolist()
    else:
        eval_it_pool = [args.Iteration]
    if args.model == 'MLP_disentangle':
        model_eval_pool = ['BiasMLP']
    elif 'mlp' in args.model or 'MLP' in args.model:
        model_eval_pool = ['MLP']
    else:
        model_eval_pool = ['ConvNet']
    print('eval_it_pool: ', eval_it_pool)
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path, args.conflict_percent)

    # Load the bias-label model (map onto args.device so CPU-only runs work).
    # NOTE(review): bias_label_model is never used below; only the load happens.
    if args.use_bias_label_model:
        bias_label_model = get_network(args.model, channel, num_classes, im_size).to(args.device)
        bias_label_model.load_state_dict(torch.load(args.bias_label_model_path, map_location=args.device)['state_dict'])
        bias_label_model.eval()

    # Load the SupCon feature model used for KDE weights. eval() makes the
    # features independent of the current batch, as for bias_label_model above.
    # NOTE(review): with --use_sim_matrix the model is loaded but no sim-matrix
    # weighting is applied anywhere below.
    if args.use_sim_matrix or args.use_kde:
        model_c = SupConResNet(name='resnet18').to(args.device)
        model_c.load_state_dict(torch.load(args.supcon_model_path, map_location=args.device)['state_dict'])
        model_c.eval()

    # Load checkpoints once; the per-iteration networks copy these weights in
    # via load_state_dict (which copies, so the cached tensors are never mutated).
    debias_state = torch.load(args.debias_model_path, map_location=args.device)['state_dict'] if args.load_debiased else None
    bias_state = torch.load(args.bias_model_path, map_location=args.device)['state_dict'] if args.load_biased else None

    accs_all_exps = {key: [] for key in model_eval_pool}  # record performances of all experiments
    data_save = []

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)
        print('Evaluation model pool: ', model_eval_pool)

        # Organize the real dataset: read every sample once (dst_train items
        # are indexed as (image, label, ...)).
        samples = [dst_train[i] for i in range(len(dst_train))]
        images_all = torch.stack([s[0] for s in samples], dim=0).to(args.device)
        labels = [s[1] for s in samples]
        del samples
        indices_class = [[] for _ in range(num_classes)]
        for i, lab in enumerate(labels):
            indices_class[lab].append(i)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))

        def get_images(c, n):  # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle]

        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        # Initialize the synthetic data.
        image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)  # [0,0,0, 1,1,1, ..., 9,9,9]

        if args.init == 'real':
            print('initialize synthetic data from random real images')
            for c in range(num_classes):
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
        else:
            print('initialize synthetic data from random noise')

        # Training.
        optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)  # optimizer_img for synthetic data
        optimizer_img.zero_grad()
        print('%s training begins'%get_time())

        max_accuracy = {model_eval: 0.0 for model_eval in model_eval_pool}  # best mean accuracy per eval model
        for it in range(args.Iteration+1):

            # Evaluate synthetic data.
            if it in eval_it_pool:
                for model_eval in model_eval_pool:
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))

                    print('DSA augmentation strategy: \n', args.dsa_strategy)
                    print('DSA augmentation parameters: \n', args.dsa_param.__dict__)

                    accs = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)  # get a random model
                        # Copies, so evaluation cannot modify the synthetic set.
                        image_syn_eval, label_syn_eval = image_syn.detach().clone(), label_syn.detach().clone()
                        _, acc_train, acc_test = evaluate_synset_old(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                    current_accuracy = np.mean(accs)
                    wandb.log({"Accuracy/{}".format(model_eval) : current_accuracy}, step=it)
                    wandb.log({"Std/{}".format(model_eval) : np.std(accs)}, step=it)
                    max_accuracy[model_eval] = max(max_accuracy[model_eval], current_accuracy)
                    wandb.log({"Max Accuracy/{}".format(model_eval) : max_accuracy[model_eval]}, step=it)
                    print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                    if it == args.Iteration:  # record the final results
                        accs_all_exps[model_eval] += accs

                # Visualize and save: undo the per-channel normalisation, then clip to [0, 1].
                save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))
                image_syn_vis = image_syn.detach().cpu().clone()
                for ch in range(channel):
                    image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]
                image_syn_vis.clamp_(0.0, 1.0)
                save_image(image_syn_vis, save_name, nrow=args.ipc)  # Trying normalize = True/False may get better visual effects.
                if args.ipc <= 50:
                    upsampled = torch.repeat_interleave(image_syn_vis, repeats=4, dim=2)
                    upsampled = torch.repeat_interleave(upsampled, repeats=4, dim=3)
                    grid = torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)
                    wandb.log({"Reconstructed_Images": wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)

            # Train synthetic data: sample a frozen embedding network for this iteration.
            model_i = get_network(args.model, channel, num_classes, im_size).to(args.device)  # get a random model
            model_i.train()
            if debias_state is not None:
                model_i.load_state_dict(debias_state)
            for param in model_i.parameters():
                param.requires_grad = False
            if args.load_biased:
                model_b = get_network(args.model, channel, num_classes, im_size).to(args.device)  # get a random model
                model_b.train()
                model_b.load_state_dict(bias_state)
                for param in model_b.parameters():
                    param.requires_grad = False

            embed_i = model_i.module.embed if torch.cuda.device_count() > 1 else model_i.embed  # for GPU parallel
            if args.load_biased:
                # NOTE(review): embed_b is computed but not used in the loss below.
                embed_b = model_b.module.embed if torch.cuda.device_count() > 1 else model_b.embed  # for GPU parallel

            loss_avg = 0

            # Update synthetic data: sum over classes of the squared distance
            # between real and synthetic mean embeddings.
            loss = torch.tensor(0.0).to(args.device)
            for c in range(num_classes):
                img_real = get_images(c, args.batch_real)
                img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                if args.dsa:
                    # Same seed for real and synthetic so both get identical augmentations.
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                with torch.no_grad():
                    output_real_i = embed_i(img_real).detach()
                output_syn_i = embed_i(img_syn)

                if args.use_kde:
                    with torch.no_grad():
                        sim_feature = model_c(img_real)
                        weights = _compute_kde_weights(sim_feature, args.cutoff_score, args.kernel_temp, args.kernel_h)
                        # Weighted sum over samples; equals sum(diag(w) @ X, dim=0)
                        # without building a B x B matrix (assumes 2-D embeddings).
                        real_sum = weights @ output_real_i
                else:
                    real_sum = torch.mean(output_real_i, dim=0)
                loss += torch.sum((real_sum - torch.mean(output_syn_i, dim=0))**2)

            optimizer_img.zero_grad()
            loss.backward()
            optimizer_img.step()
            loss_avg += loss.item()

            loss_avg /= num_classes

            if it%10 == 0:
                print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))

            # NOTE: evaluation at it == Iteration runs before this iteration's
            # update, so the saved set has one more SGD step than the evaluated one.
            if it == args.Iteration:  # only record the final results
                data_save.append([image_syn.detach().cpu().clone(), label_syn.detach().cpu().clone()])
                saved_path = os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc))
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, saved_path)
                print(saved_path)

    print('\n==================== Final Results ====================\n')
    for key in model_eval_pool:
        accs = accs_all_exps[key]
        print('Run %d experiments, train on %s, evaluate %d random %s, mean  = %.2f%%  std = %.2f%%'%(args.num_exp, args.model, len(accs), key, np.mean(accs)*100, np.std(accs)*100))


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not on import.
    main()

