import argparse
import math
import os
import numpy as np
import torch
import torch.distributed as dist
#import torchvision
from ddp_utils import init_processes
from models import create_network
from pytorch_fid.fid_score import calculate_fid_given_paths
from sampler.karras_sample import karras_sample
from sampler.random_util import get_generator
from torch import nn
from torch.multiprocessing import Process
#from torchdiffeq import odeint_adjoint as odeint
from tqdm import tqdm
from PIL import Image
from diffusers.models import AutoencoderKL
def return_base_velocity(t, polynom_order=8):
    """Return the polynomial velocity basis evaluated at every time in ``t``.

    For each time value ``t_k`` the terms ``(n + 1) * t_k**n`` for
    ``n = 0 .. polynom_order - 1`` are emitted, i.e. ``1, 2t, 3t^2, ...``.
    These are the time derivatives of the monomials ``t, t^2, ...`` produced
    by ``return_base_flow``.

    Args:
        t: Sequence of scalar time values (one per sample).
        polynom_order: Number of basis terms per time value. Defaults to 8,
            which reproduces the original fixed basis.

    Returns:
        A 1-D float32 CPU tensor of length ``len(t) * polynom_order``; the
        terms of each time value are stored contiguously.
    """
    coeffs = []
    for t_k in t:
        coeffs.extend((n + 1) * t_k**n for n in range(polynom_order))
    # Going through an object array converts each entry individually;
    # presumably so that non-float scalars (e.g. 0-d tensors) are accepted
    # as time values -- TODO confirm against callers.
    return torch.from_numpy(np.array(coeffs, dtype="object").astype(np.float32))

def return_base_flow(t, polynom_order=8):
    """Return the polynomial flow basis evaluated at every time in ``t``.

    For each time value ``t_k`` the monomials ``t_k**n`` for
    ``n = 1 .. polynom_order`` are emitted, i.e. ``t, t^2, t^3, ...``.

    Args:
        t: Sequence of scalar time values (one per sample).
        polynom_order: Number of basis terms per time value. Defaults to 8,
            which reproduces the original fixed basis.

    Returns:
        A 1-D float32 CPU tensor of length ``len(t) * polynom_order``; the
        terms of each time value are stored contiguously.
    """
    powers = []
    for t_k in t:
        powers.extend(t_k**n for n in range(1, polynom_order + 1))
    # Going through an object array converts each entry individually;
    # presumably so that non-float scalars (e.g. 0-d tensors) are accepted
    # as time values -- TODO confirm against callers.
    return torch.from_numpy(np.array(powers, dtype="object").astype(np.float32))

