from stc_utils.divergence_measures.kl_div import calc_kl_divergence
from stc_utils.divergence_measures.mm_div import poe
from stc_utils.BaseExperiment import BaseExperiment
from stc_utils.BaseMMVae import BaseMMVae
from stc_utils.utils import Flatten, Unflatten
from stc_utils.save_samples import write_samples_img_to_file
from smuco.stc import DMCExperiment
from smuco.stc_utils.filehandling import set_exp_flags

from abc import ABC, abstractmethod
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.distributions as dist
import scipy.stats
import random
from torch.distributions.normal import Normal


class ReplayBuffer:
    """Fixed-capacity ring buffer holding multi-view observations and actions.

    Storage is pre-allocated as float32 numpy arrays:
    ``views`` has shape ``(capacity, num_views, *obs_shape)`` and
    ``actions`` has shape ``(capacity, *action_shape)``.
    """

    def __init__(self, obs_shape, action_shape, capacity, num_views, batch_size, path_len, sub_seq_len, device):
        """
        Inputs:
        - obs_shape: shape of a single view, e.g. (C, H, W)
        - action_shape: shape of a single action
        - capacity: number of slots in the ring buffer
        - num_views: number of views stored per slot
        - batch_size: number of slots drawn by `sample`
        - path_len: path (episode) length used by the sequential sampler
        - sub_seq_len: default sub-sequence length for `sample_subsequences`
        - device: torch device that sampled sub-sequences are moved to
        """
        self.obs_shape = obs_shape
        self.capacity = capacity
        self.num_views = num_views
        self.batch_size = batch_size
        self.device = device

        self.views = np.empty(
            (capacity, num_views, *obs_shape), dtype=np.float32)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)

        self.idx = 0  # next slot to write
        self.last_save = 0  # used for partial save mechanism
        self._path_len = path_len  # used for sequential sample
        self.sub_seq_len = sub_seq_len
        self.full = False  # True once the write index has wrapped around

    def add(self, views, action=None):
        """
        Store one set of views (and optionally its action) in the next slot.

        Inputs:
        - views: array-like with `num_views` entries, each of shape obs_shape
        - action: optional array-like of shape action_shape; when omitted the
          action slot is left untouched
        """
        assert len(views) == self.num_views, "Invalid view set"
        np.copyto(self.views[self.idx], views)
        if action is not None:
            np.copyto(self.actions[self.idx], action)
        self.idx = (self.idx + 1) % self.capacity
        if self.idx == 0:
            # The write index wrapped: every slot now holds data.
            self.full = True

    def sample(self) -> dict:
        """
        Uniformly sample `batch_size` stored slots.

        Returns:
        { m_i -> tensor with shape (B, C, H, W) }

        Raises:
        ValueError if nothing has been stored yet.
        """
        upper = self.capacity if self.full else self.idx
        if upper == 0:
            raise ValueError("Cannot sample from an empty replay buffer")
        idxs = np.random.randint(0, upper, size=self.batch_size)
        batch_views = self.views[idxs]
        return {'m{}'.format(i): torch.Tensor(batch_views[:, i])
                for i in range(batch_views.shape[1])}

    def _sample_sequential_idx(self, n, L):
        """
        Adopted from
        https://github.com/JmfanBU/DRIBO/blob/main/DRIBO/utils.py

        Draws n start positions and expands each into L consecutive indices.
        Starts whose offset inside their path exceeds `path_len - L` are
        moved to `path_start + L`.

        Returns:
        a flat index array of length n * L, laid out as n runs of L
        consecutive indices
        e.x. [1,2,3, | 3,4,5, | 4,5,6, |6,7,8] n = 4 subsequences starting from [1,3,4,6] with length L = 3

        Raises:
        ValueError if fewer than L + 1 slots are available.
        """
        high = (self.capacity if self.full else self.idx) - L
        if high <= 0:
            raise ValueError(
                "Not enough stored transitions to sample sub-sequences")
        idx = np.random.randint(0, high, size=n)
        pos_in_path = idx % self._path_len
        crosses = pos_in_path > self._path_len - L
        idx[crosses] = idx[crosses] // self._path_len * self._path_len + L
        # Broadcasting yields an (n, L) grid; flatten idxes into 1-d array
        return (idx[:, None] + np.arange(L)).reshape(-1)

    def sample_subsequences(self, n=None, L=None):
        """
        Return sub-sequential views and actions

        Inputs:
        - n: number of sub-sequences (defaults to `batch_size`)
        - L: sub-sequence length (defaults to `sub_seq_len`)

        Returns:
        - views: float tensor of shape (n, L, num_views, *obs_shape)
        - actions: tensor of shape (n, L, -1)
        """
        n = self.batch_size if n is None else n
        L = self.sub_seq_len if L is None else L
        # Draw indices with the SAME n and L used for the reshape below.
        idxs = self._sample_sequential_idx(n, L)
        views = torch.as_tensor(self.views[idxs], device=self.device).float().reshape(
            n, L, *self.views.shape[1:])  # remain image dimension
        actions = torch.as_tensor(
            self.actions[idxs], device=self.device).reshape(n, L, -1)

        return views, actions

    def save(self, save_dir):
        """
        Write the views added since the previous save to
        `<save_dir>/<last_save>_<idx>.pt`. Only views are saved, not actions.
        """
        if self.idx == self.last_save:
            print("Hit last_save:", self.idx)
            return
        path = os.path.join(
            save_dir, '{}_{}.pt'.format(self.last_save, self.idx))
        # NOTE(review): when the write index has wrapped (idx < last_save)
        # this slice is empty - confirm how wrapped buffers should be saved.
        payload = self.views[self.last_save:self.idx]  # partial payload
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """
        Restore views from the chunk files written by `save`, in order of
        their start index. Chunks must be contiguous from the current index.
        """
        chunks = os.listdir(save_dir)
        chunks = sorted(chunks, key=lambda x: int(x.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split(
                '.')[0].split('_')]  # exist extension
            path = os.path.join(save_dir, chunk)
            # NOTE: torch.load unpickles the file - only load trusted data.
            payload = torch.load(path)
            assert self.idx == start
            self.views[start:end] = payload
            self.idx = end


class Encoder(nn.Module):
    """
    Convolutional encoder producing Gaussian parameters for the content
    (class) latent and, optionally, the style latent.

    Adopted from:
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html
    """

    def __init__(self, flags, act_dim):
        """
        Inputs:
        - flags: config providing style_dim, class_dim and
          factorized_representation
        - act_dim: put conditional into forward process
        """
        super(Encoder, self).__init__()

        self.flags = flags
        self.act_dim = act_dim
        self.shared_encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            Flatten(),
            # 2048 = 128 * 4 * 4, i.e. the flattened size for 28x28 inputs.
            nn.Linear(2048, flags.style_dim + flags.class_dim),
            nn.ReLU(),
        )

        # content branch
        self.class_mu = nn.Linear(
            flags.style_dim + flags.class_dim + act_dim, flags.class_dim)
        self.class_logvar = nn.Linear(
            flags.style_dim + flags.class_dim + act_dim, flags.class_dim)
        # optional style branch
        if flags.factorized_representation:
            self.style_mu = nn.Linear(
                flags.style_dim + flags.class_dim + act_dim, flags.style_dim)
            self.style_logvar = nn.Linear(
                flags.style_dim + flags.class_dim + act_dim, flags.style_dim)

    def forward(self, x, conditional=None):
        """
        Inputs:
        - x: image batch fed to the shared convolutional encoder
        - conditional: optional tensor of shape (B, act_dim) concatenated to
          the shared features; defaults to zeros

        Returns:
        (style_mu, style_logvar, class_mu, class_logvar); the two style
        entries are None when the representation is not factorized
        """
        h_prime = self.shared_encoder(x)
        if conditional is None:
            # new_zeros inherits device and dtype from the features, so the
            # concatenation below also works when the model is not on CPU.
            conditional = h_prime.new_zeros((h_prime.shape[0], self.act_dim))
        h = torch.cat([h_prime, conditional], dim=1)
        if self.flags.factorized_representation:
            return self.style_mu(h), self.style_logvar(h), self.class_mu(h), \
                self.class_logvar(h)
        else:
            return None, None, self.class_mu(h), self.class_logvar(h)


