import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

import tqdm

import absl.app
import absl.flags
from absl import logging


from vs_maml.misc_utils import AttrDict
from vs_maml.misc_utils import define_flags_with_default, set_random_seed, TensorBoardLogger, print_flags
from vs_maml.data import datasets

FLAGS = absl.flags.FLAGS

# Command-line flags (with their defaults); main() reads them as FLAGS.<name>.
flags_def = define_flags_with_default(
    lr=1e-4,  # Learning rate for the Adam optimizer built in main().
    embedding_dim=256,  # Feature / label-embedding width and GRU hidden size of MetaLearner.
    clip_gradient=0.0,  # Gradient-clipping knob. NOTE(review): check main() for whether/how it is applied.
    meta_train_tasks=5,  # Tasks sampled per training step; loss and accuracies are divided by this.
    inner_batch_size='10',  # Dash-separated pre-adaptation ("shot") sizes, e.g. '1-5-10'; cycled one per step.
    outer_batch_size=256,  # Passed to the dataset samplers as the post-adaptation batch size (presumably; verify in data module).
    weight_decay=0.0,  # Regularisation knob. NOTE(review): check main() for whether it is passed to the optimizer.
    n_steps=10000,  # Number of meta-training steps.
    test_interval=1000,  # Run validation when train_step % test_interval == 0 (includes step 0).
    test_batches=5,  # Batches sampled per validation task at each evaluation.
    log_interval=20,  # Training metrics are logged when train_step % log_interval < number of shot sizes.
    n_training_tasks=0,  # Forwarded to the dataset constructor; meaning defined there (TODO confirm, e.g. whether 0 means "all").
    seed=42,  # Passed to set_random_seed before the model is built.
    dataset='mnist',  # Key into vs_maml.data.datasets; 'mnist' selects MnistFeature, anything else ImageNetFeatures.
    device='cuda',  # Torch device for the model and every batch.
    data_file='./rainbow_mnist.pkl',  # Path forwarded to the dataset constructor.
    output_dir='/tmp/vs_maml',  # Directory given to TensorBoardLogger.
)


def parse_variabe_batch_size(batch_sizes):
    """Turn a dash-separated string such as ``'1-5-10'`` into ``[1, 5, 10]``.

    Args:
        batch_sizes: flag string; each dash-separated token must parse as int.

    Returns:
        A list of ints in the order given, or ``None`` when ``batch_sizes`` is
        the empty string.
    """
    if batch_sizes != '':
        return list(map(int, batch_sizes.split('-')))
    return None


def average_dict(dicts):
    """Average a sequence of dicts key-wise.

    Args:
        dicts: a sequence of dicts that all share exactly the same key set and
            whose values support ``+`` and ``/`` (e.g. Python floats).

    Returns:
        A new dict mapping each key to the arithmetic mean of its values
        across ``dicts``. An empty ``dicts`` yields an empty dict.

    Raises:
        ValueError: if any dict's key set differs from the first one's.
    """
    if not dicts:
        # Nothing to average (e.g. zero validation tasks or test batches).
        return {}
    keys = set(dicts[0])
    totals = dict.fromkeys(keys, 0)
    for d in dicts:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if set(d) != keys:
            raise ValueError(
                'average_dict: all dicts must have the same keys; '
                'expected {}, got {}'.format(sorted(keys), sorted(d))
            )
        for key in keys:
            totals[key] += d[key]
    count = len(dicts)
    return {key: total / count for key, total in totals.items()}


class Flatten(nn.Module):
    """Collapse every dimension after the leading one into a single axis.

    ``(B, d1, d2, ...) -> (B, d1 * d2 * ...)`` via ``view``, so the input must
    be viewable (contiguous-compatible) with that shape.
    """

    def forward(self, input):
        leading = input.size(0)
        return input.view(leading, -1)


