# use diffusion model to generate pseudo ground truth flow volume based on RegionMM
# 3D noise to 3D flow
# flow size: 2*32*32*40
# parts of this code are based on https://github.com/lucidrains/video-diffusion-pytorch

import os
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from DM.modules.DDPM import Unet3D, GaussianDiffusion


class FlowDiffusion(nn.Module):
    """Wrap a 3D U-Net denoiser in a Gaussian diffusion process.

    ``Unet3D`` and ``GaussianDiffusion`` come from ``DM.modules.DDPM``; this
    class wires the configuration dictionaries into them, keeps the most
    recent training loss on ``self.loss`` and decodes sampled flows through
    an externally supplied autoencoder.
    """

    def __init__(self, model_params, data_params, pde_params, is_train=True):
        """Build the denoising U-Net and the diffusion wrapper around it.

        Args:
            model_params: dict of model hyper-parameters. Keys read here:
                ``use_residual_flow``, ``unet_dim``, ``max_features``,
                ``dim_mults``, ``use_deconv``, ``padding_mode``,
                ``sampling_timesteps``, ``timesteps``, ``null_cond_prob``
                and ``ddim_sampling_eta``.
            data_params: dict with ``field_size``, ``num_frames`` and
                ``num_channels``.
            pde_params: forwarded unchanged to ``GaussianDiffusion``.
            is_train: when true, switch the sub-modules to training mode and
                initialise ``self.loss``.
        """
        super(FlowDiffusion, self).__init__()
        self.model_params = model_params
        self.data_params = data_params
        # Read from the config; not referenced by any other method of this class.
        self.use_residual_flow = self.model_params["use_residual_flow"]

        self.unet = Unet3D(model_params['unet_dim'],
                           channels=model_params['max_features']*2,
                           out_grid_dim=model_params['max_features'],
                           dim_mults=self.model_params['dim_mults'],
                           use_bert_text_cond=False,
                           use_final_activation=False,
                           use_deconv=self.model_params["use_deconv"],
                           padding_mode=self.model_params['padding_mode'])

        self.diffusion = GaussianDiffusion(
            self.unet,
            image_size=data_params['field_size'],
            num_frames=data_params['num_frames'],
            channels=data_params['num_channels'],
            sampling_timesteps=self.model_params['sampling_timesteps'],
            timesteps=self.model_params['timesteps'],  # number of steps
            loss_type='l2',  # L1 or L2
            use_dynamic_thres=True,
            null_cond_prob=self.model_params['null_cond_prob'],
            ddim_sampling_eta=self.model_params['ddim_sampling_eta'],
            pde_params=pde_params
        )
        # training
        self.is_train = is_train
        if self.is_train:
            self.unet.train()
            self.diffusion.train()
            # Placeholder until the first forward pass. It lives on the CPU;
            # for GPU runs the author used ``torch.tensor(0.0).cuda()`` here.
            self.loss = torch.tensor(0.0)

    def forward(self, f, u0, z0, autoencoder, normalizer):
        """Compute the diffusion training loss and store it on ``self.loss``.

        All arguments are forwarded unchanged to ``self.diffusion``. Nothing
        is returned. If the model was constructed with ``is_train=False``
        the call does nothing.
        """
        # Author's note on the latent layout: z : [batch_size, channels, frames, H, W]
        if self.is_train:
            self.loss = self.diffusion(f, u0, z0, autoencoder, normalizer)

    def sample_one_video(self, z0, u0, autoencoder, normalizer, shape, cond_scale=1):
        """Sample flows with the diffusion model and decode them into a rollout.

        Args:
            z0: conditioning latent, passed to the sampler and then tiled via
                ``repeat(1, 1, num_frames - 1, 1, 1)`` before decoding.
            u0: initial frame(s); prepended to the decoded output along dim 2.
            autoencoder: object providing ``decode_from_flow(z, flow)``; also
                passed to the sampler.
            normalizer: passed to the sampler.
            shape: passed to the sampler.
            cond_scale: guidance scale; 1.0 means the unconditional model is
                not used.

        Returns:
            ``u0`` concatenated with the decoded prediction along dim 2.
        """
        pred = self.diffusion.sample(z0, u0, autoencoder, normalizer, shape, cond_scale=cond_scale)
        z0 = z0.repeat(1, 1, self.data_params['num_frames'] - 1, 1, 1)
        # Decoding is inference only; don't build an autograd graph through it.
        with torch.no_grad():
            u_pred = autoencoder.decode_from_flow(z0, pred)
        return torch.concat([u0, u_pred], dim=2)

    def set_train_input(self, ref_img, real_vid):
        """Store a reference image and target video as attributes."""
        self.ref_img = ref_img
        self.real_vid = real_vid

    def set_sample_input(self, sample_img, sample_text):
        """Store a sampling image and text prompt as attributes."""
        self.sample_img = sample_img
        self.sample_text = sample_text

    def print_learning_rate(self):
        """Print the learning rate of the first param group of ``optimizer_diff``.

        ``self.optimizer_diff`` is not created by this class; the caller must
        attach an optimizer under that name before calling this method.
        """
        lr = self.optimizer_diff.param_groups[0]['lr']
        assert lr > 0
        print('lr= %.7f' % lr)

    def get_grid(self, b, nf, H, W, normalize=True, device=None):
        """Build an identity sampling grid of shape ``(b, 2, nf, H, W)``.

        Channel 0 holds the x (width) coordinate and channel 1 the y (height)
        coordinate; the same 2-D grid is repeated over the batch and frames.

        Args:
            b: batch size.
            nf: number of frames.
            H: grid height.
            W: grid width.
            normalize: if true, coordinates span ``[-1, 1]``; otherwise they
                are integer pixel indices ``0..W-1`` / ``0..H-1``.
            device: device for the returned tensor; ``None`` keeps PyTorch's
                default device (the previous behaviour).

        Returns:
            A float32 tensor of shape ``(b, 2, nf, H, W)``.
        """
        if normalize:
            h_range = torch.linspace(-1, 1, H, device=device)
            w_range = torch.linspace(-1, 1, W, device=device)
        else:
            h_range = torch.arange(0, H, device=device)
            w_range = torch.arange(0, W, device=device)
        # Explicit "ij" is the historical default; passing it avoids the
        # torch>=1.10 warning about omitting ``indexing``.
        grid = torch.stack(torch.meshgrid(h_range, w_range, indexing="ij"), -1)
        grid = grid.repeat(b, 1, 1, 1).flip(3).float()  # flip h,w to x,y
        return grid.permute(0, 3, 1, 2).unsqueeze(dim=2).repeat(1, 1, nf, 1, 1)

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given networks.

        Typically used with ``requires_grad=False`` to freeze networks and
        avoid unnecessary gradient computation.

        Parameters:
            nets (network or list of networks) -- networks to update; ``None``
                entries are skipped
            requires_grad (bool) -- whether the networks require gradients
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
    

