import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, einsum, unpack
from copy import deepcopy
from hydra.utils import instantiate

from src.models.basemodel import BaseLightningModel

# def skipped(latents: torch.Tensor, indices: torch.Tensor):
#     '''
#     slice indices from latents
#     '''
#     indices = (indices / 2).floor().to(torch.long)
#     return latents[indices]


class SSM(BaseLightningModel):
    '''
    Slot-based latent dynamics model: encoder -> evolver -> decoder.

    ``encoder`` maps frames to latents, ``evolver`` propagates the latents
    through time (optionally conditioned on actions) and ``decoder`` renders
    latents back to images. Calling :meth:`update` switches the instance into
    a fine-tuning mode trained on action-conditioned rollouts
    (:meth:`rollout` / :meth:`loss_finetune`).
    '''

    # Number of sequences / time steps written by the validation image dumps.
    save_num_imgs = 8
    save_seq_len = 8
    # Weight of the auxiliary ``loss_zproj`` term in :meth:`loss`.
    zproj_coef = 0.1

    @property
    def burnin(self):
        '''Number of burn-in frames; delegated to the evolver.'''
        return self.evolver.burnin

    def __init__(self, encoder, decoder, evolver, 
                 reversal=True, skip_weight=0.0, entropy_weight=0.0, dec_sampling_ratio=1.0,
                 dropslot_ratio=0.0,
                 l1=0.0, l2=0.0, zproj=False):
        '''
        Args:
            encoder, decoder: modules exposing a ``dim`` attribute.
            evolver: time-evolution module (exposes ``burnin``).
            reversal: also train on reverse-time predictions in :meth:`loss`.
            skip_weight: weight of the skipped-prediction loss (0 disables it).
            entropy_weight: weight of the mask-entropy term.
            dec_sampling_ratio: fraction of output time steps decoded during
                training; must lie in (0, 1].
            dropslot_ratio: stored only; not read anywhere in this class.
            l1, l2: weights of L1 / L2 penalties on the evolver deltas.
            zproj: use a learned linear projector for ``loss_zproj``.

        Raises:
            ValueError: if ``dec_sampling_ratio`` is outside (0, 1].
        '''
        super().__init__()

        if not 0 < dec_sampling_ratio <= 1:
            raise ValueError('dec_sampling_ratio should be in (0, 1]')

        self.encoder, self.decoder = encoder, decoder
        self.evolver = evolver

        self.zproj = zproj
        # forward() compares zprojector(pred_latents) with encoder latents, and
        # pred_latents are themselves fed to dim_reduction (encoder.dim input)
        # in decode(), so the projection must stay in the encoder space.
        # (The previous encoder.dim -> decoder.dim only worked when the two
        # dims were equal.)
        self.zprojector = nn.Linear(encoder.dim, encoder.dim) if zproj else nn.Identity()

        self.dim_reduction = \
            nn.Linear(encoder.dim, decoder.dim) if encoder.dim != decoder.dim else nn.Identity()

        self.reversal = reversal
        self.skip_weight = skip_weight
        self.entropy_weight = entropy_weight
        self.dropslot_ratio = dropslot_ratio
        self.l1 = l1
        self.l2 = l2

        self.dec_sampling_ratio = dec_sampling_ratio

    def encode(self, frames):
        '''Encode ``frames`` into latents with ``self.encoder``.'''
        slots = self.encoder(frames)
        return slots

    def decode(self, slots):
        '''Project latents to the decoder dimension (if needed) and decode.'''
        slots = self.dim_reduction(slots)
        images = self.decoder(slots)
        return images

    def update(self, config):
        '''
        Switch the model to fine-tuning mode.

        Copies ``evolver.ap`` into ``evolver.alignment``, replaces
        ``evolver.ap`` with the ``controller`` instantiated from
        ``config.finetune.new_modules``, and rebinds ``self.loss`` /
        ``self._validation_step`` to their fine-tuning variants.

        Args:
            config: config object providing ``finetune.burnin``,
                ``finetune.use_pixel_loss``, ``finetune.use_delta_loss`` and
                ``finetune.new_modules``.
        '''
        self.evolver.alignment = deepcopy(self.evolver.ap)
        self.evolver.alignment.alignment_only = False
        self.evolver.burnin = config.finetune.burnin
        self.use_pixel_loss = config.finetune.use_pixel_loss
        self.use_delta_loss = config.finetune.use_delta_loss

        controller = instantiate(config.finetune.new_modules)['controller']   # nn.ModuleDict
        self.evolver.ap = controller

        self.loss = self.loss_finetune
        self._validation_step = self._validation_step_finetune

    def forward(self, frames, actions=None, seq_indices=None, reversal=False):
        '''
        Encode ``frames``, evolve the latents with ``self.evolver`` and decode.

        Args:
            frames: time-major frame sequence passed to the encoder.
            actions: optional actions forwarded to the evolver.
            seq_indices: optional index tensor selecting which predicted time
                steps to decode (see ``_sample_seq_indices``).
            reversal: if True, also decode the reverse-time and the skipped
                predictions.

        Returns:
            (deltas, preds, rev_preds, preds_skipped); ``rev_preds`` and
            ``preds_skipped`` are None when ``reversal`` is False.

        Side effect: stores ``self.loss_zproj``.
        '''
        latents = self.encode(frames)

        # Forward and reverse time evolution.
        # pred_latents contain [latent_t0, ..., latent_T], whereas
        # rev_pred_latents contain [latent_t0-1, ..., latent_T-1], where t0 is
        # the initial frame after burn-in.
        pred_latents, rev_pred_latents, pred_latents_skipped, deltas = self.evolver(latents, actions)

        # Both inputs are detached, so this only trains ``zprojector`` (and is
        # a pure metric when it is nn.Identity).
        self.loss_zproj = F.mse_loss(latents[self.burnin + 1:].detach(), 
                                     self.zprojector(pred_latents.detach()))

        if seq_indices is not None:
            pred_latents = pred_latents[seq_indices]
            rev_pred_latents = rev_pred_latents[seq_indices]
            pred_latents_skipped = pred_latents_skipped[seq_indices]

        preds = self.decode(pred_latents)

        if not reversal:
            return deltas, preds, None, None

        rev_preds = self.decode(rev_pred_latents)
        preds_skipped = self.decode(pred_latents_skipped)
        return deltas, preds, rev_preds, preds_skipped

    def rollout(self, latents, actions):
        '''
        Action-conditioned open-loop rollout starting at frame ``burnin - 1``.

        Frames before the start frame are reconstructed by evolving the start
        latent backwards with the negated cumulative alignment deltas; the
        remaining ``seq_len - burnin`` frames are predicted one step at a time
        from deltas produced by ``self.evolver.ap``.

        Args:
            latents: torch.Tensor, presumably shape=[seq_len, batch, slot, dim]
            actions: torch.Tensor, presumably shape=[seq_len, batch, num_actions];
                ``seq_len`` is taken from ``actions.shape[0]``.

        Returns:
            pred_latents: latents (after ``evolver.deactionable``) for
                t = 0, ..., seq_len - 1, i.e. seq_len steps assuming
                ``time_evolve`` returns one step per delta (TODO confirm).

        Side effect: stores ``self.loss_delta``, the MSE between the
        cumulative alignment deltas and the cumulative predicted deltas.

        Raises:
            ValueError: if ``burnin < 1`` or ``seq_len <= burnin``.
        '''
        seq_len = actions.shape[0]
        burnin = self.burnin
        num_steps = seq_len - burnin
        # burnin == 0 would make start == -1 and silently pick the last frame;
        # num_steps == 0 leaves nothing to roll out.
        if burnin < 1 or num_steps < 1:
            raise ValueError(
                f'rollout needs 1 <= burnin < seq_len, got burnin={burnin}, seq_len={seq_len}')
        start = burnin - 1

        latents = self.evolver.actionable(latents)

        latents, signal_deltas = self.evolver.alignment(latents)
        init_latent = latents[start]

        # pred_latents = [Z'_0, ..., Z'_t0-1, Z_t0] where Z' is rolled out
        # backwards from Z_t0 with the reversed cumulative deltas.
        rev_cum_deltas = -signal_deltas[:start].flip(0).cumsum(dim=0).flip(0)
        pred_latents = self.evolver.time_evolve(init_latent, rev_cum_deltas)
        pred_latents = torch.cat([pred_latents, init_latent.unsqueeze(0)], dim=0)

        # Predict deltas = [delta_t0, ..., delta_T-1] one step at a time; each
        # step conditions ``ap`` on every latent predicted so far. Deltas are
        # collected in a list (one cat at the end) and summed incrementally
        # instead of re-concatenating / re-summing the whole history per step.
        step_deltas = []
        cum_delta = None
        for step in range(num_steps):
            _, ap_deltas = self.evolver.ap(pred_latents, 
                                           actions[ : burnin + step])
            new_delta = ap_deltas[-1:]
            step_deltas.append(new_delta)
            cum_delta = new_delta if cum_delta is None else cum_delta + new_delta

            pred_latent = self.evolver.time_evolve(init_latent, cum_delta)
            pred_latents = torch.cat([pred_latents, pred_latent], dim=0)
        deltas = torch.cat(step_deltas, dim=0)

        self.loss_delta = F.mse_loss(signal_deltas[start:].cumsum(dim=0), 
                                     deltas.cumsum(dim=0))

        pred_latents = self.evolver.deactionable(pred_latents)
        return pred_latents

    def _sample_seq_indices(self, out_len):
        '''
        Indices of the output time steps to decode.

        During training with ``dec_sampling_ratio < 1`` this is a random subset
        of size ``ceil(out_len * dec_sampling_ratio)``; otherwise it is
        ``arange(out_len)``.
        '''
        if self.dec_sampling_ratio < 1 and self.training:
            num_samples = math.ceil(out_len * self.dec_sampling_ratio)
            return torch.randperm(out_len)[:num_samples]
        return torch.arange(out_len)

    def loss_finetune(self, batch):
        '''
        Fine-tuning objective (installed as ``self.loss`` by :meth:`update`).

        Rolls the model out from the burn-in frames with :meth:`rollout` and
        combines the pixel MSE on the post-burn-in frames with the delta loss
        stored by the rollout, gated by ``use_pixel_loss`` / ``use_delta_loss``.

        Args:
            batch: tuple ``(fullseq, actions, _)``; fullseq is time-major.

        Returns:
            (loss dict with keys total/reconst/delta, [forward_target, preds]).
        '''
        fullseq, actions, _ = batch
        out_len = len(fullseq) - self.burnin
        seq_indices = self._sample_seq_indices(out_len)

        latents = self.encode(fullseq)
        pred_latents = self.rollout(latents, actions)   # t = 0, ..., burnin - 1, burnin, ..., T

        pred_latents = pred_latents[self.burnin:][seq_indices]
        preds = self.decode(pred_latents)
        pred_img = preds.compose()

        forward_target = fullseq[self.burnin:][seq_indices]

        loss_rollout = F.mse_loss(pred_img, forward_target)

        loss_total = 0
        loss_total += float(self.use_pixel_loss) * loss_rollout
        loss_total += float(self.use_delta_loss) * self.loss_delta

        loss = dict(
            total = loss_total, 
            reconst = loss_rollout,
            delta = self.loss_delta,
        )

        return loss, [forward_target, pred_img]

    def loss(self, batch):
        '''
        Pre-training objective.

        Args:
            batch: tuple ``(fullseq, actions, _)``; fullseq is time-major.

        Returns:
            (loss dict, visuals). ``visuals`` is ``[forward_target, pred]``,
            extended with ``[reversal_target, rev_pred]`` when
            ``self.reversal`` is set.
        '''
        fullseq, actions, _ = batch
        out_len = len(fullseq) - self.burnin - 1
        seq_indices = self._sample_seq_indices(out_len)

        # forward() only returns the reversed and skipped predictions when
        # reversal=True, so request them whenever either loss needs them.
        need_aux = self.reversal or self.skip_weight > 0
        deltas, pred, rev_pred, pred_skipped = self(fullseq, actions, seq_indices, reversal=need_aux)

        forward_target = fullseq[self.burnin + 1: ]
        forward_target = forward_target[seq_indices]
        pred_img = pred.compose()
        loss_reconst = F.mse_loss(pred_img, forward_target)
        visuals = [forward_target, pred_img]

        if self.reversal:
            reversal_target = fullseq[self.burnin : -1]
            reversal_target = reversal_target[seq_indices]
            rev_pred_img = rev_pred.compose()
            loss_reversal = F.mse_loss(rev_pred_img, reversal_target)
            visuals += [reversal_target, rev_pred_img]
        else:
            loss_reversal = 0

        if self.skip_weight > 0:
            loss_skipped = F.mse_loss(pred_skipped.compose(), forward_target)
        else:
            loss_skipped = 0

        # Entropy of the mask averaged over its last two axes (presumably the
        # spatial ones). Clamping the log argument keeps 0 * log(0) at 0
        # instead of NaN; values above ``tiny`` are unaffected.
        prob = pred.mask.mean(dim=[-1, -2])
        log_prob = prob.clamp_min(torch.finfo(prob.dtype).tiny).log()
        entropy = - (prob * log_prob).mean()

        loss_total = 0
        loss_total += loss_reconst
        loss_total += float(self.reversal) * loss_reversal
        loss_total += self.zproj_coef * self.loss_zproj
        if self.skip_weight > 0:
            loss_total += self.skip_weight * loss_skipped
        loss_total += self.l1 * deltas.abs().mean()
        loss_total += self.l2 * deltas.pow(2).mean()
        loss_total += self.entropy_weight * entropy

        loss = dict(total=loss_total, 
                    reconst=loss_reconst,
                    reversal=loss_reversal,
                    entropy=entropy,
                    zproj=self.loss_zproj,
                    )
        return loss, visuals

    def _save_slotwise_images(self, slotwise_preds, split):
        '''
        Save per-slot images of the first ``save_num_imgs`` sequences and
        ``save_seq_len`` time steps of a decoded prediction via ``save_image``.
        '''
        slotwise_preds = slotwise_preds.compose(keepslot=True)

        num_slots = slotwise_preds.shape[2]
        slotwise_preds = rearrange(slotwise_preds[:self.save_seq_len, :self.save_num_imgs], 
                                   't b s ... -> (b t) s ...')
        self.save_image(slotwise_preds, split, self.current_epoch, 'slotwise', 
                        num_imgs=num_slots)

    def _validation_step(self, batch, is_first_batch, split):
        '''Run the base validation step; on the first batch also dump
        slot-wise forward predictions.'''
        super()._validation_step(batch, is_first_batch, split)
        if not is_first_batch:
            return None

        fullseq, actions, _ = batch
        _, slotwise_preds, _, _ = self(fullseq, actions, reversal=False)
        self._save_slotwise_images(slotwise_preds, split)

    def _validation_step_finetune(self, batch, is_first_batch, split):
        '''Fine-tuning validation step (installed by :meth:`update`); on the
        first batch also dumps slot-wise rollout predictions.'''
        super()._validation_step(batch, is_first_batch, split)
        if not is_first_batch:
            return None

        fullseq, actions, _ = batch
        latents = self.encode(fullseq)

        end = self.burnin + self.save_seq_len
        pred_latents = self.rollout(latents[:end], actions[:end])
        slotwise_preds = self.decode(pred_latents)
        # NOTE(review): the saved frames are the first save_seq_len steps, which
        # include the burn-in reconstruction; confirm this (rather than the
        # post-burn-in rollout) is what should be visualised.
        self._save_slotwise_images(slotwise_preds, split)