class Decoder(nn.Module):
    """
    Transposed-convolution decoder mapping a latent vector back to an image.

    Adopted from:
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html
    """

    def __init__(self, flags):
        """
        Inputs:
        - flags: config providing style_dim, class_dim and
          factorized_representation
        """
        super().__init__()
        self.flags = flags
        latent_width = flags.style_dim + flags.class_dim
        stages = [
            nn.Linear(latent_width, 2048),
            nn.ReLU(),
            Unflatten((128, 4, 4)),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
        ]
        self.decoder = nn.Sequential(*stages)

    def forward(self, style_latent_space, class_latent_space):
        """
        Decode the latent(s) into a reconstruction.

        The style latent is only used (concatenated in front of the class
        latent) when the representation is factorized.

        Returns:
        (x_hat, scale) where scale is the constant 0.75 on x_hat's device
        """
        z = (
            torch.cat((style_latent_space, class_latent_space), dim=1)
            if self.flags.factorized_representation
            else class_latent_space
        )
        fixed_scale = torch.tensor(0.75, device=z.device)
        return self.decoder(z), fixed_scale


class View(ABC):
    """
    One view (modality): a named encoder/decoder pair plus the likelihood
    family used to score reconstructions.

    Adopted from https://github.com/gr8joo/MVTCAE
    """

    def __init__(self, name, enc, dec, class_dim, style_dim, lhood_name):
        self.name = name
        self.encoder, self.decoder = enc, dec
        self.class_dim, self.style_dim = class_dim, style_dim
        self.likelihood_name = lhood_name
        self.likelihood = self.get_likelihood(lhood_name)

    def get_likelihood(self, name):
        """Return the torch distribution class for `name`, or None (after
        printing a message) when the name is not supported."""
        supported = (
            ('laplace', dist.Laplace),
            ('bernoulli', dist.Bernoulli),
            ('normal', dist.Normal),
            ('categorical', dist.OneHotCategorical),
        )
        for known_name, dist_cls in supported:
            if name == known_name:
                return dist_cls
        print('likelihood not implemented')
        return None

    @abstractmethod
    def save_data(self, d, fn, args):
        pass

    @abstractmethod
    def plot_data(self, d):
        pass

    def calc_log_prob(self, out_dist, target, norm_value):
        """Sum the log-probability of `target` under `out_dist` and divide
        it by `norm_value`."""
        return out_dist.log_prob(target).sum() / norm_value

    def save_networks(self, dir_checkpoints):
        """Save encoder and decoder state dicts as 'enc_<name>' and
        'dec_<name>' inside `dir_checkpoints`."""
        for prefix, network in (('enc_', self.encoder), ('dec_', self.decoder)):
            torch.save(network.state_dict(), os.path.join(
                dir_checkpoints, prefix + self.name))


