import torch
import torch.nn as nn
from torchvision import transforms
from einops import rearrange, repeat
from dino_wm.models.contrastive_utils import *

class VWorldModel(nn.Module):
    """Latent world model over multi-view images, proprioception and actions.

    Each camera view is encoded by ``encoder`` (optionally LayerNorm-ed per view
    and again after concatenation), proprio and actions by their own encoders,
    and the three streams are joined into one latent ``z`` either as two extra
    tokens (``concat_dim == 0``) or as extra feature channels tiled over every
    visual patch (``concat_dim == 1``). ``predictor`` (optional) maps a history
    window of latents to future latents; ``decoder`` (optional) maps visual
    latents back to images, one view at a time.
    """

    def __init__(
        self,
        image_size,  # 224
        num_hist,
        num_pred,
        encoder,
        proprio_encoder,
        action_encoder,
        decoder,
        predictor,
        proprio_dim=0,
        action_dim=0,
        concat_dim=0,
        num_action_repeat=7,
        num_proprio_repeat=7,
        train_encoder=True,
        train_predictor=False,
        train_decoder=True,
        view_names=None,
        contrastive_loss_type="triplet",
        contrastive_loss_level='image',
        normalize_for_contrastive=True,
        triplet_loss_margin=0.5,
        disable_reconstruction=False,
        minimize_recon_loss_pred=False,
        action_conditioned_time_contrastive=False,
        encoder_type='dino',
        use_layernorm=True,
        contrastive_loss_weight=1e-4,
    ):
        """Build the world model.

        Args:
            image_size: input image side; only used to derive the encoder
                resize when the encoder name contains "dino" or "vae".
            num_hist: number of leading frames fed to the predictor.
            num_pred: frame offset of the prediction targets.
            encoder: visual encoder; must expose ``name`` and ``emb_dim``
                (and ``patch_size`` for "dino"/"vae" encoders).
            proprio_encoder, action_encoder: encoders for proprio / actions.
            decoder, predictor: optional (may be None).
            proprio_dim, action_dim: per-copy embedding widths. They are stored
                multiplied by the repeat counts, and for ``concat_dim == 1``
                they must match the output widths of the proprio/action
                encoders for the feature slicing to line up.
            concat_dim: 0 (extra tokens) or 1 (extra feature channels).
            view_names: names of the camera views; defaults to ``['view1']``.
            use_layernorm: add parameter-free LayerNorms per view and, with
                more than one view, on the fused embedding.
            contrastive_loss_weight: multiplier on the action-conditioned
                time-contrastive loss (previously hard-coded to 1e-4).
        """
        super().__init__()
        # Copy into a fresh list so no instance shares (or mutates) a default.
        view_names = ['view1'] if view_names is None else list(view_names)

        self.num_hist = num_hist
        self.num_pred = num_pred
        self.encoder = encoder
        self.proprio_encoder = proprio_encoder
        self.action_encoder = action_encoder
        self.decoder = decoder  # decoder could be None
        self.predictor = predictor  # predictor could be None
        self.train_encoder = train_encoder
        self.train_predictor = train_predictor
        self.train_decoder = train_decoder
        self.view_names = view_names
        self.contrastive_loss_type = contrastive_loss_type
        self.contrastive_loss_level = contrastive_loss_level
        self.normalize_for_contrastive = normalize_for_contrastive
        self.disable_reconstruction = disable_reconstruction
        self.minimize_recon_loss_pred = minimize_recon_loss_pred
        self.triplet_loss_margin = triplet_loss_margin
        self.action_conditioned_time_contrastive = action_conditioned_time_contrastive
        self.contrastive_loss_weight = contrastive_loss_weight
        self.num_action_repeat = num_action_repeat
        self.num_proprio_repeat = num_proprio_repeat
        # Widths of the proprio / action blocks *after* tiling (concat_dim == 1).
        self.proprio_dim = proprio_dim * num_proprio_repeat
        self.action_dim = action_dim * num_action_repeat
        self.emb_dim = self.encoder.emb_dim * len(self.view_names) + (self.action_dim + self.proprio_dim) * (concat_dim) # Not used
        self.encoder_type = encoder_type

        print(f"num_action_repeat: {self.num_action_repeat}")
        print(f"num_proprio_repeat: {self.num_proprio_repeat}")
        print(f"proprio encoder: {proprio_encoder}")
        print(f"action encoder: {action_encoder}")
        print(f"proprio_dim: {proprio_dim}, after repeat: {self.proprio_dim}")
        print(f"action_dim: {action_dim}, after repeat: {self.action_dim}")
        print(f"emb_dim: {self.emb_dim}")
        print(f'image_size: {image_size}')
        print(f'encoder_type: {encoder_type}')

        self.concat_dim = concat_dim # 0 or 1
        assert concat_dim == 0 or concat_dim == 1, f"concat_dim {concat_dim} not supported."
        print("Model emb_dim: ", self.emb_dim)

        if "dino" in self.encoder.name or 'vae' in self.encoder.name:
            decoder_scale = 14  # from vqvae
            num_side_patches = image_size // decoder_scale
            self.encoder_image_size = num_side_patches * encoder.patch_size
            self.encoder_transform = transforms.Compose(
                [transforms.Resize(self.encoder_image_size)]
            )
            print(f'num_side_patches: {num_side_patches}')
            print(f'encoder_image_size: {self.encoder_image_size}')
        else:
            # Identity transform for encoders that need no resizing.
            self.encoder_transform = lambda x: x

        self.decoder_criterion = nn.MSELoss()
        self.decoder_latent_loss_weight = 0.25
        self.emb_criterion = nn.MSELoss()

        if use_layernorm:
            # One parameter-free LayerNorm per view (applied before concatenation).
            # No explicit .to(device): these norms hold no parameters or buffers,
            # so they simply follow whatever device the inputs live on.
            self.per_view_norm = nn.ModuleDict({
                view_name: nn.LayerNorm(self.encoder.emb_dim, elementwise_affine=False)
                for view_name in view_names
            })
            # One LayerNorm for the fused embedding (applied after concatenation).
            if len(view_names) > 1:
                total_dim = self.encoder.emb_dim * len(view_names)
                self.fusion_norm = nn.LayerNorm(total_dim, elementwise_affine=False)

    def train(self, mode=True):
        """Set training mode on this module and its sub-modules; return ``self``.

        NOTE(review): ``super().train(mode)`` already switches every child
        module (encoder, predictor and decoder included) to ``mode``, so the
        ``train_*`` guards below do not keep frozen sub-modules in eval mode.
        Behavior is kept as-is; confirm whether frozen modules should be forced
        back to ``eval()`` here.
        """
        super().train(mode)
        if self.train_encoder:
            self.encoder.train(mode)
        if self.predictor is not None and self.train_predictor:
            self.predictor.train(mode)
        self.proprio_encoder.train(mode)
        self.action_encoder.train(mode)
        if self.decoder is not None and self.train_decoder:
            self.decoder.train(mode)
        # Match nn.Module.train, which returns self so calls can be chained.
        return self

    def eval(self):
        """Put the model and every sub-module in eval mode; return ``self``."""
        super().eval()
        self.encoder.eval()
        if self.predictor is not None:
            self.predictor.eval()
        self.proprio_encoder.eval()
        self.action_encoder.eval()
        if self.decoder is not None:
            self.decoder.eval()
        return self

    def encode(self, obs, act):
        """Encode observations and actions into one joint latent.

        input :  obs (dict): "visual" (dict of views), "proprio"; act: actions
        output:  z (tensor):
                 concat_dim == 0 -> (b, num_frames, num_patches + 2, dim)
                 concat_dim == 1 -> (b, num_frames, num_patches,
                                     dim + proprio_dim + action_dim)
        """
        z_dct = self.encode_obs(obs)
        act_emb = self.encode_act(act)
        if self.concat_dim == 0:
            # proprio and action each become one extra token
            z = torch.cat(
                [z_dct['visual'], z_dct['proprio'].unsqueeze(2), act_emb.unsqueeze(2)], dim=2
            )
        elif self.concat_dim == 1:
            # tile proprio/action over every patch, then repeat along features
            num_patches = z_dct['visual'].shape[2]
            proprio_tiled = repeat(z_dct['proprio'].unsqueeze(2), "b t 1 a -> b t f a", f=num_patches)
            proprio_repeated = proprio_tiled.repeat(1, 1, 1, self.num_proprio_repeat)
            act_tiled = repeat(act_emb.unsqueeze(2), "b t 1 a -> b t f a", f=num_patches)
            act_repeated = act_tiled.repeat(1, 1, 1, self.num_action_repeat)
            z = torch.cat(
                [z_dct['visual'], proprio_repeated, act_repeated], dim=3
            )
        else:
            raise ValueError(f"Unsupported concat_dim {self.concat_dim}")
        return z

    def encode_act(self, act):
        """Embed actions with ``action_encoder``."""
        act = self.action_encoder(act) # (b, num_frames, action_emb_dim)
        return act

    def encode_proprio(self, proprio):
        """Embed proprioception with ``proprio_encoder``."""
        proprio = self.proprio_encoder(proprio)
        return proprio

    def encode_obs(self, obs):
        """Encode ``obs['visual']`` and ``obs['proprio']`` into a dict of embeddings."""
        visual_embs = self.encode_obs_visual(obs['visual'])
        proprio_emb = self.encode_proprio(obs['proprio'])
        return {"visual": visual_embs, "proprio": proprio_emb}

    def encode_obs_visual(self, obs_visual):
        """Encode every view and concatenate the embeddings on the last axis.

        Args:
            obs_visual: dict mapping each name in ``self.view_names`` to that
                view's frames. For ``encoder_type == 'dino'`` each entry is
                flattened over its first two axes before encoding; otherwise
                the whole dict is passed to ``self.encoder``, which must
                return a dict keyed by view name.

        Returns:
            Per-view embeddings concatenated on the last axis, with the
            optional per-view and fused LayerNorms applied.
        """
        if self.encoder_type == 'dino':
            view_embs = []
            b = None
            for view_name in self.view_names:
                imgs = obs_visual[view_name]  # assumed (B, T, 3, H, W)
                if b is None:
                    b = imgs.shape[0]
                imgs = rearrange(imgs, "b t ... -> (b t) ...")
                imgs = self.encoder_transform(imgs)  # e.g. resize if needed
                emb = self.encoder(imgs)             # assumed (b*t, P, D)
                emb = rearrange(emb, "(b t) p d -> b t p d", b=b)
                if hasattr(self, "per_view_norm"):
                    emb = self.per_view_norm[view_name](emb)
                view_embs.append(emb)
        else:
            raw_embs = self.encoder(obs_visual)
            view_embs = []
            # Build a new list rather than overwriting the encoder's returned dict.
            for view_name in self.view_names:
                emb = raw_embs[view_name]
                if hasattr(self, "per_view_norm"):
                    emb = self.per_view_norm[view_name](emb)
                view_embs.append(emb)

        visual_embs = torch.cat(view_embs, dim=-1)
        if hasattr(self, "fusion_norm"):
            visual_embs = self.fusion_norm(visual_embs)
        return visual_embs

    def predict(self, z):  # in embedding space
        """Run the predictor over a window of latents ``z`` of shape (b, t, p, d)."""
        T = z.shape[1]
        # flatten time and patches into one token sequence: (b, t * p, d)
        z = rearrange(z, "b t p d -> b (t p) d")
        z = self.predictor(z)
        z = rearrange(z, "b (t p) d -> b t p d", t=T)
        return z

    def decode(self, z):
        """Decode a joint latent ``z`` back to observations; returns ``(obs, diff)``."""
        z_obs, _ = self.separate_emb(z)
        obs, diff = self.decode_obs(z_obs)
        return obs, diff

    def decode_obs(self, z_obs):
        """Decode each view's slice of ``z_obs['visual']`` separately.

        The visual features are assumed to be the per-view embeddings
        concatenated on the last axis in ``self.view_names`` order (as built
        by ``encode_obs_visual``). Returns ``(obs, combined_diff)`` where
        ``obs['visual']`` maps view name to decoded frames and
        ``combined_diff`` is the mean of the per-view decoder diffs.
        """
        b, t, n, total_e = z_obs["visual"].shape
        num_views = len(self.view_names)
        assert total_e % num_views == 0, "Total embedding dimension must be divisible by num_views."
        e = total_e // num_views

        # (b, t, n, num_views * e) -> (b, t, n, num_views, e) -> (b, t, num_views, n, e)
        z_reshaped = z_obs["visual"].reshape(b, t, n, num_views, e).permute(0, 1, 3, 2, 4)

        visuals_dict = {}
        diffs = []
        for i, view_name in enumerate(self.view_names):
            z_view = z_reshaped[:, :, i]  # (b, t, n, e)
            # The decoder receives (b, t, n, e) directly; its image output is
            # expected with batch and time flattened, (b*t, c, h, w).
            visual_decoded, diff_view = self.decoder(z_view)
            visual_decoded = rearrange(visual_decoded, "(b t) c h w -> b t c h w", b=b, t=t)
            visuals_dict[view_name] = visual_decoded
            diffs.append(diff_view)

        combined_diff = torch.mean(torch.stack(diffs))

        obs = {
            "visual": visuals_dict,       # dict with keys corresponding to self.view_names
            "proprio": z_obs["proprio"],  # passed through unchanged
        }
        return obs, combined_diff

    def compute_recon_loss(self, obs_pred, obs_tgt, criterion):
        """Sum ``criterion(pred, target)`` over all views in ``self.view_names``."""
        loss = 0
        for view_name in self.view_names:
            loss += criterion(obs_pred[view_name], obs_tgt[view_name])
        return loss

    def _feature_boundaries(self, total_dim):
        """Return ``(visual_end, proprio_end)`` feature indices for concat_dim == 1.

        The last axis is laid out as ``[visual | proprio | action]`` with widths
        ``total_dim - proprio_dim - action_dim``, ``proprio_dim`` and
        ``action_dim`` (see ``encode``). Positive indices are used so that a
        zero-width block still slices correctly; negative ones would not
        (``x[..., -0:]`` selects everything).
        """
        visual_end = total_dim - (self.proprio_dim + self.action_dim)
        return visual_end, visual_end + self.proprio_dim

    def _visual_part(self, z):
        """Return the visual portion of a latent laid out as in ``encode``."""
        if self.concat_dim == 0:
            return z[:, :, :-2, :]  # drop the proprio and action tokens
        visual_end, _ = self._feature_boundaries(z.shape[-1])
        return z[..., :visual_end]

    def separate_emb(self, z):
        """Split a joint latent into ``({"visual", "proprio"}, act)``.

        For ``concat_dim == 1`` the tiling done in ``encode`` is undone: only
        patch 0 and the first (un-repeated) copy of proprio/action are kept.
        """
        if self.concat_dim == 0:
            z_visual, z_proprio, z_act = z[:, :, :-2, :], z[:, :, -2, :], z[:, :, -1, :]
        elif self.concat_dim == 1:
            visual_end, proprio_end = self._feature_boundaries(z.shape[-1])
            z_visual = z[..., :visual_end]
            z_proprio = z[..., visual_end:proprio_end]
            z_act = z[..., proprio_end:]
            # remove tiled dimensions
            z_proprio = z_proprio[:, :, 0, : self.proprio_dim // self.num_proprio_repeat]
            z_act = z_act[:, :, 0, : self.action_dim // self.num_action_repeat]
        else:
            raise ValueError(f"Unsupported concat_dim {self.concat_dim}")
        z_obs = {"visual": z_visual, "proprio": z_proprio}
        return z_obs, z_act

    def forward(self, obs, act, multi_step_pos_imgs=None):
        """Encode, predict, decode and compute the training losses.

        Args:
            obs: dict with "visual" (dict view_name -> frames, time on axis 1)
                and "proprio".
            act: actions for ``action_encoder``.
            multi_step_pos_imgs: visual dict of positive frames; required only
                when ``action_conditioned_time_contrastive`` is enabled.

        Returns:
            ``(z_pred, visual_pred, visual_reconstructed, loss, loss_components)``;
            the first three are None when the corresponding predictor /
            decoder branch is disabled.
        """
        loss = 0.0
        loss_components = {}
        z = self.encode(obs, act)
        z_src = z[:, : self.num_hist, :, :]  # (b, num_hist, ...) history window
        z_tgt = z[:, self.num_pred :, :, :]  # (b, T - num_pred, ...) shifted targets
        visual_tgt = {view_name: obs['visual'][view_name][:, self.num_pred :, ...] for view_name in self.view_names}

        if self.predictor is not None:
            z_pred = self.predict(z_src)
            if self.decoder is not None:
                # z_pred is not detached; this term only enters `loss` when
                # minimize_recon_loss_pred is set, otherwise it is just logged.
                obs_pred, diff_pred = self.decode(z_pred)
                visual_pred = obs_pred['visual']
                recon_loss_pred = self.compute_recon_loss(visual_pred, visual_tgt, self.decoder_criterion)
                decoder_loss_pred = (
                    recon_loss_pred + self.decoder_latent_loss_weight * diff_pred
                )
                if self.minimize_recon_loss_pred:
                    loss = loss + recon_loss_pred
                loss_components["decoder_recon_loss_pred"] = recon_loss_pred
                loss_components["decoder_vq_loss_pred"] = diff_pred
                loss_components["decoder_loss_pred"] = decoder_loss_pred
            else:
                visual_pred = None

            # Latent loss over visual + proprio (action part excluded).
            if self.concat_dim == 0:
                z_visual_loss = self.emb_criterion(z_pred[:, :, :-2, :], z_tgt[:, :, :-2, :].detach())
                z_proprio_loss = self.emb_criterion(z_pred[:, :, -2, :], z_tgt[:, :, -2, :].detach())
                z_loss = self.emb_criterion(z_pred[:, :, :-1, :], z_tgt[:, :, :-1, :].detach())
            else:
                visual_end, proprio_end = self._feature_boundaries(z_pred.shape[-1])
                z_visual_loss = self.emb_criterion(
                    z_pred[..., :visual_end],
                    z_tgt[..., :visual_end].detach(),
                )
                z_proprio_loss = self.emb_criterion(
                    z_pred[..., visual_end:proprio_end],
                    z_tgt[..., visual_end:proprio_end].detach(),
                )
                z_loss = self.emb_criterion(
                    z_pred[..., :proprio_end],
                    z_tgt[..., :proprio_end].detach(),
                )

            loss = loss + z_loss
            loss_components["z_loss"] = z_loss
            loss_components["z_visual_loss"] = z_visual_loss
            loss_components["z_proprio_loss"] = z_proprio_loss
        else:
            visual_pred = None
            z_pred = None

        if self.decoder is not None and not self.disable_reconstruction:
            # z is detached: the reconstruction loss only trains the decoder.
            obs_reconstructed, diff_reconstructed = self.decode(z.detach())
            visual_reconstructed = obs_reconstructed["visual"]
            recon_loss_reconstructed = self.compute_recon_loss(visual_reconstructed, obs['visual'], self.decoder_criterion)
            decoder_loss_reconstructed = (
                recon_loss_reconstructed
                + self.decoder_latent_loss_weight * diff_reconstructed
            )

            loss_components["decoder_recon_loss_reconstructed"] = (
                recon_loss_reconstructed
            )
            loss_components["decoder_vq_loss_reconstructed"] = diff_reconstructed
            loss_components["decoder_loss_reconstructed"] = (
                decoder_loss_reconstructed
            )
            loss = loss + decoder_loss_reconstructed
        else:
            visual_reconstructed = None

        if self.action_conditioned_time_contrastive:
            assert self.predictor is not None, "Predictor must be defined for action-conditioned time contrastive loss."
            if multi_step_pos_imgs is None:
                raise ValueError(
                    "multi_step_pos_imgs must be provided for action-conditioned time contrastive loss."
                )
            z_pos = self.encode_obs_visual(multi_step_pos_imgs)
            # Visual parts only, for both concat layouts, to match z_pos.
            z_anchor = self._visual_part(z_src)
            z_pred_visual = self._visual_part(z_pred).clone()
            loss_contrastive = action_conditioned_time_contrastive_loss(z_anchor, z_pred_visual, z_pos)
            loss_components["loss_contrastive"] = loss_contrastive
            loss = loss + self.contrastive_loss_weight * loss_contrastive

        loss_components["loss"] = loss

        return z_pred, visual_pred, visual_reconstructed, loss, loss_components

    def replace_actions_from_z(self, z, act):
        """Return a copy of latent ``z`` with its action part replaced by ``act``'s embedding."""
        act_emb = self.encode_act(act)

        if self.concat_dim == 0:
            # Build a new tensor instead of assigning in place.
            z_updated = torch.cat([
                z[:, :, :-1, :],       # everything except the action token
                act_emb.unsqueeze(2),  # the new action token
            ], dim=2)
        elif self.concat_dim == 1:
            act_tiled = repeat(act_emb.unsqueeze(2), "b t 1 a -> b t f a", f=z.shape[2])
            act_repeated = act_tiled.repeat(1, 1, 1, self.num_action_repeat)
            # Positive index so that action_dim == 0 keeps every feature.
            keep = z.shape[-1] - self.action_dim
            z_updated = torch.cat([
                z[..., :keep],  # visual + proprio features
                act_repeated,   # new action features
            ], dim=-1)
        else:
            raise ValueError(f"Unsupported concat_dim {self.concat_dim}")

        return z_updated
