# for some reason transformers doesn't want to be imported after torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    DDIMScheduler,
    DDIMPipeline,
    StableDiffusionPipeline,
)
from diffusers.utils import convert_state_dict_to_diffusers
from accelerate import Accelerator
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from diffusers.optimization import get_scheduler
import json

import torch
import torch.nn.functional as F
import torch.nn as nn

from dataset import ImageDataset
from utils import (
    encode_text,
    generate_image,
    cast_training_params,
    torch2latents,
)

from PIL import Image
import os
from copy import deepcopy
from functools import partial
import warnings
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import argparse
import itertools

def get_args(should_parse=True):

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="A <rare_token> caution wet floor sign")
    parser.add_argument("--ic_light_prompt", type=str, default="A yellow stand sign")
    # dca
    # take something from https://github.com/2kpr/dreambooth-tokens
    # <doh> is okay
    parser.add_argument("--rare_token", type=str, default="<ktn>")
    parser.add_argument("--exp_desc", type=str, default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--model_name", type=str, default="stabilityai/stable-diffusion-2-1-base")
    parser.add_argument("--sd_unet_path", type=str, default=None)
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--num_inference_steps", type=int, default=25)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--prob", type=float, default=0.7)

    parser.add_argument("--lora_rank", type=int, default=4)

    parser.add_argument("--num_train_iterations", type=int, default=1000)
    parser.add_argument("--image_dir", type=str, default=None)
    parser.add_argument("--generate_cp_dir", type=str, default=None)
    parser.add_argument("--generate_cp_iclight", type=str, default=None)
    parser.add_argument("--generate_using_iclight", type=int, default=1)
    parser.add_argument("--show_iter", type=int, default=100)
    parser.add_argument("--generate_n", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_acc_step", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--lr_scheduler", type=str, default="constant")
    parser.add_argument("--lr_num_cycles", type=int, default=1)
    parser.add_argument("--lr_power", type=float, default=1.0)
    parser.add_argument("--mixed_precision", type=str, default="fp16")
    # parser.add_argument("--lambd_preservation_loss", type=float, default=0.0)
    parser.add_argument("--save_dir", type=str, default="./personalization_play/")
    parser.add_argument("--debug", type=int, default=1)
    parser.add_argument("--skip_used", action="store_true")
    parser.add_argument("--fixed_place_prompt", type=str, default=None)


    if should_parse:
        args = parser.parse_args()
        return args
    
    return parser