class MnistFeature(nn.Module):
    """Convolutional image encoder with an optional 10-way label embedding.

    Four unpadded 3x3 convolutions (3 -> 32 -> 32 -> 32 -> 32 channels) with
    LeakyReLU(0.2), then a linear projection to ``embedding_dim``. The linear
    layer hard-codes a 20x20 spatial map, which assumes 28x28 inputs
    (28 -> 26 -> 24 -> 22 -> 20).
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Build the conv stack in the same order as a literal listing so that
        # parameter initialisation (RNG order) and state_dict indices match.
        stages = []
        for in_channels in (3, 32, 32, 32):
            stages += [nn.Conv2d(in_channels, 32, 3), nn.LeakyReLU(0.2)]
        self.net = nn.Sequential(
            *stages,
            Flatten(),
            nn.Linear(20 * 20 * 32, embedding_dim),
        )
        # One learned vector per label id 0..9.
        self.embedding = nn.Embedding(10, embedding_dim)

    def forward(self, x, y=None):
        """Encode images ``x``; if labels ``y`` are given, append their embeddings.

        Returns ``(N, embedding_dim)`` without labels, ``(N, 2 * embedding_dim)``
        with labels.
        """
        features = self.net(x)
        if y is not None:
            features = torch.cat([features, self.embedding(y)], dim=1)
        return features


class ImageNetFeatures(nn.Module):
    """Conv + max-pool image encoder with an optional 2-way label embedding.

    Four stages of (3x3 conv -> LeakyReLU(0.2) -> 2x2 max-pool) map 6 input
    channels (presumably two stacked RGB images — confirm with the dataset) to
    32 channels, then a linear layer projects to ``embedding_dim``. The linear
    layer hard-codes a 3x3 final spatial map (e.g. 84x84 inputs).
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Layers are created in the same order as a literal listing, keeping
        # RNG consumption and state_dict indices unchanged.
        stages = []
        for in_channels in (6, 32, 32, 32):
            stages.extend([
                nn.Conv2d(in_channels, 32, 3),
                nn.LeakyReLU(0.2),
                nn.MaxPool2d(2, 2),
            ])
        self.net = nn.Sequential(
            *stages,
            Flatten(),
            nn.Linear(3 * 3 * 32, embedding_dim),
        )
        # One learned vector per label id 0..1.
        self.embedding = nn.Embedding(2, embedding_dim)

    def forward(self, x, y=None):
        """Encode images ``x``; if labels ``y`` are given, append their embeddings.

        Returns ``(N, embedding_dim)`` without labels, ``(N, 2 * embedding_dim)``
        with labels.
        """
        features = self.net(x)
        if y is not None:
            features = torch.cat([features, self.embedding(y)], dim=1)
        return features


class MetaLearner(nn.Module):
    """Sequence-model meta-learner.

    The labelled pre-adaptation examples are embedded by a feature network and
    read by a GRU as one sequence. The GRU's final hidden state is a code that is
    concatenated to the features of every post-adaptation image. A linear
    decoder then produces logits. With no pre-adaptation examples the code is
    all zeros.

    Args:
        dataset: ``'mnist'`` selects ``MnistFeature`` with 10 logits; any other
            value selects ``ImageNetFeatures`` with 2 logits.
        embedding_dim: feature/label embedding width and GRU hidden size.
    """

    def __init__(self, dataset, embedding_dim=256):
        super().__init__()
        self.embedding_dim = embedding_dim
        is_mnist = dataset == 'mnist'
        feature_cls = MnistFeature if is_mnist else ImageNetFeatures
        self.feature_net = feature_cls(embedding_dim)
        n_logits = 10 if is_mnist else 2

        # GRU input = [image features, label embedding]; decoder input = [code, image features].
        self.encoder = nn.GRU(2 * embedding_dim, embedding_dim, batch_first=True)
        self.decoder = nn.Linear(2 * embedding_dim, n_logits)

    def forward(self, pre_adapt_images, pre_adapt_labels, post_adapt_images):
        """Return logits ``(n_post, n_logits)`` for ``post_adapt_images``
        conditioned on the (images, labels) pre-adaptation set."""
        n_post = post_adapt_images.shape[0]
        if pre_adapt_images.shape[0] > 0:
            # Whole pre-adaptation set -> a single GRU sequence (batch of 1).
            sequence = self.feature_net(pre_adapt_images, pre_adapt_labels).unsqueeze(0)
            _, hidden = self.encoder(sequence)
            # hidden is (1, 1, embedding_dim); share the one code across all queries.
            code = hidden.squeeze(0).expand(n_post, -1)
        else:
            code = post_adapt_images.new_zeros((n_post, self.embedding_dim))
        query_features = self.feature_net(post_adapt_images)
        return self.decoder(torch.cat([code, query_features], dim=1))

    def forward_no_adapt(self, post_adapt_images):
        """Logits for ``post_adapt_images`` with an empty pre-adaptation set (zero code)."""
        empty_images = post_adapt_images.new_zeros((0, 0, 0, 0))
        empty_labels = post_adapt_images.new_zeros(0)
        return self.forward(empty_images, empty_labels, post_adapt_images)