def get_velocity_and_flow(model, x0, t, device, batch_size, polynom_order=8):
    """Run the model once per polynomial term and reduce to a velocity and a flow.

    ``x0`` is repeated ``polynom_order`` times along dim 0 and passed to the
    model together with a fixed set of conditioning values. The predictions
    are weighted by the velocity basis (``return_base_velocity``) or the flow
    basis (``return_base_flow``) at time ``t`` and averaged over the
    polynomial terms; the flow additionally adds ``x0``.

    Args:
        model: Callable invoked as ``model(conditioning, x, None)``.
        x0: Starting latents; assumed to be ``(batch_size, 4, 32, 32)`` and
            already on ``device`` (it is added to the averaged prediction) --
            TODO confirm.
        t: Sequence with one scalar time value per sample.
        device: Device on which the model inputs and basis weights are placed.
        batch_size: Number of samples in ``x0``.
        polynom_order: Number of polynomial terms per sample.
            NOTE(review): the conditioning values and both basis helpers
            produce 8 terms per sample, so values other than 8 will not line
            up with the repeated ``x0``.

    Returns:
        Tuple ``(velocity, flow)``, each of shape ``(batch_size, 4, 32, 32)``.
        ``flow`` is computed without gradient tracking.
    """
    # Fixed conditioning values (one per polynomial term), tiled per sample.
    polynom_degrees = torch.tensor([0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]).repeat(batch_size).float()
    x0_repeated = torch.repeat_interleave(x0, polynom_order, 0).to(device)
    pred = model(polynom_degrees.to(device), x0_repeated, None)

    # Predictions are regrouped as (sample, term, C, H, W) before averaging
    # over the term axis; the latent shape 4x32x32 is fixed here.
    grouped_shape = (batch_size, polynom_order, 4, 32, 32)

    # Shape the weights as (N, 1, 1, 1) so each one scales a whole prediction
    # instead of being broadcast along the last (width) axis.
    base_velocity = return_base_velocity(t).to(device).view(-1, 1, 1, 1)
    velocity = torch.mean((pred * base_velocity).view(grouped_shape), 1)

    with torch.no_grad():
        base_flow = return_base_flow(t).to(device).view(-1, 1, 1, 1)
        flow = x0 + torch.mean((pred.detach() * base_flow).view(grouped_shape), 1)
    return velocity, flow


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Args:
        value: A bool, or a string such as ``"true"``/``"false"``,
            ``"1"``/``"0"``, ``"yes"``/``"no"``, ``"on"``/``"off"``.
            The empty string is treated as ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Build the command-line parser for the FlowFit sampling script.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser("FlowFit parameters")
    parser.add_argument("--use_origin_adm", action="store_true")
    parser.add_argument("--generator", type=str, default="determ", help="type of seed generator", choices=["dummy", "determ", "determ-indiv"],)
    parser.add_argument("--seed", type=int, default=100, help="seed used for initialization")
    parser.add_argument("--compute_fid", action="store_true", default=False, help="whether or not compute FID")
    parser.add_argument("--compute_nfe", action="store_true", default=False, help="whether or not compute NFE")
    parser.add_argument("--measure_time", action="store_true", default=False, help="whether or not measure time")
    parser.add_argument("--epoch_id", type=int, default=1000)
    parser.add_argument("--n_sample", type=int, default=50000, help="number of sampled images")
    parser.add_argument("--model_type",type=str,default="DiT-B/2",help="model_type",choices=["adm", "ncsn++", "ddpm++", "DiT-B/2", "DiT-L/2", "DiT-XL/2"],)
    parser.add_argument("--image_size", type=int, default=256, help="size of image")
    parser.add_argument("--f", type=int, default=8, help="downsample rate of input image by the autoencoder")
    parser.add_argument("--scale_factor", type=float, default=0.18215, help="scale factor the latents are divided by before decoding")
    parser.add_argument("--num_in_channels", type=int, default=4, help="in channel image")
    parser.add_argument("--num_out_channels", type=int, default=4, help="out channel image")
    parser.add_argument("--nf", type=int, default=256, help="channel of image")
    parser.add_argument("--centered", action="store_false", default=True, help="-1,1 scale")
    # type=bool would turn "--resamp_with_conv False" into True.
    parser.add_argument("--resamp_with_conv", type=str2bool, default=True)
    parser.add_argument("--num_res_blocks", type=int, default=2, help="number of resnet blocks per scale")
    parser.add_argument("--num_heads", type=int, default=4, help="number of head")
    parser.add_argument("--num_head_upsample", type=int, default=-1, help="number of head upsample")
    parser.add_argument("--num_head_channels", type=int, default=-1, help="number of head channels")
    parser.add_argument("--attn_resolutions", nargs="+", type=int, default=(16, 8, 4), help="resolution of applying attention")
    parser.add_argument("--ch_mult", nargs="+", type=int, default=(1, 2, 3, 4), help="channel mult")
    parser.add_argument("--label_dim", type=int, default=0, help="label dimension, 0 if unconditional")
    parser.add_argument("--augment_dim", type=int, default=0, help="dimension of augmented label, 0 if not used")
    parser.add_argument("--dropout", type=float, default=0.0, help="drop-out rate")
    parser.add_argument("--num_classes", type=int, default=1, help="num classes")
    parser.add_argument("--label_dropout", type=float, default=0.0, help="Dropout probability of class labels for classifier-free guidance",)
    parser.add_argument("--cfg_scale", type=float, default=1.0, help="Scale for classifier-free guidance")
    parser.add_argument("--pretrained_autoencoder_ckpt", type=str, default="stabilityai/sd-vae-ft-mse")
    parser.add_argument("--output_log", type=str, default="")

    #######################################
    parser.add_argument("--num_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=16, help="sample generating batch size")

    # sampling argument
    parser.add_argument("--use_karras_samplers", action="store_false", default=True)
    parser.add_argument("--method", type=str, default="euler", help="solver_method",)
    parser.add_argument("--step_size", type=float, default=0.01, help="step_size")
    parser.add_argument("--perturb", action="store_true", default=False)
    return parser


def main():
    """Generate one sample with the FlowFit model and save it as ``x1.png``.

    Loads the network built from the parsed arguments with weights from
    ``flow_fit.pth``, evaluates the flow at ``t = 1`` from a random latent,
    decodes the result with the pretrained autoencoder and writes the image.
    """
    args = build_arg_parser().parse_args()

    torch.set_grad_enabled(False)
    device = "cuda"

    model = create_network(args).to(device)
    first_stage_model = AutoencoderKL.from_pretrained(args.pretrained_autoencoder_ckpt).to(device)
    first_stage_model.eval()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load("flow_fit.pth", map_location=device)
    model.load_state_dict(ckpt, strict=True)
    model.eval()

    # NOTE(review): --seed and --batch_size are parsed but not applied here;
    # a single latent is drawn from the global RNG.
    x0 = torch.randn(1, 4, 32, 32).to(device)
    _, x1 = get_velocity_and_flow(model, x0, [1.], device, 1)

    fake_image = first_stage_model.decode(x1 / (args.scale_factor)).sample
    # Map the decoder output from [-1, 1] to [0, 1] before quantising.
    fake_image = torch.clamp((fake_image + 1.0) / 2.0, 0, 1)

    # CHW -> HWC, then scale to 8-bit pixel values.
    pixels = (fake_image.squeeze().permute(1, 2, 0).cpu().numpy()*255.).astype(np.uint8)
    Image.fromarray(pixels).save("x1.png")


if __name__ == "__main__":
    main()


