import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *

import copy
import random
from reparam_module import ReparamModule
import torch.utils.data
import warnings
import gc

from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter

warnings.filterwarnings("ignore", category=DeprecationWarning)

def main(args):
    """Distill a synthetic latent dataset by matching expert training trajectories.

    Loads the dataset and an autoencoder, initializes synthetic latents
    (``prepare_latent``) and a learnable student learning rate ``syn_lr``.
    Then, for ``args.Iteration + 1`` iterations, it:

    1. evaluates the synthetic data at iterations in the evaluation pool and
       writes the best accuracy/std to TensorBoard;
    2. samples an expert trajectory from the ``replay_buffer_<n>.pt`` files in
       the expert directory and a random starting epoch;
    3. unrolls ``args.syn_steps`` student SGD steps on minibatches of
       synthetic latents, starting from the expert parameters at that epoch;
    4. updates the latents and ``syn_lr`` to reduce
       ``||start - syn_lr * sum(step grads) - target||^2 / ||start - target||^2``.
       Here ``target`` is the expert ``args.expert_epochs`` epochs later.

    Args:
        args: parsed command-line namespace. This function also stores derived
            values on it (device, dataset info, latent size, save path, ...).

    Raises:
        ValueError: if ``args.syn_steps`` is smaller than 1.
        FileNotFoundError: if the expert directory holds no
            ``replay_buffer_0.pt``.
    """
    torch.set_num_threads(args.torch_num_threads)
    torch.random.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    # Fail fast: the unrolled student loop below needs at least one step.
    if args.syn_steps < 1:
        raise ValueError(f'--syn_steps must be >= 1, got {args.syn_steps}')

    args.dsa_param = ParamDiffAug()
    args.dsa = True if args.dsa == 'True' else False
    if args.dsa_strategy in ['none', 'None']:
        args.dsa = False
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # makedirs also creates missing parent directories (os.mkdir would fail).
    os.makedirs(args.data_path, exist_ok=True)

    args.channel, args.im_size, args.num_classes, _, class_map, _, _, _, dst_train, _, testloader, _ = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)
    eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    accs_all_exps = dict() # record performances of all experiments
    for key in model_eval_pool:
        accs_all_exps[key] = []

    if args.max_experts is not None and args.max_files is not None:
        args.total_experts = args.max_experts * args.max_files

    ae_config = OmegaConf.load(args.ae_config)
    ae_model = load_autoencoder_from_config(ae_config, args.ae_ckpt).to(args.device)

    # Spatial size of the latents: image size divided by the factor args.f.
    args.latent_size = (args.im_size[0] // args.f, args.im_size[1] // args.f)
    if args.lpc is None:
        args.lpc = get_lpc(args)
    args.convnet_pooling = 'avgpooling' if args.latent_size[0] >= 2 ** (args.train_depth + 1) else 'none'

    args.save_path = os.path.join(args.save_path, 'MTT', f'{args.dataset}-{args.im_size[0]}', get_run_name(args))
    os.makedirs(args.save_path, exist_ok=True)
    logger = Logger(os.path.join(args.save_path, 'log.txt'))
    tb_writer = SummaryWriter(args.save_path)

    get_latent = None
    if args.init == 'real':
        latent_all, label_all, indices_class = build_dataset(args, ae_model, dst_train, class_map, batch_size = 16 if args.latent_size[0] <= 64 else 4)
        def get_latent(c, n): # get random n latents from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return latent_all[idx_shuffle].to(args.device)

    if args.batch_syn is None:
        if args.num_batch_syn > args.lpc:
            print('nbsyn > lpc, will set nbsyn = lpc')
            args.num_batch_syn = args.lpc
        args.batch_syn = args.num_classes * args.lpc // args.num_batch_syn
    args.distributed = torch.cuda.device_count() > 1

    latent_syn, label_syn = prepare_latent(args, get_latent)
    optimizer_latent = get_optimizer_latent(args, latent_syn)
    if args.init == 'real':
        del get_latent, latent_all, label_all, indices_class    # save memory

    # Learnable student learning rate: a leaf scalar on the target device.
    syn_lr = torch.tensor(args.lr_net, device=args.device, requires_grad=True)
    optimizer_lr = torch.optim.SGD([syn_lr], lr = args.lr_lr, momentum = 0.5)

    criterion = nn.CrossEntropyLoss().to(args.device)

    logger.log('%s training begins' % get_time())
    logger.log('Evaluation iteration pool: ' + print_eval_it_pool(eval_it_pool))
    logger.log('Evaluation model pool: ', model_eval_pool)
    logger.log(f'Dataset info: {args.dataset}, {args.channel} * {args.im_size[0]} * {args.im_size[1]}, {args.num_classes} classes')
    logger.log('Args: ' + str(args.__dict__))

    expert_dir = os.path.join(args.buffer_path, f'{args.dataset}-{args.res}-l{args.latent_size[0]}_{args.model}-d{args.train_depth}w{args.train_width}')
    logger.log("Expert Dir: {}".format(expert_dir))

    # Buffer files are numbered replay_buffer_0.pt, replay_buffer_1.pt, ...;
    # the scan stops at the first missing index.
    expert_files = []
    n = 0
    while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n))):
        expert_files.append(os.path.join(expert_dir, "replay_buffer_{}.pt".format(n)))
        n += 1
    logger.log(f'Found {n} buffer files')
    if not expert_files:
        raise FileNotFoundError(f'No replay_buffer_<n>.pt files found in {expert_dir}')

    if args.load_all:
        # Keep every trajectory from every file in memory and sample uniformly.
        buffer = []
        for expert_file in expert_files:
            buffer.extend(torch.load(expert_file))
    else:
        # Stream files: walk the shuffled trajectories of the current file, then
        # move on to the next (shuffled) file.
        file_idx = 0
        expert_idx = 0
        random.shuffle(expert_files)
        if args.max_files is not None:
            expert_files = expert_files[:args.max_files]
        buffer = torch.load(expert_files[file_idx])
        if args.max_experts is not None:
            buffer = buffer[:args.max_experts]
        random.shuffle(buffer)

    best_acc = {"{}".format(m): 0 for m in model_eval_pool}
    best_std = {m: 0 for m in model_eval_pool}

    for it in range(0, args.Iteration+1):
        ''' Evaluate synthetic data '''
        if it in eval_it_pool:
            best_acc, best_std = eval_and_save(args, latent_syn, label_syn, ae_model, logger, testloader=testloader, model_eval_pool=model_eval_pool, it=it)
            for model_eval in best_acc.keys():
                tb_writer.add_scalar(f'best_acc/{model_eval}', best_acc[model_eval], it)
                tb_writer.add_scalar(f'best_std/{model_eval}', best_std[model_eval], it)
            tb_writer.flush()
        elif args.save_image_it and it % args.save_image_it == 0:
            save(args, latent_syn, ae_model, it = it)

        student_net = get_network(args.model, args.C, args.num_classes, args.latent_size, depth = args.train_depth, width = args.train_width, convnet_pooling = args.convnet_pooling).to(args.device) # get a random model

        if args.load_all:
            expert_trajectory = buffer[np.random.randint(0, len(buffer))]
        else:
            expert_trajectory = buffer[expert_idx]
            expert_idx += 1
            if expert_idx == len(buffer):
                # Current file exhausted: advance to the next file (reshuffling
                # the file order after a full pass).
                expert_idx = 0
                file_idx += 1
                if file_idx == len(expert_files):
                    file_idx = 0
                    random.shuffle(expert_files)
                if args.max_files != 1 and len(expert_files) > 1:
                    del buffer
                    buffer = torch.load(expert_files[file_idx])
                if args.max_experts is not None:
                    buffer = buffer[:args.max_experts]
                random.shuffle(buffer)

        start_epoch = np.random.randint(0, args.max_start_epoch)
        starting_params = expert_trajectory[start_epoch]

        # Flatten expert parameters into single vectors on the device.
        target_params = expert_trajectory[start_epoch+args.expert_epochs]
        target_params = torch.cat([p.data.to(args.device).reshape(-1) for p in target_params], 0)
        student_params = [torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0).requires_grad_(True)]
        starting_params = torch.cat([p.data.to(args.device).reshape(-1) for p in starting_params], 0)

        student_net = ReparamModule(student_net)
        gradient_sum = torch.zeros(starting_params.shape, device=args.device)
        # Normalizer: squared distance between the expert start and target.
        param_dist = torch.tensor(0.0).to(args.device)
        param_dist += torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum")

        if args.distributed:
            student_net = torch.nn.DataParallel(student_net)

        student_net.train()

        y_hat = label_syn
        x_list = []
        y_list = []
        indices_chunks = []
        indices_chunks_copy = []
        original_x_list = []
        gc.collect()

        syn_latent_grad = torch.zeros(latent_syn.shape, device=args.device)

        # Pass 1: unroll syn_steps student SGD steps on minibatches of synthetic
        # latents. Each step starts from a detached copy of the previous
        # parameters, and only the gradient values are kept, so no higher-order
        # graph is needed here. The inputs of each step are recorded for pass 2.
        for il in range(args.syn_steps):
            if not indices_chunks:
                indices = torch.randperm(len(latent_syn))
                indices_chunks = list(torch.split(indices, args.batch_syn))

            these_indices = indices_chunks.pop()
            indices_chunks_copy.append(these_indices.clone())

            x = latent_syn[these_indices]
            this_y = y_hat[these_indices]

            original_x_list.append(x)
            if args.dsa:
                x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)

            x_list.append(x.clone())
            y_list.append(this_y.clone())

            forward_params = student_params[-1].detach().clone().requires_grad_(True)

            if args.distributed:
                forward_params_expanded = forward_params.unsqueeze(0).expand(torch.cuda.device_count(), -1)
            else:
                forward_params_expanded = forward_params

            x = student_net(x, flat_param=forward_params_expanded)

            ce_loss = criterion(x, this_y)

            # The gradient is used as a plain value, so no create_graph is needed.
            grad = torch.autograd.grad(ce_loss, forward_params)[0]
            student_params.append(forward_params - syn_lr.item() * grad)
            gradient_sum = gradient_sum + grad

        # Pass 2: for each step, recompute its gradient g_i with a graph back to
        # the synthetic latents. Then backpropagate the terms of
        # ||start - lr * sum_j g_j - target||^2 / param_dist that involve g_i,
        # holding the other steps' gradients constant.
        for il in range(args.syn_steps):
            w = student_params[il]

            if args.distributed:
                w_expanded = w.unsqueeze(0).expand(torch.cuda.device_count(), -1)
            else:
                w_expanded = w

            output = student_net(x_list[il], flat_param=w_expanded)

            # Labels of the minibatch used at this step in pass 1.
            ce_loss = criterion(output, y_list[il])

            grad = torch.autograd.grad(ce_loss, w, create_graph=True, retain_graph=True)[0]

            # lr^2 * g_i.g_i  +  2 lr * g_i.(lr * (G - g_i) - start + target)
            square_term = syn_lr.item() ** 2 * (grad @ grad)
            single_term = 2 * syn_lr.item() * grad @ (
                        syn_lr.item() * (gradient_sum - grad.detach().clone()) - starting_params + target_params)

            per_batch_loss = (square_term + single_term) / param_dist
            gradients = torch.autograd.grad(per_batch_loss, original_x_list[il], retain_graph=False)[0]

            with torch.no_grad():
                syn_latent_grad[indices_chunks_copy[il]] += gradients

        # ---------end of computing input image gradients and learning rates--------------

        del w, output, ce_loss, grad, square_term, single_term, per_batch_loss, gradients, student_net, w_expanded, forward_params, forward_params_expanded

        optimizer_latent.zero_grad()
        optimizer_lr.zero_grad()

        # Learning-rate update: the full matching loss is a cheap closed form of syn_lr.
        syn_lr.requires_grad_(True)
        grand_loss = starting_params - syn_lr * gradient_sum - target_params
        grand_loss = grand_loss.dot(grand_loss)
        grand_loss = grand_loss / param_dist

        lr_grad = torch.autograd.grad(grand_loss, syn_lr)[0]
        syn_lr.grad = lr_grad

        optimizer_lr.step()
        optimizer_lr.zero_grad()

        # The latent gradient was accumulated by hand; give it to the optimizer.
        latent_syn.grad = syn_latent_grad
        del syn_latent_grad, lr_grad

        # Drop this iteration's per-step buffers (and the graphs they reference)
        # before releasing cached GPU memory; they are rebuilt next iteration.
        del student_params, x_list, y_list, original_x_list, indices_chunks_copy

        torch.cuda.empty_cache()
        gc.collect()

        optimizer_latent.step()
        optimizer_latent.zero_grad()

        if it % args.log_it == 0:
            logger.log('%s iter = %04d, loss = %.4f, syn_lr = %.6f' % (get_time(), it, grand_loss.item(), syn_lr.item()))

    # Flush pending events and release the event file handle.
    tb_writer.close()

