import os
import sys
import fire
import time
import copy
import ml_collections
from tqdm import tqdm
from absl import logging
from functools import partial

# NOTE(review): presumably pins any JAX usage (e.g. inside the project's dataset helpers) to the CPU
# backend; it must be set before JAX is first imported, hence its position among the imports — confirm.
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
# Make packages two levels above the *current working directory* (not this file) importable, so the
# `lib` and `lib_torch` imports below resolve when the script is launched from its own directory.
sys.path.append("../..")

import numpy as np
import tensorflow as tf

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch._vmap_internals import vmap

from lib.dataset.dataloader import get_dataset

from lib_torch.utils import get_network, evaluate_synset, get_time, TensorDataset, ParamDiffAug, \
    save_torch_image
from lib_torch.translator import Translator

from clu import metric_writers


def get_config():
    """Build the default experiment configuration.

    The returned ``ConfigDict`` has top-level run settings (seed, log/image
    directories, resume flag, shape/size placeholders) and three nested groups:

    * ``dataset`` -- dataset name, storage paths and ZCA regularisation strength;
    * ``kernel``  -- shape/size placeholders, optimizer, learning rate, batch sizes;
    * ``online``  -- architecture, width, normalization and optimizer settings.

    Fields set to ``None`` are placeholders that callers fill in later.
    ``kernel.resume`` copies the top-level ``resume`` value at build time.
    """
    resume = True

    # Dataset
    dataset = ml_collections.ConfigDict()
    dataset.name = 'cifar100'  # ['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'tiny_imagenet']
    dataset.data_path = 'data/tensorflow_datasets'
    dataset.zca_path = 'data/zca'
    dataset.zca_reg = 0.1

    # Kernel
    kernel = ml_collections.ConfigDict()
    kernel.img_size = None
    kernel.img_channels = None
    kernel.num_prototypes = None
    kernel.train_size = None
    kernel.resume = resume
    kernel.optimizer = 'adam'
    kernel.learning_rate = 0.001
    kernel.batch_size = 1024
    kernel.eval_batch_size = 1000

    # Online
    online = ml_collections.ConfigDict()
    online.img_size = None
    online.img_channels = None
    online.optimizer = 'adam'
    online.learning_rate = 0.0003
    online.arch = 'conv'
    online.output = 'feat_fc'
    online.width = 128
    online.normalization = 'identity'

    # Assemble the top level in the same field order as before: run settings, then the groups.
    config = ml_collections.ConfigDict()
    config.random_seed = 0
    config.train_log = 'train_log'
    config.train_img = 'train_img'
    config.resume = resume
    config.img_size = None
    config.img_channels = None
    config.num_prototypes = None
    config.train_size = None
    config.dataset = dataset
    config.kernel = kernel
    config.online = online
    return config


@vmap
def lb_margin_th(logits):
    """Negative, capped top-2 logit margin, vectorised over the leading axis by ``vmap``.

    For each row, the gap between the largest and second-largest logit is capped
    at ``1 / num_classes`` (``num_classes`` being the row length) and negated.
    """
    num_classes = logits.shape[-1]
    cap = torch.tensor(1 / num_classes, dtype=torch.float, device=logits.device)
    top2, _ = torch.topk(logits, k=2)
    gap = top2[0] - top2[1]
    return -torch.minimum(gap, cap)


class SynData(nn.Module):
    """Synthetic dataset whose images and labels are stored as module parameters.

    The images (``x_syn``) always require gradients; the labels (``y_syn``) only
    require gradients when ``learn_label`` is true.
    """

    def __init__(self, x_init, y_init, learn_label=False):
        super().__init__()
        self.x_syn = nn.Parameter(x_init, requires_grad=True)
        self.y_syn = nn.Parameter(y_init, requires_grad=learn_label)

    def forward(self):
        """Return the live ``(images, labels)`` parameters so gradients can reach them."""
        return self.x_syn, self.y_syn

    def value(self):
        """Return ``(images, labels)`` detached from autograd.

        Used with a deterministic parameterisation of the synthetic data.
        """
        return tuple(param.detach() for param in (self.x_syn, self.y_syn))


