#-*- coding:utf-8 -*-

from diffusion_policy.vision_encoder import MultiImageObsEncoder
from diffusion.model import ConditionalUnet1D, get_resnet
from torchvision import models
from enum import Enum
from typing import Union
import torch 

class ModelType(Enum):
    """Enumeration of model kinds, each identified by a one-letter string code.

    NOTE(review): not referenced anywhere in the visible part of this file;
    presumably used by callers to pick a backbone -- confirm against usage.
    """
    # Value "C" -- presumably a convolutional backbone (inferred from the name).
    CNN = "C"
    # Value "T" -- presumably a transformer backbone (inferred from the name).
    TRANSFORMER = "T"

class MultiModelType(Enum):
    """Enumeration of model variants, each identified by a string code.

    NOTE(review): not referenced anywhere in the visible part of this file.
    The member names suggest a CNN, a minGPT-style transformer and four
    DiT sizes (S/B/L/XL) -- confirm against the code that consumes them.
    """
    # Shares the value "C" with ModelType.CNN.
    CNN = "C"
    MINGPT = "minGPT"
    # The DiT_* members use their own name as their value.
    DiT_S = "DiT_S"
    DiT_B = "DiT_B"
    DiT_L = "DiT_L"
    DiT_XL = "DiT_XL"

def rand_log_normal(shape, loc=0., scale=1., device='cuda', dtype=torch.float32):
    """Sample a tensor of the given shape from a log-normal distribution.

    A standard-normal tensor is drawn, scaled by ``scale``, shifted by
    ``loc`` and then exponentiated, so ``loc`` and ``scale`` are the mean
    and standard deviation of the underlying normal (i.e. of the log of
    the returned values).

    Note that ``device`` defaults to ``'cuda'``; pass ``device='cpu'``
    explicitly on machines without a GPU.
    """
    gaussian = torch.randn(shape, device=device, dtype=dtype)
    log_values = gaussian * scale + loc
    return torch.exp(log_values)