def compute_loss(logits, labels):
    """Mean cross-entropy of unnormalised ``logits`` against class-index ``labels``."""
    return F.cross_entropy(input=logits, target=labels)


def compute_acc(logits, labels):
    """Fraction of rows whose arg-max logit (over dim 1) equals ``labels``, as a Python float."""
    predicted = logits.argmax(dim=1)
    hits = (predicted == labels).float()
    return hits.mean().item()


def batch_to_tensor(pre_x, pre_y, post_x, post_y):
    """Move one numpy task batch to ``FLAGS.device`` as torch tensors.

    Image arrays get their last axis moved to axis 1 (presumably channels-last
    -> channels-first, as ``nn.Conv2d`` expects). Label arrays are moved as-is.

    Returns:
        ``(pre_x, pre_y, post_x, post_y)`` as tensors, in the input order.
    """
    device = FLAGS.device

    def as_images(array):
        return torch.from_numpy(array).to(device).permute(0, 3, 1, 2)

    def as_labels(array):
        return torch.from_numpy(array).to(device)

    return as_images(pre_x), as_labels(pre_y), as_images(post_x), as_labels(post_y)



def _meta_train_step(meta_learner, optimizer, data_batches):
    """Take one optimizer step on a set of sampled training tasks.

    Loss and accuracies are summed over ``data_batches`` and divided by
    ``FLAGS.meta_train_tasks``. If ``FLAGS.clip_gradient > 0`` it is used as the
    max global L2 gradient norm before the step.

    Returns:
        ``(loss, pre_adapt_acc, post_adapt_acc)`` as Python floats.

    Raises:
        RuntimeError: if ``data_batches`` yields no batches.
    """
    loss = 0
    pre_adapt_acc = 0
    post_adapt_acc = 0
    for pre_x, pre_y, post_x, post_y in data_batches:
        pre_x, pre_y, post_x, post_y = batch_to_tensor(pre_x, pre_y, post_x, post_y)
        # The no-adaptation baseline is only measured, never trained on, so no
        # autograd graph is needed for it.
        with torch.no_grad():
            pre_adapt_logits = meta_learner.forward_no_adapt(post_x)
        pre_adapt_acc += compute_acc(pre_adapt_logits, post_y)
        post_adapt_logits = meta_learner(pre_x, pre_y, post_x)
        loss += compute_loss(post_adapt_logits, post_y)
        post_adapt_acc += compute_acc(post_adapt_logits, post_y)

    if not torch.is_tensor(loss):
        raise RuntimeError('dataset.sample_multiple_tasks returned no training batches')

    n_tasks = FLAGS.meta_train_tasks
    loss = loss / n_tasks

    optimizer.zero_grad()
    loss.backward()
    if FLAGS.clip_gradient > 0:
        nn.utils.clip_grad_norm_(meta_learner.parameters(), FLAGS.clip_gradient)
    optimizer.step()

    return loss.item(), pre_adapt_acc / n_tasks, post_adapt_acc / n_tasks


def _evaluate(meta_learner, dataset, inner_batch_size):
    """Average test accuracies over every validation task for one shot size.

    Samples ``FLAGS.test_batches`` batches from each of ``dataset.n_val_tasks``
    tasks. Gradients are disabled.

    Returns:
        A dict of averaged ``'{k}_shot_test/...'`` metrics, or ``{}`` when no
        batches were evaluated.
    """
    test_metrics = []
    with torch.no_grad():
        for task_id in range(dataset.n_val_tasks):
            for _ in range(FLAGS.test_batches):
                pre_x, pre_y, post_x, post_y = batch_to_tensor(
                    *dataset.sample_from_task(
                        'val', task_id, inner_batch_size,
                        FLAGS.outer_batch_size
                    )
                )
                pre_adapt_acc = compute_acc(meta_learner.forward_no_adapt(post_x), post_y)
                post_adapt_acc = compute_acc(meta_learner(pre_x, pre_y, post_x), post_y)
                test_metrics.append(
                    {'{}_shot_test/pre_adapt_acc'.format(inner_batch_size): pre_adapt_acc,
                     '{}_shot_test/post_adapt_acc'.format(inner_batch_size): post_adapt_acc,
                     '{}_shot_test/acc_improvement'.format(inner_batch_size): post_adapt_acc - pre_adapt_acc,
                    }
                )
    if not test_metrics:
        return {}
    return average_dict(test_metrics)


