import os
import collections.abc
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
from torchmetrics.functional.classification import multiclass_accuracy
from transformers import ViTConfig, AutoImageProcessor
# from .transformer_encoder import Encoder
from .diffusion.d3pm import make_diffusion
from .diffusion.ddpm import GaussianDiffusion


class ActionEncoder(pl.LightningModule):
    """Embed robot actions into one token embedding per action dimension.

    Each action step is split into pose values (every entry but the last) and a
    gripper value (the last entry). Pose values are embedded either with a
    small MLP applied to each scalar (continuous actions) or with an
    ``nn.Embedding`` lookup (discrete actions); the gripper always uses a
    two-entry embedding table.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        if config.use_continuous_action:
            # Shared MLP applied independently to every scalar pose value.
            self.action_embedding = nn.Sequential(
                nn.Linear(1, config.n_embd * 4),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.Linear(config.n_embd * 4, config.n_embd),
                nn.SiLU(),
                nn.Dropout(0.1),
                nn.LayerNorm(config.n_embd),
            )
            self.action_type = "continuous_action"
        else:
            self.action_embedding = nn.Embedding(config.continuous_vocab_size, config.n_embd)
            self.action_type = "curr_action"
        # Learned per-pose-dimension embedding, added to continuous pose embeddings.
        self.tt_embedding = nn.Parameter(torch.zeros(1, 1, config.action_dim-2, config.n_embd))
        # Two entries: gripper index 0 and 1.
        self.gripper_embedding = nn.Embedding(2, config.n_embd)
        # self.action_type = next((feature.name for feature in config.feature_list if "action" in feature.name), "curr_action")
        print(f"Action type: {self.action_type}")

    def forward(self, action):
        """Embed a batch of action sequences.

        Args:
            action: action tensor, expected shape (B, S, action_dim - 1), where
                the last entry of each step is the gripper and the others are
                pose values. In continuous mode the pose values are fed to the
                MLP as floats; in discrete mode they are looked up in an
                ``nn.Embedding`` and must therefore be integer token ids (see
                ``normalize_action``).

        Returns:
            Tensor of shape (B, S, action_dim - 1, n_embd). If S == 0, a zero
            tensor of shape ``action.shape[:3] + (n_embd,)`` is returned.
        """
        B, S = action.shape[:2]
        if S == 0:
            # Empty sequence: return an empty embedding with the usual rank.
            return torch.zeros(tuple(action.shape[:3]) + (self.config.n_embd,), device=self.device)

        if self.config.use_continuous_action:
            assert action.dim() == 3  # (B, S, action_dim-1)
            pose = action[:, :, :-1]  # (B, S, action_dim-2)
            # Give every scalar pose value its own trailing feature axis of size 1
            # so the shared MLP embeds each value independently.
            pose = pose.reshape(-1, pose.size(-1), 1).float()  # (B*S, action_dim-2, 1)
            pose_embed = self.action_embedding(pose)  # (B*S, action_dim-2, hidden)
            pose_embed = pose_embed.view(B, S, -1, pose_embed.size(-1))  # (B, S, action_dim-2, hidden)
            pose_embed = pose_embed + self.tt_embedding  # (B, S, action_dim-2, hidden)
            # Binarize the gripper: values > 0 -> index 1, everything else -> index 0.
            gripper = (action[:, :, -1:] > 0).long()  # (B, S, 1)
        else:
            pose_embed = self.action_embedding(action[:, :, :-1])  # (B, S, action_dim-2, hidden)
            gripper = action[:, :, -1:]  # (B, S, 1); must already hold 0/1 indices
        gripper_embed = self.gripper_embedding(gripper)  # (B, S, 1, hidden)

        return torch.cat([pose_embed, gripper_embed], dim=-2)  # (B, S, action_dim-1, hidden)

    # def normalize_mode(self, mode):
    #     # Map mode to 0 or 1, moving pad
    #     return mode - 1

    def normalize_action(self, action):
        """Split off the action mode and rescale the remaining action values.

        action: dim == -1: [mode, x, y, z, r, p, y, gripper]

        vocab_size == 2051, discrete_vocab_size == 1024, continuous_vocab_size == 1024
        idx:    0    1       2              1024       1025         2048       2049   2050
        vocab: [pad, disc_0, disc_1,   ..., disc_1023, cont_0, ..., cont_1023, start, end]
                        Move    Terminate                 -1           1

        Continuous actions: mode is shifted by -1 and cast to long (shape
        (..., 1)); roll/pitch/yaw (entries 3:6 after the mode) are divided by
        pi; the gripper is mapped from [-1, 1] to [0, 1]; the result is float32.

        Discrete actions: mode is shifted by -1 (shape (...)); action tokens are
        shifted so that cont_0 becomes 0, and the gripper token is divided by
        (continuous_vocab_size - 1), so the two gripper tokens become 0 and 1.

        The input tensor is never modified.

        Returns:
            (mode, action) with action of shape (..., action_dim - 1).
        """
        if self.config.use_continuous_action:
            mode = (action[..., :1] - 1).long()  # (B, S, 1)
            # Copy (and cast) before the in-place edits below: ``action[..., 1:]``
            # is a view, so editing it directly would overwrite the caller's tensor.
            action = action[..., 1:].to(torch.float32, copy=True)  # (B, S, action_dim-1)
            action[..., 3:6] = action[..., 3:6] / np.pi
            action[..., -1] = (action[..., -1] + 1) / 2
            return mode, action
        else:
            mode = action[..., 0]  # (B, S)
            mode = mode - 1
            # The subtraction allocates a new tensor, so the in-place edit below is safe.
            action = action[..., 1:] - self.config.discrete_vocab_size - 1  # (B, S, action_dim-1)
            # For integer tensors the quotient is truncated back to the tensor's dtype.
            action[..., -1] = action[..., -1] / (self.config.continuous_vocab_size - 1)
            return mode, action

    def denormalize_action(self, mode, action):
        """Invert ``normalize_action`` and re-attach the mode as the first entry.

        Args:
            mode: normalized mode, shape (..., 1) or (...) (the latter is what
                ``normalize_action`` returns for discrete actions).
            action: normalized action of shape (..., action_dim - 1). It is not
                modified.

        Returns:
            Tensor of shape (..., action_dim) laid out as [mode, action...].
        """
        mode = mode + 1
        if mode.dim() == action.dim() - 1:
            # Add the trailing axis so the mode can be concatenated with the action.
            mode = mode.unsqueeze(-1)
        # Work on a copy so the caller's tensor is left untouched.
        action = action.clone()
        if self.config.use_continuous_action:
            action[..., 3:6] = action[..., 3:6] * np.pi
            action[..., -1] = action[..., -1] * 2 - 1
        else:
            action[..., -1] = action[..., -1] * (self.config.continuous_vocab_size - 1)
            action = action + self.config.discrete_vocab_size + 1
        return torch.cat([mode, action], dim=-1)


# class ActionEncoder(pl.LightningModule):
#     def __init__(self, config):
#         super().__init__()
#         self.config = config

#         self.encoder = Encoder(
#             src_vocab_size = config.continuous_vocab_size,
#             d_model = config.n_embd,
#             n_heads = 3,
#             d_k = 64,
#             d_v = 64,
#             d_ff = config.n_embd,
#             n_layers = 3
#         )
#         # self.action_type = next((feature.name for feature in config.feature_list if "action" in feature.name), "curr_action")
#         self.action_type = "curr_action"
#         print(f"Action type: {self.action_type}")

#     def forward(self, actions):
#         """
#         actions as a batch of strings
#         """
#         # if type(batch_inputs) == dict:
#         #     if type(batch_inputs[self.action_type]) == list:
#         #         actions = batch_inputs[self.action_type][0]
#         #     else:
#         #         actions = batch_inputs[self.action_type] # (B, S, action_dim)
#         # else:
#         #     actions = batch_inputs
#         B, S, _ = actions.shape
#         actions = actions.view(-1, actions.size(-1)) # (B*S, action_dim-1)
#         # actions = actions[:, 1:] # Remove the action mode token
#         # actions = self.normalize_action(actions) # (B*S, action_dim)
#         action_embed = self.encoder(actions) # (B*S, 1, hidden)
#         return action_embed.view(B, S, 1, -1) # (B, S, 1, hidden)

