import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers.mlp import SwiGLU
from timm.layers.norm import RmsNorm

from .layers.quantizers.vq1 import VectorQuantizer2 as VectorQuantizer
from .layers.quantizers.fsq import FSQ
from .layers.quantizers.ohq import OneHotQuantizer
from .models.modeling_mixin import LatenActionMixin

from .dino import Dinov2Embedder
from .model import (
    BaseModelConfig,
    BaseModelDecoder,
    BaseModelEncoder,
    BaseModelEncoderVAE,
    ModelForwardOutput,
    BasePretrainedModel,
    mask2len,
    select_last_k,
)

import os
from einops import rearrange, repeat
from .acceleration.checkpoint import set_grad_checkpoint
from .registry import MODELS, SCHEDULERS, build_module
from .utils.misc import (
    format_numel_str,
    get_model_numel,
)
from .utils.ckpt_utils import (
    load,
    model_gathering,
    model_sharding,
    record_model_param_shape,
)
from .utils.train_utils import MaskGenerator, update_ema, update_ema_accelerator

import mediapy as media
from .utils.inference_utils import prepare_multi_resolution_info
import dataclasses
from copy import deepcopy

from types import SimpleNamespace
import collections.abc
import torch.distributions as D

def dict_to_namespace(d):
    """Recursively convert nested dicts (and lists of them) into SimpleNamespaces.

    Dicts become ``SimpleNamespace`` objects whose attributes are the converted
    values; lists are converted element-wise into new lists; any other value
    (including tuples) is returned unchanged.
    """
    if isinstance(d, list):
        return list(map(dict_to_namespace, d))
    if not isinstance(d, dict):
        return d
    fields = {key: dict_to_namespace(value) for key, value in d.items()}
    return SimpleNamespace(**fields)
    
def namespace_to_dict(obj):
    """Recursively convert SimpleNamespaces and mappings into plain containers.

    ``SimpleNamespace`` instances and any ``Mapping`` become ``dict``; every
    other iterable except ``str``/``bytes`` (lists, tuples, sets, generators,
    ...) becomes a ``list``; all remaining values are returned as-is.
    """
    if isinstance(obj, SimpleNamespace):
        pairs = vars(obj).items()
    elif isinstance(obj, collections.abc.Mapping):
        pairs = obj.items()
    else:
        is_collection = isinstance(obj, collections.abc.Iterable) and not isinstance(obj, (str, bytes))
        return list(map(namespace_to_dict, obj)) if is_collection else obj
    return {key: namespace_to_dict(value) for key, value in pairs}
    