def main(_):
    """Meta-train a MetaLearner and log metrics to TensorBoard.

    Each step uses the next shot size from ``--inner_batch_size`` (cycled) and
    samples ``--meta_train_tasks`` training tasks. It then takes one Adam step
    (lr=``--lr``, weight_decay=``--weight_decay``) on the averaged
    post-adaptation cross-entropy. Every ``--test_interval`` steps (including
    step 0), validation accuracies are averaged per shot size and logged.

    If any training metric is non-finite, the metrics are logged and flushed
    and the process exits with status 1.

    Raises:
        absl.app.UsageError: if ``--inner_batch_size`` parses to no sizes.
    """
    import sys  # only needed for the non-finite-metric exit path

    # Validate flags before any expensive setup (dataset load, model build).
    inner_batch_sizes = parse_variabe_batch_size(FLAGS.inner_batch_size)
    if not inner_batch_sizes:
        raise absl.app.UsageError(
            '--inner_batch_size must be a dash-separated list of integers, e.g. "1-5-10".'
        )

    set_random_seed(FLAGS.seed)
    tb_logger = TensorBoardLogger(FLAGS.output_dir)

    dataset = datasets[FLAGS.dataset](FLAGS.data_file, FLAGS.n_training_tasks)

    meta_learner = MetaLearner(FLAGS.dataset, FLAGS.embedding_dim)
    meta_learner.to(FLAGS.device)

    # weight_decay was previously defined as a flag but never applied; its
    # default (0.0) equals Adam's own default, so default runs are unchanged.
    optimizer = torch.optim.Adam(
        meta_learner.parameters(), lr=FLAGS.lr, weight_decay=FLAGS.weight_decay
    )

    for train_step in range(FLAGS.n_steps):
        inner_batch_size = inner_batch_sizes[train_step % len(inner_batch_sizes)]
        data_batches = dataset.sample_multiple_tasks(
            'train', FLAGS.meta_train_tasks, inner_batch_size, FLAGS.outer_batch_size
        )
        loss, pre_adapt_acc, post_adapt_acc = _meta_train_step(
            meta_learner, optimizer, data_batches
        )

        train_metrics = {
            '{}_shot_train/loss'.format(inner_batch_size): loss,
            '{}_shot_train/pre_adapt_acc'.format(inner_batch_size): pre_adapt_acc,
            '{}_shot_train/post_adapt_acc'.format(inner_batch_size): post_adapt_acc,
            '{}_shot_train/acc_improvement'.format(inner_batch_size): post_adapt_acc - pre_adapt_acc,
        }

        exist_nan = not all(np.isfinite(v) for v in train_metrics.values())

        # Log once per shot size around each log_interval boundary, and always
        # when something went non-finite.
        if train_step % FLAGS.log_interval < len(inner_batch_sizes) or exist_nan:
            tb_logger.log_dict(train_step, train_metrics)

        if exist_nan:
            # Flush so the metrics just logged reach disk, and exit non-zero so
            # the failed run is not reported as a success.
            tb_logger.flush()
            logging.error('Non-finite training metrics at step %d: %s', train_step, train_metrics)
            sys.exit(1)

        if train_step % FLAGS.test_interval == 0:
            for eval_batch_size in inner_batch_sizes:
                average_test_metric = _evaluate(meta_learner, dataset, eval_batch_size)
                if average_test_metric:
                    tb_logger.log_dict(train_step, average_test_metric)
            tb_logger.flush()

    # Metrics logged after the last test interval would otherwise stay buffered.
    tb_logger.flush()
    print('Training completed!')


if __name__ == '__main__':
    # absl.app.run parses the command-line flags, then calls main(argv).
    absl.app.run(main)