class DMCView(View):
    """Image view with 3x28x28 data whose samples are written as PNG grids."""

    def __init__(self, name, enc, dec, class_dim, style_dim, lhood_name):
        super().__init__(
            name=name,
            enc=enc,
            dec=dec,
            class_dim=class_dim,
            style_dim=style_dim,
            lhood_name=lhood_name,
        )
        self.file_suffix = '.png'
        self.gen_quality_eval = True
        self.data_size = torch.Size((3, 28, 28))

    def save_data(self, d, fn, args):
        """Write the samples `d` to file `fn`, `args['img_per_row']` per row."""
        write_samples_img_to_file(d, fn, args['img_per_row'])

    def plot_data(self, d):
        """Return `d` unchanged; no transform is applied for plotting."""
        return d


class DMCExperiment(BaseExperiment):
    """Experiment container that builds the per-view encoder/decoder pairs,
    the multi-view VAE and (on request) its optimizer."""

    def __init__(self, cfg, replay_buffer=None):
        """
        Inputs:
        - cfg: config whose `mvtc` node holds the model/training flags
        - replay_buffer: optional buffer that training samples from
        """
        flags = cfg.mvtc
        # self.flags is read below; presumably BaseExperiment sets it from
        # `flags` - TODO confirm.
        super().__init__(flags)
        self.num_train = flags.train.num_train
        self.replay_buffer = replay_buffer

        self.num_views = flags.num_views
        self.plot_img_size = torch.Size((3, 28, 28))
        self.flags.num_features = 111

        self.views = self.set_views()
        self.subsets = self.set_subsets()  # create modality subsets
        self.dataset_train = None
        self.dataset_test = None

        self.mm_vae = self.set_model()
        self.optimizer = None  # created by set_optimizer()
        self.rec_weights = self.set_rec_weights()
        self.style_weights = self.set_style_weights()

        self.paths_fid = self.set_paths_fid()

    def set_views(self):
        """Build one DMCView per view, keyed by its name 'v<i>'."""
        mods = [
            DMCView(
                "v{}".format(m),
                Encoder(self.flags, self.flags.act_dim),
                Decoder(self.flags),
                self.flags.class_dim,
                self.flags.style_dim,
                self.flags.likelihood,
            ) for m in range(self.num_views)
        ]
        return {m.name: m for m in mods}

    def set_model(self):
        """Create the multi-view VAE and move it to `flags.device`."""
        model = BaseMMVae(self.flags, self.views, self.subsets)
        return model.to(self.flags.device)

    def set_optimizer(self):
        """Create the Adam optimizer over all model parameters and store it
        in `self.optimizer`."""
        # Materialise the parameters once; the list is reused for both the
        # count and the optimizer.
        params = list(self.mm_vae.parameters())
        total_params = sum(p.numel() for p in params)
        print('num parameters: ' + str(total_params))
        self.optimizer = optim.Adam(
            params,
            lr=self.flags.train.initial_learning_rate,
            betas=(self.flags.train.adam.beta_1,
                   self.flags.train.adam.beta_2))

    def set_rec_weights(self):
        """Return a reconstruction weight of 1.0 for every view."""
        return {mod.name: 1.0 for mod in self.views.values()}

    def set_style_weights(self):
        """Return `flags.loss.beta_style` as the style weight of every view."""
        weights = {"v%d" %
                   m: self.flags.loss.beta_style for m in range(self.num_views)}
        return weights

    def mean_eval_metric(self, values):
        """Return the arithmetic mean of `values`."""
        return np.mean(np.array(values))

    def get_prediction_from_attr(self, attr, index=None):
        """Return the argmax over axis 1 of `attr` as ints (`index` is
        unused)."""
        pred = np.argmax(attr, axis=1).astype(int)
        return pred

    def eval_label(self, values, labels, index):
        """Score argmax predictions of `values` against `labels` with
        `self.eval_metric` (`index` is unused)."""
        # eval_metric is not defined in this class; presumably inherited
        # from BaseExperiment - TODO confirm.
        pred = self.get_prediction_from_attr(values)
        return self.eval_metric(labels, pred)


