
import torch
import torch.nn as nn

from networks import Encoder, Decoder

import json

from torch.nn import DataParallel as DP

from omegaconf import OmegaConf
from latent_diffusion_main.ldm.util import instantiate_from_config

def load_model_from_config(config, ckpt):
    """Instantiate ``config.model`` and load weights from a checkpoint file.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint passed to ``torch.load``. The loaded
            object must contain a ``"state_dict"`` entry.

    Returns:
        The instantiated model, with the checkpoint weights loaded
        non-strictly and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU: without map_location the tensors are restored
    # to the device they were saved from, which fails when that device does
    # not exist here and otherwise keeps a redundant copy in GPU memory.
    # load_state_dict copies the values into the model's own parameters, so
    # the returned model is unaffected by this choice.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; report them rather than dropping
    # the information silently.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"  missing keys: {len(missing)}")
    if unexpected:
        print(f"  unexpected keys: {len(unexpected)}")
    model.eval()
    return model

def extract(v, t, x_shape):
    """
    Gather the entries of the coefficient tensor ``v`` at the timestep
    indices ``t`` (along dim 0), cast them to float32 and reshape the result
    to ``[t.shape[0], 1, 1, ...]`` with as many dimensions as ``x_shape`` so
    it broadcasts against a tensor of that shape.
    """
    picked = v.gather(0, t).float()
    # One singleton axis for every non-batch dimension of the target shape.
    singleton_dims = [1] * (len(x_shape) - 1)
    return picked.view(t.shape[0], *singleton_dims)

class Model(nn.Module):
    """Encoder/decoder pair trained with a noise-prediction objective on VQGAN latents.

    ``forward`` encodes the input images with ``Encoder``, maps them to VQGAN
    latents, adds Gaussian noise at a random timestep and hands everything to
    ``Decoder``. The loss is the summed squared error between the injected
    noise and the decoder's ``'predict_noise'`` output, divided by the batch
    size.

    Config keys read directly by this class: ``'image_shape'``, ``'T'``,
    ``'beta_1'``, ``'beta_T'``, ``'mean_type'``, ``'VQGAN_config'``,
    ``'VQGAN'`` and ``'summ_image_count'``. ``Encoder`` and ``Decoder``
    receive the same config object.
    """

    def __init__(self, config):
        super().__init__()
        self.modelencoder = Encoder(config)
        self.modeldecoder = Decoder(config)

        self.img_size = config['image_shape'][-1]

        # Diffusion schedule: T linearly spaced betas from beta_1 to beta_T.
        self.T = config['T']

        self.register_buffer(
            'betas', torch.linspace(config['beta_1'], config['beta_T'], config['T']).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # Coefficients used in forward() to mix the clean latents with noise.
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

        self.mean_type = config['mean_type']

        VQGAN_config = OmegaConf.load(config['VQGAN_config'])
        self.VQGAN = load_model_from_config(VQGAN_config, config['VQGAN'])
        # Affine map applied to the VQGAN latents before noising
        # (and inverted before decoding sampled latents).
        self.linear_mean = 0
        self.linear_scale = 0.18215
        # Fixed starting noise for the sampler; one sample, expanded over the
        # batch in forward().
        self.register_buffer('X_T', torch.randn(1, 4, self.img_size // 8, self.img_size // 8))
        self.summ_image_count = config['summ_image_count']

    def forward(self, data, phase_param, requires_results=False, certain_t=None):
        """Run one noising / noise-prediction pass and compute the loss.

        Args:
            data: Mapping of tensors; only the ``'image'`` and ``'segment'``
                entries are used, and both are moved to CUDA here.
            phase_param: Mapping providing ``'num_slots'``.
            requires_results: When true, only the first ``summ_image_count``
                samples are processed and a sampled reconstruction plus
                masked images are added to the returned results.
            certain_t: If not ``None``, every sample uses this timestep
                instead of a random one in ``[0, T)``.

        Returns:
            ``(results_raw, losses)``: the dict of inputs, decoder outputs and
            coloured segmentation images, and the dict from
            ``compute_losses``.
        """
        num_slots = phase_param['num_slots']
        data = {key: val.cuda(non_blocking=True) for key, val in data.items() if key in {'image', 'segment'}}
        images = data['image']
        segments = data['segment'][:, None, None].long()
        # One-hot encode the segment ids along dim 1 (24 channels, hard-coded).
        scatter_shape = [segments.shape[0], 24, *segments.shape[2:]]
        segments = torch.zeros(scatter_shape, device=segments.device).scatter_(1, segments, 1)
        if requires_results:
            images = images[:self.summ_image_count]
            segments = segments[:self.summ_image_count]
        results_raw = {'image': images, 'segment': segments}
        # Unpacking four values also enforces a 4-D image batch.
        batch_size, _, _, _ = images.shape

        infer_output = self.modelencoder(images, num_slots)

        # VQGAN encode; no gradients flow into the VQGAN through this path.
        with torch.no_grad():
            latents = self.VQGAN.encode(images).sample()

        input_image = (latents - self.linear_mean) * self.linear_scale

        noise = torch.randn_like(input_image)

        if certain_t is not None:
            t = torch.randint(low=certain_t, high=certain_t + 1,
                              size=(input_image.shape[0],), device=input_image.device)
        else:
            t = torch.randint(self.T, size=(input_image.shape[0],), device=input_image.device)

        # Noised latents: sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise.
        input_image = (
            extract(self.sqrt_alphas_bar, t, input_image.shape) * input_image +
            extract(self.sqrt_one_minus_alphas_bar, t, input_image.shape) * noise)

        infer_output.update({'t': t,
                             'noised_image': input_image,
                             'noise': noise})

        # The decoder's result dict is expected to provide 'mask',
        # 'infer_mask', 'predict_noise' and 'noise'; all are read below.
        results = self.modeldecoder(infer_output)

        if requires_results:
            # Sample latents from the fixed start noise, undo the affine
            # scaling and decode them back to image space.
            X_T = self.X_T
            output = self.modeldecoder.sampler(
                X_T.expand(batch_size, -1, -1, -1),
                infer_output['obj_latents']) / self.linear_scale + self.linear_mean
            self.VQGAN.eval()
            output = self.VQGAN.decode(output).clamp(0, 1)
            masked_apc_infer = infer_output['image'] * infer_output['infer_mask']
            masked_apc = infer_output['image'] * results['mask']
            recon = output
            # The single reconstruction is repeated for each of the
            # num_slots + 1 entries along a new dim 1.
            apc = output.unsqueeze(1).expand(-1, num_slots + 1, -1, -1, -1)
            results.update({'recon': recon, 'infer_m_x': masked_apc_infer, 'apc_all': apc, 'm_x': masked_apc})

        results_raw.update(results)

        ins_seg = self.colour_seg_masks(results_raw['mask'])
        results_raw.update({'ins_seg': ins_seg})
        ins_seg_gt = self.colour_seg_masks(results_raw['segment'])
        results_raw.update({'ins_seg_gt': ins_seg_gt})
        infer_ins_seg = self.colour_seg_masks(results_raw['infer_mask'])
        results_raw.update({'infer_ins_seg': infer_ins_seg})

        losses = self.compute_losses(results_raw)

        return results_raw, losses

    def compute_losses(self, results, eps=1e-5):
        """Return ``{'nll': ...}``: summed squared noise error per sample.

        Args:
            results: Mapping with ``'noise'`` and ``'predict_noise'`` tensors.
            eps: Unused; kept for interface compatibility.

        Returns:
            Dict with key ``'nll'`` holding
            ``((noise - predict_noise) ** 2).sum() / noise.shape[0]``.
        """
        noise = results['noise']
        predict_noise = results['predict_noise']
        losses = {'nll': ((noise - predict_noise) ** 2).sum() / noise.shape[0]}
        return losses

    @staticmethod
    def colour_seg_masks(masks, palette='15'):
        """Render the per-pixel argmax of ``masks`` as an RGB image tensor.

        Args:
            masks: Tensor whose dim 2 is squeezed before taking the argmax
                over dim 1.
            palette: Number of palette entries to cycle through (converted
                with ``int``); the colours are always read from
                ``colour_palette15.json`` in the working directory.

        Returns:
            Integer tensor with three channels (R, G, B) concatenated along
            dim 1.
        """
        ins_seg = torch.argmax(masks.squeeze(2), 1, True)
        # Context manager so the palette file handle is closed promptly.
        with open('colour_palette15.json') as palette_file:
            colours = json.load(palette_file)
        palette_colours = colours['palette']
        n_colours = int(palette)
        img_r = torch.zeros_like(ins_seg)
        img_g = torch.zeros_like(ins_seg)
        img_b = torch.zeros_like(ins_seg)
        for c_idx in range(ins_seg.max().item() + 1):
            c_map = ins_seg == c_idx
            if c_map.any():
                red, green, blue = palette_colours[c_idx % n_colours][:3]
                img_r[c_map] = red
                img_g[c_map] = green
                img_b[c_map] = blue
        return torch.cat([img_r, img_g, img_b], dim=1)

def get_model(config):
    """Build a ``Model`` on the GPU, wrapped in ``DataParallel`` if requested.

    Args:
        config: Configuration mapping passed to ``Model``; its ``'use_dp'``
            entry decides whether the network is wrapped.

    Returns:
        The CUDA-resident ``Model``, or a ``DataParallel`` wrapper around it
        when ``config['use_dp']`` is truthy.
    """
    model = Model(config).cuda()
    return DP(model) if config['use_dp'] else model
