import os
import sys
import fire
import time
import copy
import ml_collections
from tqdm import tqdm
from absl import logging

os.environ['JAX_PLATFORM_NAME'] = 'cpu'
sys.path.append("../..")

import numpy as np
import tensorflow as tf

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch._vmap_internals import vmap

from lib.dataset.dataloader import get_dataset, get_dataset_imagenet32

from lib_torch.utils import get_network, evaluate_synset, get_time, TensorDataset, ParamDiffAug
from lib_torch.translator import Translator

from clu import metric_writers
from lib_torch.reparam_module import ReparamModule
from learn2learn import clone_module


def get_config():
    """Build the default experiment configuration.

    Returns an ``ml_collections.ConfigDict`` holding top-level run settings plus three
    nested sub-configs: ``dataset``, ``online`` and ``kernel``. Fields initialised to
    ``None`` are placeholders; ``main`` overwrites several of them at run time.
    """
    resume = False

    dataset_cfg = ml_collections.ConfigDict({
        'name': 'cifar100',  # ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'tiny_imagenet']
        'data_path': 'data/tensorflow_datasets',
        'zca_path': 'data/zca',
        'zca_reg': 0.1,
    })

    online_cfg = ml_collections.ConfigDict({
        'img_size': None,
        'img_channels': None,
        'optimizer': 'adam',
        'learning_rate': 0.0003,
        'arch': 'conv',
        'output': 'feat_fc',
        'width': 128,
        'normalization': 'identity',
    })

    kernel_cfg = ml_collections.ConfigDict({
        'img_size': None,
        'img_channels': None,
        'num_prototypes': None,
        'train_size': None,
        'resume': resume,  # mirrors the top-level ``resume`` flag
        'optimizer': 'adam',
        'learning_rate': 0.001,
        'batch_size': 1024,
        'eval_batch_size': 1000,
    })

    return ml_collections.ConfigDict({
        'random_seed': 0,
        'train_log': 'train_log',
        'train_img': 'train_img',
        'resume': resume,
        'img_size': None,
        'img_channels': None,
        'num_prototypes': None,
        'train_size': None,
        'dataset': dataset_cfg,
        'kernel': kernel_cfg,
        'online': online_cfg,
    })


@vmap
def lb_margin_th(logits):
    """Return the negated, clipped gap between the two largest logits.

    The function is wrapped by the legacy ``torch._vmap_internals.vmap``, so the body is
    written for a single example and is mapped over the leading dimension of the input.
    For each example: ``margin = min(top1 - top2, 1 / dim)``, where ``dim`` is the size of
    the last axis seen by the per-example body, and the returned value is ``-margin``.

    NOTE(review): this function is not called anywhere in this file as shown. Confirm that
    the legacy vmap provides batching support for ``torch.topk`` in the installed PyTorch.
    """
    dim = logits.shape[-1]
    # Two largest entries along the last axis (idx is unused).
    val, idx = torch.topk(logits, k=2)
    # Gap between the top-2 values, capped at 1/dim.
    margin = torch.minimum(val[..., 0] - val[..., 1], torch.tensor(1 / dim, dtype=torch.float, device=logits.device))
    return -margin


class SynData(nn.Module):
    """Holds a synthetic image set and its labels as ``nn.Parameter`` tensors.

    Args:
        x_init: initial synthetic images; always trainable.
        y_init: initial labels; trainable only when ``learn_label`` is True.
        learn_label: whether gradients should flow into the labels.
    """

    def __init__(self, x_init, y_init, learn_label=False):
        super().__init__()
        self.x_syn = nn.Parameter(x_init)
        self.y_syn = nn.Parameter(y_init, requires_grad=learn_label)

    def forward(self):
        """Return the ``(images, labels)`` parameters themselves (graph-connected)."""
        return self.x_syn, self.y_syn

    def value(self):
        '''Return the synthetic images and labels. Used in deterministic parameterization of synthetic data'''
        # Detached views: same storage, no autograd history.
        return tuple(param.detach() for param in (self.x_syn, self.y_syn))


