import os
import wandb
import numpy as np

import torch
import torchvision.utils as vutils

from nerv.training import BaseMethod

from slotformer.base_slots.method import SAViMethod, STEVEMethod, \
    to_rgb_from_tensor

from gym.vector import AsyncVectorEnv

from transformers import T5Tokenizer, TFT5EncoderModel
from sentence_transformers import SentenceTransformer

from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.rewards import block2block

import tensorflow_datasets as tfds

from ..base_slots.datasets.utils import BaseTransforms

import time


def denormalize_act(action, min_val=-0.1, max_val=0.1):
    """Affinely rescale `action` so that -1 maps to `min_val` and 1 to `max_val`.

    Args:
        action: value(s) to rescale; anything supporting `+`, `/` and `*`
            with Python numbers works.
        min_val: output produced for an input of -1.
        max_val: output produced for an input of 1.

    Returns:
        The rescaled action.
    """
    span = max_val - min_val
    # shift [-1, 1] to [0, 1] first, then stretch onto [min_val, max_val]
    unit = (action + 1) / 2
    return unit * span + min_val
    
def build_method(**kwargs):
    """Build the training method selected by `kwargs['params'].model`.

    All keyword arguments are forwarded unchanged to the method's
    constructor.

    Raises:
        KeyError: if no `params` keyword argument is given.
        NotImplementedError: if `params.model` is not a supported model name.
    """
    params = kwargs['params']
    supported_models = ('SlotFormer', 'SingleStepSlotFormer', 'LSlotFormer')
    if params.model not in supported_models:
        raise NotImplementedError(f'{params.model} method is not implemented.')
    return SlotFormerMethod(**kwargs)




