
import argparse
import os
import shutil
from functools import partial
from time import time
from PIL import Image
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from accelerate.utils import set_seed
from datasets_prep import get_dataset
from EMA import EMA
from models import create_network
from torchdiffeq import odeint_adjoint as odeint
from torch.utils.data import Dataset
import copy
from copy import deepcopy
from collections import OrderedDict
# Faster training: allow TF32 for CUDA matmuls and for cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import numpy as np
from torch_ema import ExponentialMovingAverage

class DatasetEncoded(Dataset):
    """Dataset over a pre-encoded collection stored in a ``torch.save`` file.

    Every item is returned as ``(self.data[idx], [0])``; the second element
    is a constant placeholder label.

    Args:
        dataset_path: Path of the file to read with ``torch.load``.
    """

    def __init__(self, dataset_path):
        # Load onto the CPU regardless of the device the file was saved from:
        # DataLoader worker processes and ``pin_memory=True`` need CPU tensors.
        self.data = torch.load(dataset_path, map_location="cpu")

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item at ``idx`` together with the placeholder label."""
        return self.data[idx], [0]




def requires_grad(model, flag=True):
    """Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``.
        flag: Value assigned to each parameter's ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad = flag

# %%
def train(args):
    """Jointly train a velocity network and a second ("shortcut") network.

    Two networks are built with ``create_network(args)``:

    * ``velocity_model`` is regressed (L1) onto ``u = (1 - 1e-5) * z_1 - z_0``
      at points ``z_t`` interpolated between noise ``z_0`` and data ``z_1``.
    * ``model`` is trained (L1) so that its output at time ``t + dt`` for a
      noise sample matches a blend of its own output at ``t`` and the
      velocity network's output at ``z_0 + t * model(t, z_0)``; both of those
      are computed without gradients.

    An exponential moving average of ``model``'s parameters is written to
    ``./saved_info/latent_flow/<args.dataset>/<args.exp>/model_ema.pth`` once
    all epochs have finished.

    Args:
        args: Parsed command-line namespace. This function reads
            ``batch_size``, ``lr``, ``exp``, ``dataset``, ``num_classes``,
            ``dt`` and ``num_epoch``; the whole namespace is also passed to
            ``create_network``.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    device = "cuda"
    dtype = torch.float32
    batch_size = args.batch_size

    dataset = DatasetEncoded("encoded_celeba.pt")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )

    # Trainable network plus a frozen copy that receives the EMA weights at the end.
    model = create_network(args).to(device, dtype=dtype)
    ema_model = deepcopy(model).to(device)
    requires_grad(ema_model, False)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

    velocity_model = create_network(args).to(device, dtype=dtype)
    ema_velocity = ExponentialMovingAverage(velocity_model.parameters(), decay=0.999)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.)
    optimizer_velocity = optim.AdamW(velocity_model.parameters(), lr=args.lr, weight_decay=0.)

    exp = args.exp
    parent_dir = "./saved_info/latent_flow/{}".format(args.dataset)
    exp_path = os.path.join(parent_dir, exp)

    # The run configuration is written only when the experiment directory is
    # first created; an existing directory is left untouched.
    if not os.path.exists(exp_path):
        os.makedirs(exp_path)
        config_dict = vars(args)
        OmegaConf.save(config_dict, os.path.join(exp_path, "config.yaml"))

    init_epoch = 0
    use_label = args.num_classes > 1
    dt = args.dt
    model.train()
    for epoch in range(init_epoch, args.num_epoch + 1):
        pbar = tqdm(data_loader)
        cum_loss = 0.
        for step, (x, y) in enumerate(pbar, start=1):
            z_1 = x.to(device, dtype=dtype, non_blocking=True)
            y = None if not use_label else y.to(device, non_blocking=True)
            num_samples = z_1.size(0)

            # ---- Step 1: update the velocity network ----
            velocity_model.zero_grad()

            # Noise is drawn with the shape/device/dtype of the data batch
            # rather than a hard-coded (batch_size, 4, 32, 32).
            z_0 = torch.randn_like(z_1)
            # Keep a flat (B,) time vector for the network and a broadcastable
            # view for tensor arithmetic, so a batch of one is not squeezed
            # down to a 0-dim tensor.
            t = torch.rand((num_samples,), dtype=dtype, device=device)
            t_b = t.view(-1, 1, 1, 1)
            z_t = (1 - t_b) * z_0 + (1e-5 + (1 - 1e-5) * t_b) * z_1
            u = (1 - 1e-5) * z_1 - z_0  # derivative of z_t with respect to t
            v = velocity_model(t, z_t, y)
            loss = F.l1_loss(v, u)
            loss.backward()
            optimizer_velocity.step()
            ema_velocity.update()

            # ---- Step 2: update the shortcut network ----
            model.zero_grad()

            z_0 = torch.randn_like(z_1)
            # t is uniform in [dt, 1 - dt], so t + dt stays within [2 * dt, 1].
            t = dt + (1 - 2 * dt) * torch.rand((num_samples,), dtype=dtype, device=device)
            t_b = t.view(-1, 1, 1, 1)
            v_t_shortcut = model(t + dt, z_0, y)
            with torch.no_grad():
                v_t_shortcut_past = model(t, z_0, y)
                v = velocity_model(t, z_0 + t_b * v_t_shortcut_past, y)
            # NOTE(review): the blend weights are (t - dt) / t and dt / t while
            # the networks are evaluated at t + dt and t. Weights t / (t + dt)
            # and dt / (t + dt) would match those evaluation times -- confirm
            # which form is intended before changing the objective.
            target = (
                ((t - dt) / t).view(-1, 1, 1, 1) * v_t_shortcut_past
                + (dt / t).view(-1, 1, 1, 1) * v
            )
            loss = F.l1_loss(v_t_shortcut, target)
            loss.backward()
            optimizer.step()
            ema.update()

            # Only the shortcut loss is tracked in the progress bar.
            cum_loss += loss.item()
            pbar.set_description(f"loss: {round(cum_loss / step, 5)} ")

    # Write the EMA weights into the frozen copy and save them.
    ema.copy_to(ema_model.parameters())
    torch.save(ema_model.state_dict(), os.path.join(exp_path, "model_ema.pth"),)
        

# %%
def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    converter maps the usual spellings explicitly instead.

    Args:
        value: A ``bool`` or a string such as ``"True"``, ``"false"``,
            ``"1"``, ``"no"``. The empty string maps to ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for this training script.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser("direct models parameters")
    parser.add_argument("--seed", type=int, default=1024, help="seed used for initialization")

    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--model_ckpt", type=str, default=None, help="Model ckpt to init from")

    parser.add_argument(
        "--model_type",
        type=str,
        default="DiT-B/2",
        help="model_type",
        choices=[
            "adm",
            "ncsn++",
            "ddpm++",
            "DiT-B/2",
            "DiT-L/2",
            "DiT-L/4",
            "DiT-XL/2",
        ],
    )
    parser.add_argument("--image_size", type=int, default=32, help="size of image")
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsample rate of input image by the autoencoder",
    )
    parser.add_argument("--scale_factor", type=float, default=0.18215, help="scale factor")
    parser.add_argument("--num_in_channels", type=int, default=4, help="in channel image")
    parser.add_argument("--num_out_channels", type=int, default=4, help="out channel image")
    parser.add_argument("--nf", type=int, default=256, help="channel of model")
    parser.add_argument(
        "--num_res_blocks",
        type=int,
        default=2,
        help="number of resnet blocks per scale",
    )
    parser.add_argument(
        "--attn_resolutions",
        nargs="+",
        type=int,
        default=(16,),
        help="resolution of applying attention",
    )
    parser.add_argument(
        "--ch_mult",
        nargs="+",
        type=int,
        default=(1, 1, 2, 2, 4, 4),
        help="channel mult",
    )
    parser.add_argument("--dropout", type=float, default=0.0, help="drop-out rate")
    parser.add_argument("--label_dim", type=int, default=0, help="label dimension, 0 if unconditional")
    parser.add_argument(
        "--augment_dim",
        type=int,
        default=0,
        help="dimension of augmented label, 0 if not used",
    )
    parser.add_argument("--num_classes", type=int, default=1, help="num classes")
    parser.add_argument(
        "--label_dropout",
        type=float,
        default=0.0,
        help="Dropout probability of class labels for classifier-free guidance",
    )

    # Original ADM
    parser.add_argument("--layout", action="store_true")
    parser.add_argument("--use_origin_adm", action="store_true")
    # Boolean-valued options go through str2bool so that "False" parses as False.
    parser.add_argument("--use_scale_shift_norm", type=str2bool, default=True)
    parser.add_argument("--resblock_updown", type=str2bool, default=False)
    parser.add_argument("--use_new_attention_order", type=str2bool, default=False)
    parser.add_argument("--centered", action="store_false", default=True, help="-1,1 scale")
    parser.add_argument("--resamp_with_conv", type=str2bool, default=True)
    parser.add_argument("--num_heads", type=int, default=4, help="number of head")
    parser.add_argument("--num_head_upsample", type=int, default=-1, help="number of head upsample")
    parser.add_argument("--num_head_channels", type=int, default=-1, help="number of head channels")

    parser.add_argument("--pretrained_autoencoder_ckpt", type=str, default="stabilityai/sd-vae-ft-mse")

    # training
    parser.add_argument("--exp", default="experiment_celeba_default", help="name of experiment")
    parser.add_argument("--dataset", default="celeba", help="name of dataset")
    parser.add_argument("--datadir", default="./data")
    parser.add_argument("--num_timesteps", type=int, default=100)
    parser.add_argument(
        "--use_grad_checkpointing",
        action="store_true",
        default=False,
        help="Enable gradient checkpointing for mem saving",
    )

    parser.add_argument("--batch_size", type=int, default=128, help="input batch size")
    parser.add_argument("--num_epoch", type=int, default=500)

    parser.add_argument("--lr", type=float, default=5e-4, help="learning rate g")

    parser.add_argument("--beta1", type=float, default=0.5, help="beta1 for adam")
    parser.add_argument("--beta2", type=float, default=0.9, help="beta2 for adam")
    parser.add_argument("--no_lr_decay", action="store_true", default=False)

    parser.add_argument("--use_ema", action="store_true", default=False, help="use EMA or not")
    parser.add_argument("--ema_decay", type=float, default=0.9999, help="decay rate for EMA")

    parser.add_argument("--save_content", action="store_true", default=False)
    parser.add_argument(
        "--save_content_every",
        type=int,
        default=10,
        help="save content for resuming every x epochs",
    )
    parser.add_argument("--save_ckpt_every", type=int, default=25, help="save ckpt every x epochs")
    parser.add_argument("--plot_every", type=int, default=5, help="plot every x epochs")
    # dt is fractional (default 0.01), so it must be parsed as a float;
    # type=int rejected any value such as "0.05" given on the command line.
    parser.add_argument("--dt", type=float, default=0.01, help="delta t")
    return parser


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    # Apply --seed, which was previously parsed but never used.
    set_seed(args.seed)
    train(args)
