import os
import torch
import torch.nn as nn
from sympy.physics.units import action

import utils.tensor_utils as TensorUtils
from algos.components.ResNet import FilmResNet
import utils.utils as utils

class Action_VAE(nn.Module):
    """Training wrapper that feeds chunked actions to an autoencoder.

    The wrapped ``autoencoder`` is called directly on the reshaped action
    tensor and is expected to return a ``(loss, info)`` pair.
    """

    def __init__(self, autoencoder, action_dim, action_chunk, **kwargs):
        """Store the autoencoder and the action layout.

        Args:
            autoencoder: Callable module invoked on the chunked actions.
            action_dim: Size of the last axis after reshaping.
            action_chunk: Size of the middle axis after reshaping.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.action_chunk = action_chunk
        self.action_dim = action_dim
        self.autoencoder = autoencoder

    def compute_loss(self, data):
        """Run the autoencoder on ``data["cur_actions"]``.

        Returns:
            The ``(loss, info)`` pair produced by the autoencoder.
        """
        chunked_actions = self.get_raw_action(data)
        ae_loss, ae_info = self.autoencoder(chunked_actions)
        return ae_loss, ae_info

    def get_raw_action(self, data):
        """Reshape the flat actions to ``(B, action_chunk, action_dim)``.

        ``data["cur_actions"]`` must be two-dimensional; its second axis
        must hold ``action_chunk * action_dim`` values for the reshape to
        succeed.
        """
        flat_actions = data["cur_actions"]
        batch_size, _ = flat_actions.shape
        return flat_actions.reshape(
            batch_size, self.action_chunk, self.action_dim
        )

    def sample_actions(self, data):
        """Placeholder: sampling is not implemented and yields ``None``."""
        return None

class MeanFlowScalePolicy(nn.Module):
    """Two-stage policy: an action autoencoder plus a flow model over its latents.

    ``stage`` selects what ``compute_loss`` trains:

    * ``0`` - the autoencoder loss on the raw action chunk.
    * ``1`` - the flow-model loss on autoencoder latents, conditioned on the
      encoded observation. The autoencoder's parameters are frozen
      (``requires_grad_(False)``) at construction time for this stage.
    """

    def __init__(
            self,
            autoencoder,
            flowar,
            image_encoder,
            latent_action_dim,
            latent_action_chunk,
            stage,
            action_dim,
            action_chunk,
            **kwargs
    ):
        """Store sub-modules and the action / latent layout.

        Args:
            autoencoder: Action autoencoder. Called directly for the stage-0
                loss; its ``get_sample`` and ``get_action`` methods are used
                for encoding and decoding.
            flowar: Flow model, stored as ``self.flow_model``. Called as
                ``flow_model(latent, obs)`` for the loss, and through
                ``sample_tokens`` for generation.
            image_encoder: Only its ``lang_proj`` method is used here, to
                project the instruction into a conditioning vector.
            latent_action_dim: Last axis of the latent codes in ``generate``.
            latent_action_chunk: Middle axis of the latent codes in ``generate``.
            stage: Training stage, ``0`` or ``1`` (see class docstring).
            action_dim: Last axis of the reshaped raw action.
            action_chunk: Middle axis of the reshaped raw action.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.stage = stage
        self.autoencoder = autoencoder  # vae
        self.flow_model = flowar

        if stage == 1:
            # Stage 1 trains only the flow model; keep the autoencoder fixed.
            self.autoencoder.requires_grad_(False)

        self.action_dim = action_dim
        self.action_chunk = action_chunk
        self.latent_action_dim = latent_action_dim
        self.latent_action_chunk = latent_action_chunk

        self.image_encoder = image_encoder
        # NOTE(review): cond_dim=1024 presumably has to match the output width
        # of image_encoder.lang_proj - confirm against that module.
        self.vision_encoder = FilmResNet(image_dim=3, cond_dim=1024, backbone_name="resnet34")

    def compute_loss(self, data):
        """Dispatch to the loss of the configured training stage.

        Returns:
            ``(loss, info)`` from the stage-specific loss method.

        Raises:
            ValueError: If ``self.stage`` is neither ``0`` nor ``1``.
        """
        if self.stage == 0:
            return self.compute_autoencoder_loss(data)
        if self.stage == 1:
            return self.compute_flowscale_loss(data)
        # Fail loudly rather than returning None, which callers would only
        # notice later as an opaque tuple-unpacking error.
        raise ValueError(
            f"Unsupported training stage {self.stage!r}; "
            "expected 0 (autoencoder) or 1 (flow model)."
        )

    def compute_autoencoder_loss(self, data):
        """Return the autoencoder's ``(loss, info)`` on the raw action chunk."""
        action = self.get_raw_action(data)
        loss, info = self.autoencoder(action)

        return loss, info

    def compute_flowscale_loss(self, data):
        """Return the flow-model loss on autoencoder latents.

        The info dict holds ``'flow_loss'`` (the scalar loss value) merged
        with the dict returned by the flow model; on a key clash the flow
        model's entry wins.
        """
        action = self.get_raw_action(data)
        # Original author's shape note: [b, 8, T] - TODO confirm.
        action_latent = self.autoencoder.get_sample(action)

        all_obs = self.get_obs(data)

        loss, loss_dict = self.flow_model(action_latent, all_obs)

        info = {
            'flow_loss': loss.item(),
        }
        info.update(loss_dict)

        return loss, info

    def get_raw_action(self, data):
        """Reshape ``data["cur_actions"]`` to ``(B, action_chunk, action_dim)``.

        The input must be two-dimensional, with ``action_chunk * action_dim``
        values on its second axis.
        """
        action = data["cur_actions"]
        B, _ = action.shape
        action = action.reshape(B, self.action_chunk, self.action_dim)

        return action

    def get_obs(self, data):
        """Encode images, instruction and proprioception into one vector per sample.

        Reads ``data["cur_images"]`` (five-dimensional, unpacked as
        ``B, V, C, H, W``), ``data["instruction"]`` and
        ``data["cur_proprios"]``.

        Returns:
            Tensor with leading axis ``B``: the per-view vision features
            flattened per sample, concatenated with ``cur_proprios`` on the
            last axis.
        """
        cur_images = data["cur_images"]
        instruction = data["instruction"]
        cur_proprios = data["cur_proprios"]

        conditions = self.image_encoder.lang_proj(instruction)

        B, V, C, H, W = cur_images.shape
        # Give every view of a sample the same language condition, then fold
        # the view axis into the batch axis for the vision encoder.
        cond = conditions.unsqueeze(1).repeat(1, V, 1).reshape(B * V, -1)
        vision_obs = cur_images.reshape(B * V, C, H, W)
        vision_semantics = self.vision_encoder(vision_obs, cond)
        # Unfold: all views of one sample end up side by side on the last axis.
        vision_semantics = vision_semantics.reshape(B, -1)
        all_obs = torch.cat([vision_semantics, cur_proprios], dim=-1)

        return all_obs

    def generate(self, **data):
        """Sample an action chunk for the given observation.

        Takes the same keys as ``get_obs``, passed as keyword arguments.

        Side effect: the module is switched to eval mode and is NOT switched
        back; callers that continue training must call ``.train()`` themselves.

        Returns:
            ``(action, details)`` as returned by ``utils.process_outputs``,
            which receives the decoded action flattened to ``(B, -1)``.
        """
        self.eval()

        with torch.no_grad():
            # Reuse the training-time observation encoder so both paths stay
            # in sync.
            all_obs = self.get_obs(data)
            B = all_obs.shape[0]

            codes = self.flow_model.sample_tokens(all_obs, training=False)
            codes = codes.reshape(B, self.latent_action_chunk, self.latent_action_dim)

            action = self.autoencoder.get_action(codes)

            action = action.reshape(B, -1)

            details = dict(actions=action)

        action, details = utils.process_outputs(action, **details)

        return action, details
