import os
import wandb
import numpy as np

import torch
import torchvision.utils as vutils
import torch.optim as optim
import torch.nn.functional as F

from nerv.training import BaseMethod, CosineAnnealingWarmupRestarts
from nerv.utils.io import check_file_exist, mkdir_or_exist

from slotformer.base_slots.method import SAViMethod, STEVEMethod, \
    to_rgb_from_tensor

from .models import cosine_anneal, get_lr, gumbel_softmax, make_one_hot, \
    to_rgb_from_tensor

from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.rewards import block2block


def build_method(**kwargs):
    """Create the training method that matches ``params.model``.

    Expects a ``params`` keyword argument; every keyword argument is forwarded
    unchanged to the method's constructor.

    Raises:
        NotImplementedError: if ``params.model`` is not a supported model name.
    """
    supported_models = (
        'SlotFormer',
        'SingleStepSlotFormer',
        'LSlotFormer',
        'E2E',
    )
    model_name = kwargs['params'].model
    if model_name not in supported_models:
        raise NotImplementedError(f'{model_name} method is not implemented.')
    # all supported models share the same training method
    return SlotFormerMethod(**kwargs)

class SlotBaseMethod(BaseMethod):
    """Base method in this project.

    Adds video helpers for wandb logging, a warmup + cosine learning-rate
    schedule and a validation loop that logs sample videos on rank 0.
    ``validation_epoch`` calls ``self._sample_video(model)``, which is not
    defined here and has to be provided by subclasses.
    """

    @staticmethod
    def _pad_frame(video, target_T):
        """Pad the video to a target length at the end.

        The last frame is repeated until the video has `target_T` frames. A
        video that already has at least `target_T` frames is returned as is.
        """
        if video.shape[0] >= target_T:
            return video
        dup_video = torch.stack(
            [video[-1]] * (target_T - video.shape[0]), dim=0)
        return torch.cat([video, dup_video], dim=0)

    @staticmethod
    def _pause_frame(video, N=4):
        """Pause the video on the first frame by duplicating it `N` times."""
        dup_video = torch.stack([video[0]] * N, dim=0)
        return torch.cat([dup_video, video], dim=0)

    def _convert_video(self, video, caption=None):
        """Merge a list of videos into one `wandb.Video`.

        The videos are concatenated along dim 2, scaled by 255 and cast to
        uint8. `.numpy()` is called directly, so the tensors must live on CPU;
        values are presumably in [0, 1] -- TODO confirm against callers.
        """
        video = torch.cat(video, dim=2)  # [T, 3, B*H, L*W]
        video = (video * 255.).numpy().astype(np.uint8)
        return wandb.Video(video, fps=self.vis_fps, caption=caption)

    @staticmethod
    def _get_sample_idx(N, dst):
        """Load videos uniformly from the dataset.

        Args:
            N: requested number of samples.
            dst: dataset exposing a `files` sequence, one entry per video.

        Returns:
            1-D tensor of evenly strided video indices. It is empty when the
            dataset is empty or `N` is not positive.
        """
        dst_len = len(dst.files)  # treat each video as a sample
        if dst_len == 0 or N <= 0:
            return torch.arange(0)
        N = N - 1 if dst_len % N != 0 else N
        # `dst_len // N` is 0 when more samples are requested than videos
        # exist; torch.arange rejects a zero step, so fall back to a stride
        # of 1 (i.e. every video is sampled).
        step = max(dst_len // N, 1)
        sampled_idx = torch.arange(0, dst_len, step)
        return sampled_idx

    @torch.no_grad()
    def validation_epoch(self, model, san_check_step=-1, sample_video=True):
        """Validate one epoch.

        We aggregate the avg of all statistics and only log once.
        Sample videos are logged on rank 0 only, and only if `sample_video`.
        """
        super().validation_epoch(model, san_check_step=san_check_step)
        if self.local_rank != 0:
            return
        # visualization after every epoch
        if sample_video:
            self._sample_video(model)

    def _configure_optimizers(self):
        """Returns an optimizer, a scheduler and its frequency (step/epoch).

        The optimizer is the first item returned by the parent class. The
        scheduler spans `max_epochs * len(train_loader)` steps, warms up for
        `warmup_steps_pct` of them and anneals from `lr` down to `lr / 100`.
        """
        optimizer = super()._configure_optimizers()[0]

        lr = self.params.lr
        total_steps = self.params.max_epochs * len(self.train_loader)
        warmup_steps = self.params.warmup_steps_pct * total_steps

        scheduler = CosineAnnealingWarmupRestarts(
            optimizer,
            total_steps,
            max_lr=lr,
            min_lr=lr / 100.,
            warmup_steps=warmup_steps,
        )

        return optimizer, (scheduler, 'step')

    @property
    def vis_fps(self):
        """Frame rate of the logged videos, chosen by dataset name."""
        # PHYRE
        if 'phyre' in self.params.dataset.lower():
            return 4
        # OBJ3D, CLEVRER, Physion
        else:
            return 8

class SAViMethod(SlotBaseMethod):
    """Training method for SAVi with slot-decomposition video logging.

    NOTE: this local definition shadows the ``SAViMethod`` name imported from
    ``slotformer.base_slots.method`` at the top of the file.
    """

    def _make_video_grid(self, imgs, recon_combined, recons, masks):
        """Build a video whose frames are grids of the slot decomposition.

        Each frame shows, side by side, the input image, the combined
        reconstruction and every slot's masked reconstruction.
        """
        # on PHYRE the video is held on its first frame for a moment
        if 'phyre' in self.params.dataset.lower():
            imgs = self._pause_frame(imgs)
            recon_combined = self._pause_frame(recon_combined)
            recons = self._pause_frame(recons)
            masks = self._pause_frame(masks)

        # fill value outside the masks: 0 with `reverse_color`, otherwise 1
        bg_value = 0. if self.params.get('reverse_color', False) else 1.
        masked_slots = recons * masks + (1. - masks) * bg_value

        # one row per frame: original | reconstruction | slot_1 ... slot_N
        panels = torch.cat(
            [imgs.unsqueeze(1), recon_combined.unsqueeze(1), masked_slots],
            dim=1,
        )
        out = to_rgb_from_tensor(panels)  # [T, num_slots+2, 3, H, W]

        n_cols = out.shape[1]
        frames = []
        for t in range(recons.shape[0]):
            grid = vutils.make_grid(
                out[t].cpu(),
                nrow=n_cols,
                pad_value=1. - bg_value,
            )
            frames.append(grid)
        return torch.stack(frames)  # [T, 3, H, (num_slots+2)*W]

    @torch.no_grad()
    def _sample_video(self, model):
        """Log slot-decomposition videos of a few validation samples.

        `model` is a plain nn.Module, not wrapped in e.g. DataParallel.
        """
        model.eval()
        dataset = self.val_loader.dataset
        grids, labels = [], []
        for idx in self._get_sample_idx(self.params.n_samples, dataset):
            sample = dataset.get_video(idx.item())
            video = sample['video'].float().to(self.device)
            label = sample.get('label', None)  # label for PHYRE

            # run the model on a batch of one video, then drop the batch dim
            outputs = model({'img': video[None]})
            outputs = {key: val[0] for key, val in outputs.items()}
            recon_combined = outputs['post_recon_combined']
            recons = outputs['post_recons']
            masks = outputs['post_masks']

            frames = video.type_as(recon_combined)
            grids.append(
                self._make_video_grid(frames, recon_combined, recons, masks))
            labels.append(label)

        # a caption is only written when every sampled video carries a label
        if any(lbl is None for lbl in labels):
            caption = None
        else:
            caption = '\n'.join(
                'Success' if lbl == 1 else 'Fail' for lbl in labels)

        logged_video = self._convert_video(grids, caption=caption)
        wandb.log({'val/video': logged_video}, step=self.it)
        torch.cuda.empty_cache()

class SlotFormerMethod(SAViMethod):

    @torch.no_grad()
    def load_ckp(self, ckp_path=None, auto_detect=True):
        """Load from checkpoint.

        Support automatic detection of existing checkpoints.
        Useful in SLURM preemption systems.

        Two cases are handled:

        * Resuming: a `.pth` file was auto-detected in `self.ckp_path`. The
          most recently modified one is loaded and iteration, epoch, model
          and optimizer state are restored from it.
        * Fresh start with an explicit `ckp_path`: the model is initialised
          from the hard-coded pretrained checkpoints below. `ckp_path` itself
          is only checked for existence, as in the original code.

        Scheduler, grad scaler and sampler state are read from whichever
        checkpoint was loaded.

        Args:
            ckp_path: checkpoint path; replaced by the auto-detected one when
                resuming. Nothing is loaded if it ends up empty.
            auto_detect: look for existing checkpoints in `self.ckp_path`.
        """
        # Must be bound even when detection is skipped: it decides below
        # whether we are resuming (previously an UnboundLocalError).
        ckp_files = []
        # automatically detect checkpoints
        if auto_detect and os.path.exists(self.ckp_path):
            ckp_files = [
                ckp for ckp in os.listdir(self.ckp_path)
                if ckp.endswith('.pth')
            ]
            if ckp_files:
                ckp_files = sorted(
                    ckp_files,
                    key=lambda x: os.path.getmtime(
                        os.path.join(self.ckp_path, x)))
                last_ckp = ckp_files[-1]
                print(f'INFO: automatically detect checkpoint {last_ckp}')
                ckp_path = os.path.join(self.ckp_path, last_ckp)

        if not ckp_path:
            return

        print(f'INFO: loading checkpoint {ckp_path}')

        check_file_exist(ckp_path)
        resuming = bool(ckp_files)
        if resuming:
            # restore the detected run itself, not the pretrained weights
            ckp = torch.load(ckp_path)
            self.it, self.epoch = ckp['it'], ckp['epoch']
            self.model.module.load_state_dict(ckp['state_dict'])
            self.optimizer.load_state_dict(ckp['opt_state_dict'])
        else:
            # NOTE(review): placeholder paths, presumably meant to be edited
            # to point at the real pretrained checkpoints -- confirm.
            ckp = torch.load('PRETRAINED_SAVI_CKP_PATH')
            try:
                self.model.module.load_state_dict(ckp['state_dict'])
            except RuntimeError:
                # load_state_dict raises RuntimeError on key mismatches
                print("INFO: some keys are missing (probably rollouter), but loaded anyway")
                self.model.module.load_state_dict(
                    ckp['state_dict'], strict=False)
                print("INFO: slot encoder loaded")
                # only needed on this path, so it is loaded lazily
                ckp_1 = torch.load('PRETRAINED_LSLOTFORMER_CKP_PATH')
                self.model.module.load_state_dict(
                    ckp_1['state_dict'], strict=False)
                print("INFO: slot predictor loaded")

        # NOTE(review): on a fresh start the state below comes from the
        # pretrained SAVi checkpoint, as in the original code -- confirm that
        # this is intended.
        if self.scheduler_method:
            self.scheduler.load_state_dict(ckp['scheduler_state_dict'])
            self.scheduler_method = ckp['scheduler_method']
        if self.use_fp16:
            self.grad_scaler.load_state_dict(ckp['grad_scaler'])
        # should consider loading data sampler
        if 'rank0_train_sampler' in ckp.keys():
            print('INFO: loading train loader state')
            self.train_loader.sampler.load_state_dict(
                ckp[f'rank{self.local_rank}_train_sampler'])



    def _training_step_start(self):
        """Update the model's `loss_decay_factor` before each training step.

        Nothing is changed if the config has no `use_loss_decay` option. With
        the option disabled the factor is 1. Otherwise it grows linearly from
        0.01 to 1 over the first `loss_decay_pct` of all training steps and
        stays at 1 afterwards.
        """
        super()._training_step_start()

        if not hasattr(self.params, 'use_loss_decay'):
            return

        factor = 1.
        if self.params.use_loss_decay:
            total_steps = self.params.max_epochs * len(self.train_loader)
            decay_steps = self.params.loss_decay_pct * total_steps
            ramp_done = self.it >= decay_steps
            if not ramp_done:
                # linear ramp: 0.01 at step 0, 1 at `decay_steps`
                factor = 0.01 + self.it / decay_steps * 0.99
        self.model.module.loss_decay_factor = factor

    def _log_train(self, out_dict):
        """Log training statistics, plus the loss decay factor, to wandb.

        The extra value is only logged on rank 0, every `print_iter`
        iterations, and only if the config defines `use_loss_decay`.
        """
        super()._log_train(out_dict)

        if self.local_rank != 0:
            return
        if (self.epoch_it + 1) % self.print_iter != 0:
            return
        if not hasattr(self.params, 'use_loss_decay'):
            return

        wandb.log(
            {'train/loss_decay_factor': self.model.module.loss_decay_factor},
            step=self.it,
        )

    def _compare_videos(self, img, recon_combined, rollout_combined):
        """Put the input, reconstruction and rollout videos side by side.

        The first two videos are padded with their last frame up to the
        rollout's length, then every frame becomes a one-row grid of 3 images.
        """
        # on PHYRE all three videos are held on their first frame
        if 'phyre' in self.params.dataset.lower():
            img = self._pause_frame(img)
            recon_combined = self._pause_frame(recon_combined)
            rollout_combined = self._pause_frame(rollout_combined)

        # pad to the length of the rollout video
        n_frames = rollout_combined.shape[0]
        img = self._pad_frame(img, n_frames)
        recon_combined = self._pad_frame(recon_combined, n_frames)

        # original | reconstruction | rollout
        triplets = torch.stack(
            [img, recon_combined, rollout_combined], dim=1)
        out = to_rgb_from_tensor(triplets)  # [T, 3, 3, H, W]

        # pad white if using black background
        border = 1 if self.params.get('reverse_color', False) else 0
        grids = []
        for t in range(img.shape[0]):
            grids.append(
                vutils.make_grid(
                    out[t].cpu(),
                    nrow=out.shape[1],
                    pad_value=border,
                ))
        return torch.stack(grids)  # [T, 3, H, 3*W]


    def _read_video_and_slots(self, dst, idx):
        """Read a video and its slots from the dataset.

        Returns:
            `(video, slots)`, both cut to the shorter of the two lengths.
            `slots` is a float tensor on `self.device`.
        """
        if 'phyre' in self.params.dataset.lower():
            # PHYRE: video and slots are both read with `video_len`
            sample = dst.get_video(idx, video_len=self.params.video_len)
            video = sample['video']
            slots = dst._read_slots(
                sample['data_idx'],
                video_len=self.params.video_len,
            )['slots']  # [T, N, C]
        else:
            # OBJ3D, CLEVRER, Physion, langtable: slots are stored per video
            # file name
            video = dst.get_video(idx)['video']
            video_name = os.path.basename(dst.files[idx])
            slots = dst.video_slots[video_name]  # [T, N, C]
            stride = self.params.frame_offset
            if stride > 1:
                # keep every `stride`-th slot frame
                slots = np.ascontiguousarray(slots[::stride])

        slots = torch.from_numpy(slots).float().to(self.device)

        # video: [T, 3, H, W], slots: [T, N, C]
        n_frames = min(video.shape[0], slots.shape[0])
        return video[:n_frames], slots[:n_frames]

    def _read_video(self, dst, idx):
        """Read one full video from the dataset.

        Args:
            dst: dataset providing `get_video`.
            idx: index of the video to read.

        Returns:
            The `'video'` entry of the dataset's sample, untrimmed
            (`[T, 3, H, W]` according to the original comment).
        """
        if 'phyre' in self.params.dataset.lower():
            # PHYRE: the requested length is passed on to the dataset
            return dst.get_video(
                idx, video_len=self.params.video_len)['video']
        # OBJ3D, CLEVRER, Physion, langtable
        return dst.get_video(idx)['video']
        

    @torch.no_grad()
    def validation_epoch(self, model, san_check_step=-1, sample_video=True):
        """Validate one epoch.

        We aggregate the avg of all statistics and only log once.

        During validation the image reconstruction loss is switched on only
        if `sample_video` is set, and the loss decay factor is forced to 1 so
        the loss is computed in normal scale. Both model attributes are
        restored afterwards, even if validation raises. Sample videos are
        logged on rank 0 only.
        """
        saved_use_img_recon_loss = model.use_img_recon_loss
        saved_loss_decay_factor = model.loss_decay_factor
        model.use_img_recon_loss = sample_video  # do img_recon_loss eval
        model.loss_decay_factor = 1.  # compute loss in normal scale
        try:
            # calls BaseMethod directly, so the parent's own video sampling
            # is skipped and done once at the end of this method instead
            BaseMethod.validation_epoch(
                self, model, san_check_step=san_check_step)
        finally:
            # never leave the training model with the validation overrides
            model.use_img_recon_loss = saved_use_img_recon_loss
            model.loss_decay_factor = saved_loss_decay_factor

        if self.local_rank != 0:
            return
        # visualization after every epoch
        if sample_video:
            self._sample_video(model)

    @torch.no_grad()
    def _build_dslotformer(self):
        """Build a DSlotFormer model.

        The model comes from `self._build_model()`, is moved to `self.device`
        and wrapped in `DataParallel` over `self.device_ids`.
        """
        net = self._build_model()
        return torch.nn.DataParallel(
            net.to(self.device), device_ids=self.device_ids)


    @torch.no_grad()
    def _sample_task(self, model):
        """Validation on RL environment.

        Runs `model` in closed loop on 10 Language-Table tasks (seeds 0-9)
        for at most 160 environment steps each. The rollout videos, the
        decoded instructions and the success rate are then logged to wandb.

        NOTE(review): `BaseTransforms`, `SentenceTransformer` and
        `self.reward_factory` are used here but are not imported or defined
        in the visible part of this file -- confirm they are provided
        elsewhere.
        """
        model.eval()

        def decode_inst(inst):
            """Utility to decode encoded language instruction"""
            return bytes(inst[np.where(inst != 0)].tolist()).decode("utf-8")

        pred_len = self.params.loss_dict['rollout_len']
        n_input = self.params.input_frames
        video_step = 160  # max number of environment steps per task
        task_num = 10
        success_task = 0
        videos = []

        inst_table = wandb.Table(columns=["idx", "inst"])

        # Loop-invariant helpers: build them once instead of once per task
        # (the sentence encoder in particular is expensive to load).
        transform = BaseTransforms(self.params.resolution)
        lm = SentenceTransformer('sentence-transformers/sentence-t5-base')

        for i in range(task_num):  # sample `task_num` tasks
            # NOTE(review): the environment is never closed explicitly.
            env = language_table.LanguageTable(
                block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
                reward_factory=self.reward_factory,
                seed=i
            )
            obs = env.reset()

            frames = transform(obs['rgb']).unsqueeze(0)  # [1, 3, H, W]
            caption = decode_inst(obs['instruction'])
            # [1, 768]; moved to the device once, it is reused at every step
            inst = torch.tensor(
                lm.encode(caption)).unsqueeze(0).to(self.device)

            # the context window starts as the first observation repeated
            input_frames = [frames for _ in range(n_input)]
            # zero frames stand in for the frames that are to be predicted
            dummy_frames = [torch.zeros_like(frames) for _ in range(pred_len)]
            print_frames = [frames]

            inst_table.add_data(i, caption)

            for j in range(video_step):  # at most `video_step` steps
                # [1, input_frames+pred_len, 3, H, W]
                img = torch.stack(
                    input_frames + dummy_frames, dim=1).to(self.device)
                data_dict = {
                    'img': img,
                    'instruction': inst,
                }
                out_dict = model(data_dict)
                action = out_dict['actions'].squeeze().cpu().numpy()

                obs, r, done, _ = env.step(action)
                last_frame = transform(obs['rgb']).unsqueeze(0)

                # slide the context window by one frame
                input_frames.pop(0)
                input_frames.append(last_frame)
                print_frames.append(last_frame)
                assert len(input_frames) == n_input
                # a task counts as solved when the env reports done or a
                # reward of 1
                if done or r == 1.:
                    # repeat the last frame so all videos have equal length
                    print_frames += [
                        last_frame for _ in range(video_step - j - 1)
                    ]
                    success_task += 1
                    print(f"INFO: task {i} succeeded in {j} steps!!")
                    break
            assert len(print_frames) == video_step + 1
            print_frames = torch.cat(print_frames, dim=0)
            videos.append(print_frames)

        log_dict = {
            'val/video': self._convert_video(videos),
            'val/inst_table': inst_table,
            'val/success_rate': success_task / task_num
        }

        wandb.log(log_dict, step=self.it)
        torch.cuda.empty_cache()

    @torch.no_grad()
    def _sample_video(self, model):
        """Log reconstruction, rollout and comparison videos to wandb.

        `model` is a simple nn.Module, not wrapped in e.g. DataParallel.
        For each sampled validation video the slots are extracted with
        `model._forward`, decoded as a sanity check, and rolled out from the
        first `input_frames` steps conditioned on the instruction.
        """
        model.eval()
        dst = self.val_loader.dataset
        sampled_idx = self._get_sample_idx(self.params.n_samples, dst)
        results, rollout_results, compare_results = [], [], []
        inst_table = wandb.Table(columns=["idx", "inst"])
        for i in sampled_idx:
            # Convert the 0-dim tensor once, so the dataset accessors and the
            # wandb table all receive a plain Python int.
            idx = i.item()
            video = self._read_video(dst, idx).to(self.device)

            slots = model._forward(video.unsqueeze(0))['post_slots']
            if slots.ndim > 3:
                # drop the batch dimension added above
                slots = slots.squeeze(0)
            inst = dst._read_insts(idx).to(self.device).unsqueeze(0)
            T = video.shape[0]

            # reconstruct gt_slots as sanity-check
            # i.e. if the pre-trained weights are loaded correctly
            recon_combined, recons, masks, _ = model.decode(slots)
            img = video.type_as(recon_combined)
            save_video = self._make_video_grid(img, recon_combined, recons,
                                               masks)
            results.append(save_video)
            # rollout
            past_steps = self.params.input_frames
            past_slots = slots[:past_steps][None]  # [1, t, N, C]

            out_dict = model.rollout(
                past_slots, inst, T - past_steps, decode=True, with_gt=True)
            out_dict = {k: v[0] for k, v in out_dict.items()}
            rollout_combined, recons, masks = out_dict['recon_combined'], \
                out_dict['recons'], out_dict['masks']
            img = video.type_as(rollout_combined)
            pred_video = self._make_video_grid(img, rollout_combined, recons,
                                               masks)
            rollout_results.append(pred_video)  # per-slot rollout results
            # stack (gt video, gt slots recon video, slot_0 rollout video)
            # horizontally to better compare the 3 videos
            compare_video = self._compare_videos(img, recon_combined,
                                                 rollout_combined)
            compare_results.append(compare_video)
            inst_table.add_data(idx, dst._read_insts_text(idx))

        log_dict = {
            'val/video': self._convert_video(results),
            'val/rollout_video': self._convert_video(rollout_results),
            'val/compare_video': self._convert_video(compare_results),
            'val/inst_table': inst_table,
        }
        wandb.log(log_dict, step=self.it)
        torch.cuda.empty_cache()