def convert_yaml_data(data):
    """Recursively convert numeric-looking strings in parsed YAML to numbers.

    A string that parses as ``int`` becomes an int; otherwise a string that
    parses as ``float`` (e.g. ``"1e-4"``, ``"0.5"``, ``"inf"``) becomes a
    float; any other string is returned unchanged. Lists and dicts are
    rebuilt with their items/values converted (dict keys are left as-is).
    Every other value is returned untouched.

    >>> convert_yaml_data({"lr": "1e-4", "steps": "10", "dims": ["-1", "x"]})
    {'lr': 0.0001, 'steps': 10, 'dims': [-1, 'x']}
    """
    if isinstance(data, str):
        # Try int before float so signed integers such as "-1" stay ints (a
        # ``str.isdigit`` guard would send them down the float path), and so
        # strings like "²" -- accepted by ``isdigit`` but rejected by
        # ``int`` -- fall through to "unchanged" instead of raising.
        for parse in (int, float):
            try:
                return parse(data)
            except ValueError:
                pass
        return data
    if isinstance(data, list):
        return [convert_yaml_data(item) for item in data]
    if isinstance(data, dict):
        return {key: convert_yaml_data(value) for key, value in data.items()}
    return data


if __name__ == "__main__":
    # Ad-hoc smoke test: load model hyper-parameters from ./test_process.yaml
    # (resolved against the working directory printed below), build a
    # FlowDiffusion and call it on random tensors.
    print(os.getcwd())
    #os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    bs = 1
    img_size = 64
    num_frames = 40
    ref_text = ["play basketball"] * bs
    ref_img = torch.rand((bs, 3, img_size, img_size), dtype=torch.float32)
    real_vid = torch.rand((bs, 3, num_frames, img_size, img_size), dtype=torch.float32)
    # Legacy constructor call kept as an inert string literal (not executed);
    # its keyword arguments do not match the current FlowDiffusion signature.
    '''model = FlowDiffusion(use_residual_flow=False,
                          sampling_timesteps=10,
                          img_size=16,
                          config_pth="/workspace/code/CVPR23_LFDM/config/mug128.yaml",
                          pretrained_pth="")'''
    file_path = "./test_process.yaml"
    # NOTE(review): yaml.load with FullLoader is fine for a trusted local
    # config; prefer yaml.safe_load if this file could come from elsewhere.
    with open(file_path, 'r') as yaml_file:
        yaml_data = yaml.load(yaml_file, Loader=yaml.FullLoader)["test_model_params"]
    # Convert numeric-looking strings to numbers (presumably for values such
    # as "1e-4", which PyYAML's YAML 1.1 resolver loads as str).
    yaml_data = convert_yaml_data(yaml_data)

    # NOTE(review): FlowDiffusion.__init__ also requires data_params and
    # pde_params, so this call raises TypeError as written. TODO: pass them.
    model = FlowDiffusion(model_params=yaml_data)

    # NOTE(review): the line numbers in the string below look stale; the
    # other CUDA toggle is the commented-out `.cuda()` line in
    # FlowDiffusion.__init__.
    '''cuda is here'''
    '''the others are in 69,118 line'''
    #model.cuda()
    # NOTE(review): set_train_input takes no ref_text argument and this class
    # defines no optimize_parameters; the two lines below would fail if re-enabled.
    # model.train()
    # model.set_train_input(ref_img=ref_img, real_vid=real_vid, ref_text=ref_text)
    # model.optimize_parameters()
    model.eval()
    model.set_sample_input(sample_img=ref_img, sample_text=ref_text)
    # NOTE(review): sample_one_video also requires z0, u0, autoencoder,
    # normalizer and shape.
    #model.sample_one_video(cond_scale=1.0)
    # NOTE(review): forward expects (f, u0, z0, autoencoder, normalizer);
    # calling it with two arguments raises TypeError.
    model(real_vid,ref_img)