class ConditionalKarrasUnet1D(ConditionalUnet1D):
    """1-D conditional U-Net whose step encoder is driven by a noise level sigma.

    Construction is delegated unchanged to ``ConditionalUnet1D``; every layer
    used in ``forward`` (``diffusion_step_encoder``, ``down_modules``,
    ``mid_modules``, ``up_modules``, ``final_conv``) is inherited from it.
    Only ``forward`` is overridden: the step encoder receives
    ``log(sigma) / 4`` computed from the ``sigmas`` argument.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, 
            sample: torch.Tensor, 
            sigmas: Union[torch.Tensor, float, int], 
            global_cond=None):
        """
            sample: (B,T,input_dim)
            sigmas: (B,) tensor, 0-dim tensor, or Python scalar; noise level(s).
                    Must be strictly positive because the logarithm is taken.
            global_cond: (B,global_cond_dim), optional
            output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1, -2)
        # (B,C,T)

        # 1. sigmas
        if not torch.is_tensor(sigmas):
            # Sigma is a continuous quantity: build the tensor in a floating
            # dtype. An integer dtype would truncate e.g. 0.5 to 0 and make
            # the logarithm below return -inf.
            # TODO: this requires sync between CPU and GPU. So try to pass sigmas as tensors if you can
            sigma_dtype = sample.dtype if sample.is_floating_point() else torch.get_default_dtype()
            sigmas = torch.tensor([sigmas], dtype=sigma_dtype, device=sample.device)
        elif sigmas.dim() == 0:
            sigmas = sigmas[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        sigmas = sigmas.expand(sample.shape[0])

        # Noise conditioning input: ln(sigma) / 4 (the c_noise form used in
        # Karras et al.-style preconditioning).
        c_noises = sigmas.log() / 4
        global_feature = self.diffusion_step_encoder(c_noises)

        if global_cond is not None:
            global_feature = torch.cat([global_feature, global_cond], dim=-1)

        x = sample
        # Stack of encoder activations, consumed in reverse order as skip
        # connections by the decoder.
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            # Concatenate the matching encoder activation along channels.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1, -2)
        # (B,T,C)
        return x

class ImageConditionalKarrasUnet1D(ConditionalUnet1D): # Condition : Image Only
    """Sigma-conditioned 1-D U-Net whose global conditioning comes from images.

    The U-Net layers are inherited from ``ConditionalUnet1D``. In addition the
    module owns a ``MultiImageObsEncoder`` that turns each observation frame
    into a feature vector; the per-frame features of one batch element are
    flattened together and concatenated to the noise-level embedding.

    Because of that, ``global_cond_dim`` passed to the parent constructor must
    equal (features per frame) * (frames per batch element).
    """

    # Channel count of the RGB frames; matches the 'image' entry of shape_meta.
    _IMAGE_CHANNELS = 3

    def __init__(self, input_width:int = 96, input_height:int = 96, *args, **kwargs):
        """
            input_width / input_height: spatial size of the observation frames.
            *args, **kwargs: forwarded unchanged to ConditionalUnet1D.
        """
        super().__init__(*args, **kwargs)
        self.image_encoder = MultiImageObsEncoder(
            shape_meta={
                'obs':{
                    'image':{
                        'shape':[self._IMAGE_CHANNELS, input_height, input_width],
                        'type':'rgb'
                    },
                    # To also condition on low-dimensional state, add e.g.
                    # 'agent_pos': {'shape': [2], 'type': 'low_dim'}
                },
                'action':{
                    'shape':[2]   
                }
            },
            rgb_model=get_resnet(),
            crop_shape=None
        )
        self.input_width = input_width 
        self.input_height = input_height

    def forward(self, 
            sample: torch.Tensor, 
            sigmas: Union[torch.Tensor, float, int], 
            global_cond=None):
        """
            sample: (B,T,input_dim)
            sigmas: (B,) tensor, 0-dim tensor, or Python scalar; noise level(s).
                    Must be strictly positive because the logarithm is taken.
            global_cond: (B, T_obs, 3, input_height, input_width) image frames, optional
            output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1, -2)
        # (B,C,T)
        batch_size = sample.shape[0]

        # 1. sigmas
        if not torch.is_tensor(sigmas):
            # Sigma is a continuous quantity: build the tensor in a floating
            # dtype. An integer dtype would truncate e.g. 0.5 to 0 and make
            # the logarithm below return -inf.
            # TODO: this requires sync between CPU and GPU. So try to pass sigmas as tensors if you can
            sigma_dtype = sample.dtype if sample.is_floating_point() else torch.get_default_dtype()
            sigmas = torch.tensor([sigmas], dtype=sigma_dtype, device=sample.device)
        elif sigmas.dim() == 0:
            sigmas = sigmas[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        sigmas = sigmas.expand(batch_size)

        # Noise conditioning input: ln(sigma) / 4 (the c_noise form used in
        # Karras et al.-style preconditioning).
        c_noises = sigmas.log() / 4
        global_feature = self.diffusion_step_encoder(c_noises)

        if global_cond is not None:
            # Fold every leading axis into one batch of single frames. The
            # number of frames is derived from global_cond itself rather than
            # from `sample`, whose axes (after moveaxis) are (B, C, T) and say
            # nothing about the observation horizon.
            frames = global_cond.reshape(
                -1, self._IMAGE_CHANNELS, self.input_height, self.input_width)
            image_cond = self.image_encoder({'image': frames})
            # Regroup per batch element: (B, frames_per_element * feature_dim)
            image_cond = image_cond.reshape(batch_size, -1)
            global_feature = torch.cat([global_feature, image_cond], dim=-1)

        x = sample
        # Stack of encoder activations, consumed in reverse order as skip
        # connections by the decoder.
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        for resnet, resnet2, upsample in self.up_modules:
            # Concatenate the matching encoder activation along channels.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1, -2)
        # (B,T,C)
        return x


if __name__ == '__main__':
    # Smoke test: run one forward pass through each network on dummy inputs
    # and print the shape of the result.
    obs_dim = 5
    action_dim = 2
    
    pred_horizon = 16
    obs_horizon = 2
    action_horizon = 8  # not used below; kept as part of the example setup

    # --- low-dimensional conditioning -------------------------------------
    # create network object
    noise_pred_net = ConditionalKarrasUnet1D(
        input_dim=action_dim,
        global_cond_dim=obs_dim*obs_horizon
    )

    # example inputs
    noised_action = torch.randn((1, pred_horizon, action_dim))
    obs = torch.zeros((1, obs_horizon, obs_dim))
    diffusion_iter = torch.ones((1,))

    # the noise prediction network
    # takes noisy action, diffusion iteration and observation as input
    # predicts the noise added to action
    noise = noise_pred_net(
        sample=noised_action, 
        sigmas=diffusion_iter,
        global_cond=obs.flatten(start_dim=1))

    # illustration of removing noise 
    # the actual noise removal is performed by NoiseScheduler 
    # and is dependent on the diffusion noise schedule
    denoised_action = noised_action - noise
    print(denoised_action.shape)

    # --- image conditioning -----------------------------------------------
    # assumes the image encoder yields 512 features per frame -- TODO confirm
    image_feature_dim = 512
    noise_pred_net2 = ImageConditionalKarrasUnet1D(
        input_dim=action_dim,
        global_cond_dim=image_feature_dim*obs_horizon
    )

    noised_action = torch.randn((1, pred_horizon, action_dim))
    obs = torch.zeros((1, obs_horizon, 3, 96, 96))
    diffusion_iter = torch.ones((1,))

    noise2 = noise_pred_net2(
        sample=noised_action, 
        sigmas=diffusion_iter,
        global_cond=obs)
    # Subtract the image-conditioned prediction, not the earlier `noise`,
    # so this line actually exercises the second network's output.
    denoised_action = noised_action - noise2
    print(denoised_action.shape)