def main():
    args = get_args()
    if args.model_name == "2.1":
        args.model_name = "stabilityai/stable-diffusion-2-1-base"
    elif args.model_name == "2.0":
        args.model_name = "stabilityai/stable-diffusion-2-base"
    elif args.model_name == "1.5":
        args.model_name = "runwayml/stable-diffusion-v1-5"
    # elif args.model_name == "3.0":
    #     args.model_name = 

    # replce <rare_token> in prompt
    prompt_object = args.prompt.replace("<rare_token> ", "")
    args.prompt = args.prompt.replace("<rare_token>", args.rare_token)

    print("prompt_object:", prompt_object)
    print("prompt_train :", args.prompt)

    os.makedirs(args.save_dir, exist_ok=True)
    if args.debug:
        args.save_dir = str(Path(args.save_dir) / f"{len(list(Path(args.save_dir).glob('*'))):03}")
    print("Save dir:", args.save_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    with open(Path(args.save_dir) / "my_args.json", "w") as f:
        json.dump(vars(args), f, indent=4)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.num_acc_step,
        project_dir=os.path.join(args.save_dir, "logs"),
        mixed_precision=args.mixed_precision,
    )
    torch_device = accelerator.device
    if accelerator.is_main_process:
        accelerator.init_trackers("train_example")

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16


    pipe = StableDiffusionPipeline.from_pretrained(args.model_name, torch_dtype=weight_dtype)
    vae = pipe.vae
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    unet = pipe.unet
    if not args.sd_unet_path is None:
        unet = unet.from_pretrained(args.sd_unet_path, device=torch_device, torch_dtype=weight_dtype)
        unet = unet.to(torch_device) # idk why but I have to do it again 0_o
        print(f"Unet from {args.sd_unet_path} loaded!")
    _scheduler = DDIMScheduler.from_pretrained(args.model_name, subfolder="scheduler")
    _train_scheduler = DDIMScheduler.from_pretrained(args.model_name, subfolder="scheduler")

    # freeze unet, vae and text_encoder
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    unet_lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
    )

    unet.add_adapter(unet_lora_config)
    if args.mixed_precision == "fp16":
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(unet, dtype=torch.float32)

    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
    torch.manual_seed(args.seed)

    model, text_encoder, vae = accelerator.prepare(
    unet, text_encoder, vae
    )

    dataset = ImageDataset(
        image_dir=args.image_dir, 
        width=args.width, 
        height=args.height,
        generate_n=args.generate_n,
        orig_obj_prompt=prompt_object,
        ic_light_prompt=args.ic_light_prompt,
        save_dir=args.save_dir,
        generate_using_iclight=(args.generate_using_iclight != 0),
        generate_cp_dir=args.generate_cp_dir,
        generate_cp_iclight=args.generate_cp_iclight,
        prob=args.prob,
        fixed_place_prompt=args.fixed_place_prompt,
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(
        lora_layers,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_train_iterations,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    optimizer, dataloader, lr_scheduler = accelerator.prepare(
    optimizer, dataloader, lr_scheduler
    )

    prompt_emb = encode_text([args.prompt], text_encoder, tokenizer, device=torch_device)
    prompt_object_emb = encode_text([prompt_object], text_encoder, tokenizer, device=torch_device)
    embs = torch.cat([prompt_object_emb, prompt_emb], axis=0)

    generate_image(
        prompt=prompt_object,
        num_same=1,
        save_name="initial.png",
        seed=args.seed,
        save_dir=args.save_dir,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        height=args.height, width=args.width,
        scheduler=_scheduler,
        tokenizer=tokenizer,
        text_encoder=accelerator.unwrap_model(text_encoder),
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(model),
        device=accelerator.device,
    )

    iter_dataloader = iter(dataloader)
    for i in tqdm(range(args.num_train_iterations)):
        # get batch
        model.train()
        try:
            batch_images, batch_emb_ids = next(iter_dataloader)
        except StopIteration:
            iter_dataloader = iter(dataloader)
            batch_images, batch_emb_ids = next(iter_dataloader)

        clean_images = batch_images.to(device=torch_device, dtype=weight_dtype)
        clean_images = torch2latents(clean_images, vae, torch_device) # get latents
        curr_bs = clean_images.size(0)

        timesteps = torch.randint(0, _train_scheduler.config.num_train_timesteps, size=(curr_bs,), device=torch_device)
        prompt_emb_batch = embs[batch_emb_ids][:, 0, :, :]

        noise = torch.randn_like(clean_images)
        noised_images = _train_scheduler.add_noise(clean_images, noise, timesteps)

        with accelerator.accumulate(model):
            noise_pred = model(noised_images, timesteps, encoder_hidden_states=prompt_emb_batch).sample
            loss = F.mse_loss(noise_pred.float(), noise.float())
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        if accelerator.is_main_process and ((i + 1) % args.show_iter == 0 or i == args.num_train_iterations - 1):
            model.eval()
            unwrapped_unet = deepcopy(model).to(torch.float32)
            unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
            StableDiffusionPipeline.save_lora_weights(
                save_directory=args.save_dir,
                unet_lora_layers=unet_lora_state_dict,
                safe_serialization=True,
            )

            generate_image(
                prompt=args.prompt, 
                num_same=9,
                save_name="training.png",
                seed=args.seed,
                save_dir=args.save_dir,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                height=args.height, width=args.width,
                scheduler=_scheduler,
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                vae=vae,
                unet=accelerator.unwrap_model(model),
                device=accelerator.device,
            )

if __name__ == "__main__":
    main()