def encode(model, views: list):
    """
    Encode one observation per view and sample from the fused posterior.

    Inputs:
    - model: object exposing `encoders` (mapping 'v<i>' -> encoder module)
      and `ivw_fusion(mus, logvars)`
    - views: sequence indexed by view number; entry i is the input for
      encoder 'v<i>' (without batch dimension)

    Returns:
    a sample from Normal(fused_mu, sqrt(exp(fused_logvar)))
    """
    mus, logvars = [], []
    # The returned sample carries no gradient, so skip building a graph.
    with torch.no_grad():
        for m_key, encoder in model.encoders.items():  # iterate over dictionary
            idx = int(m_key.strip().split('v')[-1])
            # Place the input where the encoder's weights live; a CPU-only
            # tensor fails as soon as the model is moved to another device.
            first_param = next(encoder.parameters(), None)
            device = None if first_param is None else first_param.device
            m = torch.as_tensor(
                views[idx], dtype=torch.float32, device=device).unsqueeze(0)
            _, _, mu, logvar = encoder(m)
            mus.append(mu)
            logvars.append(logvar)
        ivw_mu, ivw_logvar = model.ivw_fusion(
            torch.cat(mus), torch.cat(logvars))
        ivw_std = torch.sqrt(torch.exp(ivw_logvar))
        # Named `posterior` so the module-level `dist` alias is not shadowed.
        posterior = Normal(ivw_mu, ivw_std)
        return posterior.sample()


