import torch
import torch.nn as nn
from libs.caption_decoder import CaptionDecoder
import libs.autoencoder
from libs.clip import FrozenCLIPEmbedder
import clip
from libs.diffusion_schedule import stable_diffusion_beta_schedule, Schedule
import random
from torch.nn.parallel import DistributedDataParallel as DDP
from absl import logging
from typing import Optional, List, Dict
import utils

def mos(a, start_dim=1):
    """Return the mean of squares of ``a`` over all dims from ``start_dim`` on.

    The tensor is squared elementwise, every dimension from ``start_dim``
    onward is collapsed into a single trailing dimension, and the mean is
    taken over that collapsed dimension. Leading dimensions are preserved.
    """
    flat_sq = torch.flatten(a ** 2, start_dim=start_dim)
    return torch.mean(flat_sq, dim=-1)
        
class FeedModel(nn.Module):
    """Training/sampling wrapper around a trainable feed model and a diffusion nnet.

    Bundles frozen helper models (caption decoder, autoencoder, CLIP text and
    image encoders) with the train state returned by
    ``utils.initialize_train_state``. As used here, that train state exposes
    ``nnet`` and ``feed_model``; either may be wrapped in
    ``DistributedDataParallel``.
    """

    def __init__(self, device, config):
        super().__init__()
        self.device = device
        self.config = config
        # Frozen caption decoder, autoencoder and text encoder.
        self.caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
        self.autoencoder = libs.autoencoder.get_model(**config.autoencoder).to(device)
        self.clip_text_model = FrozenCLIPEmbedder(device=device)
        # clip.load also returns a preprocessing transform, which is not used here.
        self.clip_img_model, _ = clip.load(config.clip_img_model, jit=False)
        self.clip_img_model.to(device).eval().requires_grad_(False)

        self.train_state = utils.initialize_train_state(config, device)

        _betas = stable_diffusion_beta_schedule()
        self.schedule = Schedule(_betas)
        logging.info(f'SCHEDULE: use {self.schedule}')

    @staticmethod
    def _unwrap(module):
        """Return ``module.module`` if ``module`` is a DDP wrapper, else ``module``."""
        return module.module if isinstance(module, DDP) else module

    def load_feed_weight(self, feed_resume_path):
        """Load a saved state dict from ``feed_resume_path`` into the feed model.

        The checkpoint is deserialized onto CPU; ``load_state_dict`` then copies
        the tensors into the feed model's existing parameters. A checkpoint
        written from a different GPU can therefore still be restored.
        """
        state_dict = torch.load(feed_resume_path, map_location='cpu')
        self.train_state.feed_model.load_state_dict(state_dict)

    def train_mode(self):
        """Put both the diffusion nnet and the feed model into training mode."""
        self.train_state.nnet.train()
        self.train_state.feed_model.train()

    def eval_mode(self):
        """Put both the diffusion nnet and the feed model into eval mode."""
        self.train_state.nnet.eval()
        self.train_state.feed_model.eval()

    def _forward_without_feed(self, img_n, clip_img_n, text, n, t_text, data_type):
        """Run ``nnet`` under ``no_grad`` with the feed model's multiplier set to 0.

        The feed model's previous multiplier is read from the (unwrapped)
        module. It is restored afterwards, even if the forward pass raises.
        Returns the nnet's output dict.
        """
        feed_model = self._unwrap(self.train_state.feed_model)
        origin_multiplier = feed_model.multiplier
        feed_model.set_multiplier(0.)
        try:
            with torch.no_grad():
                return self.train_state.nnet(img_n, clip_img_n, text, t_img=n, t_text=t_text,
                                             data_type=data_type)
        finally:
            feed_model.set_multiplier(origin_multiplier)

    def compute_loss(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch: mapping with keys "image1", "caption", "img4clip", "detect",
                "mask" and "data_type". Every entry except "caption" is moved
                to ``self.device``.

        Returns:
            ``(loss, loss_img, loss_clip_img)``. ``loss_clip_img`` and the nnet's
            text output enter ``loss`` multiplied by 0. They stay in the autograd
            graph but do not change the value (presumably so DDP does not see
            unused parameters; TODO confirm).
        """
        img = batch["image1"].to(self.device)
        text = batch["caption"]
        img4clip = batch["img4clip"].to(self.device)
        detect = batch["detect"].to(self.device)
        mask = batch["mask"].to(self.device)
        data_type = batch["data_type"].to(self.device)

        # The encoders are frozen, so no gradients are needed here.
        with torch.no_grad():
            img = self.autoencoder.encode(img)
            clip_img = self.clip_img_model.encode_image(img4clip).unsqueeze(1)
            text = self.clip_text_model.encode(text)

        # With probability cfg_p the detect condition is zeroed
        # (presumably classifier-free-guidance dropout; TODO confirm).
        if random.random() < self.config.cfg_p:
            detect = torch.zeros_like(detect)

        # The return value is discarded. The call is kept for its effect on the
        # feed model's internal state, which is not visible here; TODO confirm.
        self.train_state.feed_model(detect)
        text = self.caption_decoder.encode_prefix(text)

        # Add noise.
        n, (img_eps, clip_img_eps), (img_n, clip_img_n) = self.schedule.sample([img, clip_img])  # n in {1, ..., 1000}
        n = n.to(self.device)
        t_text = torch.zeros_like(n, device=self.device)
        dict_out = self.train_state.nnet(img_n, clip_img_n, text, t_img=n, t_text=t_text, data_type=data_type)

        img_out, clip_img_out, text_out = dict_out["img_out"], dict_out["clip_img_out"], dict_out["text_out"]
        diff = img_eps - img_out
        if random.random() < self.config.mask_p:
            # The reference prediction (feed model disabled) is only needed for
            # the masked loss, so the extra forward pass runs only in this branch.
            origin_img_out = self._forward_without_feed(
                img_n, clip_img_n, text, n, t_text, data_type)["img_out"]
            # Inside the mask, fit the noise. Outside it, stay close to the
            # prediction made without the feed model.
            loss_img = mos(diff * mask) + mos((origin_img_out - img_out) * (1 - mask))
        else:
            loss_img = mos(diff)
        loss_clip_img = mos(clip_img_eps - clip_img_out)
        loss = loss_img + 0. * loss_clip_img + 0. * mos(text_out)
        return loss, loss_img, loss_clip_img

    def gen_one_function(self, data_type: Optional[str], caption: str, source_group: List[Dict], class_word="",
                         cond_multiplier=1.,
                         uncond_multiplier=0.,
                         cfg=None):
        """Generate samples for ``caption`` using the first entry of ``source_group``.

        Args:
            data_type: not used by this method (kept for interface compatibility).
            caption: text prompt passed to the sampler.
            source_group: sequence of dicts. Only ``source_group[0]`` is used, and
                it must provide "path", "detect_path", "detect_mask_path" and
                "mask_path".
            class_word: not used by this method (kept for interface compatibility).
            cond_multiplier: forwarded to the sampler.
            uncond_multiplier: forwarded to the sampler.
            cfg: guidance scale. Defaults to ``config.sample.scale`` when None.

        Returns:
            The "samples" entry of the sampler's result.
        """
        sample_fn = utils.get_sample_fn(**self.config.feed_model)
        dic = source_group[0]
        if cfg is None:
            cfg = self.config.sample.scale

        return sample_fn(
                prompt=caption,
                image2=dic["path"],
                detect_path=dic["detect_path"],
                detect_mask_path=dic["detect_mask_path"],
                mask_path=dic["mask_path"],
                config=self.config,
                nnet=self._unwrap(self.train_state.nnet),
                caption_decoder=self.caption_decoder,
                clip_text_model=self.clip_text_model,
                feed_model=self._unwrap(self.train_state.feed_model),
                autoencoder=self.autoencoder,
                device=self.device,
                cond_multiplier=cond_multiplier,
                uncond_multiplier=uncond_multiplier,
                cfg=cfg)["samples"]

    def gen_one(self, prompt, ref_img, cond_multiplier=1., uncond_multiplier=0., cfg= 5.):
        """Generate samples for ``prompt`` with ``ref_img`` passed as ``detect_path``.

        Unlike ``gen_one_function``, ``image2`` and ``detect_mask_path`` are None,
        and no ``mask_path`` is passed. This presumably relies on the sampler's
        default for it; NOTE(review): confirm.
        Returns the "samples" entry of the sampler's result.
        """
        sample_fn = utils.get_sample_fn(**self.config.feed_model)

        return sample_fn(
                prompt=prompt,
                image2=None,
                detect_path=ref_img,
                detect_mask_path=None,
                config=self.config,
                nnet=self._unwrap(self.train_state.nnet),
                caption_decoder=self.caption_decoder,
                clip_text_model=self.clip_text_model,
                feed_model=self._unwrap(self.train_state.feed_model),
                autoencoder=self.autoencoder,
                device=self.device,
                cond_multiplier=cond_multiplier,
                uncond_multiplier=uncond_multiplier,
                cfg=cfg)["samples"]