class WorldModel(nn.Module):
    """Action-conditioned video world model built around a frozen VAE.

    Components (all built from ``args``):
      * ``vae``: frozen VAE (``args.vae``) exposing ``spatial_vae`` and
        ``temporal_vae``; used to map frames to normalised latents and back.
      * ``transformer``: diffusion transformer (``args.transformer``) that
        receives latent-action conditioning; ``ema`` is a frozen copy of it.
      * ``scheduler``: diffusion scheduler (``args.scheduler``) providing
        ``training_losses``.
      * ``mask_generator``: ``MaskGenerator`` built from ``args.mask_ratios``.

    Args:
        args: Namespace-style config (attribute access). Sub-configs are turned
            back into dicts with ``namespace_to_dict`` before being passed to
            ``build_module`` / ``MaskGenerator``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vae_max_compress_rate = self.args.vae_max_compress_rate
        self.train_wm_seq_length = self.args.train_wm_seq_length
        # Pixel-frame length spanned by ``train_wm_seq_length`` latent frames;
        # ``vae_decode`` right-pads reconstructions up to this length.
        self.sample_length = self.vae_max_compress_rate * self.train_wm_seq_length

        # The VAE is used as a frozen codec: no gradients, always in eval mode.
        self.vae = build_module(namespace_to_dict(self.args.vae), MODELS)
        self.vae.requires_grad_(False)
        self.vae.eval()
        self.latent_size = self.vae.get_latent_size(
            (self.args.num_frames, self.args.input_height, self.args.input_width)
        )

        self.transformer = build_module(
            namespace_to_dict(self.args.transformer),
            MODELS,
            input_size=self.latent_size,
            in_channels=self.vae.out_channels,
            caption_channels=self.args.text_encoder_output_dim,
            model_max_length=self.args.text_encoder_model_max_length,
            enable_sequence_parallelism=False,
            latent_action_in_channel=args.latent_action_dim,
            latent_action_num_tokens=args.latent_action_num_tokens,
        )
        if self.args.load_transformer_pretrained_weights:
            # strict=False: checkpoint keys that do not match the model are
            # ignored without being reported. Storages are kept where torch.load
            # deserialises them (CPU).
            self.transformer.load_state_dict(
                torch.load(self.args.transformer_path, map_location=lambda storage, loc: storage), strict=False
            )

        # Attention modules of blocks that enable flash attention are cast to bfloat16.
        for block in (*self.transformer.spatial_blocks, *self.transformer.temporal_blocks):
            if getattr(block, "enable_flash_attn", False):
                block.attn = block.attn.to(torch.bfloat16)

        if self.args.grad_checkpoint:
            set_grad_checkpoint(self.transformer)

        # Frozen EMA copy of the transformer; decay=0 presumably makes it an
        # exact copy of the current weights (TODO confirm update_ema_accelerator).
        self.ema = deepcopy(self.transformer)
        self.ema.requires_grad_(False)
        self.ema.eval()
        self.ema_shape_dict = record_model_param_shape(self.ema)
        update_ema_accelerator(self.ema, self.transformer, decay=0)

        self.scheduler = build_module(namespace_to_dict(self.args.scheduler), SCHEDULERS)

        self.mask_generator = MaskGenerator(namespace_to_dict(self.args.mask_ratios))

    def vae_encode(self, obs4vaewm, obs4vaewm_mask):
        """Encode pixel frames into normalised VAE latents.

        Args:
            obs4vaewm: Frames laid out as ``(B, T, H, W, C)``. Values are mapped
                via ``(v / 255 - 0.5) / 0.5``, so an 8-bit ``[0, 255]`` range is
                assumed (TODO confirm).
            obs4vaewm_mask: 2-D per-frame mask (indexed as ``[:, None, :, None, None]``)
                multiplied into the spatial latents, zeroing masked-out frames.

        Returns:
            Latents with time on dim 2, normalised as ``(z - shift) / scale``.
        """
        frames = rearrange(obs4vaewm, "B T H W C -> B C T H W")
        frames = (frames / 255.0 - 0.5) / 0.5
        spatial_z = self.vae.spatial_vae.encode(frames)
        spatial_z = spatial_z * obs4vaewm_mask[:, None, :, None, None]

        # The temporal VAE encodes windows of ``vae_max_compress_rate`` frames;
        # one posterior sample is drawn per window.
        step = self.vae_max_compress_rate
        chunks = [
            self.vae.temporal_vae.encode(spatial_z[:, :, start : start + step]).sample()
            for start in range(0, spatial_z.shape[2], step)
        ]
        latents = torch.cat(chunks, dim=2)
        return (latents - self.vae.shift) / self.vae.scale

    def vae_decode(self, x_z, batch_micro_frame_sizes):
        """Decode normalised latents back to pixel space.

        Each latent time step is decoded by the temporal VAE into
        ``vae_max_compress_rate`` frames, after which the spatial VAE decodes
        the frames. For each sample, only the last ``micro_frame_size`` frames
        of every group of ``vae_max_compress_rate`` are kept, and the time axis
        is right-padded with zeros up to ``sample_length``.

        Args:
            x_z: Normalised latents with time on dim 2 (as from ``vae_encode``).
            batch_micro_frame_sizes: One frame count per sample in the batch.

        Returns:
            The per-sample reconstructions stacked along dim 0.
        """
        x_z = x_z * self.vae.scale.to(x_z.dtype) + self.vae.shift.to(x_z.dtype)
        decoded_steps = [
            self.vae.temporal_vae.decode(x_z[:, :, t : t + 1], num_frames=self.vae_max_compress_rate)
            for t in range(x_z.size(2))
        ]
        frames = self.vae.spatial_vae.decode(torch.cat(decoded_steps, dim=2))

        recon = []
        for x_i, micro_frame_size in zip(frames, batch_micro_frame_sizes):
            # Split time into groups of ``vae_max_compress_rate`` and keep the tail of each group.
            x_i = x_i.reshape(x_i.shape[0], -1, self.vae_max_compress_rate, *x_i.shape[-2:])
            x_i = x_i[:, :, -micro_frame_size:]
            x_i = x_i.reshape(x_i.shape[0], -1, *x_i.shape[-2:])
            time_pad = max(0, self.sample_length - x_i.shape[1])
            x_i = F.pad(x_i, (0, 0, 0, 0, 0, time_pad, 0, 0), value=0, mode="constant")
            recon.append(x_i)
        return torch.stack(recon, dim=0)

    def build_loss(self, x, action_cond, pad_mask_in, return_one_step_pred, detach_y_in_one_step_pred):
        """Compute the masked diffusion training loss for latents ``x``.

        Args:
            x: Latents from ``vae_encode``.
            action_cond: Action conditioning, passed to the transformer as ``y``.
            pad_mask_in: Frame validity mask, presumably ``(B, T)`` — combined
                with the sampled frame mask and reduced over the last dim.
            return_one_step_pred: Forwarded to ``scheduler.training_losses``.
            detach_y_in_one_step_pred: Forwarded to ``scheduler.training_losses``.

        Returns:
            The scheduler's loss dict with ``"loss"`` replaced by the mean
            per-sample loss over samples having at least one valid, masked
            frame (0 when there is none), and ``"x_mask"`` set to the sampled
            frame mask.
        """
        model_args = {
            "y": action_cond,
            "image_cond": None,
            "height": torch.tensor(self.args.input_height),
            "width": torch.tensor(self.args.input_width),
            "num_frames": torch.tensor(self.args.num_frames),
        }
        mask = self.mask_generator.get_masks(x)
        model_args["x_mask"] = mask

        # NOTE(review): ``get_motion_reinforced_weights`` is not defined on this
        # class; a truthy ``motion_loss_weight`` raises AttributeError here
        # unless it is provided elsewhere — TODO confirm.
        weights = self.get_motion_reinforced_weights(x) if self.args.motion_loss_weight else None
        loss_dict = self.scheduler.training_losses(
            pad_mask=pad_mask_in,
            model=self.transformer,
            x_start=x,
            model_kwargs=model_args,
            mask=mask,
            weights=weights,
            return_one_step_pred=return_one_step_pred,
            detach_y_in_one_step_pred=detach_y_in_one_step_pred,
        )
        loss_mask = torch.logical_and(pad_mask_in, mask).any(dim=-1)
        loss_valid = torch.where(loss_mask, loss_dict["loss"], 0.0)
        # Clamp so a batch without any supervised sample gives 0 instead of 0/0 = NaN.
        loss = loss_valid.sum() / loss_mask.sum().clamp(min=1)
        loss_dict["loss"] = loss
        loss_dict["x_mask"] = mask
        return loss_dict

class MyModelConfig(BaseModelConfig):
    """Configuration for ``MyModel``; extends ``BaseModelConfig``."""

    # FSQ levels per latent dimension (used when quantizer_type == "fsq").
    # NOTE(review): a mutable list default is only safe if BaseModelConfig copies
    # defaults per instance (e.g. pydantic) — TODO confirm.
    codebook_levels: list[int] = [8, 5, 5, 5]
    # Not read by the model code visible in this file.
    wm_config_path: str | None = None
    # Default per-component training switches, used until
    # ``MyModel.update_training_state`` supplies an explicit state.
    train_idm: bool = True
    train_quantizer: bool = True
    train_fdm: bool = True
    train_wm_action: bool = True
    # Build the inverse-dynamics head, and whether to train it (VQ quantizer only).
    build_inverse_dynamics: bool = False
    train_inverse_dynamics: bool = False
    inverse_dynamics_loss_weight: float = 1.0
    # One of "vq", "fsq", "ohq" or "none" (Gaussian posterior + KL to N(0, I)).
    quantizer_type: str = "vq"
    # Weight of the quantizer loss (or KL term when quantizer_type == "none").
    quantizer_loss_weight: float = 1.0
    # Auxiliary loss on frame-to-frame differences of the one-step prediction.
    use_difference_loss: bool = False
    difference_loss_weight: float = 1.0
    # Freeze the encoder and quantizer at construction time.
    only_train_fdm: bool = False
    # World-model settings; converted to a namespace and passed to ``WorldModel``.
    wm_config: dict | None = None

class MyModel(BasePretrainedModel, LatenActionMixin):
    """Latent-action model trained through an action-conditioned world model.

    Pipeline in ``forward``:
      1. ``self.encoder`` maps the observation clip ``obs4lam`` to latent actions.
      2. ``self.quantizer`` discretises them (``"vq"``, ``"fsq"``, ``"ohq"``); for
         ``"none"`` the encoder's distribution is KL-regularised towards N(0, I).
      3. The action tokens condition ``self.wm`` (a ``WorldModel``); its masked
         diffusion loss on VAE latents of ``obs4vaewm`` is the main objective.
    Optional auxiliary objectives: an inverse-dynamics classifier over VQ code
    indices and a temporal-difference loss on the one-step prediction.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config

        # Pretrained DINO weights are only requested when the encoder consumes DINO input.
        dino_pretrained = bool(config.encoder_dino_input)
        self.embed = Dinov2Embedder(
            pretrained=dino_pretrained, freeze=True, size=config.pretrained_dino_size
        )

        # Without a quantizer, the VAE-style encoder is used; its output is
        # unpacked as (sample, distribution) in ``forward``.
        if self.config.quantizer_type != "none":
            self.encoder = BaseModelEncoder(config, self.embed)
        else:
            self.encoder = BaseModelEncoderVAE(config, self.embed)

        if self.config.quantizer_type == "vq":
            self.quantizer = VectorQuantizer(n_e=config.n_codes, e_dim=config.action_latent_dim, beta=0.25)
        elif self.config.quantizer_type == "none":
            self.quantizer = nn.Identity()
        elif self.config.quantizer_type == "fsq":
            self.quantizer = FSQ(levels=config.codebook_levels, dim=config.action_latent_dim)
        elif self.config.quantizer_type == "ohq":
            self.quantizer = OneHotQuantizer(embedding_dim=config.num_learned_tokens, discrete_dim=config.action_latent_dim)
        else:
            raise ValueError(f"Unsupported quantizer type: {self.config.quantizer_type}")

        # Latent-action and sequence-length settings of the world model are
        # dictated by this model's config.
        wm_config = dict_to_namespace(config.wm_config)
        wm_config.latent_action_dim = config.action_latent_dim
        wm_config.latent_action_num_tokens = config.num_learned_tokens
        wm_config.train_wm_seq_length = config.d_t
        self.wm = WorldModel(wm_config)

        if self.config.build_inverse_dynamics:
            # NOTE(review): ``InversePredictor`` is not among the visible imports of
            # this file; this raises NameError unless it is provided elsewhere.
            self.inverse_predictor = InversePredictor()
            if not self.config.train_inverse_dynamics:
                self.inverse_predictor.requires_grad_(False)
                self.inverse_predictor.eval()

        self.post_init()
        # NOTE(review): ``self.embed`` is rebuilt after ``post_init`` while the
        # encoder keeps the instance it was constructed with — TODO confirm intent.
        self.embed = Dinov2Embedder(
            pretrained=dino_pretrained, freeze=True, size=config.pretrained_dino_size
        )

        # Names of parameters trainable right after construction;
        # ``update_training_state`` only ever toggles these.
        self.ever_trainable_params_name = [k for k, v in self.named_parameters() if v.requires_grad]

        if self.config.only_train_fdm:
            self.encoder.requires_grad_(False)
            self.encoder.eval()
            self.quantizer.requires_grad_(False)
            self.quantizer.eval()

    def update_training_state(self, component_state):
        """Enable or disable training of the individual components.

        Args:
            component_state: Object with boolean attributes ``train_idm``,
                ``train_quantizer``, ``train_fdm`` and ``train_wm_action``.

        Only parameters listed in ``ever_trainable_params_name`` have their
        ``requires_grad`` toggled; every other parameter is asserted to stay
        frozen. The action-conditioning layers of the world-model transformer
        are handled after the transformer itself, so ``train_wm_action``
        overrides ``train_fdm`` for them.
        """
        # Set lookup: membership is tested once per parameter.
        ever_trainable = set(self.ever_trainable_params_name)

        def set_component(module, enabled, module_name):
            # requires_grad only accepts a real bool.
            enabled = bool(enabled)
            for name, param in module.named_parameters():
                if f"{module_name}.{name}" in ever_trainable or name in ever_trainable:
                    param.requires_grad = enabled
                else:
                    assert not param.requires_grad
            if enabled:
                module.train()
            else:
                module.eval()

        set_component(self.encoder, component_state.train_idm, "encoder")
        set_component(self.quantizer, component_state.train_quantizer, "quantizer")

        transformer = self.wm.transformer
        set_component(transformer, component_state.train_fdm, "wm.transformer")
        for attr in ("latent_action_embedder", "act_block", "t_act_block"):
            submodule = getattr(transformer, attr, None)
            if submodule is not None:
                set_component(submodule, component_state.train_wm_action, f"wm.transformer.{attr}")

        self._current_component_state = component_state

    def get_current_training_config(self):
        """Return the active component training switches as a dict.

        Uses the state passed to the last ``update_training_state`` call, or
        the defaults from ``self.config`` if it has never been called.
        """
        source = self._current_component_state if hasattr(self, '_current_component_state') else self.config
        return {
            key: getattr(source, key)
            for key in ("train_idm", "train_quantizer", "train_fdm", "train_wm_action")
        }

    def init_weights(self):
        """Initialise quantizer weights (and the inverse-dynamics head if it is trained)."""
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        if self.config.quantizer_type == "vq":
            nn.init.uniform_(
                self.quantizer.embedding.weight, -1.0 / self.config.n_codes, 1.0 / self.config.n_codes
            )
        elif self.config.quantizer_type == "fsq":
            self.quantizer.apply(_basic_init)
        elif self.config.quantizer_type in ("none", "ohq"):
            pass
        else:
            raise ValueError(f"Unsupported quantizer type: {self.config.quantizer_type}")
        if self.config.train_inverse_dynamics:
            self.inverse_predictor.apply(_basic_init)

    @property
    def codebook_size(self) -> int | None:
        """Codebook size (``n_e``) of the quantizer, or None if it has none."""
        return getattr(self.quantizer, "n_e", None)

    @property
    def codebook(self):
        """The quantizer's ``codebook`` attribute, or None if it has none."""
        return getattr(self.quantizer, "codebook", None)

    def get_unflatten_results(self, x_flat, clip_len, B, T):
        """Scatter per-clip transition entries into a zero-padded ``(B, T-1, ...)`` tensor.

        ``x_flat`` holds ``length - 1`` entries per clip, concatenated along
        dim 0 in clip order. Entries beyond ``T - 1`` for a clip are dropped,
        and positions past a clip's length stay zero. Zero-length clips
        contribute no entries.

        Args:
            x_flat: Tensor of shape ``(N, *feat)``.
            clip_len: Per-clip valid lengths (ints or integral floats).
            B: Number of clips.
            T: Padded sequence length.

        Returns:
            Tensor ``(B, T - 1, *feat)`` with the dtype and device of ``x_flat``.
        """
        out = x_flat.new_zeros((B, T - 1, *x_flat.shape[1:]))
        offset = 0
        for i, length in enumerate(clip_len):
            num_entries = max(int(length) - 1, 0)
            num_kept = min(num_entries, T - 1)
            out[i, :num_kept] = x_flat[offset : offset + num_kept]
            offset += num_entries
        return out

    def forward(
        self,
        obs4lam,
        obs4vaewm,
        obs4vaewm_mask,
        pad_mask_in: torch.Tensor | None = None,
        latent_action=None,
        only_return_actions=False,
    ) -> ModelForwardOutput:
        """Infer latent actions and compute the world-model training loss.

        Args:
            obs4lam: Clip fed to the latent-action encoder; dims 0-1 are (B, T).
            obs4vaewm: Frames for the world model's VAE (see ``WorldModel.vae_encode``).
            obs4vaewm_mask: Per-frame mask for ``vae_encode``.
            pad_mask_in: ``(B, T)`` validity mask; all frames are treated as
                valid when None.
            latent_action: Precomputed encoder output; when None the encoder
                is run on ``obs4lam``.
            only_return_actions: Return the action tokens without computing losses.

        Returns:
            ``ModelForwardOutput`` with the total loss, a log dict of loss
            terms, the flat action tokens, VQ labels and codebook indices.
        """
        B, T = obs4lam.shape[:2]
        if pad_mask_in is None:
            pad_mask_in = torch.ones(B, T, dtype=torch.bool, device=obs4lam.device)
        clip_len = pad_mask_in.sum(1).tolist()

        if latent_action is None:
            current_config = self.get_current_training_config()
            # Augmentation only while the encoder is actually being trained.
            use_augment = current_config["train_idm"] and self.encoder.training and not only_return_actions
            augment_type = self.config.augment_type if use_augment else "none"
            enc_clips = self.preprocess(obs4lam, augment_type=augment_type)
            latent_action = self.encoder(enc_clips, clip_len, right_padding=True)

        labels = None
        quantizer_type = self.config.quantizer_type
        if quantizer_type == "vq":
            action_tokens_flat, quantizer_loss, (_, _, indices) = self.quantizer(latent_action)
            indices_flat = indices.view(-1, self.config.num_learned_tokens)
            indices_labels = self.get_unflatten_results(indices_flat, clip_len, B, T)
            labels = indices_flat
        elif quantizer_type == "fsq":
            action_tokens_flat, indices = self.quantizer(latent_action.squeeze(1))
            action_tokens_flat = action_tokens_flat.unsqueeze(1)
            quantizer_loss = torch.tensor(0.0, device=action_tokens_flat.device)
        elif quantizer_type == "ohq":
            action_tokens_flat = self.quantizer(latent_action.squeeze(1))
            action_tokens_flat = action_tokens_flat.reshape(*action_tokens_flat.shape[:-2], -1).unsqueeze(1)
            quantizer_loss = torch.tensor(0.0, device=action_tokens_flat.device)
            indices = None
        elif quantizer_type == "none":
            # KL to a standard normal over the last two dims (Independent with 2 event dims).
            action_tokens_flat, action_dist_flat = latent_action
            prior_dist = D.Independent(D.Normal(torch.zeros_like(action_tokens_flat), torch.ones_like(action_tokens_flat)), 2)
            quantizer_loss = D.kl_divergence(action_dist_flat, prior_dist).mean()
            indices = None
        else:
            raise ValueError(f"Unsupported quantizer type: {self.config.quantizer_type}")

        if only_return_actions:
            return ModelForwardOutput(
                loss=None,
                loss_dict=None,
                reconstructions=None,
                action_tokens=action_tokens_flat,
                label=labels,
                codebook_indices=indices,
            )

        action_tokens = self.get_unflatten_results(action_tokens_flat, clip_len, B, T)

        with torch.no_grad():
            x = self.wm.vae_encode(obs4vaewm, obs4vaewm_mask)

        # Time position 0 keeps the action embedder's ``y_embedding``; positions
        # 1.. receive the T-1 inferred action tokens.
        action_cond = repeat(
            self.wm.transformer.latent_action_embedder.y_embedding,
            "N D -> B T 1 (N D)",
            B=x.shape[0],
            T=x.shape[2],
        ).clone()
        action_cond[:, 1:] = action_tokens

        # The auxiliary losses need the one-step prediction; inverse dynamics
        # additionally detaches the conditioning inside it.
        return_one_step_pred, detach_y_in_one_step_pred = False, False
        if self.config.train_inverse_dynamics:
            return_one_step_pred, detach_y_in_one_step_pred = True, True
        elif self.config.use_difference_loss:
            return_one_step_pred, detach_y_in_one_step_pred = True, False
        wm_loss_dict = self.wm.build_loss(x, action_cond, pad_mask_in, return_one_step_pred, detach_y_in_one_step_pred)
        recon_loss = wm_loss_dict["loss"]
        one_step_pred = wm_loss_dict["one_step_pred"] if return_one_step_pred else None
        loss = recon_loss + self.config.quantizer_loss_weight * quantizer_loss

        if self.config.train_inverse_dynamics:
            assert self.config.quantizer_type == "vq", "Inverse dynamics loss only supports VQ quantizer currently"
            transition_mask = pad_mask_in[:, 1:]
            # Clamp avoids 0/0 when no clip has a valid transition.
            num_transitions = transition_mask.sum().clamp(min=1)
            pred_action_logits = self.inverse_predictor(one_step_pred[:, :, :-1], one_step_pred[:, :, 1:])
            loss_fn = nn.CrossEntropyLoss(reduction="none")
            inverse_dynamics_loss = loss_fn(
                pred_action_logits.view(-1, pred_action_logits.shape[-1]), indices_labels.view(-1)
            )
            inverse_dynamics_loss = inverse_dynamics_loss.view(
                pred_action_logits.shape[0], pred_action_logits.shape[1], -1
            ).mean(-1)
            inverse_dynamics_loss = (inverse_dynamics_loss * transition_mask).sum() / num_transitions
            loss = loss + self.config.inverse_dynamics_loss_weight * inverse_dynamics_loss
            correct = (torch.argmax(pred_action_logits, dim=-1) == indices_labels).float()
            inverse_accuracy = (correct * transition_mask.unsqueeze(-1)).sum() / (
                num_transitions * pred_action_logits.shape[2]
            )
        elif self.config.use_difference_loss:
            from .schedulers.iddpm.gaussian_diffusion import mean_flat

            x_mask = wm_loss_dict["x_mask"]
            # Differences only exist for frames 1..T-1, so the per-frame mask
            # and the per-sample validity mask are both taken over that range.
            # Otherwise a sample whose only valid masked frame is frame 0 would
            # be selected with an empty mask.
            diff_frame_mask = pad_mask_in[:, 1:] & x_mask[:, 1:]
            pred_diff = one_step_pred[:, :, 1:] - one_step_pred[:, :, :-1]
            gt_diff = x[:, :, 1:] - x[:, :, :-1]
            diff_loss = mean_flat((pred_diff - gt_diff).pow(2), mask=diff_frame_mask)
            loss_mask = diff_frame_mask.any(dim=-1)
            diff_loss_valid = torch.where(loss_mask, diff_loss, 0.0)
            diff_loss = diff_loss_valid.sum() / loss_mask.sum().clamp(min=1)
            loss = loss + self.config.difference_loss_weight * diff_loss

        log_loss_dict = {
            "loss": loss,
            "loss/recon": recon_loss,
            f"loss/{self.config.quantizer_type}": quantizer_loss,
            "loss/inverse_dynamics": inverse_dynamics_loss if self.config.train_inverse_dynamics else torch.tensor(0.0, device=loss.device),
            "inverse_accuracy": inverse_accuracy if self.config.train_inverse_dynamics else torch.tensor(0.0, device=loss.device),
            "loss/difference_loss": diff_loss if self.config.use_difference_loss else torch.tensor(0.0, device=loss.device),
        }

        return ModelForwardOutput(
            loss=loss,
            loss_dict=log_loss_dict,
            reconstructions=None,
            action_tokens=action_tokens_flat,
            label=labels,
            codebook_indices=indices,
        )

    def post_process(self):
        """Update the world model's EMA transformer with decay ``wm.args.ema_decay``."""
        update_ema_accelerator(self.wm.ema, self.wm.transformer, decay=self.wm.args.ema_decay)

    def train(self, mode: bool = True):
        """Set train/eval mode while respecting the per-component training switches.

        Frozen components stay in eval mode regardless of ``mode``. The VAE and
        the EMA transformer are always in eval mode.
        """
        super().train(mode)

        current_config = self.get_current_training_config()

        transformer = self.wm.transformer
        if current_config["train_fdm"]:
            transformer.train(mode)
        else:
            transformer.eval()
            if current_config["train_wm_action"]:
                # Same guarded set of action layers as ``update_training_state``.
                for attr in ("latent_action_embedder", "act_block", "t_act_block"):
                    submodule = getattr(transformer, attr, None)
                    if submodule is not None:
                        submodule.train(mode)

        if current_config["train_idm"]:
            self.encoder.train(mode)
        else:
            self.encoder.eval()

        if current_config["train_quantizer"]:
            self.quantizer.train(mode)
        else:
            self.quantizer.eval()

        self.wm.vae.eval()
        self.wm.ema.eval()

        return self

    def eval(self):
        """Switch to evaluation mode (equivalent to ``train(False)``)."""
        return self.train(False)