import os
import wandb
import numpy as np

import torch
import torchvision.utils as vutils

from nerv.training import BaseMethod, CosineAnnealingWarmupRestarts

from ..base_slots.models import to_rgb_from_tensor

from ..base_slots.datasets.utils import BaseTransforms
from nerv.utils.io import check_file_exist
from sentence_transformers import SentenceTransformer

from language_table.environments import blocks
from language_table.environments import language_table
from language_table.environments.rewards import block2block

from gym.vector import AsyncVectorEnv


import time

def normalize_act(action, min_val=-0.1, max_val=0.1):
    """Linearly rescale ``action`` from ``[min_val, max_val]`` to ``[-1, 1]``.

    Works element-wise on Python scalars as well as numpy arrays / tensors.
    """
    span = max_val - min_val
    unit = (action - min_val) / span  # position inside the range, in [0, 1]
    return unit * 2 - 1.

def denormalize_act(action, min_val=-0.1, max_val=0.1):
    """Linearly rescale ``action`` from ``[-1, 1]`` back to ``[min_val, max_val]``.

    This is the inverse mapping of ``normalize_act`` for the same bounds.
    """
    unit = (action + 1) / 2  # position inside the target range, in [0, 1]
    return min_val + unit * (max_val - min_val)
    
def build_method(**kwargs):
    """Build the training method for ``kwargs['params'].model``.

    All keyword arguments are forwarded to the method's constructor.

    Raises:
        NotImplementedError: if no method is registered for the model name.
    """
    model_name = kwargs['params'].model
    supported_models = (
        'RoboticsTransformer',
        'RoboticsLSlotFormer',
        'RoboticsLSlotFormerMLPDec',
        'RoboticsLSlotFormerMLPDec1Slot',
        'RoboticsLSlotFormer1Slot',
        'RoboticsLSlotFormerMLPDec1SlotLang',
        'RoboticsLSlotFormerSlotFormer',
    )
    if model_name not in supported_models:
        raise NotImplementedError(f'{model_name} method is not implemented.')
    return RoboticsTransformerMethod(**kwargs)