def build_exp(cfg):
    """
    Seed the RNGs, finalise the `cfg.mvtc` flags and build the experiment.

    Note that `cfg.mvtc` is modified in place (multimodal method, divergence
    weights, alpha_modalities).

    Returns:
    a DMCExperiment with its optimizer already created
    """
    FLAGS = cfg.mvtc
    FLAGS.multimodal.method = 'tc'
    FLAGS.multimodal.modality_ivw = True

    SEED = cfg.seed
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    random.seed(SEED)

    # postprocess flags: unset divergence weights default to a uniform
    # 1 / (num_mods + 1)
    uniform_weight = 1 / (FLAGS.num_mods + 1)
    if FLAGS.div_weight_uniform_content is None:
        FLAGS.div_weight_uniform_content = uniform_weight
    FLAGS.alpha_modalities = [FLAGS.div_weight_uniform_content]
    if FLAGS.div_weight is None:
        FLAGS.div_weight = uniform_weight
    FLAGS.alpha_modalities.extend(
        [FLAGS.div_weight for _ in range(FLAGS.num_mods)])
    print("alpha_modalities:", FLAGS.alpha_modalities)
    set_exp_flags(FLAGS)
    exp = DMCExperiment(cfg)
    exp.set_optimizer()
    return exp


def calc_log_probs(exp, result, batch):
    """
    Compute the negated, batch-normalised log-probability of each view.

    Inputs:
    - exp: experiment exposing `views`, `rec_weights` and
      `flags.train.batch_size`
    - result: model output; `result['rec'][name]` is the output
      distribution of view `name`
    - batch: mapping view name -> target tensor

    Returns:
    - log_probs: dict view name -> negated log-probability
    - the rec-weighted sum of those values, element-wise capped by a random
      threshold drawn from [0, 50); shape (1,)
    """
    log_probs = dict()
    weighted_log_prob = 0.0
    for mod in exp.views.values():
        log_probs[mod.name] = -mod.calc_log_prob(result['rec'][mod.name],
                                                 batch[mod.name],
                                                 exp.flags.train.batch_size)
        weighted_log_prob += exp.rec_weights[mod.name]*log_probs[mod.name]
    # Still a plain float when there are no views; torch.min needs a tensor.
    weighted_log_prob = torch.as_tensor(weighted_log_prob)
    # NOTE(review): the cap is random on every call, which makes the loss
    # non-deterministic - confirm this is intended. It is created on the
    # same device/dtype as the loss so the comparison also works off-CPU.
    cap = weighted_log_prob.new_tensor([random.random() * 50])
    return log_probs, torch.min(weighted_log_prob, cap)


def calc_klds(exp, result):
    """Return {subset name -> KL divergence} for every (mu, logvar) pair in
    `result['latents']['subsets']`, normalised by the training batch size."""
    return {
        subset: calc_kl_divergence(
            mu, logvar, norm_value=exp.flags.train.batch_size)
        for subset, (mu, logvar) in result['latents']['subsets'].items()
    }


def calc_klds_style(exp, result):
    """Return {key -> KL divergence} for the entries of
    `result['latents']['views']` whose key ends with 'style', normalised by
    the training batch size."""
    style_klds = {}
    for name, stats in result['latents']['views'].items():
        if not name.endswith('style'):
            continue
        mu, logvar = stats
        style_klds[name] = calc_kl_divergence(
            mu, logvar, norm_value=exp.flags.train.batch_size)
    return style_klds


def calc_style_kld(exp, klds):
    """Return the sum over views of `style_weights[view]` times
    `klds[view + '_style']`."""
    view_names = exp.views
    per_view_weight = exp.style_weights
    return sum(
        (per_view_weight[name] * klds[name + '_style'] for name in view_names),
        0.0,
    )


