import os
import torch
import time
from sympy.integrals.heurisch import components
import torch.nn as nn

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply

class MeanFlowScalePolicy(BaseImagePolicy):
    """Image-conditioned policy that generates actions through an autoencoder latent space.

    Training behaviour is selected by ``stage``:

    * ``stage == 0`` -- the loss is whatever ``self.autoencoder`` returns when
      called on normalized actions.
    * any other stage -- normalized actions are mapped to latents with
      ``self.autoencoder.get_sample`` and ``self.flow_model`` is trained on
      those latents, conditioned on image-encoder features of the
      observations. When ``stage == 1`` the autoencoder parameters are frozen
      (``requires_grad_(False)``).

    At inference, ``predict_action`` samples latent codes from the flow model
    and passes them to ``self.autoencoder.get_action``.
    """

    def __init__(
            self,
            autoencoder,
            flowar,
            image_encoder,
            latent_action_dim,
            latent_action_chunk,
            stage,
            action_dim,
            action_chunk,
            n_action_steps,
            n_obs_steps,
            **kwargs
    ):
        """Build the policy.

        Args:
            autoencoder: action autoencoder (VAE). This class calls it as
                ``autoencoder(action) -> (loss, info)`` and also uses
                ``get_sample`` and ``get_action``.
            flowar: flow model. This class calls it as
                ``flowar(latent, cond) -> (loss, loss_dict)`` and also uses
                ``sample_tokens``.
            image_encoder: observation encoder exposing ``output_shape()``.
            latent_action_dim: size of each latent token (used when reshaping
                sampled codes).
            latent_action_chunk: number of latent tokens per sampled chunk.
            stage: training stage; 0 trains the autoencoder, anything else
                trains the flow model.
            action_dim: number of leading action channels kept from the
                decoded output.
            action_chunk: action horizon; stored, not read elsewhere in this
                class.
            n_action_steps: number of actions returned for execution.
            n_obs_steps: number of observation steps used as conditioning.
            **kwargs: forwarded to ``BaseImagePolicy.__init__``.
        """
        super().__init__(**kwargs)
        self.stage = stage
        self.autoencoder = autoencoder  # vae
        self.flow_model = flowar

        # Stage 1 trains only the flow model; keep the autoencoder fixed.
        if stage == 1:
            self.autoencoder.requires_grad_(False)

        self.action_dim = action_dim
        self.action_chunk = action_chunk
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps

        self.latent_action_dim = latent_action_dim
        self.latent_action_chunk = latent_action_chunk

        self.image_encoder = image_encoder
        obs_feature_dim = image_encoder.output_shape()[-1]
        # The conditioning vector is the per-step feature concatenated over
        # n_obs_steps. NOTE(review): this is assigned after `flowar` was
        # constructed, so it only has an effect if the flow model reads
        # `obs_in_features` lazily -- confirm.
        self.flow_model.obs_in_features = obs_feature_dim * n_obs_steps

        self.normalizer = LinearNormalizer()

    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's own normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def compute_loss(self, data):
        """Return ``(loss, info)`` for the current training stage.

        Stage 0 delegates to ``compute_autoencoder_loss``; every other stage
        delegates to ``compute_flowscale_loss``.
        """
        if self.stage == 0:
            return self.compute_autoencoder_loss(data)
        return self.compute_flowscale_loss(data)

    def compute_autoencoder_loss(self, data):
        """Run the autoencoder on normalized actions and return its ``(loss, info)``."""
        action = self.get_raw_action(data)
        loss, info = self.autoencoder(action)
        return loss, info

    def compute_flowscale_loss(self, data):
        """Train-time loss of the flow model on autoencoder latents.

        Returns:
            ``(loss, info)`` where ``info`` holds ``'flow_loss'`` (the scalar
            value of ``loss``) merged with the flow model's own ``loss_dict``.
        """
        action = self.get_raw_action(data)
        # Original note on layout: [b, 8, T] -- TODO confirm against autoencoder.
        action_latent = self.autoencoder.get_sample(action)

        all_obs = self.get_obs(data)

        loss, loss_dict = self.flow_model(action_latent, all_obs)

        info = {
            'flow_loss': loss.item(),
        }
        info.update(loss_dict)

        return loss, info

    def get_raw_action(self, batch):
        """Return ``batch['action']`` normalized with the ``'action'`` normalizer."""
        return self.normalizer["action"].normalize(batch["action"])

    def get_obs(self, data):
        """Encode ``data['obs']`` into a flat per-sample conditioning tensor.

        The batch size is taken from the observations themselves, so
        ``data['action']`` is no longer required (or normalized) here.

        Returns:
            Tensor reshaped to ``(batch_size, -1)``.
        """
        nobs = self.normalizer.normalize(data["obs"])
        batch_size = next(iter(nobs.values())).shape[0]
        return self._encode_obs(nobs, batch_size)

    def _encode_obs(self, nobs, batch_size):
        """Slice to ``n_obs_steps``, run the image encoder, flatten per sample.

        Every entry of ``nobs`` is cut to its first ``self.n_obs_steps``
        entries along dim 1, and dims 0 and 1 are merged before encoding. The
        encoder output is then reshaped to ``(batch_size, -1)``.
        """
        this_nobs = dict_apply(
            nobs,
            lambda x: x[:, :self.n_obs_steps, ...].reshape(-1, *x.shape[2:]),
        )
        nobs_features = self.image_encoder(this_nobs)
        return nobs_features.reshape(batch_size, -1)

    def predict_action(self, obs_dict):
        """Sample an action sequence conditioned on ``obs_dict``.

        Args:
            obs_dict: observation dict; its keys are normalized with
                ``self.normalizer`` and every value is indexed as
                ``(batch, time, ...)``.

        Returns:
            dict with
            ``'action'``: the ``n_action_steps`` actions starting at index
            ``n_obs_steps - 1`` of the prediction, and
            ``'action_pred'``: the full unnormalized predicted sequence
            (first ``action_dim`` channels).
        """
        nobs = self.normalizer.normalize(obs_dict)
        batch_size = next(iter(nobs.values())).shape[0]
        n_obs = self.n_obs_steps

        global_cond = self._encode_obs(nobs, batch_size)

        codes = self.flow_model.sample_tokens(global_cond, training=False)
        # NOTE(review): compute_flowscale_loss annotates latents as [b, 8, T];
        # confirm that this (chunk, dim) layout agrees with it.
        codes = codes.reshape(
            batch_size, self.latent_action_chunk, self.latent_action_dim)

        decoded = self.autoencoder.get_action(codes)

        naction_pred = decoded[..., :self.action_dim]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # Execute from the step aligned with the latest observation onwards.
        start = n_obs - 1
        end = start + self.n_action_steps
        action = action_pred[:, start:end]

        return {
            'action': action,
            'action_pred': action_pred
        }
