import robomimic.utils.file_utils as FileUtils
import argparse 
import os 
import torch.nn as nn
from torchvision import transforms
import torch 
import numpy as np
import math 
from image_models import ResNet18Dec, VQVAE
import einops

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a token sequence, then apply dropout.

    Args:
        d_model: embedding size of each token (odd sizes are supported).
        dropout: dropout probability applied after the encoding is added.
        max_len: longest sequence length covered by the precomputed table.
        batch_first: if True, inputs are ``[batch, seq_len, d_model]``;
            otherwise ``[seq_len, batch, d_model]``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first = True):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd (cosine) slot than there are
        # frequencies, so only the first d_model // 2 frequencies are used here.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)
        self.batch_first = batch_first

    def forward(self, x):
        """Return ``dropout(x + pe)`` with the same shape as ``x``.

        Arguments:
            x: Tensor of shape ``[batch_size, seq_len, embedding_dim]`` when
                ``batch_first`` is True, else ``[seq_len, batch_size, embedding_dim]``.
                ``seq_len`` must not exceed ``max_len``.
        """
        if self.batch_first:
            x = torch.transpose(x, 0, 1)
        # pe is [max_len, 1, d_model]; it broadcasts over the batch dimension.
        x = x + self.pe[:x.size(0)]
        if self.batch_first:
            x = torch.transpose(x, 0, 1)
        return self.dropout(x)

class ActionEmbedding(nn.Module):
    """Embed an action chunk into a single token.

    A 1-D convolution over time (kernel size = stride = ``tubelet_size``) maps
    each group of ``tubelet_size`` action steps from ``in_chans`` features to
    ``emb_dim // num_frames`` channels. The resulting ``[B, C, T']`` map is
    flattened channel-major into one token of width ``C * T'``.

    That width equals ``emb_dim`` only when ``tubelet_size == 1``, the input has
    ``num_frames`` steps and ``num_frames`` divides ``emb_dim``.

    NOTE(review): ``out_project`` is registered (so it shows up in
    ``state_dict`` and parameter counts) but ``forward`` never applies it, and
    ``use_3d_pos`` is accepted but unused.
    """

    def __init__(
        self,
        num_frames=16,  # horizon
        tubelet_size=1,
        in_chans=8,  # action_dim
        emb_dim=384,  # output_dim
        use_3d_pos=False  # always False for now
    ):
        super().__init__()
        self.num_frames, self.tubelet_size = num_frames, tubelet_size
        self.in_chans, self.emb_dim = in_chans, emb_dim

        channels_per_step = emb_dim // num_frames

        # Pure temporal downsampling: each window of tubelet_size steps -> channels_per_step features.
        self.patch_embed = nn.Conv1d(
            in_chans,
            channels_per_step,
            kernel_size=tubelet_size,
            stride=tubelet_size,
        )
        self.out_project = nn.Linear(num_frames * channels_per_step, emb_dim)

    def forward(self, x):
        """Map actions of shape ``[B, T, in_chans]`` to one token ``[B, 1, C * T']``."""
        per_step = self.patch_embed(x.transpose(1, 2))  # [B, C, T']
        # Channel-major flatten: all time steps of channel 0, then channel 1, ...
        return torch.flatten(per_step, start_dim=1).unsqueeze(1)