class InfiniteDataLoader(DataLoader):
    """A ``DataLoader`` that is its own never-ending iterator.

    ``next(loader)`` keeps returning batches: when the current pass over the data
    is exhausted, a fresh pass is started transparently. ``iter(loader)`` returns
    the loader itself, so a plain ``for`` loop over it never terminates on its own.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The first pass is created eagerly so `next()` works right after construction.
        self._fresh_pass()

    def _fresh_pass(self):
        """Replace the stored iterator with a new single pass from ``DataLoader.__iter__``."""
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            pass
        # Current pass exhausted: start over and hand out the first batch of the new pass.
        self._fresh_pass()
        return next(self.dataset_iterator)


class PoolElement:
    """One online network in a pool, bundled with its own optimizer and LR scheduler.

    Besides plain forward / embedding access, it computes kernel-ridge-regression
    ("nfr") predictions from the network's ``embed`` features. ``train_steps``
    re-initialises the network, optimizer and scheduler once ``step`` reaches
    ``max_online_updates``.
    """

    def __init__(self, get_model, get_optimizer, get_scheduler, loss_fn, batch_size, max_online_updates, idx, device,
                 step=0):
        # Factories used to (re)build the network, its optimizer and its scheduler.
        self.get_model = get_model
        self.get_optimizer = get_optimizer
        self.get_scheduler = get_scheduler
        self.loss_fn = loss_fn.to(device)
        self.batch_size = batch_size
        self.max_online_updates = max_online_updates
        self.idx = idx
        self.device = device
        self.model = self.optimizer = self.scheduler = None
        self.initialize()
        # initialize() zeroes the counter; apply the caller-supplied starting step afterwards.
        self.step = step

    def __call__(self, x, no_grad=False):
        """Forward ``x`` through the network in eval mode (optionally without autograd)."""
        self.model.eval()
        if not no_grad:
            return self.model(x)
        with torch.no_grad():
            return self.model(x)

    def feature(self, x, no_grad=False, weight_grad=False):
        """Return ``model.embed(x)`` in eval mode.

        With ``no_grad`` the call runs under ``torch.no_grad()``. Otherwise the
        parameters' ``requires_grad`` flag is set to ``weight_grad`` first. That
        setting persists on the model after this call.
        """
        self.model.eval()
        if no_grad:
            with torch.no_grad():
                return self.model.embed(x)
        self.model.requires_grad_(weight_grad)
        return self.model.embed(x)

    @staticmethod
    def _with_flips(x_syn, y_syn):
        """Append copies of ``x_syn`` reversed along the last axis, duplicating the labels."""
        x_all = torch.cat((x_syn, torch.flip(x_syn, dims=[-1])), dim=0)
        y_all = torch.cat((y_syn, y_syn), dim=0)
        return x_all, y_all

    @staticmethod
    def _ridge(kss, reg):
        """Add ``|reg| * trace(kss) / n`` to the diagonal of the ``n x n`` Gram matrix ``kss``."""
        n = kss.shape[0]
        return kss + np.abs(reg) * torch.trace(kss) * torch.eye(n, device=kss.device) / n

    @staticmethod
    def _krr_predict(kts, kss_reg, y_syn):
        """Kernel ridge regression prediction ``kts @ kss_reg^{-1} @ y_syn`` (via a linear solve)."""
        return torch.mm(kts, torch.linalg.solve(kss_reg, y_syn))

    def nfr(self, x_syn, y_syn, x_tar, reg=1e-6, weight_grad=False, use_flip=False):
        """Predict targets for ``x_tar`` by kernel ridge regression on embedding features of ``x_syn``.

        Target features are computed without gradients. Synthetic features keep
        the graph to ``x_syn`` (and to the weights when ``weight_grad``).
        ``use_flip`` augments the synthetic set with flipped copies.
        """
        if use_flip:
            x_syn, y_syn = self._with_flips(x_syn, y_syn)
        feat_tar = self.feature(x_tar, no_grad=True)
        feat_syn = self.feature(x_syn, weight_grad=weight_grad)
        kss = torch.mm(feat_syn, feat_syn.t())
        kts = torch.mm(feat_tar, feat_syn.t())
        return self._krr_predict(kts, self._ridge(kss, reg), y_syn)

    def nfr_feat(self, x_syn, y_syn, feat_tar, reg=1e-6, weight_grad=False, use_flip=False):
        """Like ``nfr`` but with precomputed target features and a random half of the feature dims."""
        if use_flip:
            x_syn, y_syn = self._with_flips(x_syn, y_syn)
        feat_syn = self.feature(x_syn, weight_grad=weight_grad)
        dim = feat_syn.shape[-1]
        keep = torch.randperm(dim, device=feat_syn.device)[:dim // 2]
        feat_syn, feat_tar = feat_syn[..., keep], feat_tar[..., keep]
        kss = torch.mm(feat_syn, feat_syn.t())
        kts = torch.mm(feat_tar, feat_syn.t())
        return self._krr_predict(kts, self._ridge(kss, reg), y_syn)

    def nfr_eval(self, feat_syn, y_syn, x_tar, kss_reg):
        """Kernel ridge prediction for ``x_tar`` given precomputed synthetic features and regularised Gram."""
        kts = torch.mm(self.feature(x_tar, no_grad=True), feat_syn.t())
        return self._krr_predict(kts, kss_reg, y_syn)

    def train_steps(self, x_syn, y_syn, steps=1):
        """Take ``steps`` optimizer + scheduler steps on random minibatches, then maybe re-initialise."""
        self.model.train()
        self.model.requires_grad_(True)
        for _ in range(steps):
            x_batch, y_batch = self.get_batch(x_syn, y_syn)
            self.optimizer.zero_grad(set_to_none=True)
            loss = self.loss_fn(self.model(x_batch), y_batch)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
        self.check_for_reset(steps=steps)

    def evaluate_syn(self, x_syn, y_syn):
        """Placeholder: currently does nothing and returns ``None``."""

    def get_batch(self, xs, ys):
        """Sample ``batch_size`` rows without replacement (global NumPy RNG); return all rows if fewer."""
        n = ys.shape[0]
        if n >= self.batch_size:
            picked = np.random.choice(n, size=(self.batch_size,), replace=False)
            return xs[picked], ys[picked]
        return xs, ys

    def initialize(self):
        """(Re)build the network on ``device`` with a fresh optimizer and scheduler; reset ``step``."""
        self.model = self.get_model().to(self.device)
        self.optimizer = self.get_optimizer(self.model)
        self.scheduler = self.get_scheduler(self.optimizer)
        self.step = 0

    def check_for_reset(self, steps=1):
        """Advance the step counter and re-initialise once it reaches ``max_online_updates``."""
        self.step = self.step + steps
        if self.step >= self.max_online_updates:
            self.initialize()


def _ridge_readout(feats, targets, ridge=1e-6):
    """Fit a linear readout ``W`` with ``feats @ W ~= targets`` by trace-scaled ridge regression.

    The ridge term is ``ridge * trace(G) / n`` on the diagonal of the Gram/covariance
    matrix ``G``. The dual (samples x samples) form is used when there are fewer
    samples than feature dims; otherwise the primal (features x features) form is used.
    Both forms return a ``(num_features, num_targets)`` matrix.
    """
    num_samples, num_features = feats.shape
    if num_samples < num_features:
        gram = feats @ feats.T
        gram_reg = gram + ridge * torch.trace(gram) * torch.eye(num_samples, device=gram.device) / num_samples
        return feats.T @ torch.linalg.solve(gram_reg, targets)
    cov = feats.T @ feats
    cov_reg = cov + ridge * torch.trace(cov) * torch.eye(num_features, device=cov.device) / num_features
    # solve() instead of inv() @ ...: same result, better conditioned and cheaper.
    return torch.linalg.solve(cov_reg, feats.T @ targets)


def main(dataset_name, data_path=None, zca_path=None, train_log=None, train_img=None, save_image=True,
         arch='conv', width=128, depth=3, normalization='identity', learn_label=True,
         num_prototypes_per_class=10, random_seed=0, num_train_steps=100000, lr=1e-4,
         translator_path=None, exp_name='adaptation_cifar10', reg=1e-6):
    """Adapt a fixed synthetic image set through a trainable ``Translator`` network.

    Pipeline as implemented below:

    1. Load the dataset via ``get_dataset`` and index the training images per class.
    2. Train a feature extractor for 100 steps on real data and fit a ridge readout
       ``w_t`` on its features. The synthetic images start as real images sampled
       per class; for ``ipc >= 5`` their labels are replaced by the readout's predictions.
    3. For ``num_train_steps`` iterations, translate the synthetic images, fit kernel
       ridge regression on features of a freshly re-initialised network, and minimise
       the MSE of its predictions on a real batch w.r.t. the translator parameters.
    4. Periodically log metrics, evaluate with ``evaluate_synset``, save images and
       checkpoint the translator.

    Args:
        dataset_name: dataset identifier passed to ``get_dataset``.
        data_path, zca_path: dataset / ZCA storage paths (defaults used when falsy).
        train_log, train_img: root directories for logs/checkpoints and images
            (default ``'train_log'`` / ``'train_img'`` when falsy).
        save_image: whether to write synthetic images at each evaluation.
        arch, width, depth, normalization: network configuration for ``get_network``.
        learn_label: passed to ``SynData`` (labels get ``requires_grad``).
        num_prototypes_per_class: synthetic images per class (IPC).
        random_seed: seed for the NumPy and PyTorch global RNGs.
        num_train_steps: number of translator optimisation steps.
        lr: translator learning rate (cosine-annealed to ``0.1 * lr``).
        translator_path: optional checkpoint to initialise the translator from.
        exp_name: experiment sub-directory name.
        reg: kernel ridge regularisation passed to ``PoolElement.nfr``.
    """
    config = get_config()
    config.random_seed = random_seed
    config.train_log = train_log if train_log else 'train_log'
    config.train_img = train_img if train_img else 'train_img'

    # Previously `random_seed` was recorded but never applied to any RNG.
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Use the resolved paths: the raw arguments default to None, which os.path.exists rejects.
    os.makedirs(config.train_log, exist_ok=True)
    os.makedirs(config.train_img, exist_ok=True)

    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # --------------------------------------
    # Dataset
    # --------------------------------------
    config.dataset.data_path = data_path if data_path else 'data/tensorflow_datasets'
    config.dataset.zca_path = zca_path if zca_path else 'data/zca'
    config.dataset.name = dataset_name

    (x_train, y_train, x_test, y_test), preprocess_op, rev_preprocess_op, proto_scale = get_dataset(
        config.dataset, return_raw=True)

    # NOTE(review): get_dataset appears to fill img_shape / num_classes / class_names on config.dataset.
    im_size = config.dataset.img_shape[0:2]
    channel = config.dataset.img_shape[-1]
    num_classes = config.dataset.num_classes
    class_names = config.dataset.class_names
    class_map = {x: x for x in range(num_classes)}

    # Channels-last arrays -> channels-first tensors.
    x_train = torch.from_numpy(np.transpose(x_train, axes=[0, 3, 1, 2]))
    x_test = torch.from_numpy(np.transpose(x_test, axes=[0, 3, 1, 2]))
    y_train = torch.from_numpy(y_train)
    y_test = torch.from_numpy(y_test)

    dst_train = TensorDataset(x_train, y_train)

    num_prototypes = num_prototypes_per_class * config.dataset.num_classes
    config.kernel.num_prototypes = num_prototypes

    # --------------------------------------
    # Online
    # --------------------------------------
    config.online.arch = arch
    config.online.width = width
    config.online.depth = depth
    config.online.normalization = normalization
    config.online.img_size = config.dataset.img_shape[0]
    config.online.img_channels = config.dataset.img_shape[-1]

    # --------------------------------------
    # Logging
    # --------------------------------------
    steps_per_eval = 3000
    steps_per_save = 15000
    log_every = 100

    lr_net = config.online.learning_rate
    exp_name = os.path.join('{}'.format(dataset_name), 'adapt_w_teacher', exp_name)

    image_dir = os.path.join(config.train_img, exp_name)
    work_dir = os.path.join(config.train_log, exp_name)
    ckpt_dir = os.path.join(work_dir, 'ckpt')
    writer = metric_writers.create_default_writer(logdir=work_dir)
    logging.info('work_dir: {}'.format(work_dir))

    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)

    logging.info("image_dir: {}!".format(image_dir))

    if save_image:
        image_saver = partial(save_torch_image, num_classes=num_classes, class_names=class_names,
                              rev_preprocess_op=rev_preprocess_op, image_dir=image_dir, is_grey=False, save_img=True,
                              save_np=False)
    else:
        image_saver = None

    eval_it_pool = [1000]
    model_eval_pool = [arch]

    if normalization == 'batch':
        eval_normalization = 'identity'
    else:
        eval_normalization = normalization

    if dataset_name in ['mnist', 'fashion_mnist']:
        use_flip = False
        aug_strategy = 'color_crop_rotate_scale_cutout'
    else:
        use_flip = True
        aug_strategy = 'flip_color_crop_rotate_scale_cutout'

    if dataset_name == 'tiny_imagenet':
        if num_prototypes_per_class == 1:
            use_flip = True
        elif num_prototypes_per_class == 10:
            use_flip = False
        else:
            raise ValueError(
                'Unsupported prototypes per class {} for {}'.format(num_prototypes_per_class, dataset_name))

    if dataset_name == 'imagenet_resized/64x64':
        use_flip = False

    # Online-eval update budget per total prototype count; sizes not listed fall back to 2000,
    # which is what every listed size >= 20 uses (avoids a KeyError for e.g. 1000 classes x 10).
    step_per_prototypes = {10: 1000, 20: 2000, 50: 2000, 100: 2000, 200: 2000, 300: 2000, 400: 2000, 500: 2000,
                           1000: 2000, 2000: 2000, 5000: 2000}

    args = ml_collections.ConfigDict()
    args.model = arch
    args.device = config.device
    args.lr_net = lr_net

    args.dsa = True
    args.dsa_strategy = aug_strategy
    args.dsa_param = ParamDiffAug()  # Todo: Implementation is slightly different from JAX Version.

    criterion = nn.MSELoss(reduction='none').to(config.device)
    # --------------------------------------
    # Organize the real dataset
    # --------------------------------------
    images_all = []
    labels_all = []
    indices_class = [[] for _ in range(num_classes)]
    print("BUILDING DATASET")
    for i in tqdm(range(len(dst_train))):
        sample = dst_train[i]
        images_all.append(torch.unsqueeze(sample[0], dim=0))
        # int() instead of torch.tensor(sample[1]).item(): same value without re-wrapping a tensor.
        labels_all.append(class_map[int(sample[1])])

    for i, lab in tqdm(enumerate(labels_all)):
        indices_class[lab].append(i)
    images_all = torch.cat(images_all, dim=0).to("cpu")

    y_scale = np.sqrt(num_classes / 10)
    # Integer labels -> centred one-hot regression targets.
    y_train = F.one_hot(y_train, num_classes=num_classes) - 1 / num_classes
    y_test = F.one_hot(y_test, num_classes=num_classes) - 1 / num_classes

    for c in range(num_classes):
        print('class c = %d: %d real images' % (c, len(indices_class[c])))

    def get_images(c, n):  # get random n images from class c
        # A private RandomState seeded with the class index picks exactly the images that
        # np.random.seed(c) + np.random.permutation did, without resetting the global NumPy RNG.
        idx_shuffle = np.random.RandomState(c).permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]

    dst_train = TensorDataset(x_train, y_train)
    dst_test = TensorDataset(x_test, y_test)
    trainloader = DataLoader(dst_train, batch_size=1024, shuffle=False, num_workers=0)
    trainloader_rand = InfiniteDataLoader(dst_train, batch_size=1024, shuffle=True, num_workers=0)
    testloader = DataLoader(dst_test, batch_size=256, shuffle=False, num_workers=0)

    get_model = lambda: get_network(arch, channel, num_classes, im_size, width=width, depth=depth, norm=normalization)
    get_extractor = get_model
    get_optimizer = lambda m: torch.optim.Adam(m.parameters(), lr=args.lr_net, betas=(0.9, 0.999))
    get_scheduler = lambda o: torch.optim.lr_scheduler.CosineAnnealingLR(o, T_max=500, eta_min=args.lr_net * 0.01)
    extractor = PoolElement(get_model=get_extractor, get_optimizer=get_optimizer, get_scheduler=get_scheduler,
                            loss_fn=nn.MSELoss(), batch_size=500, max_online_updates=10000000, idx=0,
                            device=config.device, step=0)
    # Briefly train the teacher feature extractor on real data.
    for _ in range(100):
        x, y = next(trainloader_rand)
        extractor.train_steps(x.to(config.device), y.to(config.device))
    with torch.no_grad():
        feats = []
        y_tar = []
        for x_, y in trainloader:
            feats.append(extractor.feature(x_.to(config.device), no_grad=True))
            y_tar.append(y.to(config.device))
        feats = torch.cat(feats, dim=0)
        y_tar = torch.cat(y_tar, dim=0)
        # Linear readout from teacher features to targets; used below to relabel synthetic images.
        w_t = _ridge_readout(feats, y_tar, ridge=1e-6)

    # --------------------------------------
    # Train
    # --------------------------------------
    all_syn_data = []
    # Previously hard-coded to [10], silently ignoring num_prototypes_per_class.
    for ipc in [num_prototypes_per_class]:
        num_prototypes = ipc * num_classes
        num_online_eval_updates = step_per_prototypes.get(num_prototypes, 2000)
        args.epoch_eval_train = num_online_eval_updates
        args.batch_train = min(num_prototypes, 500)
        best_val_acc = 0.0

        # --------------------------------------
        # Initialize the synthetic data
        # --------------------------------------
        y_syn = torch.tensor(np.array([np.ones(ipc) * i for i in range(num_classes)]),
                             dtype=torch.long,
                             device=config.device).view(-1)  # [0,0,0, 1,1,1, ..., 9,9,9]
        x_syn = torch.randn(size=(num_classes * ipc, channel, im_size[0], im_size[1]),
                            dtype=torch.float)

        for c in range(num_classes):
            x_syn.data[c * ipc:(c + 1) * ipc] = get_images(c, ipc).detach().data

        y_syn = (F.one_hot(y_syn, num_classes=num_classes) - 1 / num_classes) / y_scale

        if ipc >= 5:
            feat_ = extractor.feature(x_syn.to(config.device), no_grad=True)
            y_syn = feat_ @ w_t
        syndata = SynData(x_syn, y_syn, learn_label=learn_label).to(config.device)

        all_syn_data.append(syndata)

    train_syn_data = [all_syn_data[0]]

    net = PoolElement(get_model=get_model, get_optimizer=get_optimizer, get_scheduler=get_scheduler,
                      loss_fn=nn.MSELoss(), batch_size=500, max_online_updates=-1, idx=0,
                      device=config.device, step=0)
    translator = Translator(channel, config.online.width, config.online.depth, 'relu',
                            'batch', 'avgpooling').to(config.device)
    if translator_path is not None:
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        ckpt = torch.load(translator_path, map_location=config.device)
        # Accept the 'state_dict' key the loader originally expected as well as the
        # 'condenser_state_dict' key this script itself writes when checkpointing.
        state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt['condenser_state_dict']
        translator.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(translator.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_train_steps, eta_min=lr * 0.1)
    trainloader = InfiniteDataLoader(dst_train, batch_size=1024, shuffle=True, num_workers=0)

    total_time = 0.
    loss_sum = 0.0
    count = 0
    steps_since_log = 0
    last_t = time.time()

    for it in range(1, num_train_steps + 1):
        ''' Train synthetic data '''
        cur_time = time.time()

        x_target, y_target = next(trainloader)
        x_target = x_target.to(config.device)
        y_target = y_target.to(config.device)
        syndata = train_syn_data[np.random.randint(0, len(train_syn_data))]
        x_syn_, y_syn = syndata.value()
        x_syn = translator(x_syn_)

        # Honour the dataset-specific flip policy computed above (it was previously ignored,
        # so e.g. MNIST digits were flipped); flipping is still limited to small synthetic sets.
        y_pred = net.nfr(x_syn, y_syn, x_target, reg=reg, use_flip=use_flip and x_syn.shape[0] <= 500)
        loss = criterion(y_pred, y_target).sum(dim=-1).mean(0)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        scheduler.step()
        # Fresh random network for the next step's kernel features.
        net.initialize()
        total_time += (time.time() - cur_time)
        loss_sum += loss.item() * x_target.shape[0]
        count += x_target.shape[0]
        steps_since_log += 1

        if it % log_every == 0:
            x_flat = x_syn.detach().reshape(x_syn.shape[0], -1)
            y_flat = y_syn.detach().reshape(y_syn.shape[0], -1)
            x_norm = torch.linalg.norm(x_flat, ord=2, dim=-1).mean().item()
            y_norm = torch.linalg.norm(y_flat, ord=2, dim=-1).mean().item()
            summary = {'train/loss': loss_sum / count,
                       # Count steps directly: samples / 1024 undercounts steps with partial batches.
                       'monitor/steps_per_second': steps_since_log / (time.time() - last_t),
                       'monitor/learning_rate': scheduler.get_last_lr()[0],
                       'monitor/x_norm': x_norm,
                       'monitor/y_norm': y_norm}
            writer.write_scalars(it, summary)

            last_t = time.time()
            loss_sum, count, steps_since_log = 0.0, 0, 0

        ''' Evaluate synthetic data '''
        if it in eval_it_pool or it % steps_per_eval == 0:
            translator.eval()
            for model_eval in model_eval_pool:
                print(
                    '----------\nEvaluation\nmodel_train = {}, model_eval = {}, iteration = {}'.format(arch, model_eval,
                                                                                                       it))
                for syndata in all_syn_data:
                    accs = []
                    for it_eval in range(1):
                        net_eval = get_network(model_eval, channel, num_classes, im_size, width=width, depth=depth,
                                               norm=eval_normalization).to(
                            config.device)  # get a random model
                        x_syn, y_syn = syndata.value()
                        with torch.no_grad():
                            x_syn = translator(x_syn)
                        # Independent copies so evaluation cannot touch the synthetic parameters.
                        x_syn_eval, y_syn_eval = copy.deepcopy(x_syn.detach()), copy.deepcopy(y_syn.detach())
                        _, _, acc_test = evaluate_synset(it_eval, net_eval, x_syn_eval, y_syn_eval,
                                                         testloader, args)
                        accs.append(acc_test)
                    summary = {'eval/acc_mean': np.mean(accs), 'eval/acc_std': np.std(accs)}
                    writer.write_scalars(it, summary)

            ''' visualize and save '''
            for idx, syn_data in enumerate(all_syn_data):
                x_syn_, y_syn = syn_data.value()
                x_proto, y_proto = copy.deepcopy(x_syn_.detach().cpu().numpy()), copy.deepcopy(
                    y_syn.detach().cpu().numpy())
                if image_saver:
                    image_saver(x_proto, y_proto, step=it, suffix='%d_original' % idx)
                with torch.no_grad():
                    x_syn = translator(x_syn_)
                x_proto, y_proto = copy.deepcopy(x_syn.detach().cpu().numpy()), copy.deepcopy(
                    y_syn.detach().cpu().numpy())
                if image_saver:
                    image_saver(x_proto, y_proto, step=it, suffix='%d' % idx)
            translator.train()
            last_t = time.time()
            print('Adaptation Time for IPC %d: %f' % (ipc, total_time))

        if it % steps_per_save == 0:
            ckpt_path = os.path.join(ckpt_dir, 'ckpt_translator_ipc%d.pt' % ipc)
            torch.save(dict(step_offset=it, best_val_acc=best_val_acc, condenser_state_dict=translator.state_dict(),
                            optimizer_state_dict=optimizer.state_dict(), scheduler_state_dict=scheduler.state_dict()),
                       ckpt_path)
            logging.info('Save checkpoint to {}!'.format(ckpt_path))

    # Flush and release the metric writer's resources.
    writer.close()


if __name__ == '__main__':
    # Hide every GPU from TensorFlow (imported in this file but otherwise unused here),
    # presumably so it does not reserve GPU memory needed by the PyTorch code in `main`.
    tf.config.experimental.set_visible_devices([], device_type='GPU')
    logging.set_verbosity('info')
    # Expose `main`'s parameters as command-line flags.
    fire.Fire(component=main)
