import torch
import torch.nn as nn
from config.base_config import Config
from modules.transformer import Transformer
from modules.transformer_stochastic import StochasticNomean, StochasticGuideRailMean
from dpm_model.dpm_txt_trunc import DPM

class DiffusionAlign(nn.Module):
    """Diffusion model operating on CLIP text features, conditioned on pooled video features.

    Wraps a ``DPM`` built with ``dpm_arch='DiT_txt_trunc'``. ``forward`` has two modes:

    * ``no_aligned_embed=True``: returns ``DPM.naive_loss`` computed from the text
      features and the matched (diagonal) pooled video features.
    * ``no_aligned_embed=False``: runs ``DPM.p_sample_loop`` starting from the text
      features and returns the sampled text features.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.config = config

        if self.config.apply_ldm:
            raise NotImplementedError("apply_ldm=True is not supported by DiffusionAlign")
        self.dm_txt = DPM(self.config, dpm_arch='DiT_txt_trunc')

        self.dm_txt.set_new_noise_schedule()
        self.dm_txt.set_loss()

    def forward(self, text_features, video_features_pooled=None, no_aligned_embed=True):
        """Compute the diffusion loss, or sample diffusion-aligned text features.

        Args:
            text_features: 2-D tensor of text embeddings, ``(bs, dim)``.
            video_features_pooled: text-conditioned pooled video features. It is required
                when ``no_aligned_embed`` is true and ignored otherwise. It is assumed to be
                ``(num_vids, num_texts, dim)`` with matched pairs on the diagonal — TODO confirm.
            no_aligned_embed: if true, return the loss; otherwise return sampled features.

        Returns:
            The value returned by ``self.dm_txt.naive_loss`` when ``no_aligned_embed`` is
            true. Otherwise, the output of ``self.dm_txt.p_sample_loop`` with axis 1 squeezed.

        Raises:
            ValueError: if ``text_features`` is not 2-D, or if ``video_features_pooled``
                is None in loss mode.
        """
        # The original code enforced 2-D input implicitly via ``bs, dim = text_features.shape``.
        # The check is now explicit and keeps the same exception type.
        if text_features.ndim != 2:
            raise ValueError(
                f"text_features must be 2-D (bs, dim), got shape {tuple(text_features.shape)}"
            )

        if no_aligned_embed:
            if video_features_pooled is None:
                raise ValueError("video_features_pooled is required when no_aligned_embed=True")
            # Select entry [i, i] for every i across dims 0 and 1 (the matched text/video
            # pairs). torch.diagonal appends that axis last, so transpose it back to the
            # front. For a 3-D input this gives (n_matched, dim).
            matched_video = torch.diagonal(video_features_pooled, dim1=0, dim2=1).transpose(0, 1)
            return self.dm_txt.naive_loss(
                text_feature=text_features, vid_feature=matched_video, noise=None
            )

        # Sampling mode: video_features_pooled is not used here.
        dm_sample = self.dm_txt.p_sample_loop(x_txt=text_features)
        # NOTE(review): presumably p_sample_loop returns (bs, 1, dim); squeeze(1) drops
        # that axis — confirm against DPM.
        return dm_sample.squeeze(1)


class CLIPTransformer_txt_trunc_dm(nn.Module):
    """CLIP text/video encoder with transformer frame pooling and a diffusion text aligner.

    Text and frames are encoded with CLIP. Either the HuggingFace ``CLIPModel`` or the
    project ``load_clip`` model is used, depending on ``config.huggingface``. Frame
    features are pooled by ``Transformer(config)`` and passed with the text features to
    ``DiffusionAlign``.
    """

    # HuggingFace checkpoint used for each supported ``config.clip_arch``.
    _HF_CHECKPOINTS = {
        'ViT-B/32': "openai/clip-vit-base-patch32",
        'ViT-B/16': "openai/clip-vit-base-patch16",
    }

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        if self.config.huggingface:
            from transformers import CLIPModel
            checkpoint = self._HF_CHECKPOINTS.get(config.clip_arch)
            if checkpoint is None:
                raise ValueError(
                    f"Unsupported clip_arch for HuggingFace CLIP: {config.clip_arch!r}; "
                    f"expected one of {sorted(self._HF_CHECKPOINTS)}"
                )
            self.clip = CLIPModel.from_pretrained(checkpoint)
        else:
            from model.clip_model import load_clip
            self.clip = load_clip(config.clip_arch)

        # NOTE: this mutates the caller's config object, not a copy, before it is
        # handed to Transformer and DiffusionAlign.
        config.pooling_type = 'transformer'
        self.pool_frames = Transformer(config)

        self.text_cond_processor = DiffusionAlign(config)

    def _encode_and_pool(self, data):
        """Encode ``data['text']``/``data['video']`` with CLIP and pool the frames.

        Returns:
            ``(text_features, video_features, video_features_pooled)``, where
            ``video_features`` is reshaped to ``(batch_size, num_frames, -1)``.
        """
        batch_size = data['video'].shape[0]
        text_data = data['text']
        # Flatten all frames into one image batch for the CLIP image encoder.
        video_data = data['video'].reshape(-1, 3, self.config.input_res, self.config.input_res)

        if self.config.huggingface:
            text_features = self.clip.get_text_features(**text_data)
            video_features = self.clip.get_image_features(video_data)
        else:
            text_features = self.clip.encode_text(text_data)
            video_features = self.clip.encode_image(video_data)

        video_features = video_features.reshape(batch_size, self.config.num_frames, -1)
        video_features_pooled = self.pool_frames(text_features, video_features)
        return text_features, video_features, video_features_pooled

    def forward(self, data, return_all_frames=False, no_aligned_embed=True):
        """Encode a batch and either compute the diffusion loss or align text features.

        Args:
            data: mapping with ``'text'`` (tokenizer output; unpacked as kwargs on the
                HuggingFace path) and ``'video'`` (frames whose first dim is the batch size).
            return_all_frames: unused; kept for interface compatibility.
            no_aligned_embed: if true, also return the diffusion loss; otherwise return
                diffusion-aligned text features.

        Returns:
            If ``no_aligned_embed``: ``(text_features, video_features,
            video_features_pooled, text_features, dm_loss)``.
            Otherwise: ``(text_features, video_features, video_features_pooled,
            text_features_DMalign)``.
        """
        text_features, video_features, video_features_pooled = self._encode_and_pool(data)

        if no_aligned_embed:
            dm_loss = self.text_cond_processor(
                text_features, video_features_pooled, no_aligned_embed=no_aligned_embed
            )
            # NOTE(review): text_features fills the 4th slot too, presumably as a
            # stand-in for the aligned features of the other branch — confirm with callers.
            return text_features, video_features, video_features_pooled, text_features, dm_loss

        text_features_DMalign = self.text_cond_processor(
            text_features, video_features_pooled, no_aligned_embed=no_aligned_embed
        )
        return text_features, video_features, video_features_pooled, text_features_DMalign