class FinalStatePredictionDino(nn.Module):
    """Predict the DINOv2 patch tokens of the state reached after an action chunk.

    Every camera image is encoded by a frozen DINOv2 ViT-S/14
    (``forward_features(...)["x_norm_patchtokens"]``). The image patches, an
    optional proprioception token and one action-chunk token (from
    ``ActionEmbedding``) receive a learned positional embedding and pass through
    a transformer encoder. The outputs at the image-patch positions are the
    predicted final-state patches, aligned token-for-token with
    ``state_embedding`` of the target state. An optional reconstruction model
    decodes the detached prediction.

    Args:
        action_dim: feature size of one action step.
        action_horizon: number of steps in an action chunk.
        cameras: keys of the camera images in the state dict.
        proprio: optional state-dict key of a proprioception vector.
        proprio_dim: size of that vector (required when ``proprio`` is set).
        reconstruction: if True, also build and run a ``VQVAE`` reconstructor.
    """

    # DINOv2 ViT-S/14 on a 224x224 input yields (224 // 14) ** 2 = 256 patch tokens.
    PATCHES_PER_CAMERA = (224 // 14) ** 2

    def __init__(self, action_dim, action_horizon, cameras, proprio = None, proprio_dim = None, reconstruction = True):
        super().__init__()
        chunk_size = 384  # token width; must match the DINOv2 ViT-S/14 patch-token width
        emb_dropout = 0

        self.action_embedder = ActionEmbedding(num_frames = action_horizon, in_chans = action_dim, emb_dim = chunk_size)

        self.image_transform = transforms.Compose([  # assumes images given in 0-255
            transforms.Resize((224, 224)),
            transforms.Normalize(
                mean=(123.675, 116.28, 103.53),
                std=(58.395, 57.12, 57.375),
            ),
        ])

        self.state_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to("cuda")

        # One learned position per token: all patches of all cameras, the optional
        # proprio token, then the action token. For a single camera this is 257
        # (258 with proprio), exactly as before; the previous hard-coded sizes
        # could not broadcast against multi-camera inputs.
        num_tokens = len(cameras) * self.PATCHES_PER_CAMERA + 1
        self.proprio = proprio
        if proprio is not None:
            assert proprio_dim is not None
            self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens + 1, chunk_size))
            self.proprio_embedder = nn.Linear(proprio_dim, chunk_size)
        else:
            self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, chunk_size))

        print("Freezing the encoder!")
        for parameter in self.state_encoder.parameters():
            parameter.requires_grad = False  # freezing the encoder
        self.state_encoder.eval()

        # Despite the name this is a transformer ENCODER (depth 6, taken from dinoWM;
        # heads = 16 was noted as a larger variant).
        state_decoder_layer = nn.TransformerEncoderLayer(d_model=chunk_size, nhead=8, batch_first = True)
        self.state_decoder_transformer = nn.TransformerEncoder(state_decoder_layer, num_layers=6)

        self.dropout = nn.Dropout(emb_dropout)  # NOTE: not applied in forward()

        self.cameras = cameras
        # NOTE: the mask is not used by forward(). generate_square_subsequent_mask
        # is a static method, so no throwaway nn.Transformer is built to get it.
        self.mask = nn.Transformer.generate_square_subsequent_mask(action_horizon)

        self.reconstruction = reconstruction
        if reconstruction:
            self.reconstruction_model = VQVAE(in_channel = 3, channel = 384, n_res_block = 4, n_res_channel = 128, emb_dim = 128, quantize = False)

    def trainable_parameters(self):
        """Return the number of scalar parameters with ``requires_grad=True``."""
        count = 0
        for parameter in self.parameters():
            if parameter.requires_grad:
                count += np.prod(parameter.size())
        return count

    def compute_image_state_patches(self, state):
        """Encode every camera image and concatenate the patch tokens.

        Assumes ``state[camera]`` is an image tensor with 3 channels in 0-255
        (TODO confirm batch layout). Returns ``[B, len(cameras) * 256, D]``.
        """
        patch_list = list()
        for camera in self.cameras:
            transformed_state = self.image_transform(state[camera])
            embed = self.state_encoder.forward_features(transformed_state)["x_norm_patchtokens"]
            patch_list.append(embed)  # batch, patch, dim
        return torch.concatenate(patch_list, dim = 1)  # batch, patches (sequential), dim

    def _predict_patches(self, state, actions):
        """Run the dynamics transformer; return its outputs at the image-patch positions."""
        image_embed = self.compute_image_state_patches(state)  # [B, Patches, D]
        action_embed = self.action_embedder(actions)  # [B, 1, D]
        tokens = [image_embed]
        if self.proprio is not None:
            tokens.append(torch.unsqueeze(self.proprio_embedder(state[self.proprio]), 1))
        tokens.append(action_embed)

        combined_embed = torch.concatenate(tokens, dim = 1) + self.pos_embedding
        predicted_state = self.state_decoder_transformer(combined_embed)
        # Keep only the image patches (drop the proprio / action tokens) so the
        # result lines up with state_embedding() of the target state.
        return predicted_state[:, :len(self.cameras) * self.PATCHES_PER_CAMERA]

    def forward(self, states, actions):  # takes in 0-255 images and a regular action chunk
        """Predict the final-state patch tokens.

        Returns:
            ``predicted_z_end`` (``[B, len(cameras) * 256, D]``), or the tuple
            ``(predicted_z_end, reco_image)`` when reconstruction is enabled.
            Compare against ``state_embedding`` of the final state (e.g. MSE).
        """
        predicted_z_end = self._predict_patches(states, actions)

        if self.reconstruction:
            # CRITICAL: detach so the image reconstructor does not influence the dynamics model.
            # NOTE(review): with several cameras the reconstructor receives all cameras' patches.
            reco_image = self.image_reconstruct(predicted_z_end.detach())
            return predicted_z_end, reco_image

        return predicted_z_end

    def image_reconstruct(self, embedding):
        """Decode patch tokens with the reconstruction model."""
        return self.reconstruction_model(embedding)

    def state_embedding(self, state, normalize = False):
        """Return the frozen-encoder patch embedding ``z`` of ``state``.

        ``normalize`` is accepted for interface compatibility and ignored.
        """
        return self.compute_image_state_patches(state)

    def state_action_embedding(self, state, actions, normalize = False):
        """Return the predicted final-state patches (no reconstruction).

        ``normalize`` is accepted for interface compatibility and ignored.
        """
        return self._predict_patches(state, actions)