class InfiniteDataLoader(DataLoader):
    """A ``DataLoader`` that is its own never-ending iterator.

    ``next(loader)`` yields batches as usual. When the underlying pass is exhausted, a
    fresh ``DataLoader`` iterator is started transparently, so iteration never stops
    unless a brand-new pass is itself empty.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start the first pass eagerly so ``next`` can be called right away.
        self.dataset_iterator = self._fresh_iterator()

    def _fresh_iterator(self):
        """Begin a new pass over the data using the stock ``DataLoader`` iterator."""
        return super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            pass
        # Current pass exhausted: restart and serve the first batch of the new pass.
        self.dataset_iterator = self._fresh_iterator()
        return next(self.dataset_iterator)


class PoolElement:
    """One network (plus optimizer/scheduler) used as a feature map for kernel regression.

    The model is (re)built by ``initialize`` from the supplied factories. ``nfr``,
    ``nfr_batch`` and ``nfr_eval`` make kernel ridge regression predictions with a
    linear kernel on the model's ``embed`` features. ``train_steps`` trains the model
    directly on (x, y) pairs.

    Args:
        get_model: zero-argument factory returning an ``nn.Module`` that exposes ``embed``.
        get_optimizer: callable mapping a model to an optimizer.
        get_scheduler: callable mapping an optimizer to an LR scheduler.
        loss_fn: loss module used by ``train_steps``; moved to ``device``.
        batch_size: minibatch size used by ``get_batch``.
        max_online_updates: after each ``train_steps`` call, the model is re-initialized once
            the step counter reaches this value (so e.g. ``-1`` resets after every call).
        idx: identifier stored on the instance (not used inside this class).
        device: device the model and loss are moved to.
        step: initial value of the step counter.
    """

    def __init__(self, get_model, get_optimizer, get_scheduler, loss_fn, batch_size, max_online_updates, idx, device,
                 step=0):
        self.get_model = get_model
        self.get_optimizer = get_optimizer
        self.get_scheduler = get_scheduler
        self.loss_fn = loss_fn.to(device)
        self.batch_size = batch_size
        self.max_online_updates = max_online_updates
        self.idx = idx
        self.device = device
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.initialize()
        # initialize() zeroes the counter; restore the caller-supplied start value afterwards.
        self.step = step

    def __call__(self, x, no_grad=False):
        """Run the model's forward pass in eval mode, optionally under ``torch.no_grad()``."""
        self.model.eval()
        if no_grad:
            with torch.no_grad():
                return self.model(x)
        return self.model(x)

    def feature(self, x, no_grad=False, weight_grad=False):
        """Return ``self.model.embed(x)`` computed in eval mode.

        With ``no_grad=True`` the call runs under ``torch.no_grad()``. Otherwise the model's
        parameters get ``requires_grad`` set to ``weight_grad`` first. That setting persists
        after the call; ``train_steps`` re-enables it. Gradients w.r.t. ``x`` are kept either way.
        """
        self.model.eval()
        if no_grad:
            with torch.no_grad():
                return self.model.embed(x)
        self.model.requires_grad_(weight_grad)
        return self.model.embed(x)

    @staticmethod
    def _regularize_gram(kss, reg):
        """Return ``kss + |reg| * trace(kss) / n * I`` for a ``(n, n)`` or ``(b, n, n)`` Gram matrix.

        The trace is taken per matrix, so the ridge term scales with each kernel's mean
        diagonal value. ``|reg|`` is used so that a negative ``reg`` can never weaken the
        diagonal; both ``nfr`` and ``nfr_batch`` go through this one helper.
        """
        n = kss.shape[-1]
        trace = kss.diagonal(dim1=-2, dim2=-1).sum(-1)[..., None, None]
        eye = torch.eye(n, device=kss.device, dtype=kss.dtype)
        return kss + abs(reg) * trace * eye / n

    def nfr(self, x_syn, y_syn, x_tar, reg=1e-4, weight_grad=False, use_flip=False):
        """Kernel ridge regression prediction for ``x_tar`` from support set ``(x_syn, y_syn)``.

        Args:
            x_syn: support inputs (batched along dim 0).
            y_syn: 2-D support targets, one row per support input.
            x_tar: query inputs; their features are computed without autograd.
            reg: relative ridge strength (absolute value is used).
            weight_grad: whether model weights receive gradients through the support features.
            use_flip: if True, append copies of ``x_syn`` flipped along the last axis (the
                image width, assuming NCHW inputs) with duplicated labels.

        Returns:
            Predictions of shape ``(len(x_tar), y_syn.shape[1])``.
        """
        if use_flip:
            x_syn = torch.cat((x_syn, torch.flip(x_syn, dims=[-1])), dim=0)
            y_syn = torch.cat((y_syn, y_syn), dim=0)

        feat_tar = self.feature(x_tar, no_grad=True)
        feat_syn = self.feature(x_syn, weight_grad=weight_grad)

        kss = torch.mm(feat_syn, feat_syn.t())
        kts = torch.mm(feat_tar, feat_syn.t())
        kss_reg = self._regularize_gram(kss, reg)
        return torch.mm(kts, torch.linalg.solve(kss_reg, y_syn))

    def nfr_batch(self, x_syn, y_syn, x_tar, reg=1e-4, weight_grad=False, use_flip=False):
        """Batched ``nfr``: every argument carries a leading task dimension ``b``.

        ``x_syn``/``x_tar`` are ``(b, n, ...)`` and ``y_syn`` is ``(b, n_syn, k)``. Each task
        is regularized with its own trace. Returns ``(b, n_tar, k)``.
        """
        if use_flip:
            x_syn = torch.cat((x_syn, torch.flip(x_syn, dims=[-1])), dim=1)
            y_syn = torch.cat((y_syn, y_syn), dim=1)

        b, n_tar = x_tar.shape[:2]
        n_syn = x_syn.shape[1]
        # Run the model on all tasks at once, then restore the task dimension.
        feat_tar = self.feature(x_tar.flatten(0, 1), no_grad=True).unflatten(0, (b, n_tar))
        feat_syn = self.feature(x_syn.flatten(0, 1), weight_grad=weight_grad).unflatten(0, (b, n_syn))

        kss = torch.bmm(feat_syn, feat_syn.transpose(-1, -2))
        kts = torch.bmm(feat_tar, feat_syn.transpose(-1, -2))
        kss_reg = self._regularize_gram(kss, reg)
        return torch.bmm(kts, torch.linalg.solve(kss_reg, y_syn))

    def nfr_eval(self, feat_syn, y_syn, x_tar, kss_reg):
        """Kernel regression prediction from precomputed support features and Gram matrix.

        ``kss_reg`` is supplied by the caller (presumably the regularized Gram matrix of
        ``feat_syn`` — verify at call sites).
        """
        feat_tar = self.feature(x_tar, no_grad=True)
        kts = torch.mm(feat_tar, feat_syn.t())
        return torch.mm(kts, torch.linalg.solve(kss_reg, y_syn))

    def train_steps(self, x_syn, y_syn, steps=1):
        """Take ``steps`` optimizer/scheduler steps on minibatches of ``(x_syn, y_syn)``, then maybe reset."""
        self.model.train()
        self.model.requires_grad_(True)
        for _ in range(steps):
            x, y = self.get_batch(x_syn, y_syn)
            self.optimizer.zero_grad(set_to_none=True)
            loss = self.loss_fn(self.model(x), y)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
        self.check_for_reset(steps=steps)

    def evaluate_syn(self, x_syn, y_syn):
        """Placeholder; currently does nothing."""
        pass

    def get_batch(self, xs, ys):
        """Return all samples when fewer than ``batch_size``; else a random subset of exactly ``batch_size``.

        The subset is drawn without replacement with numpy's global RNG.
        """
        if ys.shape[0] < self.batch_size:
            return xs, ys
        sample_idx = np.random.choice(ys.shape[0], size=(self.batch_size,), replace=False)
        return xs[sample_idx], ys[sample_idx]

    def initialize(self):
        """Build a fresh model, optimizer and scheduler and zero the step counter."""
        self.model = self.get_model().to(self.device)
        self.optimizer = self.get_optimizer(self.model)
        self.scheduler = self.get_scheduler(self.optimizer)
        self.step = 0

    def check_for_reset(self, steps=1):
        """Advance the step counter by ``steps``; re-initialize once it reaches ``max_online_updates``."""
        self.step += steps
        if self.step >= self.max_online_updates:
            self.initialize()