def calc_klds_cvib(exp, result):
    """
    KL divergence between the joint posterior and each non-style view
    posterior, normalised by the training batch size.

    Returns:
    - dict view key -> KL divergence (keys containing 'style' are skipped)
    - the sum of those divergences (0.0 when there are none)
    """
    per_view = result['latents']['views']
    joint_mu, joint_logvar = result['latents']['joint']
    klds = {}
    for name, stats in per_view.items():
        if 'style' in name:
            continue
        mu, logvar = stats
        klds[name] = calc_kl_divergence(
            joint_mu, joint_logvar, mu, logvar,
            norm_value=exp.flags.train.batch_size)
    return klds, sum(klds.values(), 0.0)


def calc_entropy(data):
    """Calculate the Shannon entropy of the per-sample sums of a batch.

    Input: data is a dict mapping a key to a 4-d tensor; the tensors are
    concatenated along dim 0 and each sample is summed over dims 1-3.

    Returns the entropy computed by `scipy.stats.entropy`, which normalises
    the sums to a distribution. The result is a plain numpy value and is not
    differentiable.
    """
    p_data = torch.cat(list(data.values())).sum(
        dim=[1, 2, 3])  # 1-dimensional tensor
    # scipy needs a host numpy array: detach from autograd and copy to CPU,
    # otherwise tensors on another device or requiring grad cannot be used.
    entropy = scipy.stats.entropy(p_data.detach().cpu().numpy())
    return entropy


def update_stc(cfg, exp):
    """
    Run one optimisation step of the multi-view VAE on a batch of
    sub-sequences drawn from the replay buffer.

    Inputs:
    - cfg: config; `cfg.mvtc.train.batch_size` and
      `cfg.mvtc.train.sub_seq_len` size the sampled batch
    - exp: experiment providing `mm_vae`, `replay_buffer`, `optimizer`,
      `num_views` and `flags`

    Returns nothing; the model parameters are updated in place.
    """
    mm_vae = exp.mm_vae
    mm_vae.train()

    # batch_views is reshaped to (n, L, V, ...) and batch_acts to
    # (n, L, -1) by the replay buffer.
    batch_views, batch_acts = exp.replay_buffer.sample_subsequences(
        n=cfg.mvtc.train.batch_size,
        L=cfg.mvtc.train.sub_seq_len,
    )

    device = exp.flags.device
    mean_views = batch_views.mean(dim=1)  # average over temporal dimension
    batch = {"v{}".format(i): mean_views[:, i].to(device)
             for i in range(batch_views.shape[2])}
    # Average the actions over the temporal axis as well (dim=1), giving one
    # action vector per sample. Averaging over dim=2 would instead collapse
    # the action features and leave a (n, L) tensor, which only matches the
    # (B, act_dim) conditional that Encoder concatenates when L == act_dim.
    conditional = batch_acts.mean(dim=1).to(device)

    results = mm_vae(batch, conditional)

    _, weighted_log_prob = calc_log_probs(exp, results, batch)
    group_divergence = results['joint_divergence']
    _, klds_cvib = calc_klds_cvib(exp, results)

    n_views = exp.num_views
    tc_ratio = exp.flags.loss.tc_ratio
    cvib_weight = tc_ratio / n_views
    vib_weight = 1 - tc_ratio
    kld_weighted = cvib_weight * klds_cvib + vib_weight * group_divergence

    # NOTE(review): calc_entropy returns a non-differentiable numpy value,
    # so rec_loss only shifts the loss by a constant.
    rec_loss = - calc_entropy(batch)
    # NOTE(review): calc_log_probs already negates the log-probabilities;
    # negating again here means the step lowers the likelihood. The loss
    # terms are also summed without any beta / reconstruction weighting.
    # Left unchanged - confirm the intended sign and weights.
    ll_loss = - weighted_log_prob
    tc_loss = kld_weighted
    total_loss = rec_loss + ll_loss + tc_loss

    # backprop
    exp.optimizer.zero_grad()
    total_loss.backward()
    exp.optimizer.step()