class FinalStatePredictionDinoCLS(nn.Module):
    """Predict the DINOv2 CLS embedding of the state reached after an action chunk.

    Used during development. The CLS token of the current image (frozen DINOv2
    ViT-S/14) serves as decoder memory. The action chunk is projected to
    ``repr_dim``, given sinusoidal positions and followed by a learned
    prediction token. It then goes through a transformer decoder, and the
    prediction token's output is the predicted final CLS embedding.

    Proprioception is not supported. When ``reconstruction`` is False no
    decoder is built and ``forward`` returns only the prediction, like
    ``FinalStatePredictionDino``.
    """

    def __init__(self, action_dim, action_horizon, cameras, proprio = None, proprio_dim = None, reconstruction = True):
        super().__init__()
        repr_dim = 384  # DINOv2 ViT-S/14 token width

        assert proprio is None, "proprioception not impelemtned yet!"

        self.action_projector = nn.Sequential(
            nn.Linear(action_dim, repr_dim),
        )
        self.image_transform = transforms.Compose([  # assumes images given in 0-255
            transforms.Resize((224, 224)),
            transforms.Normalize(
                mean=(123.675, 116.28, 103.53),
                std=(58.395, 57.12, 57.375),
            ),
        ])

        self.state_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to("cuda")

        print("Freezing the encoder!")
        for parameter in self.state_encoder.parameters():
            parameter.requires_grad = False  # freezing the encoder
        self.state_encoder.eval()

        # Learned query token whose decoder output is the final prediction.
        self.prediction_token = torch.nn.Parameter(torch.zeros(1, repr_dim), requires_grad = True)
        self.decoding_position_embedding = PositionalEncoding(repr_dim, max_len = action_horizon + 1, batch_first = True)

        state_decoder_layer = nn.TransformerDecoderLayer(d_model=repr_dim, nhead = 8, batch_first = True)
        self.state_decoder_transformer = nn.TransformerDecoder(state_decoder_layer, num_layers = 6)

        self.cameras = cameras
        # Without reconstruction no ``decoder`` submodule is registered.
        if reconstruction:
            self.decoder = ResNet18Dec(z_dim = repr_dim)

    def trainable_parameters(self):
        """Return the number of scalar parameters with ``requires_grad=True``."""
        count = 0
        for parameter in self.parameters():
            if parameter.requires_grad:
                count += np.prod(parameter.size())
        return count

    def compute_image_state_patches(self, state):
        """Encode every camera image and concatenate the CLS tokens along dim 1.

        Presumably each CLS token is ``[B, D]``, giving ``[B, len(cameras) * D]``
        (TODO confirm); the rest of the model assumes a single camera.
        """
        patch_list = list()
        for camera in self.cameras:
            transformed_state = self.image_transform(state[camera])
            embed = self.state_encoder.forward_features(transformed_state)["x_norm_clstoken"]
            patch_list.append(embed)
        return torch.concatenate(patch_list, dim = 1)

    def _predict_final_state(self, state, actions):
        """Decode the action sequence against the CLS memory; return the prediction-token output."""
        embed = self.compute_image_state_patches(state)
        projected_actions = self.action_projector(actions)  # B x S x repr_dim
        projected_actions = self.decoding_position_embedding(projected_actions)
        prediction_token = torch.tile(self.prediction_token, dims = (embed.shape[0], 1))  # tiling for batch
        prediction_token = torch.unsqueeze(prediction_token, dim = 1)
        projected_actions = torch.concatenate((projected_actions, prediction_token), dim = 1)
        embed = torch.unsqueeze(embed, dim = 1)
        # No tgt mask is passed, so attention over the action/prediction tokens is unmasked.
        return self.state_decoder_transformer(projected_actions, memory = embed)[:, -1]

    def forward(self, states, actions):  # takes in 0-255 images and a regular action chunk
        """Predict the final CLS embedding.

        Returns:
            ``(predicted_final_state, reco_final_state)`` when a decoder was
            built, otherwise just ``predicted_final_state``.
        """
        predicted_final_state = self._predict_final_state(states, actions)
        if not hasattr(self, "decoder"):
            # reconstruction=False: previously this crashed with AttributeError.
            return predicted_final_state
        # CRITICAL: detach so no gradient flows from the reconstructor into the predictor.
        reco_final_state = self.decoder(predicted_final_state.detach())
        return predicted_final_state, reco_final_state

    def image_reconstruct(self, embedding):
        """Decode an embedding with the reconstruction decoder (requires reconstruction=True)."""
        return self.decoder(embedding)

    def state_embedding(self, state, normalize = False):
        """Return the CLS embedding of ``state``, optionally L2-normalised along dim 1."""
        s_embedding = self.compute_image_state_patches(state)
        if normalize:
            return torch.nn.functional.normalize(s_embedding, dim = 1)
        return s_embedding

    def state_action_embedding(self, state, actions, normalize = False):
        """Return the predicted final CLS embedding (no reconstruction)."""
        assert not normalize, "defunct feature!"
        return self._predict_final_state(state, actions)
       