#     def normalize_mode(self, mode):
#         # Map mode to 0 or 1, moving pad
#         return mode - 1

#     def denormalize_mode(self, mode):
#         return mode + 1

#     def normalize_action(self, action):
#         return action - self.config.discrete_vocab_size - 1

#     def denormalize_action(self, action):
#         return action + self.config.discrete_vocab_size + 1


class ActionDecoder(pl.LightningModule):
    """Decode transformer hidden states into action-mode, pose and gripper predictions.

    Heads created depend on the config:

    * continuous + diffusion: ``mode_head`` and one scalar head per action
      dimension (``action_head_list``);
    * continuous, no diffusion: ``mode_head``, one scalar head per pose
      dimension (``pose_head_list``) and a 2-way ``gripper_head``;
    * discrete: ``mode_head``, one ``continuous_vocab_size``-way head per pose
      dimension and a 2-way ``gripper_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        if config.use_continuous_action:
            self.action_type = "continuous_action"
        else:
            self.action_type = "curr_action"
        # self.action_type = next((feature.name for feature in config.feature_list if "action" in feature.name), "curr_action")

        if config.use_continuous_action:
            if config.use_action_diffusion:
                self.mode_head = nn.Linear(config.n_embd, 2)
                # self.action_head = nn.Linear(config.n_embd, 1)
                self.action_head_list = nn.ModuleList([nn.Linear(config.n_embd, 1) for _ in range(config.action_dim-1)])
            else:
                self.mode_head = nn.Linear(config.n_embd, 2)
                # NOTE: not used by forward(); kept so existing checkpoints still
                # match this module's state_dict keys.
                self.pose_head = nn.Linear(config.n_embd, 1)
                self.pose_head_list = nn.ModuleList([nn.Linear(config.n_embd, 1) for _ in range(config.action_dim-2)])
                self.gripper_head = nn.Linear(config.n_embd, 2)
        else:
            self.mode_head = nn.Linear(config.n_embd, 2)
            # self.action_head = nn.Linear(config.n_embd, config.continuous_vocab_size)
            self.pose_head_list = nn.ModuleList([nn.Linear(config.n_embd, config.continuous_vocab_size) for _ in range(config.action_dim-2)])
            self.gripper_head = nn.Linear(config.n_embd, 2)

        ########################### INFERENCE ############################

        # # Mask out invalide logits
        # self.action_pred_mask = torch.zeros([1, config.action_dim, config.vocab_size-2], dtype = torch.long)
        # self.action_pred_mask[:, :, 0] = 1 # Mask out the padding token, vocab[0]
        # # # ------------------ Action mode ------------------
        # # self.action_pred_mask[:, 0, 3:] = 1 # Keep 1 and 2
        # # ------------------ Action xyz, rpy ------------------
        # self.action_pred_mask[:, :6, :config.discrete_vocab_size+1] = 1 # Mask out discrete actions
        # # self.action_pred_mask[:, 1:7, -2:] = 1 # Mask out start and end tokens
        # # ------------------ Action gripper ------------------
        # self.action_pred_mask[:, 6, :] = 1
        # self.action_pred_mask[:, 6, config.discrete_vocab_size + 1] = 0 # Close gripper == -1
        # self.action_pred_mask[:, 6, config.discrete_vocab_size + config.continuous_vocab_size] = 0 # Open gripper == 1
        # self.action_pred_mask = self.action_pred_mask.bool()

    def forward(self, hidden_states, segment_lengths):
        """Predict actions from transformer hidden states.

        Args:
            hidden_states: (B, prompt_length + S * tokens_per_step, hidden).
            segment_lengths: dict with "batch_size", "timesteps", "prompt_length"
                and either "action_length" (diffusion) or "state_length" and
                "query_length" (no diffusion).

        Returns:
            * continuous + diffusion: tensor (B, action_dim-1), predicted only
              for the last timestep;
            * continuous, no diffusion: (pose (B, S, action_dim-2),
              gripper_logits (B, S, 1, 2));
            * discrete: (pose_logits (B, S, action_dim-2, continuous_vocab_size),
              gripper_logits (B, S, 1, 2)).
        """
        B, S = segment_lengths["batch_size"], segment_lengths["timesteps"]
        n_pose = self.config.action_dim - 2
        # Drop the prompt and regroup the remaining tokens per timestep.
        sa_hidden_states = hidden_states[:, segment_lengths['prompt_length']:, :]
        sa_hidden_states = sa_hidden_states.view(B, S, -1, sa_hidden_states.size(-1))

        if self.config.use_action_diffusion:
            # Only the action tokens of the last timestep are denoised.
            next_a_hidden_states = sa_hidden_states[:, -1, -segment_lengths['action_length']:, :]  # (B, action_dim-1, hidden)
        else:
            s_len = segment_lengths['state_length']
            q_len = segment_lengths['query_length']
            # The action query tokens follow the state tokens within each step.
            next_a_hidden_states = sa_hidden_states[:, :, s_len: s_len + q_len, :]  # (B, S, action_dim-1, hidden)

        if self.config.use_continuous_action and self.config.use_action_diffusion:
            per_dim = [head(next_a_hidden_states[:, i, :]) for i, head in enumerate(self.action_head_list)]  # each (B, 1)
            return torch.cat(per_dim, dim=-1)  # (B, action_dim-1)

        if self.config.use_continuous_action:
            # One scalar regression output per pose dimension.
            pose_outputs = [self.pose_head_list[i](next_a_hidden_states[:, :, i, :]) for i in range(n_pose)]  # each (B, S, 1)
            pose_logits = torch.cat(pose_outputs, dim=-1)  # (B, S, action_dim-2)
        else:
            # One vocabulary distribution per pose dimension.
            pose_outputs = [
                self.pose_head_list[i](next_a_hidden_states[:, :, i, :]).unsqueeze(-2)  # (B, S, 1, vocab_size)
                for i in range(n_pose)
            ]
            pose_logits = torch.cat(pose_outputs, dim=-2)  # (B, S, action_dim-2, vocab_size)
        # The gripper token follows the pose tokens.
        gripper_logits = self.gripper_head(next_a_hidden_states[:, :, n_pose:, :])  # (B, S, 1, 2)
        return pose_logits, gripper_logits

    def predict_mode(self, hidden_states, segment_lengths):
        """Predict action-mode logits from the last state token of each timestep.

        Returns:
            (B, 2) logits for the last timestep when ``use_action_diffusion`` is
            set, otherwise (B, S, 2) logits for every timestep.
        """
        B, S = segment_lengths["batch_size"], segment_lengths["timesteps"]
        sa_hidden_states = hidden_states[:, segment_lengths['prompt_length']:, :]
        sa_hidden_states = sa_hidden_states.view(B, S, -1, sa_hidden_states.size(-1))
        last_state_idx = segment_lengths["state_length"] - 1

        # Use last state token to predict next action
        if self.config.use_action_diffusion:
            next_a_hidden_states = sa_hidden_states[:, -1, last_state_idx, :]  # (B, hidden)
        else:
            next_a_hidden_states = sa_hidden_states[:, :, last_state_idx, :]  # (B, S, hidden)

        return self.mode_head(next_a_hidden_states)  # (B, 2) or (B, S, 2)

    def calculate_loss(self, decoded_action, true_action):
        """Compute the training loss and accuracy.

        Args:
            decoded_action: either the (pose, gripper_logits) tuple returned by
                ``forward`` or mode logits from ``predict_mode``.
            true_action: targets. For the tuple case, (B, S, action_dim-1) with
                the gripper label (0/1) in the last entry; otherwise mode labels.

        Returns:
            (loss, accuracy): a scalar loss tensor and a Python float. For
            continuous actions the accuracy covers only the gripper.
        """
        if isinstance(decoded_action, tuple):
            pose_logits, gripper_logits = decoded_action
            B, S = pose_logits.size(0), pose_logits.size(1)
            n_pose = self.config.action_dim - 2
            n_total = self.config.action_dim - 1
            if self.config.use_continuous_action:
                # Continuous pose is regressed; the gripper is classified.
                pred_pose = pose_logits.reshape(B*S, -1)
                true_pose = true_action[:, :, :-1].reshape(B*S, -1)
                loss_pose = F.mse_loss(pred_pose, true_pose)

                gripper_logits = gripper_logits.reshape(B*S, -1)
                true_gripper = true_action[:, :, -1].reshape(B*S).long()
                loss_gripper = F.cross_entropy(gripper_logits, true_gripper)

                # Weight each action dimension equally.
                loss = (loss_pose * n_pose + loss_gripper) / n_total

                acc_gripper = multiclass_accuracy(gripper_logits, true_gripper, num_classes = gripper_logits.size(-1), average = 'micro')
                accuracy = acc_gripper.detach().cpu().item()

            else:
                # Discrete pose tokens and gripper are both classified.
                pred_pose = pose_logits.reshape(-1, pose_logits.size(-1))  # (B*S*(action_dim-2), vocab_size)
                true_pose = true_action[:, :, :-1].reshape(-1)  # (B*S*(action_dim-2))
                loss_pose = F.cross_entropy(pred_pose, true_pose)  # Fine-tuning loss

                gripper_logits = gripper_logits.reshape(B*S, -1)
                true_gripper = true_action[:, :, -1].reshape(B*S)
                loss_gripper = F.cross_entropy(gripper_logits, true_gripper)

                loss = (loss_pose * n_pose + loss_gripper) / n_total

                acc_pose = multiclass_accuracy(pred_pose, true_pose, num_classes = pred_pose.size(-1), average = 'micro')
                acc_gripper = multiclass_accuracy(gripper_logits, true_gripper, num_classes = gripper_logits.size(-1), average = 'micro')
                accuracy = (acc_pose * n_pose + acc_gripper) / n_total
                accuracy = accuracy.detach().cpu().item()
        else:
            pred_mode = decoded_action.reshape(-1, decoded_action.size(-1))  # (N, num_modes)
            true_mode = true_action.reshape(-1)  # (N,)
            loss = F.cross_entropy(pred_mode, true_mode)  # Fine-tuning loss
            accuracy = multiclass_accuracy(pred_mode, true_mode, num_classes = pred_mode.size(-1), average = 'micro')
            accuracy = accuracy.detach().cpu().item()
        return loss, accuracy

    def predict_next_action(self, decoded_action):
        """Turn the last timestep of ``forward``'s (pose, gripper) output into an action.

        Returns:
            (B, action_dim-1): continuous pose values (or argmax pose tokens for
            discrete actions) followed by the argmax gripper class.
        """
        pose_logits, gripper_logits = decoded_action
        next_pose = pose_logits[:, -1, ...]
        if not self.config.use_continuous_action:
            # Discrete pose: pick the most likely token per dimension.
            next_pose = next_pose.argmax(dim = -1)  # (B, action_dim-2)
        next_gripper = gripper_logits[:, -1, ...].argmax(dim = -1)  # (B, 1)
        return torch.cat([next_pose, next_gripper], dim = -1)  # (B, action_dim-1)


class ActionDiffusion(pl.LightningModule):
    """Wrap the diffusion process used for actions.

    Continuous actions use ``GaussianDiffusion``; discrete actions use the
    process built by ``make_diffusion`` from ``.diffusion.d3pm``. A learned
    embedding of the diffusion timestep is provided for conditioning.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.hidden_size = config.n_embd

        if config.use_continuous_action:
            self.diffusion = GaussianDiffusion(
                timesteps = 10,
            )
        else:
            diffusion_config = {
                'args': {
                    'model_output': 'logits',
                    'num_pixel_vals': config.continuous_vocab_size,
                },
                'diffusion_betas': {
                    'type': 'linear',
                    'start': 1e-4,
                    'stop': 0.02,
                    'num_timesteps': 10
                },
                'model_prediction': 'x_start',
                'transition_mat_type': 'gaussian',
                'transition_bands': None,
                'loss_type': 'hybrid',
                'hybrid_coeff': 0.001
            }

            self.diffusion = make_diffusion(diffusion_config)

        # Embedding for timestep
        self.t_embedding = nn.Embedding(self.diffusion.num_timesteps, self.hidden_size)
        # self.denoiser = nn.Linear(self.hidden_size, max_range * 2)

    # def prepare_noisy_action(self, batch_inputs):
    #     """
    #     batch_inputs["xxx_action"]: (B, S, action_dim)  * 1 is the action mode

    #     noisy_action    (B, S, action_dim)
    #     noise           (B, S, action_dim)
    #     t               (B, S)
    #     """
    #     for feature in self.feature_list:
    #         if feature.name == "prev_action":
    #             action = batch_inputs[feature.name]
    #             action = action.view(-1, action.size(-1)) # (B*S, action_dim)
    #             noisy_action, noise, t = self.diffusion.q_sample(action)
    #             noisy_action = noisy_action.view(batch_inputs[feature.name].shape) # (B, S, action_dim)
    #             noise = noise.view(batch_inputs[feature.name].shape)
    #             t = t.view(batch_inputs[feature.name].shape[:2])
    #             batch_inputs[feature.name] = [noisy_action, noise, t]
    #         else:
    #             continue
    #     return batch_inputs

    @property
    def num_timesteps(self):
        """Number of diffusion steps of the underlying process."""
        return self.diffusion.num_timesteps

    # @property
    # def num_pixel_vals(self):
    #     return self.diffusion.num_pixel_vals

    def add_diffusion_timestep(self, action_embed, t):
        """Add the learned diffusion-timestep embedding to action embeddings.

        Args:
            action_embed: (..., hidden) embeddings, e.g. (B, S, action_dim, hidden).
            t: integer timesteps whose shape is a prefix of ``action_embed``'s
                leading dims, e.g. (B,) or (B, S).

        Returns:
            ``action_embed`` plus the timestep embedding, broadcast over every
            axis of ``action_embed`` that ``t`` does not index.
        """
        t_embed = self.t_embedding(t)  # t.shape + (hidden,)
        # Insert singleton axes just before the hidden dim until the ranks match,
        # so e.g. t (B,) -> (B, 1, 1, hidden) and t (B, S) -> (B, S, 1, hidden).
        while t_embed.dim() < action_embed.dim():
            t_embed = t_embed.unsqueeze(-2)
        return action_embed + t_embed

    def p_sample_loop(self, model, shape):
        """Run the reverse (sampling) process of the underlying diffusion."""
        if self.config.use_continuous_action:
            return self.diffusion.sample(model, shape)
        else:
            return self.diffusion.p_sample_loop(model, shape)

    def training_losses(self, model, action, t = None):
        """Compute diffusion training losses for ``action``.

        Args:
            model: denoising model passed through to the diffusion process.
            action: clean actions; dim 0 is the batch dimension.
            t: optional timesteps of shape (action.shape[0],) for the discrete
                process. If None, timesteps are drawn uniformly from
                [0, num_timesteps). Not used for continuous actions.

        Returns:
            For continuous actions ``(loss, 0)``; for discrete actions whatever
            the underlying ``training_losses`` returns.
        """
        if self.config.use_continuous_action:
            return self.diffusion(model, action), 0
        if t is None:
            t = torch.randint(low=0, high=self.num_timesteps, size=(action.shape[0],)).to(action.device)
        return self.diffusion.training_losses(model, action, t)

    def q_sample(self, 
                 x_start, 
                 t, 
                 noise = None,
                 ):
        """Sample from the forward (noising) process q(x_t | x_start).

        Args:
            x_start: clean input.
            t: diffusion timesteps.
            noise: optional pre-drawn noise. If None, standard normal noise of
                shape ``x_start.shape`` (continuous) or uniform noise of shape
                ``x_start.shape + (num_pixel_vals,)`` (discrete) is drawn.
        """
        if self.config.use_continuous_action:
            if noise is None:
                noise = torch.randn(size=x_start.shape).to(self.device)
            return self.diffusion.q_sample(x_start, t, noise)
        if noise is None:
            noise_shape = x_start.shape + (self.diffusion.num_pixel_vals,)
            noise = torch.rand(size=noise_shape).to(self.device)
        return self.diffusion.q_sample(
                            x_start=x_start, 
                            t=t,
                            noise=noise,
                        )

    # def forward(self, hidden_states, noise, segment_lengths):
    #     """
    #     noise: (B, S, action_dim)
    #     """
    #     B = segment_lengths['batch_size']
    #     B_prime = segment_lengths['batch_size_prime']
    #     S = segment_lengths['timesteps']
    #     prompt_length = segment_lengths['prompt_length']

    #     # Remove prompt hidden states
    #     sa_hidden_states = hidden_states[:, prompt_length:, :].view(B_prime, S, -1, hidden_states.size(-1)) # (B_prime, S, state_len + action_len, hidden)
    #     # Remove state hidden states
    #     action_hidden_states = sa_hidden_states[:B, :, segment_lengths["state_length"]:, :] # (B, S, action_len, hidden)

    #     # Calculate action logits
    #     noise_logits = self.denoiser(action_hidden_states) # (B, S, action_len, max_range * 2)
    #     # Calculate loss
    #     loss_diffusion = F.cross_entropy(noise_logits, noise)

    #     return loss_diffusion
