import os
from typing import Optional, Tuple, Union, Dict, Any, List
import argparse
import time
import psutil
import json
import logging
# Module-level logger; used by GeneralAgent.setup_pretrain to announce which
# pretraining losses are enabled.
logger = logging.getLogger(__name__)
# NOTE(review): calling basicConfig at import time configures the *root* logger
# for the whole process as a side effect of importing this module; consider
# moving it to the training entry-point script.
logging.basicConfig(level=logging.INFO)

import numpy as np
import cv2
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import transformers
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
from torchvision import transforms
from scipy.spatial.transform import Rotation


# Utilities
from envs.env_maniskill.env_evaluator import EnvEvaluator
from envs.env_maniskill.env_creator import create_maniskill_env
from envs.env_maniskill.metrics import SuccessRates

# Modules
from .trajectory_transformer import GPT2Model
from .prompt_modules import PromptEncoder, MaskedPromptModelingHead, PromptTrajectoryMatchingHead
# from .image_modules import ImageEncoder, ImageDecoder
from .state_modules import StateEncoder, MaskedStateReconstructionHead
from .action_modules import ActionEncoder, ActionDecoder, ActionDiffusion
from .cl_modules import VLAContrastiveLearningHead


class GeneralAgent(pl.LightningModule):
    def __init__(self, 
                 # Feature descriptors. Only ``feature_list[0].name`` is read in this
                 # class: it is the batch key used to infer (B, S).
                 feature_list: List[Any],
                 assets_path: str,
                 learning_rate: float = 1e-4,
                 learning_rate_decay: float = 0.6,
                 learning_rate_step_size: int = 3000,
                 weight_decay: float = 1e-4,
                 val_interval: Union[int, float] = 1e4,
                 val_online: bool = True,
                 val_offline: bool = True,
                 val_trials: int = 50,
                 config: Optional[argparse.Namespace] = None,
                 ):
        """
        Build the agent: prompt/state/action encoders, the GPT-2 control
        transformer, the pretraining heads and (optionally) an online eval env.

        Args:
            feature_list: feature descriptors; ``feature_list[0].name`` must be a
                key of the batch dict, whose value's first two dims are (B, S).
            assets_path: forwarded to ``create_maniskill_env`` when ``val_online``.
            learning_rate: AdamW learning rate; also the target of the lr warmup.
            learning_rate_decay: ``gamma`` of the StepLR scheduler.
            learning_rate_step_size: ``step_size`` of the StepLR scheduler.
            weight_decay: AdamW weight decay.
            val_interval: stored as ``self.val_interval``.
            val_online: if True, create a ManiSkill env and run online validation.
            val_offline: if True, run offline validation.
            val_trials: stored as ``self.val_trials``.
            config: model/training configuration namespace. Required; the
                ``None`` default only exists for signature compatibility.

        Raises:
            ValueError: if ``config`` is None.
        """
        # Every sub-module below is built from ``config``; fail fast with a clear
        # message instead of an AttributeError on ``None.prompt``.
        if config is None:
            raise ValueError("GeneralAgent requires a `config` namespace, got None.")

        super().__init__()

        ########################### Hyperparameters ############################
        self.feature_list = feature_list
        self.assets_path = assets_path
        self.learning_rate = learning_rate
        self.learning_rate_decay = learning_rate_decay
        self.learning_rate_step_size = learning_rate_step_size
        self.weight_decay = weight_decay
        self.val_interval = val_interval
        self.val_online = val_online
        self.val_offline = val_offline
        self.val_trials = val_trials
        self.prompt_config = config.prompt
        self.action_frame = "robot"
        self.hidden_size = config.n_embd
        self.config = config

        self.prompt_encoder = PromptEncoder(config)
        self.state_encoder = StateEncoder(config)
        self.action_encoder = ActionEncoder(config)
        self.action_decoder = ActionDecoder(config)
        self.action_diffusion = ActionDiffusion(config)
        if not self.config.use_action_diffusion:
            # Learned query tokens, one per non-mode action dimension, used by the
            # direct (non-diffusion) action-prediction path in ``forward``.
            self.action_queries = nn.Parameter(torch.zeros(1, 1, config.action_dim-1, config.n_embd))
        self._build_backbone()

        ########################### PRE-TRAINING ############################
        self.mpm_head = MaskedPromptModelingHead(self.hidden_size, self.prompt_encoder.tokenizer, self.prompt_encoder.config)
        self.msr_head = MaskedStateReconstructionHead(config)
        self.ptm_head = PromptTrajectoryMatchingHead(self.hidden_size)
        self.cl_head = VLAContrastiveLearningHead(
                            pooling = "max", 
                            action_encoder=self.action_encoder, 
                            action_diffusion=self.action_diffusion, 
                            config = config)

        if val_online:
            self.env = create_maniskill_env(assets_path = self.assets_path)

        ########################### LOG ############################
        self.metric = SuccessRates(config)

    def _build_backbone(self):
        """
        Create the input LayerNorm, the segment-id embedding table and the
        bare GPT-2 control transformer.

        The transformer has no token vocabulary of its own (``vocab_size=1``):
        inputs are pre-computed embeddings and outputs are read by custom heads.
        """
        hidden = self.hidden_size
        cfg = self.config

        self.input_embed_ln = nn.LayerNorm(hidden)
        # Segment ids: 0 for the prompt, i+1 for the i-th state/action segment.
        self.segment_embeddings = nn.Embedding(32, hidden)

        # The attention of this GPT variant is modified to accommodate the
        # different pretraining objectives.
        gpt_kwargs = dict(
            vocab_size=1,
            n_positions=1024,
            n_ctx=1024,
            n_embd=hidden,
            n_layer=cfg.n_layer,
            n_head=cfg.n_head,
            n_inner=4 * hidden,
            activation_function=cfg.activation_function,
            resid_pdrop=cfg.resid_pdrop,
            attn_pdrop=cfg.attn_pdrop,
        )
        self.model = GPT2Model(transformers.GPT2Config(**gpt_kwargs))

    def forward(self, action, t):
        """
        Run the control transformer over the cached prompt/state/action condition.

        ``encode_prompt_state`` must have been called first: every prompt, state
        and history-action embedding is read from ``self.encoded_action_condition``.

        Args:
            action: used only when ``config.use_action_diffusion`` is set. It is the
                noisy last action, viewed as (B, 1, action_dim-1) (the action mode
                is not diffused). Ignored otherwise (callers pass ``None``).
            t: diffusion time index, passed to
                ``action_diffusion.add_diffusion_timestep``; diffusion path only.

        Side effects:
            Sets ``self.mode_logits`` from ``action_decoder.predict_mode``.
            In the query-token (non-diffusion) path, also writes ``"query_length"``
            into the cached ``segment_lengths`` dict.

        Returns:
            The output of ``self.action_decoder(...)``. The exception is the
            query-token path while ``self.training`` is True: there it returns
            ``(all_losses, all_metrics)`` dicts holding the mode and action losses.

        Layout of one trajectory (S timesteps, each = state tokens then action tokens):
            [[base_image, gripper_image, cropped_image_1, ..., cropped_image_m, action_mode, action_1, ... action_n],
             ...
             [base_image, gripper_image, cropped_image_1, ..., cropped_image_m, action_mode, action_1, ... action_n]]
              |---------------------------- state ----------------------------| |------------- action -------------|
        """
        cond = self.encoded_action_condition
        seg_lens = cond["segment_lengths"]
        batch, steps = seg_lens["batch_size"], seg_lens["timesteps"]
        history_action = cond["action_embed"]
        use_diffusion = self.config.use_action_diffusion

        # ---------------- action tokens appended after the state tokens ----------------
        if use_diffusion:
            noisy = action.view(batch, -1, action.size(-1))
            assert action.size(-1) == self.config.action_dim-1 # Do not diffuse the action mode
            assert noisy.size(1) == 1 # Only diffuse the last action

            noisy_embed = self.action_encoder(noisy)
            noisy_embed = self.action_diffusion.add_diffusion_timestep(noisy_embed, t)
            # Replace the last timestep's clean action embedding by the noisy one.
            merged = torch.cat([history_action[:, :-1, :, :], noisy_embed], dim = 1)
            action_segments = [merged]
            action_masks = [None]
        else:
            queries = self.action_queries.expand(batch, steps, -1, -1)
            seg_lens["query_length"] = queries.size(2)
            action_segments = [queries, history_action]
            action_masks = [None, None]

        # ---------------- control transformer ----------------
        input_embed, attn_mask, position_ids = self.assemble_input_sequence(
            cond["prompt_embed"],
            cond["state_embed_list"] + action_segments,
            prompt_attn_mask = cond["prompt_attn_mask"],
            attn_mask_segments = cond["state_attn_mask_list"] + action_masks,
            segment_lengths = seg_lens,
        )
        transformer_out = self.model(
            inputs_embeds = input_embed,
            attention_mask = attn_mask,
            position_ids = position_ids,
            attention_type = "trajectory",
            attention_params = seg_lens,
        )
        hidden_states = transformer_out.last_hidden_state

        # The mode logits are cached so diffusion sampling / training can read them.
        self.mode_logits = self.action_decoder.predict_mode(hidden_states, seg_lens)
        decoded_action = self.action_decoder(hidden_states, seg_lens)

        if use_diffusion or not self.training:
            return decoded_action

        loss_mode, acc_mode = self.action_decoder.calculate_loss(self.mode_logits, cond["true_mode"])
        loss_action, acc_action = self.action_decoder.calculate_loss(decoded_action, cond["true_action"])
        losses = {"finetune/loss_mode": loss_mode, "finetune/loss": loss_action}
        metrics = {"finetune/acc_mode": acc_mode, "finetune/acc": acc_action}
        return losses, metrics

    def generate(self, batch : dict):
        """
        Predict the next action, combined with its mode, for ``batch``.

        The batch is encoded first. The continuous part of the action then comes
        from one of two sources: the diffusion sampler (``p_sample_loop``) when
        ``config.use_action_diffusion`` is set, or the action decoder otherwise.
        The mode is the argmax of the last-timestep ``self.mode_logits`` that the
        forward pass(es) left behind. Both parts go through
        ``action_encoder.denormalize_action``.
        """
        self.encode_prompt_state(batch)

        if not self.config.use_action_diffusion:
            decoded = self(action = None, t = None)
            next_action = self.action_decoder.predict_next_action(decoded)
        else:
            next_action = self.action_diffusion.p_sample_loop(self, (1, self.config.action_dim-1))

        last_mode = self.mode_logits[:, -1:, ...].argmax(dim = -1)
        return self.action_encoder.denormalize_action(last_mode, next_action)

    def encode_prompt_state(self, 
                batch_inputs : dict,
                ):
        """
        Encode prompt, states and actions of a batch and cache them for ``forward``.

        Reads:
            batch_inputs[self.feature_list[0].name]: its first two dims give (B, S).
            batch_inputs["prompt"]["language_instruction"]: the language prompts.
            batch_inputs[self.action_encoder.action_type]: raw actions, split by
                ``action_encoder.normalize_action`` into (mode, action).
            Everything else in ``batch_inputs`` is consumed by ``state_encoder``.

        Side effects:
            Sets ``self.encoded_action_condition``. It holds the prompt/state/action
            embeddings and masks, ``prompt_input_ids``, a ``segment_lengths`` dict
            (batch_size, timesteps, prompt_length, state_length, action_length) and
            the normalized ``true_action`` / ``true_mode`` targets.
        """
        B, S = batch_inputs[self.feature_list[0].name].shape[:2]

        instructions = batch_inputs["prompt"]["language_instruction"]
        prompt_embed, prompt_attn_mask, prompt_input_ids = self.prompt_encoder(instructions)

        state_embed_list, state_attn_mask_list, _ = self.state_encoder(batch_inputs)

        raw_actions = batch_inputs[self.action_encoder.action_type]
        true_mode, true_action = self.action_encoder.normalize_action(raw_actions)
        action_embed = self.action_encoder(true_action)

        # Token counts per timestep; state tokens are summed over all state segments.
        total_state_tokens = sum(embed.size(2) for embed in state_embed_list)

        self.encoded_action_condition = dict(
            prompt_embed = prompt_embed,
            prompt_attn_mask = prompt_attn_mask,
            prompt_input_ids = prompt_input_ids,
            state_embed_list = state_embed_list,
            state_attn_mask_list = state_attn_mask_list,
            action_embed = action_embed,
            segment_lengths = dict(
                batch_size = B,
                timesteps = S,
                prompt_length = prompt_embed.size(1),
                state_length = total_state_tokens,
                action_length = action_embed.size(2),
            ),
            true_action = true_action,
            true_mode = true_mode,
        )

    def pretrain(self, 
                batch_inputs : dict,
                pretrain_params : Optional[dict] = None,
                ):
        """
        Compute the pretraining losses for one batch.

        Args:
            batch_inputs: batch dict. It provides
                ``batch_inputs[self.feature_list[0].name]`` (first two dims (B, S)),
                ``batch_inputs["prompt"]["language_instruction"]``,
                ``batch_inputs[self.action_encoder.action_type]`` and the state
                inputs read by ``state_encoder``.
            pretrain_params: dict produced by ``setup_pretrain``. It contains
                ``"mode"`` (a list drawn from "MPM", "MSR", "PTM", "CL") and the
                options of those modes (e.g. ``"all_prompts"`` for CL).

        Mode effects here:
            "PTM": passed to the prompt encoder as ``ptm_mode``.
            "MSR": images are replaced by ``msr_head.prepare_noisy_image``.
            "CL":  contrastive loss; this is currently the only objective whose loss
                   is computed, so it is required whenever ``mode`` is non-empty.

        Returns:
            ``(all_losses, all_metrics)`` dicts. Both are empty when ``mode`` is empty.

        Raises:
            ValueError: if ``pretrain_params`` is None or contains an unknown mode.
            NotImplementedError: if ``mode`` is non-empty but lacks "CL".
        """
        if pretrain_params is None:
            raise ValueError("pretrain() requires `pretrain_params` (see setup_pretrain).")
        modes = pretrain_params["mode"]

        all_losses = dict()
        all_metrics = dict()

        # Always return a pair so callers can unpack ``losses, metrics = ...``.
        if len(modes) == 0:
            return all_losses, all_metrics

        if not all(mode in ("MPM", "MSR", "PTM", "CL") for mode in modes):
            raise ValueError('pretrain_params["mode"] should be one of MPM, MSR, PTM, CL')

        # The MPM/MSR/PTM loss heads are constructed but their losses are not
        # computed here; the transformer input needs the CL action embedding.
        if "CL" not in modes:
            raise NotImplementedError(
                "pretrain() currently requires 'CL' in pretrain_params['mode'] "
                f"(MPM/MSR/PTM losses are disabled); got {modes}")

        B, S = batch_inputs[self.feature_list[0].name].shape[:2]

        ########################### ENCODING ############################
        prompt_embed, prompt_attn_mask, prompt_input_ids = self.prompt_encoder(
            batch_inputs["prompt"]["language_instruction"], 
            ptm_mode = "PTM" in modes) # (B, prompt_len, hidden)

        # One pretraining task can affect another, e.g. MSR noise changes the
        # states seen by the CL objective.
        if "MSR" in modes:
            batch_inputs = self.msr_head.prepare_noisy_image(batch_inputs, t_index = 1)
        state_embed_list, state_attn_mask_list, _ = self.state_encoder(batch_inputs)

        segment_lengths = {
            "batch_size" : B,
            "batch_size_prime" : B * 2 if "PTM" in modes else B,
            "timesteps" : S,
            "prompt_length" : prompt_embed.size(1),
            "state_length" : sum(state_embed.size(2) for state_embed in state_embed_list),
        }

        # -------------------- Contrastive-learning inputs --------------------
        negative_prompt_embed, negative_prompt_attn_mask = \
            self.ptm_head.prepare_inputs(
                    self.prompt_encoder.get_all_prompt_embed(pretrain_params["all_prompts"], segment_lengths),
                    prompt_input_ids)

        prompt_embed, state_embed_list, action_embed_cl, prompt_attn_mask, state_attn_mask_list, segment_lengths = \
            self.cl_head.prepare_inputs(prompt_embed, state_embed_list, batch_inputs[self.action_encoder.action_type], 
                                        prompt_attn_mask, state_attn_mask_list, segment_lengths, 
                                        negative_prompt_embed = negative_prompt_embed, 
                                        negative_prompt_attn_mask = negative_prompt_attn_mask)
        segment_lengths["action_length"] = action_embed_cl.size(2)

        # -------------------- Pass into the Transformer --------------------
        input_embed, attn_mask, position_ids = self.assemble_input_sequence(
                                    prompt_embed, 
                                    state_embed_list + [action_embed_cl], 
                                    prompt_attn_mask = prompt_attn_mask, 
                                    attn_mask_segments = state_attn_mask_list + [None],
                                    segment_lengths = segment_lengths,
                                    )
        hidden_states = self.model(inputs_embeds = input_embed,
                                    attention_mask = attn_mask,
                                    position_ids = position_ids,
                                    attention_type = "trajectory",
                                    attention_params = segment_lengths,
                                    ).last_hidden_state

        # -------------------- Extract the losses --------------------
        loss_cl, acc_cl = self.cl_head(hidden_states, segment_lengths)
        all_losses["pretrain/loss_cl"] = self.config.cl_loss_weight * loss_cl
        all_metrics["pretrain/acc_cl"] = acc_cl

        return all_losses, all_metrics

    def assemble_input_sequence(self, 
                                prompt_embed, 
                                embed_segments, 
                                prompt_attn_mask = None, 
                                attn_mask_segments = None, 
                                segment_lengths = None):
        """
        Assemble the flat input sequence for the trajectory transformer.

        The sequence is ``[prompt tokens] + [tokens of timestep 1] + ... + [tokens of
        timestep S]``. The tokens of one timestep are the concatenation of all
        ``embed_segments`` along dim 2, in list order.

        Args:
            prompt_embed: (B, prompt_len, hidden)
            embed_segments: list of (B, S, seg_len_i, hidden) tensors, e.g. the
                state embeddings followed by the action (and query) embeddings.
            prompt_attn_mask: (B, prompt_len), or None for all ones.
            attn_mask_segments: None, or a list with one entry per embed segment.
                Each entry is a (B, S, seg_len_i) mask, or None for all ones.
            segment_lengths: dict with "state_length", "action_length" and,
                optionally, "query_length". Their sum is the number of tokens per
                timestep used for the position ids. If None, that number is taken
                from ``embed_segments`` directly.

        Returns:
            input_embed: (B, input_len, hidden), where input_len = prompt_len + S * tokens_per_step.
                Masked-out tokens are zeroed before the segment embedding is added,
                and the result goes through ``self.input_embed_ln``.
            attn_mask: (B, input_len)
            position_ids: (B, input_len). The value is 0 for prompt tokens and
                t + 1 for every token of timestep t (0-based).
        """
        assert attn_mask_segments is None or len(embed_segments) == len(attn_mask_segments), \
            "embed_segments and attn_mask_segments should have the same length"

        B, S = embed_segments[0].shape[:2]
        prompt_len = prompt_embed.size(1)

        # ------------------ Embeddings ------------------
        sa_embed = torch.cat(embed_segments, dim = 2).view(B, -1, self.hidden_size)
        input_embed = torch.cat([prompt_embed, sa_embed], dim = 1)

        if prompt_attn_mask is None:
            prompt_attn_mask = torch.ones((B, prompt_len), device = prompt_embed.device)

        # Segment id 0 is the prompt. The i-th entry of embed_segments gets id
        # i + 1, so ids depend on list order, and the count of segments is bounded
        # by the size of self.segment_embeddings.
        prompt_seg_ids = torch.zeros((B, prompt_len),
                                        dtype = torch.long,
                                        device = prompt_embed.device)

        sa_attn_mask_list = []
        sa_seg_id_list = []
        for i, segment in enumerate(embed_segments):
            seg_shape = (B, S, segment.size(2))
            mask = None if attn_mask_segments is None else attn_mask_segments[i]
            if mask is None:
                mask = torch.ones(seg_shape, device = segment.device)
            sa_attn_mask_list.append(mask)
            sa_seg_id_list.append(torch.full(seg_shape, i + 1,
                                             dtype = torch.long,
                                             device = segment.device))

        sa_attn_mask = torch.cat(sa_attn_mask_list, dim = 2).view(B, -1)
        attn_mask = torch.cat([prompt_attn_mask, sa_attn_mask], dim = 1)

        sa_seg_ids = torch.cat(sa_seg_id_list, dim = 2).view(B, -1)
        seg_ids = torch.cat([prompt_seg_ids, sa_seg_ids], dim = 1)
        segment_embed = self.segment_embeddings(seg_ids)

        # Set the input_embed to 0 where attn_mask is 0
        input_embed = input_embed * attn_mask.unsqueeze(-1)
        input_embed = input_embed + segment_embed
        input_embed = self.input_embed_ln(input_embed)

        # ------------------ Position ids ------------------
        if segment_lengths is None:
            state_action_len = sum(segment.size(2) for segment in embed_segments)
        else:
            state_action_len = segment_lengths["state_length"] + \
                               segment_lengths["action_length"] + \
                               segment_lengths.get("query_length", 0)

        device = input_embed.device
        position_ids = torch.cat([
            torch.zeros(prompt_len, dtype = torch.long, device = device),
            torch.arange(1, S + 1, dtype = torch.long, device = device).repeat_interleave(state_action_len),
        ])
        position_ids = position_ids.unsqueeze(0).expand(B, -1)

        return input_embed, attn_mask, position_ids

    # ========================= Training =========================
    def training_step(self, batch, batch_idx):
        """
        Run the objectives scheduled by ``setup_pretrain`` for one batch.

        Pretraining losses come from ``self.pretrain``. Finetuning losses come
        either from the diffusion objective (plus the mode loss) or from
        ``forward`` in training mode. All losses are summed into
        ``"loss_total"``, which is logged and returned.

        Note: image color channels in RGB order
        """
        tic = time.time()
        losses = {}
        metrics = {}

        do_pretrain, do_finetune, pretrain_params = self.setup_pretrain(batch_idx)

        if do_pretrain:
            pre_losses, pre_metrics = self.pretrain(batch, pretrain_params = pretrain_params)
            losses.update(pre_losses)
            metrics.update(pre_metrics)

        if do_finetune:
            self.encode_prompt_state(batch)

            if not self.config.use_action_diffusion:
                ft_losses, ft_metrics = self(action = None, t = None)
            else:
                raw = batch[self.action_encoder.action_type]
                mode, action = self.action_encoder.normalize_action(raw)
                # Only the last timestep's action is diffused / supervised.
                mode, action = mode[:, -1, :], action[:, -1, :]

                loss_diffusion, acc_diffusion = self.action_diffusion.training_losses(self, action)
                loss_diffusion = loss_diffusion.mean()
                # self.mode_logits was refreshed by the forward pass inside training_losses.
                loss_mode, acc_mode = self.action_decoder.calculate_loss(self.mode_logits, mode)
                ft_losses = {"finetune/loss": loss_diffusion,
                             "finetune/loss_mode": loss_mode}
                ft_metrics = {"finetune/acc": acc_diffusion,
                              "finetune/acc_mode": acc_mode}

            losses.update(ft_losses)
            metrics.update(ft_metrics)

        # ------------------ Aggregate ------------------
        losses["loss_total"] = sum(losses.values())

        if "pretrain/compare_imgs" in metrics:
            imgs = metrics.pop("pretrain/compare_imgs")
            diffusion = self.msr_head.diffusion
            imgs["denoised_img"] = diffusion.p_sample(imgs["target_noise"], imgs["noisy_image"], t_index = 1)
            imgs["predicted_img"] = diffusion.p_sample(imgs["pred_noise"], imgs["noisy_image"], t_index = 1)
            self.msr_head.log_compare_images(imgs)

        self.log_dict(losses, prog_bar = True, logger = True)
        self.log_dict(metrics, logger = True)
        elapsed = time.time() - tic
        self.log("train/train_step_time", elapsed, on_epoch = True, sync_dist = True)
        self.log("train/memory_usage", psutil.virtual_memory().percent, on_epoch = True, sync_dist = True)

        return losses["loss_total"]

    def setup_pretrain(self, batch_idx):
        """
        Decide which objectives run on this step and gather pretraining options.

        A pseudo-epoch ``global_step // (config.val_interval / config.val_interval_epoch)``
        drives the schedule selected by ``config.pretrain_mode``:
          * "pretrain":   pretrain only while pseudo-epoch < num_pretrain_epochs,
                          then finetune only.
          * "mix":        pretrain and finetune while pseudo-epoch < num_pretrain_epochs,
                          then finetune only.
          * "interleave": while pseudo-epoch < 2 * num_pretrain_epochs, pretrain on
                          even pseudo-epochs and finetune on odd ones; then finetune only.
        Any other value raises NotImplementedError.

        Returns:
            (run_pretrain, run_finetune, pretrain_params). ``pretrain_params["mode"]``
            lists the objectives among "MPM", "MSR", "PTM", "CL" whose
            ``*_loss_weight`` is positive; it is empty when pretraining is off.
            PTM additionally requires more than one entry in ``config.all_prompts``.
        """
        cfg = self.config

        # ------------------ Pretrain schedule ------------------
        pseudo_epoch = self.global_step // (cfg.val_interval / cfg.val_interval_epoch)
        schedule = cfg.pretrain_mode

        # Default outcome of every schedule once pretraining is over: finetune only.
        run_pretrain, run_finetune = False, True
        if schedule == "pretrain":
            if pseudo_epoch < cfg.num_pretrain_epochs:
                run_pretrain, run_finetune = True, False
        elif schedule == "mix":
            if pseudo_epoch < cfg.num_pretrain_epochs:
                run_pretrain = True
        elif schedule == "interleave":
            if pseudo_epoch < cfg.num_pretrain_epochs * 2 and pseudo_epoch % 2 == 0:
                run_pretrain, run_finetune = True, False
        else:
            raise NotImplementedError

        # ------------------ Pretrain parameters ------------------
        pretrain_params = {"mode": []}
        if not run_pretrain:
            return run_pretrain, run_finetune, pretrain_params

        enabled = pretrain_params["mode"]
        announce = batch_idx == 0  # log the enabled losses once per epoch

        if cfg.mpm_loss_weight > 0:
            enabled.append("MPM")
            pretrain_params["mpm_mask_ratio"] = cfg.mpm_mask_ratio
            if announce:
                logger.info("MPM loss is added!")

        if cfg.msr_loss_weight > 0:
            enabled.append("MSR")
            pretrain_params["msr_mask_ratio"] = cfg.msr_mask_ratio
            if announce:
                logger.info(f"MSR loss is added for {cfg.image_encoder}!")

        if cfg.ptm_loss_weight > 0:
            has_enough_prompts = hasattr(cfg, "all_prompts") and len(cfg.all_prompts) > 1
            if has_enough_prompts:
                enabled.append("PTM")
                pretrain_params["all_prompts"] = cfg.all_prompts
                if announce:
                    logger.info("PTM loss is added!")
            elif announce:
                logger.warning("PTM loss is not added because there is no sufficient prompts!")

        if cfg.cl_loss_weight > 0:
            enabled.append("CL")
            pretrain_params["all_prompts"] = cfg.all_prompts
            if announce:
                logger.info("CL loss is added!")

        return run_pretrain, run_finetune, pretrain_params

    def optimizer_step(self, 
                       epoch, 
                       batch_idx, 
                       optimizer, 
                       optimizer_idx, 
                       optimizer_closure, 
                       on_tpu, 
                       using_native_amp, 
                       using_lbfgs):
        """
        Step the optimizer, applying a linear learning-rate warmup first.

        Let ``k = self.trainer.global_step`` be the index of the current step.
        While ``config.warmup`` is truthy and ``k < config.warmup_steps``, every
        param group's lr is set to ``min(1, (k + 1) / warmup_steps) * self.learning_rate``
        *before* the update. So the first step uses ``1 / warmup_steps`` of the
        base lr, and the last warmup step (``k = warmup_steps - 1``) uses the full
        base lr. After warmup the lr is no longer touched here and is governed by
        the StepLR scheduler from ``configure_optimizers``.
        """
        # Scale before stepping. global_step has not yet counted the current step
        # at this point, so the scale applies to the update about to happen.
        # (The previous ``(step + 1) < warmup_steps`` check, applied after the
        # step, never restored the full learning rate.)
        if self.config.warmup:
            step = self.trainer.global_step
            warmup_steps = self.config.warmup_steps
            if step < warmup_steps:
                lr_scale = min(1.0, float(step + 1) / float(warmup_steps))
                for pg in optimizer.param_groups:
                    pg['lr'] = lr_scale * self.learning_rate

        optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        """
        Freeze the prompt encoder backbone and build AdamW with a per-step StepLR.

        Side effect: ``requires_grad`` is set to False on every parameter of
        ``self.prompt_encoder.model`` before the trainable parameters are collected,
        so those weights are excluded from the optimizer.

        Returns:
            Lightning optimizer config. It holds the AdamW optimizer (``learning_rate``,
            ``weight_decay``) and a StepLR (``learning_rate_step_size``,
            ``learning_rate_decay``) stepped every training step.
        """
        for frozen_param in self.prompt_encoder.model.parameters():
            frozen_param.requires_grad = False

        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr = self.learning_rate, weight_decay = self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size = self.learning_rate_step_size,
            gamma = self.learning_rate_decay,
        )

        scheduler_config = {"scheduler" : scheduler, "interval" : "step"}
        return {"optimizer" : optimizer, "lr_scheduler" : scheduler_config}

    # ========================= Validation =========================
    def validation_step(self, batch, batch_idx):
        """
        Run offline and/or online validation (per ``val_offline`` / ``val_online``),
        then log the wall-clock time of the step and the system memory usage.
        """
        tic = time.time()

        if self.val_offline:
            self.offline_validation_step(batch, batch_idx)
        if self.val_online:
            self.online_validation_step(batch, batch_idx)

        elapsed = time.time() - tic
        memory_pct = psutil.virtual_memory().percent
        self.log("val/val_step_time", elapsed, on_epoch = True, sync_dist = True)
        self.log("val/memory_usage", memory_pct, on_epoch = True, sync_dist = True)

    def offline_validation_step(self, batch, batch_idx):
        """
        Compute and log loss/accuracy metrics on one validation batch.

        Runs under ``torch.no_grad()`` with the module temporarily switched to
        training mode. The losses are computed through the same calls the
        training path uses; NOTE(review): presumably those heads need
        training-mode behaviour, so confirm. The train/eval mode that was active
        on entry is restored on exit, even if an exception is raised. Later work
        in the same validation step, such as the online rollouts, therefore does
        not run with training-mode layers (dropout etc.).

        Logged keys:
          - ``val/loss``, ``val/acc``, ``val/loss_mode``, ``val/acc_mode``:
            from ``self.action_diffusion.training_losses`` plus
            ``self.action_decoder.calculate_loss`` when
            ``config.use_action_diffusion`` is set, otherwise from the
            ``finetune/*`` entries returned by ``self(action=None, t=None)``.
          - ``val/loss_cl``, ``val/acc_cl``: when ``setup_pretrain`` put "CL"
            into ``pretrain_params["mode"]``.
        """
        was_training = self.training
        with torch.no_grad():
            self.train()
            try:
                # Only pretrain_params is needed here; the run_* flags are unused.
                _, _, pretrain_params = self.setup_pretrain(batch_idx)

                self.encode_prompt_state(batch)

                if self.config.use_action_diffusion:
                    mode, action = self.action_encoder.normalize_action(batch[self.action_encoder.action_type])
                    # Evaluate only the last timestep of the action sequence.
                    mode = mode[:, -1, :]
                    action = action[:, -1, :]

                    loss_diffusion, acc_diffusion = self.action_diffusion.training_losses(self, action)
                    loss_diffusion = loss_diffusion.mean()
                    loss_mode, acc_mode = self.action_decoder.calculate_loss(self.mode_logits, mode)
                    self.log("val/loss", loss_diffusion.detach().cpu().item(), on_epoch = True, sync_dist = True)
                    self.log("val/acc", acc_diffusion, on_epoch = True, sync_dist = True)
                    self.log("val/loss_mode", loss_mode.detach().cpu().item(), on_epoch = True, sync_dist = True)
                    self.log("val/acc_mode", acc_mode, on_epoch = True, sync_dist = True)
                else:
                    all_losses, all_metrics = self(action = None, t = None)

                    self.log("val/loss", all_losses["finetune/loss"].detach().cpu().item(), on_epoch = True, sync_dist = True)
                    self.log("val/acc", all_metrics["finetune/acc"], on_epoch = True, sync_dist = True)
                    self.log("val/loss_mode", all_losses["finetune/loss_mode"].detach().cpu().item(), on_epoch = True, sync_dist = True)
                    self.log("val/acc_mode", all_metrics["finetune/acc_mode"], on_epoch = True, sync_dist = True)

                # ------------------ Pretrain ------------------
                # Every enabled pretraining mode runs through self.pretrain, but
                # only the "CL" metrics are logged here.
                if len(pretrain_params["mode"]) > 0:
                    all_losses_pretrain, all_metrics_pretrain = self.pretrain(batch, pretrain_params = pretrain_params)

                    if "CL" in pretrain_params["mode"]:
                        self.log("val/loss_cl", all_losses_pretrain["pretrain/loss_cl"].detach().cpu().item(), on_epoch = True, sync_dist = True)
                        self.log("val/acc_cl", all_metrics_pretrain["pretrain/acc_cl"], on_epoch = True, sync_dist = True)
            finally:
                # Restore the mode the module was in on entry (Lightning puts it
                # in eval mode for validation).
                self.train(was_training)

    def online_validation_step(self, batch, batch_idx):
        """
        Roll out the policy in the simulator for every prompt in the batch.

        For each prompt, ``self.env`` is (re)created with ``create_maniskill_env``,
        passing the previous env as ``base_env``. The env is then evaluated with
        an ``EnvEvaluator`` for ``num_trials`` trials (``max_steps=100``), and a
        per-prompt result dict is passed to ``self.metric.update``.

        ``num_trials`` is ``self.val_trials``, except when ``config.mode`` is
        "train" and ``int(global_step / config.val_interval) == 0`` (the first
        validation interval); then a single trial is used.

        Logs ``val/success`` as the mean of all trial outcomes across every
        prompt in the batch, and ``val/mean_steps`` as the mean of all returned
        step counts across every prompt. Nothing is evaluated or logged when
        the batch contains no prompts.
        """
        batch_prompt = batch["prompt"]
        instructions = batch_prompt["language_instruction"]
        if len(instructions) == 0:
            # Previously this crashed on the never-assigned log_dict/steps.
            return

        evaluator = EnvEvaluator(policy=self, config = self.config)

        num_trials = self.val_trials
        if self.config.mode == "train" and int(self.global_step / self.config.val_interval) == 0:
            num_trials = 1

        # Per-prompt results collected so the logged metrics cover the whole
        # batch, not just its last prompt.
        all_success = []
        all_steps = []
        for i in range(len(instructions)):
            prompt = {key : batch_prompt[key][i] for key in batch_prompt}

            self.env = create_maniskill_env(
                prompt["language_instruction"], self.assets_path, base_env = self.env,
            )

            is_successful, steps = evaluator.evaluate(self.env, prompt, num_trials, max_steps=100)

            log_dict = {
                "env" : self.env.task_info.task_type,
                "task" : prompt["language_instruction"],
                "success" : np.mean(is_successful),
                "success_count" : np.sum(is_successful),
                "trials": len(is_successful),
                "success_list": is_successful,
                "steps": steps,
            }
            self.metric.update(log_dict)
            print("val log_dict", log_dict)

            all_success.append(np.ravel(is_successful))
            all_steps.append(np.ravel(steps))

        # With a single prompt these equal np.mean(is_successful) / np.mean(steps).
        self.log("val/success", np.mean(np.concatenate(all_success)), on_epoch = True, sync_dist = True)
        self.log("val/mean_steps", np.mean(np.concatenate(all_steps)), on_epoch = True, sync_dist = True)

    def validation_epoch_end(self, validation_step_out):
        if self.val_online:
            results = self.metric.compute()
            tb = self.logger.experiment
            for res in results:
                if res.startswith("unseen"):
                    tb.add_scalar("unseen/"+res, results[res], global_step = int(self.global_step))
                else:
                    tb.add_scalar("seen/"+res, results[res], global_step = int(self.global_step))
                
                tb.add_scalar("epoch"+res, results[res], global_step = int(self.global_step / self.config.val_interval))

            self.metric.reset()
            if self.global_rank == 0:
                logger.info(f"### Epoch: {int(self.global_step / self.config.val_interval)} Step: {self.global_step}")
                logger.info(f"===save_dir==={self.logger.save_dir}/checkpoint/{self.config.labels}")
                for key in results:
                    print(f"{key}: {results[key]}")