class SlotFormerMethod(SAViMethod):

    def _training_step_start(self):
        """Refresh the model's `loss_decay_factor` before a training step.

        The parent hook runs first. If the config has no `use_loss_decay`
        option, nothing else happens. If the option is off, the factor is
        fixed at 1. If it is on, the factor grows linearly from 0.01 towards 1
        during the first `loss_decay_pct` fraction of the
        `max_epochs * len(train_loader)` training steps and is 1 afterwards.
        """
        super()._training_step_start()

        params = self.params
        # configs without this option leave the model untouched
        if not hasattr(params, 'use_loss_decay'):
            return

        if params.use_loss_decay:
            step = self.it
            # number of steps over which the factor is ramped up
            ramp_steps = params.loss_decay_pct * \
                (params.max_epochs * len(self.train_loader))
            if step >= ramp_steps:
                factor = 1.
            else:
                # linear ramp starting at 0.01
                factor = 0.01 + step / ramp_steps * 0.99
        else:
            factor = 1.

        self.model.module.loss_decay_factor = factor

    def _log_train(self, out_dict):
        """Log training statistics, plus the current loss decay factor.

        The parent logger always runs. The extra
        'train/loss_decay_factor' entry is only sent to wandb from rank 0, on
        iterations where `(epoch_it + 1)` is a multiple of `print_iter`, and
        only when the config defines `use_loss_decay`.
        """
        super()._log_train(out_dict)

        # rank is checked first so other ranks never evaluate the step test
        is_log_step = self.local_rank == 0 and \
            (self.epoch_it + 1) % self.print_iter == 0
        if is_log_step and hasattr(self.params, 'use_loss_decay'):
            wandb.log(
                {'train/loss_decay_factor':
                 self.model.module.loss_decay_factor},
                step=self.it,
            )

    def _compare_videos(self, img, recon_combined, rollout_combined):
        """Tile three videos side by side so they can be compared.

        Args:
            img: original video frames.
            recon_combined: reconstructed video frames.
            rollout_combined: rolled-out video frames; its length is the
                length the other two videos are padded to.

        Returns:
            One grid image per frame, stacked along a new first dimension
            ([T, 3, H, 3*W] per the original shape notes). Each grid shows
            the original, the reconstruction and the rollout in one row.
        """
        clips = [img, recon_combined, rollout_combined]
        # PHYRE only: pass every clip through `_pause_frame`
        # (presumably holds the 1st frame for a while -- helper not shown)
        if 'phyre' in self.params.dataset.lower():
            clips = [self._pause_frame(clip) for clip in clips]
        gt, recon, rollout = clips

        # bring the two shorter videos to the rollout length
        num_frames = rollout.shape[0]
        gt = self._pad_frame(gt, num_frames)
        recon = self._pad_frame(recon, num_frames)

        # [T, 3, 3, H, W]: (original, reconstruction, rollout) per frame
        triplets = to_rgb_from_tensor(
            torch.stack([gt, recon, rollout], dim=1))

        # pad white if using black background
        pad_value = 1 if self.params.get('reverse_color', False) else 0
        grids = []
        for t in range(gt.shape[0]):
            grids.append(
                vutils.make_grid(
                    triplets[t].cpu(),
                    nrow=triplets.shape[1],
                    pad_value=pad_value,
                ))
        return torch.stack(grids)  # [T, 3, H, 3*W]

    def _read_video_and_slots(self, dst, idx):
        """Read the video and slots from the dataset."""
        # PHYRE
        if 'phyre' in self.params.dataset.lower():
            # read video
            data_dict = dst.get_video(idx, video_len=self.params.video_len)
            video = data_dict['video']
            # read slots
            slots = dst._read_slots(
                data_dict['data_idx'],
                video_len=self.params.video_len,
            )['slots']  # [T, N, C]
            slots = torch.from_numpy(slots).float().to(self.device)
        # OBJ3D, CLEVRER, Physion, langtable
        else:
            # read video
            video = dst.get_video(idx)['video']
            # read slots
            video_path = dst.files[idx]
            slots = dst.video_slots[os.path.basename(video_path)]  # [T, N, C]
            if self.params.frame_offset > 1:
                slots = np.ascontiguousarray(slots[::self.params.frame_offset])
            slots = torch.from_numpy(slots).float().to(self.device)
        T = min(video.shape[0], slots.shape[0])
        # video: [T, 3, H, W], slots: [T, N, C]
        return video[:T], slots[:T]

    @torch.no_grad()
    def _make_action_slot_grid(self, recon_combined, recons, masks):
        """Make a video of grid images showing the slot decomposition.

        Each frame's grid holds the combined reconstruction followed by the
        masked reconstruction of every slot, all in a single row.

        Args:
            recon_combined: combined reconstruction per frame.
            recons: per-slot reconstructions (presumably
                [T, num_slots, 3, H, W] -- confirm against the decoder).
            masks: per-slot masks multiplied onto `recons`.

        Returns:
            The per-frame grids stacked along a new first dimension.
        """
        # value blended in where a slot's mask is 0 (before RGB conversion)
        bg_value = 0.
        full_recon = recon_combined.unsqueeze(1)
        per_slot = recons * masks + (1. - masks) * bg_value
        # [T, num_slots+1, 3, H, W]: full reconstruction first, then each slot
        panels = to_rgb_from_tensor(torch.cat([full_recon, per_slot], dim=1))

        # build one single-row grid per frame, then stack them into a video
        grids = []
        for t in range(recons.shape[0]):
            grid = vutils.make_grid(
                panels[t].cpu(),
                nrow=panels.shape[1],
                pad_value=1. - bg_value,
            )
            grids.append(grid)
        return torch.stack(grids)  # [T, 3, H, (num_slots+1)*W]

    @torch.no_grad()
    def validation_epoch(self, model, san_check_step=-1, sample_video=True):
        """Validate one epoch and optionally log sampled videos.

        We aggregate the avg of all statistics and only log once.

        While `BaseMethod.validation_epoch` runs, `model.use_img_recon_loss`
        is set to `sample_video` and `model.loss_decay_factor` to 1. Both are
        put back to their previous values afterwards, also when the
        evaluation raises.

        Args:
            model: model to evaluate; the two attributes above are overridden
                temporarily.
            san_check_step: forwarded to `BaseMethod.validation_epoch`.
            sample_video: enables the image reconstruction loss during
                evaluation and, on rank 0, the call to `_sample_video`.
        """
        saved_use_img_recon_loss = model.use_img_recon_loss
        saved_loss_decay_factor = model.loss_decay_factor
        model.use_img_recon_loss = sample_video  # do img_recon_loss eval
        model.loss_decay_factor = 1.  # compute loss in normal scale
        try:
            BaseMethod.validation_epoch(
                self, model, san_check_step=san_check_step)
        finally:
            # never leave the evaluation-only overrides on the model, even if
            # the evaluation loop fails half-way
            model.use_img_recon_loss = saved_use_img_recon_loss
            model.loss_decay_factor = saved_loss_decay_factor

        if self.local_rank != 0:
            return
        # visualization after every epoch
        if sample_video:
            self._sample_video(model)


    @torch.no_grad()
    def _sample_video(self, model):
        """Log videos of sampled validation clips to wandb.

        model is a simple nn.Module, not wrapped in e.g. DataParallel.

        For every sampled clip, the stored slots are decoded as a sanity
        check. The model is then rolled out from the first
        `params.input_frames` slot frames together with the clip's
        instruction. The results are logged as 'val/video',
        'val/rollout_video' and 'val/compare_video', along with a table of
        the instruction texts ('val/inst_table').
        """
        model.eval()
        dst = self.val_loader.dataset
        sampled_idx = self._get_sample_idx(self.params.n_samples, dst)
        results, rollout_results, compare_results = [], [], []
        inst_table = wandb.Table(columns=["idx", "inst"])
        for i in sampled_idx:
            # convert once to a plain Python number so that every dataset
            # lookup and the logged table row use the same index value
            idx = i.item()
            video, slots = self._read_video_and_slots(dst, idx)
            inst = dst._read_insts(idx).to(self.device).unsqueeze(0)
            T = video.shape[0]
            # reconstruct gt_slots as sanity-check
            # i.e. if the pre-trained weights are loaded correctly
            recon_combined, recons, masks, _ = model.decode(slots)
            img = video.type_as(recon_combined)
            save_video = self._make_video_grid(img, recon_combined, recons,
                                               masks)
            results.append(save_video)
            # rollout
            past_steps = self.params.input_frames
            past_slots = slots[:past_steps][None]  # [1, t, N, C]
            out_dict = model.rollout(
                past_slots, inst, T - past_steps, decode=True, with_gt=True)
            # drop the batch dimension added above
            out_dict = {k: v[0] for k, v in out_dict.items()}
            rollout_combined, recons, masks = out_dict['recon_combined'], \
                out_dict['recons'], out_dict['masks']
            img = video.type_as(rollout_combined)
            pred_video = self._make_video_grid(img, rollout_combined, recons,
                                               masks)
            rollout_results.append(pred_video)  # per-slot rollout results
            # stack (gt video, gt slots recon video, rollout video)
            # horizontally to better compare the 3 videos
            compare_video = self._compare_videos(img, recon_combined,
                                                 rollout_combined)
            compare_results.append(compare_video)
            inst_table.add_data(idx, dst._read_insts_text(idx))

        # clips may differ in length; pad them before they are concatenated
        results = self._pad_results(results)
        rollout_results = self._pad_results(rollout_results)
        compare_results = self._pad_results(compare_results)

        log_dict = {
            'val/video': self._convert_video(results),
            'val/rollout_video': self._convert_video(rollout_results),
            'val/compare_video': self._convert_video(compare_results),
            'val/inst_table': inst_table,
        }
        wandb.log(log_dict, step=self.it)
        torch.cuda.empty_cache()

    def _pad_results(self, results):
        max_length = max([img.shape[0] for img in results])
        # Pad each array with the last element to the maximum length
        padded_results = []
        for result in results:
            padding_length = max_length - result.shape[0]
            padded_result = np.concatenate([result, np.repeat(result[-1:], padding_length, axis=0)])
            padded_result = torch.Tensor(padded_result)
            padded_results.append(padded_result)
        return padded_results

    def _convert_video(self, video, caption=None):
        """Convert video tensors to a `wandb.Video` for logging.

        Args:
            video: either a list/tuple of tensors, or a single 5-D tensor
                ([task_num, video_step, 3, H, W] per the original notes)
                whose first dimension is unrolled into separate clips. In
                both cases the clips are concatenated along dim 2.
            caption: optional caption forwarded to `wandb.Video`.

        Returns:
            A `wandb.Video` built from the uint8 frames, played at
            `self.vis_fps`.

        Raises:
            ValueError: if `video` is a tensor that is not 5-D.
        """
        if not isinstance(video, (list, tuple)) and video.ndim != 5:
            raise ValueError(
                'expected a 5-D video tensor [task_num, video_step, C, H, W], '
                f'got shape {tuple(video.shape)}')
        # iterating a tensor yields its slices along the first dimension
        clips = list(video)
        # clips are tiled along dim 2 (height): [T, 3, B*H, W]
        video = torch.cat(clips, dim=2)
        # NOTE(review): values are treated as lying in [-1, 1]; confirm that
        # the grids passed in (built via `to_rgb_from_tensor`) use that range
        video = (video + 1.) * 0.5 * 255.
        # clamp first, otherwise out-of-range values wrap around in uint8
        video = video.clamp(0., 255.)
        video = video.detach().cpu().numpy().astype(np.uint8)
        return wandb.Video(video, fps=self.vis_fps, caption=caption)


            