if __name__ == '__main__':
    # Options shared with other scripts come from shared_args; the ones below
    # are specific to latent trajectory-matching distillation.
    import shared_args
    parser = shared_args.add_shared_args()
    parser.add_argument('--batch_syn', type=int, default=None,
                        help='batch size for syn data (default: num_classes * lpc // num_batch_syn)')
    parser.add_argument('--buffer_path', type=str, default='./latent_mtt_buffer',
                        help='root directory containing the expert replay buffers')
    parser.add_argument('--load_all', action='store_true',
                        help='load every replay buffer file into memory and sample experts uniformly')
    parser.add_argument('--max_start_epoch', type=int, default=5,
                        help='the student starts from a random expert epoch in [0, max_start_epoch)')
    parser.add_argument('--max_files', type=int, default=None,
                        help='use at most this many randomly chosen buffer files (only without --load_all)')
    parser.add_argument('--max_experts', type=int, default=None,
                        help='keep at most this many trajectories per loaded buffer file (only without --load_all)')
    parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are')
    parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')

    parser.add_argument('--lr_lr', type=float, default=1e-06,
                        help='learning rate for the learnable student learning rate (syn_lr)')
    parser.add_argument('--num_batch_syn', type = int, default = 12, help='batch_syn = num_classes * lpc // num_batch_syn')
    args = parser.parse_args()

    main(args)