class ImgBaseMethod(BaseMethod):
    """Base method in this project.

    Adds video padding helpers, wandb video conversion, uniform sample-index
    selection, optional video sampling after validation and a cosine
    warmup/restart learning-rate schedule on top of ``BaseMethod``.
    """

    @staticmethod
    def _pad_frame(video, target_T):
        """Pad the video to ``target_T`` frames by repeating its last frame.

        Videos that already have at least ``target_T`` frames are returned
        unchanged.
        """
        if video.shape[0] >= target_T:
            return video
        dup_video = torch.stack(
            [video[-1]] * (target_T - video.shape[0]), dim=0)
        return torch.cat([video, dup_video], dim=0)

    @staticmethod
    def _pause_frame(video, N=4):
        """Pause the video on the first frame by prepending ``N`` copies of it."""
        dup_video = torch.stack([video[0]] * N, dim=0)
        return torch.cat([dup_video, video], dim=0)

    def _convert_video(self, video, caption=None):
        """Wrap a video into a ``wandb.Video``.

        Args:
            video: either a list of tensors that are concatenated along dim 2
                and scaled by 255 to uint8 (presumably values in [0, 1] —
                confirm with callers), or a tensor of shape
                ``[task_num, video_step, C, H, W]`` whose tasks are tiled
                vertically into ``[video_step, C, task_num * H, W]``.
            caption: optional caption shown in wandb.
        """
        if isinstance(video, list):
            video = torch.cat(video, dim=2)  # [T, 3, B*H, L*W]
            video = (video * 255.).numpy().astype(np.uint8)
        else:
            # A plain reshape of [task_num, video_step, C, H, W] would
            # reinterpret memory and mix frames of different tasks; stacking
            # the per-task videos along the height axis tiles them instead.
            video = torch.cat(list(video), dim=2)  # [video_step, C, task_num*H, W]
        return wandb.Video(video, fps=self.vis_fps, caption=caption)

    @staticmethod
    def _get_sample_idx(N, dst):
        """Load videos uniformly from the dataset.

        Each entry of ``dst.files`` is treated as one sample. Returns evenly
        spaced indices in ``[0, len(dst.files))``. The step is clamped to at
        least 1 so datasets with fewer than ``N`` videos (or none) do not
        make ``torch.arange`` fail with a zero step.
        """
        dst_len = len(dst.files)  # treat each video as a sample
        N = N - 1 if dst_len % N != 0 else N
        step = max(dst_len // N, 1)
        return torch.arange(0, dst_len, step)

    @torch.no_grad()
    def validation_epoch(self, model, san_check_step=-1, sample_video=True):
        """Validate one epoch.

        We aggregate the avg of all statistics and only log once. On rank 0,
        ``self._sample_video(model)`` is called afterwards when
        ``sample_video`` is True.
        """
        super().validation_epoch(model, san_check_step=san_check_step)
        if self.local_rank != 0:
            return
        # visualization after every epoch
        if sample_video:
            self._sample_video(model)

    def _configure_optimizers(self):
        """Returns an optimizer and ``(scheduler, 'step')``.

        The scheduler is a cosine warmup/restart schedule spanning all training
        iterations (``max_epochs * len(train_loader)``). It warms up for
        ``warmup_steps_pct`` of them and decays from ``params.lr`` to
        ``params.lr / 100``. It is stepped every iteration.
        """
        optimizer = super()._configure_optimizers()[0]

        lr = self.params.lr
        total_steps = self.params.max_epochs * len(self.train_loader)
        warmup_steps = self.params.warmup_steps_pct * total_steps

        scheduler = CosineAnnealingWarmupRestarts(
            optimizer,
            total_steps,
            max_lr=lr,
            min_lr=lr / 100.,
            warmup_steps=warmup_steps,
        )

        return optimizer, (scheduler, 'step')

    @property
    def vis_fps(self):
        """Frame rate for logged videos: 4 for PHYRE datasets, 8 otherwise."""
        # PHYRE
        if 'phyre' in self.params.dataset.lower():
            return 4
        # OBJ3D, CLEVRER, Physion
        else:
            return 8


class RoboticsTransformerMethod(ImgBaseMethod):
    """Method for the robotics (language-table) policy models.

    Extends ``ImgBaseMethod`` with:

    - checkpoint loading that falls back to pretrained SAVi / LSlotFormer
      weights;
    - a linearly ramped temporal loss weight (``loss_decay_factor``);
    - closed-loop evaluation of the policy in ``language_table``
      environments after validation.
    """

    @torch.no_grad()
    def load_ckp(self, ckp_path=None, auto_detect=True):
        """Load from checkpoint.

        Support automatic detection of existing checkpoints.
        Useful in SLURM preemption systems.

        With ``auto_detect``, the most recently modified ``.pth`` file in
        ``self.ckp_path`` (if any) overrides ``ckp_path``. The full training
        state is then resumed: iteration/epoch, optimizer, scheduler, grad
        scaler and train-sampler state.

        Otherwise, if ``ckp_path`` is truthy, it is only checked for existence.
        Model weights are initialized from the hard-coded pretrained checkpoint
        placeholders below, and no training state is restored. If neither case
        applies, nothing is loaded.
        """
        # Auto-detected checkpoints; empty means "not resuming". Defined up
        # front so later checks never hit an unbound name when auto_detect is
        # False or self.ckp_path does not exist.
        ckp_files = []
        if auto_detect and os.path.exists(self.ckp_path):
            ckp_files = [
                ckp for ckp in os.listdir(self.ckp_path) if ckp.endswith('.pth')
            ]
            if ckp_files:
                ckp_files = sorted(
                    ckp_files,
                    key=lambda x: os.path.getmtime(
                        os.path.join(self.ckp_path, x)))
                last_ckp = ckp_files[-1]
                print(f'INFO: automatically detect checkpoint {last_ckp}')
                ckp_path = os.path.join(self.ckp_path, last_ckp)
        resuming = bool(ckp_files)

        if not ckp_path:
            return

        print(f'INFO: loading checkpoint {ckp_path}')

        check_file_exist(ckp_path)
        if resuming:
            ckp = torch.load(ckp_path, map_location='cpu')
            self.it, self.epoch = ckp['it'], ckp['epoch']
        else:
            ckp = torch.load('PRETRAINED_SAVI_CKP_PATH', map_location='cpu')  # enter your pretrained SAVi path

        try:
            self.model.module.load_state_dict(ckp['state_dict'])
        except RuntimeError:
            # Strict loading failed (missing/unexpected keys or shape
            # mismatch): fall back to non-strict loading of the pretrained
            # components.
            print("INFO: some keys are missing (probably rollouter), but loaded anyway")

            self.model.module.encoder.load_state_dict(ckp['state_dict'], strict=False)
            print("INFO: slot encoder loaded")
            # Only needed on this fallback path, so it is loaded lazily.
            ckp_1 = torch.load('PRETRAINED_LSLOTFORMER_CKP_PATH', map_location='cpu')  # enter your pretrained LSlotFormer path
            # NOTE(review): the predictor weights are loaded into ``encoder``
            # too (non-strict, so non-matching keys are skipped); confirm this
            # is the intended submodule.
            self.model.module.encoder.load_state_dict(ckp_1['state_dict'], strict=False)
            print("INFO: slot predictor loaded")

        if not resuming:
            # Pretrained initialization only: any optimizer/scheduler/scaler/
            # sampler state in that checkpoint belongs to a different run.
            return

        self.optimizer.load_state_dict(ckp['opt_state_dict'])
        if self.scheduler_method:
            self.scheduler.load_state_dict(ckp['scheduler_state_dict'])
            self.scheduler_method = ckp['scheduler_method']
        if self.use_fp16:
            self.grad_scaler.load_state_dict(ckp['grad_scaler'])
        # should consider loading data sampler
        if 'rank0_train_sampler' in ckp:
            print('INFO: loading train loader state')
            self.train_loader.sampler.load_state_dict(
                ckp[f'rank{self.local_rank}_train_sampler'])

    def _make_action_slot_grid(self, recon_combined, recons, masks):
        """Make a video of grid images showing slot decomposition.

        Each frame is one row made of the combined reconstruction followed by
        every slot's masked reconstruction (black background).
        Presumably ``recon_combined`` is [T, 3, H, W] and ``recons`` / ``masks``
        are [T, num_slots, ...]; confirm with callers.
        """
        scale = 0.
        out = to_rgb_from_tensor(
            torch.cat(
                [
                    recon_combined.unsqueeze(1),  # reconstructions
                    recons * masks + (1. - masks) * scale,  # each slot
                ],
                dim=1,
            ))  # [T, num_slots+1, 3, H, W]
        # stack the slot decomposition in all frames to a video
        save_video = torch.stack([
            vutils.make_grid(
                out[i].cpu(),
                nrow=out.shape[1],
                pad_value=1. - scale,
            ) for i in range(recons.shape[0])
        ])  # [T, 3, H, (num_slots+1)*W]
        return save_video

    def _training_step_start(self):
        """Things to do at the beginning of every training step.

        If ``params.use_loss_decay`` exists:

        - When it is False, ``loss_decay_factor`` is fixed at 1.
        - When it is True, the factor ramps linearly from 0.01 to 1 over the
          first ``loss_decay_pct`` of all training iterations and stays at 1
          afterwards.
        """
        super()._training_step_start()

        if not hasattr(self.params, 'use_loss_decay'):
            return

        # decay the temporal weighting linearly
        if not self.params.use_loss_decay:
            self.model.module.loss_decay_factor = 1.
            return

        cur_steps = self.it
        total_steps = self.params.max_epochs * len(self.train_loader)
        decay_steps = self.params.loss_decay_pct * total_steps

        # also covers decay_steps == 0, avoiding a division by zero below
        if cur_steps >= decay_steps:
            self.model.module.loss_decay_factor = 1.
            return

        # increase tau linearly from 0.01 to 1
        self.model.module.loss_decay_factor = \
            0.01 + cur_steps / decay_steps * 0.99

    def _log_train(self, out_dict):
        """Log statistics in training to wandb, plus ``loss_decay_factor``.

        The extra value is logged only on rank 0, on print iterations, and only
        when ``params.use_loss_decay`` exists.
        """
        super()._log_train(out_dict)

        if self.local_rank != 0 or (self.epoch_it + 1) % self.print_iter != 0:
            return

        if not hasattr(self.params, 'use_loss_decay'):
            return

        # also log the loss_decay_factor
        x = self.model.module.loss_decay_factor
        wandb.log({'train/loss_decay_factor': x}, step=self.it)

    def _compare_videos(self, img, recon_combined, rollout_combined):
        """Stack 3 videos (input, reconstruction, rollout) side by side.

        ``img`` and ``recon_combined`` are padded with their last frame to the
        rollout's length.
        """
        # pad to the length of rollout video
        T = rollout_combined.shape[0]
        img = self._pad_frame(img, T)
        recon_combined = self._pad_frame(recon_combined, T)
        out = to_rgb_from_tensor(
            torch.stack(
                [
                    img,  # original images
                    recon_combined,  # reconstructions
                    rollout_combined,  # rollouts
                ],
                dim=1,
            ))  # [T, 3, 3, H, W]
        save_video = torch.stack([
            vutils.make_grid(
                out[i].cpu(),
                nrow=out.shape[1],
                # pad white if using black background
                pad_value=1 if self.params.get('reverse_color', False) else 0,
            ) for i in range(img.shape[0])
        ])  # [T, 3, H, 3*W]
        return save_video

    def _read_video(self, dst, idx):
        """Read video ``idx`` from the dataset (frames ``[T, 3, H, W]``).

        PHYRE datasets are read with ``video_len=self.params.video_len``.
        Others (OBJ3D, CLEVRER, Physion, langtable) use ``dst.get_video``'s
        defaults. Pre-extracted slots are not read: they were never returned,
        so loading them only cost I/O and a device transfer.
        """
        # PHYRE
        if 'phyre' in self.params.dataset.lower():
            return dst.get_video(idx, video_len=self.params.video_len)['video']
        # OBJ3D, CLEVRER, Physion, langtable
        return dst.get_video(idx)['video']

    @torch.no_grad()
    def validation_epoch(self, model, san_check_step=-1, sample_video=True):
        """Validate one epoch, then (rank 0) evaluate the policy in the env.

        We aggregate the avg of all statistics and only log once. The loss is
        computed with ``loss_decay_factor = 1``. The previous factor is
        restored even if validation raises.

        ``sample_video`` is kept for interface compatibility but unused here.
        Task rollouts are controlled by ``params.act_dec_dict['sample_task']``.
        """
        save_loss_decay_factor = model.loss_decay_factor
        model.loss_decay_factor = 1.  # compute loss in normal scale
        try:
            BaseMethod.validation_epoch(self, model, san_check_step=san_check_step)
        finally:
            model.loss_decay_factor = save_loss_decay_factor

        # Setting RL env for evaluation.
        self.reward_factory = block2block.BlockToBlockReward  # CHANGEME: change to another reward.

        if self.local_rank != 0:
            return
        if self.params.act_dec_dict['sample_task']:
            self._sample_task(model)

    def _convert_video(self, video, caption=None):
        """Convert frames in [-1, 1] into a uint8 ``wandb.Video``.

        Args:
            video: either a list of tensors concatenated along dim 2, or a
                tensor ``[task_num, video_step, 3, H, W]`` whose tasks are
                tiled vertically into ``[video_step, 3, task_num * H, W]``.
            caption: optional caption shown in wandb.
        """
        if isinstance(video, list):
            video = torch.cat(video, dim=2)  # [T, 3, B*H, L*W], from -1. to 1.
        else:
            video = torch.cat(list(video), dim=2)  # [video_step, 3, H*task_num, W]
        video = ((video + 1.) * 0.5 * 255.).detach().cpu().numpy().astype(np.uint8)
        return wandb.Video(video, fps=self.vis_fps, caption=caption)

    @staticmethod
    def _decode_inst(inst):
        """Decode a zero-padded array of UTF-8 byte values into a string."""
        return bytes(inst[np.where(inst != 0)].tolist()).decode("utf-8")

    @staticmethod
    def _quarter_success_rates(success_task, task_num):
        """Success rate inside each quarter of the task-id range [0, task_num).

        Quarter ``k`` covers ids ``[k * task_num // 4, (k + 1) * task_num // 4)``
        and is normalized by its own size. An empty quarter (``task_num < 4``)
        gets a rate of 0.
        """
        bounds = [k * task_num // 4 for k in range(5)]
        rates = []
        for lo, hi in zip(bounds[:-1], bounds[1:]):
            hits = sum(1 for t in success_task if lo <= int(t) < hi)
            rates.append(hits / (hi - lo) if hi > lo else 0.)
        return rates

    @torch.no_grad()
    def _run_task_batch(self, model, start, stop, transform, lm, inst_table,
                        first_success):
        """Roll out ``model`` on tasks ``start .. stop - 1`` (env seed = task id).

        Instructions are added to ``inst_table`` under their global task id.
        ``first_success[task_id]`` is lowered to the first step at which that
        env reports ``done``. Both are mutated in place. The vectorized envs
        are always closed, even on error.

        Returns:
            Tensor ``[stop - start, task_steps + 1, 3, H, W]``: the initial
            observation followed by one frame per step.
        """
        pred_len = self.params.loss_dict['rollout_len']
        video_step = self.params.act_dec_dict['task_steps']
        # Bind to a local so the env factories do not capture ``self``.
        reward_factory = self.reward_factory
        envs = AsyncVectorEnv([
            lambda seed=seed: language_table.LanguageTable(
                block_mode=blocks.LanguageTableBlockVariants.BLOCK_4,
                reward_factory=reward_factory,
                seed=seed)
            for seed in range(start, stop)
        ])
        try:
            obses = envs.reset()

            frames = torch.stack([transform(rgb) for rgb in obses['rgb']], dim=0)  # [B, 3, H, W]

            captions = [self._decode_inst(inst) for inst in obses['instruction']]
            inst = torch.stack([torch.Tensor(lm.encode(caption)) for caption in captions], dim=0)  # [B, 768]
            for offset, caption in enumerate(captions):
                inst_table.add_data(start + offset, caption)

            # context window initialized with the first frame repeated
            input_frames = torch.stack([frames] * self.params.input_frames, dim=1)  # [B, input_frames, 3, H, W]
            # all-zero placeholders for the frames the model predicts
            dummy_frames = frames.new_zeros((frames.shape[0], pred_len, *frames.shape[1:]))  # [B, pred_len, 3, H, W]
            print_frames = [frames]

            step_times = []
            for step in range(video_step):
                t_start = time.time()
                img = torch.cat([input_frames, dummy_frames], dim=1)  # [B, T, 3, H, W]
                data_dict = {
                    'img': img.to(self.device),
                    'instruction': inst.to(self.device),
                }
                out_dict = model(data_dict)
                actions = denormalize_act(out_dict['actions'].cpu().numpy())  # [B, D_a]
                obses, r, done, _ = envs.step(actions)
                last_frames = torch.stack([transform(rgb).unsqueeze(0) for rgb in obses['rgb']], dim=0)  # [B, 1, 3, H, W]

                # slide the context window forward by one frame
                input_frames = torch.cat([input_frames[:, 1:], last_frames], dim=1)
                # drop only the time axis so a batch of one keeps its batch dim
                print_frames.append(last_frames.squeeze(1))
                assert input_frames.shape[1] == self.params.input_frames

                # Success is taken from ``done`` (a reward of 1 without
                # ``done`` is not counted, as before).
                for local_idx in np.flatnonzero(done):
                    task_id = start + int(local_idx)
                    first_success[task_id] = min(first_success[task_id], step)
                    print(f"INFO: task {task_id} success on step {step}!!!!")
                step_times.append(time.time() - t_start)

            if step_times:
                print("INFO: time per step: ", np.mean(np.array(step_times)))
            return torch.stack(print_frames, dim=1)  # [B, video_step + 1, 3, H, W]
        finally:
            envs.close()

    @torch.no_grad()
    def _sample_task(self, model):
        """Validation on RL environment.

        Runs ``act_dec_dict['num_sample_tasks']`` tasks (env seeds
        ``0 .. task_num - 1``) in batches of at most
        ``act_dec_dict['task_batch_size']``. A trailing partial batch is
        included. Each task runs for ``act_dec_dict['task_steps']`` steps.

        Logs to wandb:

        - the overall success rate;
        - the success rate of each quarter of the task-id range;
        - a table of instructions;
        - unless ``upload_video == 'none'``, a video of all rollouts in which
          frames from a task's first success onward are replaced by noise.
        """
        model.eval()

        # maximum steps to run the task
        video_step = self.params.act_dec_dict['task_steps']
        # number of tasks to sample
        task_num = self.params.act_dec_dict['num_sample_tasks']
        batch_size = self.params.act_dec_dict['task_batch_size']
        if task_num <= 0:
            print("INFO: no tasks to sample, skipping task evaluation")
            return

        # first success step per task; ``video_step`` means "never succeeded"
        first_success = {t: video_step for t in range(task_num)}
        inst_table = wandb.Table(columns=["idx", "inst"])

        # loop-invariant helpers, built once for all batches
        transform = BaseTransforms(self.params.resolution)
        lm = SentenceTransformer('sentence-transformers/sentence-t5-base')

        videos = []
        for start in range(0, task_num, batch_size):
            torch.cuda.empty_cache()
            stop = min(start + batch_size, task_num)
            videos.append(self._run_task_batch(
                model, start, stop, transform, lm, inst_table, first_success))

        videos = torch.cat(videos, dim=0)  # [task_num, video_step + 1, 3, H, W]
        for task_id, step in first_success.items():
            if step < video_step:
                videos[task_id, step:] = torch.randn_like(videos[task_id, step:])

        success_task = [t for t, step in first_success.items() if step < video_step]
        success_rate = len(success_task) / task_num
        print("INFO: success rate is ", success_rate)

        success_rate_1, success_rate_2, success_rate_3, success_rate_4 = \
            self._quarter_success_rates(success_task, task_num)

        log_dict = {
            'val/video': self._convert_video(videos) if self.params.act_dec_dict['upload_video'] != 'none' else None,
            'val/inst_table': inst_table,
            'val/success_rate': success_rate,
            'val/success_rate_1': success_rate_1,
            'val/success_rate_2': success_rate_2,
            'val/success_rate_3': success_rate_3,
            'val/success_rate_4': success_rate_4,
        }

        wandb.log(log_dict, step=self.it)
        torch.cuda.empty_cache()
            

        

                

   