class FinalStateClassification(nn.Module):
    """Classify the outcome of an action chunk directly (pymunk toy environment).

    Camera images (0-255) are scaled to 0-1 and encoded with the caller-supplied
    ``state_vae``; its ``encode`` output is the transformer memory. The action
    chunk is projected to a 64-wide space, given sinusoidal positions and
    followed by a learned prediction token. It is then decoded by a transformer
    decoder, and the prediction token's output is mapped to ``classes`` logits.
    """

    def __init__(self, action_dim, action_horizon, cameras, state_vae, classes):
        super().__init__()
        width = 64

        self.action_projector = nn.Sequential(
            nn.Linear(action_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

        # Must already be a loaded model exposing ``encode``; deliberately left trainable.
        self.state_vae = state_vae
        print("I'M DELIBERATELY NOT FREEZING THE ENCODER!")

        # Learned query token whose decoder output becomes the classification feature.
        self.prediction_token = torch.nn.Parameter(torch.zeros(1, width), requires_grad = True)
        self.decoding_position_embedding = PositionalEncoding(width, max_len = action_horizon + 1, batch_first = True)

        self.state_decoder_transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=width, nhead = 4, batch_first = True),
            num_layers = 4,
        )

        self.prediction_head = nn.Linear(width, classes)

        self.cameras = cameras

    def unfreeze(self):
        """Make every ``state_vae`` parameter trainable and report how many there are."""
        print("UNFREEZING the encoder!")
        for weight in self.state_vae.parameters():
            weight.requires_grad = True
        print("Unfroze ", self.state_vae.trainable_parameters(), " parameters")

    def trainable_parameters(self):
        """Return the number of scalar parameters with ``requires_grad=True``."""
        return sum((np.prod(weight.size()) for weight in self.parameters() if weight.requires_grad), 0)

    def compute_image_state_patches(self, state):
        """Encode each camera image (scaled from 0-255 to 0-1) and concatenate along dim 1."""
        def encode_camera(name):
            frame = state[name]
            # Heuristic guard: a 0-255 image should have some pixel above 1.
            assert torch.max(frame) > 1, "you are feeding in an already-normalized image"
            return self.state_vae.encode(frame / 255)

        return torch.concatenate([encode_camera(name) for name in self.cameras], dim = 1)

    def forward(self, states, actions):  # takes in 0-255 images and a regular action chunk
        """Return class logits of shape ``[B, classes]``."""
        memory = self.compute_image_state_patches(states)
        batch_size = memory.shape[0]

        # Assumes actions are [B, T, action_dim]; queries become [B, T, 64].
        queries = self.decoding_position_embedding(self.action_projector(actions))
        token = torch.unsqueeze(torch.tile(self.prediction_token, dims = (batch_size, 1)), dim = 1)
        queries = torch.concatenate((queries, token), dim = 1)

        decoded = self.state_decoder_transformer(queries, memory = torch.unsqueeze(memory, dim = 1))
        return self.prediction_head(decoded[:, -1])

    def state_embedding(self, state, normalize = False):
        """Unsupported for the classifier; always raises."""
        raise Exception("Not valid for this model!")

    def state_action_embedding(self, state, actions, normalize = False):
        """Unsupported for the classifier; always raises."""
        raise Exception("Not valid for this model!")