def main(dataset_name, data_path=None, zca_path=None, train_log=None, arch='conv', width=128, depth=3,
         normalization='identity', learn_label=True, batch_size_val=1024, batch_size_real=1000, random_seed=0,
         num_train_steps=None, lr=1e-4, pretrain_path=None, suffix='', ipc_real=1300, mini_steps=1, mini_steps_val=300,
         resume=False):
    """Meta-train a ``Translator`` generator and periodically evaluate it on CIFAR-10.

    Outer loop (``num_train_steps`` iterations): sample ``n_classes`` classes from the
    training set returned by ``get_dataset_imagenet32`` and clone ``generator``. Take
    ``mini_steps`` differentiable gradient steps (step size ``lr``) on the clone's
    flattened parameters using the ``PoolElement.nfr`` regression loss. Then back-propagate
    the post-adaptation loss into ``generator``.

    Evaluation runs at iterations in ``eval_it_pool`` and every ``steps_per_eval`` steps. A
    copy of the generator is adapted for ``mini_steps_val`` Adam steps on CIFAR-10, and the
    synthetic sets it produces are scored with ``evaluate_synset``.

    Args:
        dataset_name: dataset name passed to ``get_dataset_imagenet32``; also names the experiment dir.
        data_path, zca_path: dataset / ZCA directories (defaults 'data/tensorflow_datasets', 'data/zca').
        train_log: root log directory (default 'train_log'); also caches the CIFAR-10 image subsets.
        arch, width, depth, normalization: network settings for the feature net, generator and eval nets.
        learn_label: whether ``SynData`` marks the CIFAR-10 labels as trainable.
        batch_size_val: batch size of the CIFAR-10 loader used during eval adaptation.
        batch_size_real: real images per meta-task, split evenly across the sampled classes.
        random_seed: stored in ``config.random_seed`` (not used for seeding in this function).
        num_train_steps: number of outer iterations; required.
        lr: meta learning rate, inner-loop step size and eval-adaptation learning rate.
        pretrain_path: optional generator checkpoint to initialize from.
        suffix: optional tag inserted into the experiment name.
        ipc_real: real images kept per class; classes with fewer images are dropped.
        mini_steps: inner adaptation steps per outer iteration.
        mini_steps_val: adaptation steps during evaluation.
        resume: if True, resume from ``<work_dir>/ckpt/ckpt.pt`` when it exists.

    Raises:
        ValueError: if ``num_train_steps`` is None, or too few classes have ``ipc_real`` images.
    """
    # Fail before any expensive dataset loading; None would otherwise crash in the scheduler/loop.
    if num_train_steps is None:
        raise ValueError('num_train_steps must be provided.')

    config = get_config()
    config.random_seed = random_seed
    # Resolve the log directory once. The raw argument defaults to None, which os.path cannot take,
    # and it is reused below for the cached CIFAR-10 subsets.
    train_log = train_log if train_log else 'train_log'
    config.train_log = train_log
    config.resume = resume
    config.kernel.resume = resume

    os.makedirs(train_log, exist_ok=True)

    # Task classes per meta-task -> candidate IPCs for the inner ('train_ipc') and outer ('ipc') losses.
    all_info = {100: {'ipc': [1, 2, 5, 10], 'train_ipc': [1, 2, 5, 10]}}
    all_ipc_val = [1, 10, 50]

    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # --------------------------------------
    # Dataset
    # --------------------------------------
    config.dataset.data_path = data_path if data_path else 'data/tensorflow_datasets'
    config.dataset.zca_path = zca_path if zca_path else 'data/zca'

    config.dataset.name = dataset_name
    (x_train, y_train, _, _), _, _, _ = get_dataset_imagenet32(config.dataset)
    x_train = torch.tensor(x_train.transpose((0, 3, 1, 2)))
    num_classes = config.dataset.num_classes
    indices_class = [[] for _ in range(num_classes)]
    for i, lab in tqdm(enumerate(y_train)):
        indices_class[lab].append(i)
    # Keep only classes with at least ``ipc_real`` images, truncated to exactly ``ipc_real`` so they stack.
    indices_class_torch = []
    for c in range(num_classes):
        length = len(indices_class[c])
        print('class c = %d: %d real images' % (c, length))
        if length >= ipc_real:
            indices_class_torch.append(torch.tensor(indices_class[c])[:ipc_real])
    num_classes = len(indices_class_torch)
    print('Number of Remaining Classes: %d' % num_classes)
    # Each meta-task draws up to max(all_info) distinct classes; fail early with a clear message.
    max_task_classes = max(all_info)
    if num_classes < max_task_classes:
        raise ValueError('Need at least %d classes with >= %d images each, found %d.'
                         % (max_task_classes, ipc_real, num_classes))
    indices_class_torch = torch.stack(indices_class_torch, dim=0).to(config.device)

    # Validation data: the adapted generator is evaluated on CIFAR-10.
    config.dataset.name = 'cifar10'
    (x_train_val, y_train_val, x_test_val, y_test_val), _, _, _ = get_dataset(
        config.dataset, return_raw=True)
    im_size = config.dataset.img_shape[0:2]
    channel = config.dataset.img_shape[-1]
    num_classes_val = config.dataset.num_classes

    x_train_val = torch.from_numpy(np.transpose(x_train_val, axes=[0, 3, 1, 2])).float()
    x_test_val = torch.from_numpy(np.transpose(x_test_val, axes=[0, 3, 1, 2])).float()
    y_train_val = torch.from_numpy(y_train_val).long()
    y_test_val = torch.from_numpy(y_test_val).long()

    indices_class_val = [[] for _ in range(num_classes_val)]
    for i, lab in tqdm(enumerate(y_train_val)):
        indices_class_val[lab].append(i)
    # Centered one-hot regression targets.
    y_train_val = F.one_hot(y_train_val, num_classes_val) - 1 / num_classes_val
    for c in range(num_classes_val):
        print('[val] class c = %d: %d real images' % (c, len(indices_class_val[c])))

    y_scale_val = np.sqrt(num_classes_val / 10)
    y_test_val = F.one_hot(y_test_val, num_classes_val) - 1 / num_classes_val
    dst_test_val = TensorDataset(x_test_val, y_test_val)
    dst_train_val = TensorDataset(x_train_val, y_train_val)
    testloader_val = DataLoader(dst_test_val, batch_size=256, shuffle=False, num_workers=0)

    # --------------------------------------
    # Online
    # --------------------------------------
    config.online.arch = arch
    config.online.width = width
    config.online.depth = depth
    config.online.normalization = normalization
    config.online.img_size = config.dataset.img_shape[0]
    config.online.img_channels = config.dataset.img_shape[-1]

    # --------------------------------------
    # Logging
    # --------------------------------------
    steps_per_eval = 10000
    steps_per_save = 1000

    lr_net = config.online.learning_rate
    name = 'maml_mix_%s_' % suffix if suffix != '' else 'maml_mix_'
    for n_classes, info in all_info.items():
        name += '%dclass_train_ipc' % n_classes
        for ipc in info['train_ipc']:
            name += '_%d' % ipc
        name += '_test_ipc'
        for ipc in info['ipc']:
            name += '_%d' % ipc
    exp_name = os.path.join('{}'.format(dataset_name), name)

    work_dir = os.path.join(config.train_log, exp_name)
    ckpt_dir = os.path.join(work_dir, 'ckpt')
    writer = metric_writers.create_default_writer(logdir=work_dir)
    logging.info('work_dir: {}'.format(work_dir))

    os.makedirs(ckpt_dir, exist_ok=True)

    eval_it_pool = [300, 1000, 3000]
    model_eval_pool = [arch]

    accs_all_exps = dict()  # record performances of all experiments
    for key in model_eval_pool:
        accs_all_exps[key] = []

    if normalization == 'batch':
        eval_normalization = 'identity'
    else:
        eval_normalization = normalization

    aug_strategy = 'flip_color_crop_rotate_scale_cutout'
    use_flip_val = True

    # Number of prototypes (ipc * classes) -> training epochs used by evaluate_synset.
    steps_per_prototypes = {10: 1000, 100: 2000, 200: 20000, 400: 5000, 500: 5000, 1000: 10000, 2000: 40000,
                            5000: 40000}

    args = ml_collections.ConfigDict()
    args.model = arch
    args.device = config.device
    args.lr_net = lr_net
    args.dsa = True
    args.dsa_strategy = aug_strategy
    args.dsa_param = ParamDiffAug()  # Todo: Implementation is slightly different from JAX Version.

    criterion = nn.MSELoss(reduction='none').to(config.device)

    get_model = lambda: get_network(arch, channel, 10, im_size, width=width, depth=depth,
                                    norm=normalization)
    get_optimizer = lambda m: torch.optim.Adam(m.parameters(), lr=args.lr_net, betas=(0.9, 0.999))
    get_scheduler = lambda o: torch.optim.lr_scheduler.CosineAnnealingLR(o, T_max=500, eta_min=args.lr_net * 0.01)
    net = PoolElement(get_model=get_model, get_optimizer=get_optimizer, get_scheduler=get_scheduler,
                      loss_fn=nn.MSELoss(), batch_size=500, max_online_updates=-1, idx=0,
                      device=config.device, step=0)

    trainloader_val = InfiniteDataLoader(dst_train_val, batch_size=batch_size_val, shuffle=True, num_workers=0)

    def get_images_val(cls, n):  # get random n images from class c
        # Seeding by class keeps the chosen real-image subset identical across runs.
        np.random.seed(cls)
        idx_shuffle = np.random.permutation(indices_class_val[cls])[:n]
        return x_train_val[idx_shuffle], y_train_val[idx_shuffle]

    # Fixed CIFAR-10 real-image subsets that the generator translates during evaluation.
    all_x_syn_eval = {}
    all_y_syn_eval = {}
    for ipc in all_ipc_val:

        num_prototypes = ipc * num_classes_val
        num_online_eval_updates = steps_per_prototypes[num_prototypes]
        args.epoch_eval_train = num_online_eval_updates
        args.batch_train = min(num_prototypes, 500)

        x_syn_val = []
        y_syn_val = []
        for c in range(num_classes_val):
            x, y = get_images_val(c, ipc)
            x_syn_val.append(x)
            y_syn_val.append(y)
        x_syn_val = torch.cat(x_syn_val, dim=0).to(config.device)
        y_syn_val = torch.cat(y_syn_val, dim=0).to(config.device) / y_scale_val
        syndata_val = SynData(x_syn_val, y_syn_val, learn_label=learn_label).to(config.device)

        cache_path = os.path.join(train_log, 'cifar10_%dipc.pth' % ipc)
        if os.path.exists(cache_path):
            syndata_val.load_state_dict(torch.load(cache_path))
            x_syn_eval_, y_syn_eval_ = syndata_val.value()
            all_x_syn_eval[ipc] = x_syn_eval_
            all_y_syn_eval[ipc] = y_syn_eval_
            print('Load from Existing Checkpoint')
            continue

        x_syn_eval, y_syn_eval = syndata_val.value()
        all_x_syn_eval[ipc] = x_syn_eval
        all_y_syn_eval[ipc] = y_syn_eval
        torch.save(syndata_val.state_dict(), cache_path)

    generator = Translator(channel, width, depth, 'relu', normalization, 'avgpooling').to(config.device)
    if pretrain_path is not None:
        ckpt = torch.load(pretrain_path)
        if 'state_dict' in ckpt:
            ckpt = ckpt['state_dict']
        generator.load_state_dict(ckpt)
        print('Generator Load from %s' % pretrain_path)
    optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_train_steps, eta_min=lr * 0.1)

    best_val_acc = 0.0
    step_offset = 0

    if config.resume:
        ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')
        try:
            checkpoint = torch.load(ckpt_path)
            generator.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            step_offset = checkpoint['step_offset']
            best_val_acc = checkpoint['best_val_acc']
            logging.info('Load checkpoint from {}!'.format(ckpt_path))
            logging.info('step_offset: {}, best_val_acc: {}'.format(step_offset, best_val_acc))
        except FileNotFoundError:
            logging.info('No checkpoints found in {}!'.format(ckpt_dir))

    loss_sum = 0.0
    count = 0
    last_t = time.time()

    for it in range(step_offset + 1, num_train_steps + 1):
        ''' Train synthetic data '''
        optimizer.zero_grad(set_to_none=True)
        # Differentiable clone: its flattened parameters are updated functionally in the inner loop,
        # and gradients flow back to ``generator`` through the clone.
        generator_task = clone_module(generator)
        param_task = torch.cat([p.reshape(-1) for p in generator_task.parameters()])
        generator_task = ReparamModule(generator_task)
        n_classes = np.random.choice(list(all_info.keys()))

        selected_classes = torch.randperm(num_classes, device=config.device)[:n_classes]
        selected_indices = indices_class_torch[selected_classes]
        ipc_real_batch = batch_size_real // n_classes
        selected_indices = selected_indices[:, torch.randperm(ipc_real, device=config.device)[:ipc_real_batch]].view(-1)
        x_t = x_train[selected_indices.cpu()].to(config.device).unflatten(0, (n_classes, ipc_real_batch))
        y_t = torch.arange(0, n_classes, device=config.device).unsqueeze(-1).repeat(1, ipc_real_batch)
        y_t = F.one_hot(y_t, n_classes) - 1 / n_classes

        # Inner loop: SGD on the clone's flat parameters, keeping the graph for second-order gradients.
        for _ in range(mini_steps):
            ipc = np.random.choice(all_info[n_classes]['train_ipc'])
            x_s_ = x_t[:, :ipc].clone().flatten(0, 1)
            x_s = generator_task(x_s_, flat_param=param_task)
            y_s = y_t[:, :ipc].clone().flatten(0, 1)
            y_pred = net.nfr(x_s, y_s, x_t.flatten(0, 1), use_flip=n_classes * ipc <= 500)
            loss = criterion(y_pred, y_t.flatten(0, 1)).sum(dim=-1).mean()
            grad_task = torch.autograd.grad(loss, param_task, retain_graph=True, create_graph=True)[0]
            param_task = param_task - lr * grad_task

        # Outer loss with the adapted parameters.
        ipc = np.random.choice(all_info[n_classes]['ipc'])
        x_s_ = x_t[:, :ipc].clone().flatten(0, 1)
        y_s = y_t[:, :ipc].clone().flatten(0, 1)
        x_s = generator_task(x_s_, flat_param=param_task)
        y_pred = net.nfr(x_s, y_s, x_t.flatten(0, 1), use_flip=n_classes * ipc <= 500)
        loss = criterion(y_pred, y_t.flatten(0, 1)).sum(dim=-1).mean()

        loss.backward()
        optimizer.step()

        scheduler.step()
        # Fresh random feature network for the next task.
        net.initialize()
        loss_sum += loss.item()
        count += 1

        if it % 100 == 0:
            x_syn_eval, y_syn_eval = copy.deepcopy(x_s.detach()), copy.deepcopy(y_s.detach())
            x_norm = torch.mean(torch.linalg.norm(x_syn_eval.reshape(
                x_syn_eval.shape[0], -1), ord=2, dim=-1)).cpu().numpy()
            y_norm = torch.mean(torch.linalg.norm(y_syn_eval.reshape(
                y_syn_eval.shape[0], -1), ord=2, dim=-1)).cpu().numpy()
            summary = {'train/loss': loss_sum / count,
                       'monitor/steps_per_second': 100 / (time.time() - last_t),
                       'monitor/learning_rate': scheduler.get_last_lr()[0],
                       'monitor/x_norm': x_norm,
                       'monitor/y_norm': y_norm,
                       'N_CLASSES': n_classes,
                       'IPC': ipc}
            writer.write_scalars(it, summary)

            last_t = time.time()
            loss_sum, count = 0.0, 0

        ''' Evaluate synthetic data '''
        if it in eval_it_pool or it % steps_per_eval == 0:

            # Adapt a copy of the generator to CIFAR-10 without touching the meta-trained weights.
            generator_task = copy.deepcopy(generator)
            optimizer_val = torch.optim.Adam(generator_task.parameters(), lr=lr)
            for t in range(mini_steps_val):
                ipc = np.random.choice(all_ipc_val)
                x_s = generator_task(all_x_syn_eval[ipc])
                y_s = all_y_syn_eval[ipc]
                x_t, y_t = next(trainloader_val)
                x_t = x_t.to(config.device)
                y_t = y_t.to(config.device)
                y_pred = net.nfr(x_s, y_s, x_t, use_flip=use_flip_val)
                optimizer_val.zero_grad()
                loss = criterion(y_pred, y_t).sum(dim=-1).mean()
                loss.backward()
                optimizer_val.step()
                net.initialize()
                if (t + 1) % 10 == 0:
                    print('Eval Adaptation Iter %d, Loss %f' % (t + 1, loss.item()))

            generator_task.eval()
            all_accs = []
            for ipc in all_ipc_val:
                num_prototypes = ipc * num_classes_val
                num_online_eval_updates = steps_per_prototypes[num_prototypes]
                args.epoch_eval_train = num_online_eval_updates
                args.batch_train = min(num_prototypes, 500)
                with torch.no_grad():
                    x_syn = generator_task(all_x_syn_eval[ipc])
                    y_syn = all_y_syn_eval[ipc]
                    x_norm = torch.mean(torch.linalg.norm(
                        x_syn.reshape(x_syn.shape[0], -1), ord=2, dim=-1)).cpu().numpy()
                    y_norm = torch.mean(torch.linalg.norm(
                        y_syn.reshape(y_syn.shape[0], -1), ord=2, dim=-1)).cpu().numpy()

                for model_eval in model_eval_pool:
                    print('----------\nEvaluation\nmodel_train = {}, model_eval = {}, iteration = {}'.format(
                        arch, model_eval, it))
                    accs = []
                    for it_eval in range(3):
                        net_eval = get_network(model_eval, channel, num_classes_val, im_size, width=width, depth=depth,
                                               norm=eval_normalization).to(
                            config.device)  # get a random model
                        x_syn_eval, y_syn_eval = copy.deepcopy(x_syn.detach()), copy.deepcopy(y_syn.detach())
                        _, acc_train, acc_test = evaluate_synset(it_eval, net_eval, x_syn_eval, y_syn_eval,
                                                                 testloader_val, args)
                        accs.append(acc_test)
                    summary = {'eval/acc_mean': np.mean(accs), 'eval/acc_std': np.std(accs),
                               'monitor/x_norm': x_norm, 'monitor/y_norm': y_norm}
                    writer.write_scalars(it, summary)
                    all_accs += accs

            mean_acc = float(np.mean(all_accs))
            if mean_acc > best_val_acc:
                best_val_acc = mean_acc
                ckpt_path = os.path.join(ckpt_dir, 'best_ckpt.pt')
                # Same fields as the periodic ckpt.pt, so either file can be used to resume.
                torch.save(dict(step_offset=it, best_val_acc=best_val_acc, state_dict=generator.state_dict(),
                                optimizer_state_dict=optimizer.state_dict(),
                                scheduler_state_dict=scheduler.state_dict()),
                           ckpt_path)
                logging.info('{} Save checkpoint to {}, best acc {}!'.format(get_time(), ckpt_path, best_val_acc))

            # Exclude evaluation time from the next steps_per_second measurement.
            last_t = time.time()

        if it % steps_per_save == 0:
            ckpt_path = os.path.join(ckpt_dir, 'ckpt.pt')
            torch.save(dict(step_offset=it, best_val_acc=best_val_acc, state_dict=generator.state_dict(),
                            optimizer_state_dict=optimizer.state_dict(), scheduler_state_dict=scheduler.state_dict()),
                       ckpt_path)
            logging.info('Save checkpoint to {}!'.format(ckpt_path))


if __name__ == '__main__':
    # absl logging at INFO level for the whole run.
    logging.set_verbosity('info')
    # Hide all GPUs from TensorFlow (presumably so it does not reserve memory PyTorch needs).
    tf.config.experimental.set_visible_devices([], 'GPU')
    # Expose main()'s keyword arguments as command-line flags.
    fire.Fire(main)
