import os
import torch
import logging
import argparse
import itertools
from tqdm import tqdm

from src.utils import *
from lora_diffusion import inject_trainable_lora
from src.inversion.inv_pipe import InversionPipeline
from src.fari import inject_fari, save_fari_model, one_step_inversion

def _latent_pair(pipe, prompt, step, args, device, null_embeds, use_amp):
    """Produce one (initial, recovered) latent pair for ``prompt``.

    A fresh Gaussian latent seeds image generation. The generated image is
    distorted with ``image_distortion`` (which also receives ``step`` and
    ``args``), re-encoded with ``pipe.get_image_latents``, and mapped back to
    latent space with ``one_step_inversion``.

    Returns:
        ``(initial_latents, noised_latents)``, each sliced to a leading batch
        dimension of 1.
    """
    # (4, 64, 64) latents pair with the 512x512 images requested below --
    # assumes the VAE downsamples by 8; TODO confirm if --model_id changes.
    initial_latents = torch.randn(1, 4, 64, 64).to(device)
    # fp16 autocast only on CUDA; on CPU generation runs in full precision.
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
        image = pipe(
            prompt,
            num_images_per_prompt=1,
            guidance_scale=args.guidance_scale,
            num_inference_steps=args.num_inference_steps,
            height=512,
            width=512,
            latents=initial_latents,
        ).images[0]
    noised_image, _ = image_distortion(image, step, args)
    noised_image_tensor = to_tensor(noised_image).to(device)
    noised_image_latent = pipe.get_image_latents(noised_image_tensor, sample=False)
    noised_latents = one_step_inversion(pipe, noised_image_latent, null_embeds)
    return initial_latents, noised_latents[:1]


def _train(args, output_path, logger):
    """Build the pipeline, run the optimisation loop and save the weights.

    Runs at most ``args.steps`` iterations, one per prompt from
    ``load_prompt``. Weights are written to ``fari_weights.pth`` when the
    loop ends, including when the prompt source runs out early.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"

    pipe = InversionPipeline.from_pretrained(args.model_id).to(device)
    pipe.set_progress_bar_config(disable=True)
    null_text_embeds, _ = pipe.encode_prompt("", pipe._execution_device, 1, False)
    null_embeds = null_text_embeds.detach()  # loop-invariant; detach once
    pipe.unet.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    pipe.vae.requires_grad_(False)

    # LoRA layers are injected after freezing. Only the parameters returned
    # by inject_trainable_lora are handed to the optimizer.
    trainable_params, _ = inject_trainable_lora(pipe.unet, r=args.lora_r)
    pipe.unet = inject_fari(pipe.unet)
    optimizer = torch.optim.Adam(itertools.chain(*trainable_params), lr=args.lr)
    get_num_params(optimizer)
    loss_fn = torch.nn.MSELoss()

    prompts = itertools.islice(load_prompt(args.train_dataset_id), args.steps)
    completed = 0
    with tqdm(total=args.steps, desc="Training FARI") as pbar:
        for i, prompt in enumerate(prompts):
            optimizer.zero_grad()
            pairs = [
                _latent_pair(pipe, prompt, i, args, device, null_embeds, use_amp)
                for _ in range(args.batch_size)
            ]
            initial_parts, noised_parts = zip(*pairs)
            # Stack directly on the training device in float32. The batch
            # never round-trips through CPU and no CUDA device is assumed.
            initial_batch = torch.cat(initial_parts).to(device, torch.float32)
            noised_batch = torch.cat(noised_parts).to(device, torch.float32)

            loss = loss_fn(noised_batch, initial_batch)
            loss.backward()
            optimizer.step()
            logger.info(f"Iter {i} Loss {loss.item()}")
            pbar.update(1)
            completed = i + 1

    if completed < args.steps:
        logger.warning(
            f"Prompt source exhausted after {completed} of {args.steps} steps"
        )
    # Save unconditionally: saving only on the last planned step would lose
    # all training if the prompt source ends early.
    save_fari_model(pipe.unet, os.path.join(output_path, "fari_weights.pth"))
    logger.info(f"Iter {completed - 1} Save model")


def main(args):
    """Train the FARI LoRA adapters injected into a diffusion UNet.

    For each training prompt, ``args.batch_size`` images are generated from
    random initial latents. Each image is distorted, re-encoded, and inverted
    back to latent space. Adam minimises the MSE between those recovered
    latents and the initial latents, updating only the LoRA parameters; the
    UNet, text encoder and VAE weights are frozen.

    Files written under ``<args.output_dir>/<args.name>``:
      * ``training_settings.json``: the parsed CLI arguments.
      * ``TRAIN.log``: the per-iteration loss log.
      * ``fari_weights.pth``: the trained weights, saved when training stops.

    Args:
        args: ``argparse.Namespace`` produced by this script's parser.

    Raises:
        ValueError: if ``args.batch_size`` is smaller than 1.
    """
    import json  # explicit; do not rely on `from src.utils import *` providing it

    if args.batch_size < 1:
        raise ValueError(f"--batch_size must be >= 1, got {args.batch_size}")

    output_path = os.path.join(args.output_dir, args.name)
    os.makedirs(output_path, exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # Join paths instead of writing f"./{output_path}/...": prefixing "./"
    # turns an absolute --output_dir into a cwd-relative path.
    fh = logging.FileHandler(os.path.join(output_path, "TRAIN.log"))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    try:
        settings = vars(args)
        print(settings)
        with open(os.path.join(output_path, "training_settings.json"), "w") as f:
            json.dump(settings, f, indent=4)

        # --seed is recorded in the settings; apply it so the random initial
        # latents are reproducible.
        torch.manual_seed(args.seed)
        _train(args, output_path, logger)
    finally:
        # Detach and close the file handler so repeated calls (e.g. from a
        # notebook) neither duplicate log lines nor leak the open file.
        logger.removeHandler(fh)
        fh.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='FARI Training')

    # Every CLI option as (flag, type, default), registered in this order.
    # Defaults are passed through unchanged; argparse only applies `type` to
    # string defaults, so e.g. --brightness_factor stays the int 6.
    cli_options = (
        # Experiment, model and optimisation settings.
        ("--name", str, "test"),
        ("--seed", int, 42),
        ("--output_dir", str, "results"),
        ("--model_id", str, "stabilityai/stable-diffusion-2-1-base"),
        ("--guidance_scale", float, 7.5),
        ("--num_inference_steps", int, 20),
        # NOTE(review): machine-specific placeholder path; override on the CLI.
        ("--train_dataset_id", str, "/home/xxx/fid_outputs/coco/meta_data.json"),
        ("--steps", int, 1000),
        ("--batch_size", int, 4),
        ("--lr", float, 1e-4),
        ("--lora_r", int, 8),  # LoRA rank
        # Distortion strengths; presumably read by image_distortion(), which
        # receives the full args namespace.
        ("--jpeg_ratio", int, 25),
        ("--random_crop_ratio", float, 0.6),
        ("--random_drop_ratio", float, 0.8),
        ("--gaussian_blur_r", int, 4),
        ("--median_blur_k", int, 7),
        ("--resize_ratio", float, 0.25),
        ("--gaussian_std", float, 0.05),
        ("--sp_prob", float, 0.05),
        ("--brightness_factor", float, 6),
    )
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    main(parser.parse_args())

