import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torchvision.transforms as transforms
from .info_nce import InfoNCE


class VLAContrastiveLearningHead(pl.LightningModule):
    """
    Contrastive-learning head over the action hidden states of a trajectory transformer.

    Usage is two-phase:
      * ``prepare_inputs`` expands a batch of N_Q samples into
        N_Q (queries) + N_P (positives) + N_N (negatives) samples before the
        transformer runs, and records N_P / N_N in ``segment_lengths``.
      * ``forward`` pools the action-token hidden states of every sample,
        projects them with a linear layer and scores queries against
        positives / negatives with InfoNCE.

    Args:
        pooling: 'max' selects adaptive max pooling over action tokens; any
            other value selects adaptive average pooling.
        action_encoder: module providing ``normalize_action(action)`` (the
            second returned value is used) and encoding actions when called.
        action_diffusion: object providing ``q_sample(x_start=..., t=...)``;
            used to perturb actions when building positives.
        config: must provide ``n_embd`` (hidden size); required despite the
            ``None`` default.
        temperature: InfoNCE temperature. Defaults to 0.1.
        max_noise_timestep: exclusive upper bound of the diffusion timestep
            drawn for positives (timesteps are sampled from
            ``[1, max_noise_timestep)``). Must be >= 2. Defaults to 3.

    Raises:
        ValueError: if ``max_noise_timestep`` < 2.
    """

    def __init__(self,
                 pooling = 'max',
                 action_encoder = None,
                 action_diffusion = None,
                 config = None,
                 temperature = 0.1,
                 max_noise_timestep = 3):
        super().__init__()
        if max_noise_timestep < 2:
            # torch.randint(1, high) needs high > 1 to have any value to draw.
            raise ValueError(f"max_noise_timestep must be >= 2, got {max_noise_timestep}")
        self.config = config
        self.action_encoder = action_encoder
        self.action_diffusion = action_diffusion
        self.hidden_size = config.n_embd
        self.log_image_num = 10
        # Kept small on purpose so positives stay close to the query action.
        self.max_noise_timestep = max_noise_timestep

        # Pooling choice follows RoboFlamingo's action head:
        # https://github.com/RoboFlamingo/RoboFlamingo/blob/79f2677602815750b3e66907ab4f99c3f03c1ca3/robot_flamingo/models/action_head.py#L406
        if pooling == 'max':
            self.action_pool = nn.AdaptiveMaxPool1d(1)
        else:
            self.action_pool = nn.AdaptiveAvgPool1d(1)

        self.cl_fc = nn.Linear(self.hidden_size, self.hidden_size)
        self.contrastive_loss = InfoNCE(temperature = temperature, reduction = "mean")

    def prepare_inputs(self, prompt_embed, state_embed_list, action, prompt_attn_mask, state_attn_mask_list, segment_lengths,
                       negative_prompt_embed = None,
                       negative_prompt_attn_mask = None,
                       ):
        """
        Expand a batch into query / positive / negative samples for contrastive learning.

        All returned tensors are concatenations along dim 0, in this order:
          1. queries (N_Q rows): original prompt/states with the encoded,
             normalized action;
          2. positives (N_P rows): original prompt/states with actions whose
             pose part (all channels but the last) was noised via
             ``action_diffusion.q_sample``; the last (gripper) channel is kept;
          3. negatives (N_N rows), in this order:
               - negative prompts, if ``negative_prompt_embed`` is given;
               - states rolled along the batch dim by a random non-zero shift
                 (only when N_Q >= 2);
               - noisy actions rolled along the batch dim (only when N_Q >= 2);
               - noisy actions rolled along the timestep dim (only when S >= 2).
        For every group, the components it does not replace are filled with
        copies of the original query rows (see ``accumulate_embeddings``).
        Rolls are skipped along size-1 dimensions because any shift there maps
        a sample onto itself and would yield a "negative" equal to the query.

        Args:
            prompt_embed: (B, prompt_len, hidden) prompt embeddings.
            state_embed_list: list of (B, S, seg_len_i, hidden) state embeddings,
                e.g. base / gripper / crop image patch embeddings.
            action: (B, S, action_dim) raw actions.
            prompt_attn_mask: attention mask for ``prompt_embed``.
            state_attn_mask_list: list of 3-D masks, paired by position with
                ``state_embed_list``.
            segment_lengths: dict with at least 'batch_size' (N_Q, presumably
                equal to B) and 'timesteps' (S). Updated IN PLACE with
                'positive_num' and 'negative_num'.
            negative_prompt_embed: optional negative prompt embeddings; presumably
                N_Q rows so they line up with the repeated originals - TODO confirm.
            negative_prompt_attn_mask: mask for ``negative_prompt_embed``; required
                when it is given.

        Returns:
            (prompt_embed, state_embed_list, action_embed, prompt_attn_mask,
             state_attn_mask_list, segment_lengths), each tensor with
            N_Q + N_P + N_N rows.

        Raises:
            ValueError: if no negatives can be built (N_Q == 1, S == 1 and no
                negative prompts).
        """
        N_Q = segment_lengths['batch_size']
        S = segment_lengths['timesteps']

        # ----------------- Queries (original input) -----------------
        _, action = self.action_encoder.normalize_action(action)
        action_embed = self.action_encoder(action)  # (B, S, 1, hidden)

        # The lists are shallow copies: accumulate_embeddings replaces their
        # entries, and the caller's lists must stay untouched.
        cl_embeds = {
            'prompt_embed': prompt_embed,
            'prompt_attn_mask': prompt_attn_mask,
            'state_embed_list': list(state_embed_list),
            'state_attn_mask_list': list(state_attn_mask_list),
            'action_embed': action_embed,
        }

        # ----------------- Action positives -----------------
        # reshape (not view): the normalized action is not guaranteed contiguous.
        action_flat = action.reshape(N_Q * S, -1)  # (B * S, action_dim)
        t = torch.randint(low=1, high=self.max_noise_timestep, size=(action_flat.shape[0],)).to(action.device)
        pose_noisy = self.action_diffusion.q_sample(x_start=action_flat[:, :-1], t=t)
        pose_noisy = pose_noisy.reshape(N_Q, S, -1)  # (B, S, action_dim - 1)
        action_noisy = torch.cat((pose_noisy, action[:, :, -1:]), dim=-1)  # (B, S, action_dim)
        action_embed_noisy = self.action_encoder(action_noisy)  # (B, S, 1, hidden)
        cl_embeds = self.accumulate_embeddings(cl_embeds, segment_lengths, new_action_embed=action_embed_noisy)
        N_P = action_embed_noisy.size(0)

        N_N = 0
        # ----------------- Prompt negatives -----------------
        if negative_prompt_embed is not None:
            cl_embeds = self.accumulate_embeddings(cl_embeds, segment_lengths,
                                                   new_prompt_embed=negative_prompt_embed,
                                                   new_prompt_attn_mask=negative_prompt_attn_mask)
            N_N += negative_prompt_embed.size(0)

        # ----------------- State negatives -----------------
        # torch.roll(x, -k, dims=0) == cat(x[k:], x[:k]): pairs sample i with
        # the states of sample (i + k) % B.
        if N_Q >= 2:
            batch_shift = torch.randint(1, N_Q, (1,)).item()
            rolled_states = []
            rolled_masks = []
            for state_embed, state_mask in zip(state_embed_list, state_attn_mask_list):
                rolled_states.append(torch.roll(state_embed, shifts=-batch_shift, dims=0))
                rolled_masks.append(torch.roll(state_mask, shifts=-batch_shift, dims=0))
            cl_embeds = self.accumulate_embeddings(cl_embeds, segment_lengths,
                                                   new_state_embed_list=rolled_states,
                                                   new_state_attn_mask_list=rolled_masks)
            # accumulate_embeddings appended N_Q repeated action rows for this group.
            N_N += N_Q

        # ----------------- Action negatives -----------------
        if N_Q >= 2:
            batch_shift = torch.randint(1, N_Q, (1,)).item()
            action_embed_neg_batch = torch.roll(action_embed_noisy, shifts=-batch_shift, dims=0)
            cl_embeds = self.accumulate_embeddings(cl_embeds, segment_lengths, new_action_embed=action_embed_neg_batch)
            N_N += action_embed_neg_batch.size(0)

        # Temporally misaligned actions of the same sample.
        if S >= 2:
            timestep_shift = torch.randint(1, S, (1,)).item()
            action_embed_neg_ts = torch.roll(action_embed_noisy, shifts=-timestep_shift, dims=1)
            cl_embeds = self.accumulate_embeddings(cl_embeds, segment_lengths, new_action_embed=action_embed_neg_ts)
            N_N += action_embed_neg_ts.size(0)

        if N_N == 0:
            raise ValueError(
                "Cannot build contrastive negatives: batch_size and timesteps are both 1 "
                "and no negative prompts were provided."
            )

        assert cl_embeds['action_embed'].size(0) == N_Q + N_P + N_N

        segment_lengths['positive_num'] = N_P
        segment_lengths['negative_num'] = N_N

        return (cl_embeds["prompt_embed"], cl_embeds["state_embed_list"], cl_embeds["action_embed"],
                cl_embeds["prompt_attn_mask"], cl_embeds["state_attn_mask_list"], segment_lengths)

    def accumulate_embeddings(self, cl_embeds, segment_lengths, new_prompt_embed = None, new_prompt_attn_mask = None, new_state_embed_list = None, new_state_attn_mask_list = None, new_action_embed = None):
        """
        Append one group of samples to the contrastive-learning buffers.

        Supplied components are concatenated onto ``cl_embeds`` along dim 0.
        Every component NOT supplied gets a copy of its first
        ``segment_lengths['batch_size']`` rows (the original queries) appended
        instead, so all components keep the same number of rows. Entries that
        are neither tensors nor lists (e.g. a ``None`` mask) are left as-is.

        Args:
            cl_embeds: dict with keys 'prompt_embed', 'prompt_attn_mask',
                'state_embed_list', 'state_attn_mask_list', 'action_embed'.
                Mutated in place (including the lists it holds).
            segment_lengths: dict providing 'batch_size'.
            new_prompt_embed / new_prompt_attn_mask: supplied together.
            new_state_embed_list / new_state_attn_mask_list: supplied together;
                entry i is appended to list entry i.
            new_action_embed: replacement action embeddings.

        Returns:
            The same ``cl_embeds`` dict.
        """
        repeat_keys = ['prompt_embed', 'prompt_attn_mask', 'state_embed_list', 'state_attn_mask_list', 'action_embed']
        if new_prompt_embed is not None:
            cl_embeds['prompt_embed'] = torch.cat((cl_embeds['prompt_embed'], new_prompt_embed), dim = 0)
            cl_embeds['prompt_attn_mask'] = torch.cat((cl_embeds['prompt_attn_mask'], new_prompt_attn_mask), dim = 0)
            repeat_keys.remove('prompt_embed')
            repeat_keys.remove('prompt_attn_mask')
        if new_state_embed_list is not None:
            for i, state_embed in enumerate(new_state_embed_list):
                cl_embeds['state_embed_list'][i] = torch.cat((cl_embeds['state_embed_list'][i], state_embed), dim = 0)
            for i, state_attn_mask in enumerate(new_state_attn_mask_list):
                cl_embeds['state_attn_mask_list'][i] = torch.cat((cl_embeds['state_attn_mask_list'][i], state_attn_mask), dim = 0)
            repeat_keys.remove('state_embed_list')
            repeat_keys.remove('state_attn_mask_list')
        if new_action_embed is not None:
            cl_embeds['action_embed'] = torch.cat((cl_embeds['action_embed'], new_action_embed), dim = 0)
            repeat_keys.remove('action_embed')

        n_query = segment_lengths['batch_size']
        for key in repeat_keys:
            value = cl_embeds[key]
            if isinstance(value, torch.Tensor):
                cl_embeds[key] = torch.cat((value, value[:n_query, ...]), dim = 0)
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    value[i] = torch.cat((item, item[:n_query, ...]), dim = 0)

        return cl_embeds

    def forward(self, hidden_states, segment_lengths):
        """
        Compute the InfoNCE loss over pooled action hidden states.

        Args:
            hidden_states: (N_Q + N_P + N_N, seq_len, hidden) transformer output,
                where seq_len = prompt_length + S * tokens_per_timestep and the
                last 'action_length' tokens of every timestep are action tokens.
            segment_lengths: dict with 'batch_size', 'positive_num',
                'negative_num', 'timesteps', 'prompt_length', 'action_length'
                (as filled in by ``prepare_inputs``).

        Returns:
            (loss_cl, acc_cl) as returned by the InfoNCE module.
        """
        N_Q = segment_lengths['batch_size']  # number of queries, i.e. original data
        N_P = segment_lengths['positive_num']
        N_N = segment_lengths['negative_num']
        S = segment_lengths['timesteps']
        prompt_length = segment_lengths['prompt_length']
        N_total = N_Q + N_P + N_N
        hidden = hidden_states.size(-1)

        # Drop prompt tokens and split the remainder per timestep.
        sa_hidden_states = hidden_states[:, prompt_length:, :].view(N_total, S, -1, hidden)  # (N_total, S, state_len + action_len, hidden)
        # Keep only the trailing action tokens of every timestep.
        action_hidden_states = sa_hidden_states[:, :, -segment_lengths['action_length']:, :]  # (N_total, S, action_len, hidden)
        # Concatenate the action tokens of all timesteps.
        action_hidden_states = action_hidden_states.reshape(N_total, -1, hidden)  # (N_total, S * action_len, hidden)

        # Pool over the token axis: Adaptive*Pool1d expects (N, C, L).
        action_pooled = self.action_pool(action_hidden_states.permute(0, 2, 1)).squeeze(-1)  # (N_total, hidden)
        action_pooled = self.cl_fc(action_pooled)  # (N_total, hidden)
        query = action_pooled[:N_Q, :]  # (N_Q, hidden)
        positive = action_pooled[N_Q:N_Q + N_P, :]  # (N_P, hidden)
        negative = action_pooled[N_Q + N_P:, :]  # (N_N, hidden)

        loss_cl, acc_cl = self.contrastive_loss(query, positive, negative)

        return loss_cl, acc_cl
