import os.path
from tqdm import tqdm

import torch as t
import pandas as pd
from PIL import Image
from daam import trace
import torch.nn.functional as F
import numpy as np

from utils import utils
from utils.itdiffusion import DiffusionModel  # Info theoretic diffusion library and flow sampler
from utils.stablediffusion import StableDiffuser

from configs.image_edit_configs import parse_args_and_update_config


def _load_stable_diffuser(sdm_version):
    """Instantiate the ``StableDiffuser`` wrapper selected by ``sdm_version``.

    Args:
        sdm_version: Either ``'sdm_2_0_base'`` or ``'sdm_2_1_base'``.

    Returns:
        A ``StableDiffuser`` built from the matching model identifier.

    Raises:
        ValueError: If ``sdm_version`` is not one of the supported keys.
    """
    model_ids = {
        'sdm_2_0_base': "stabilityai/stable-diffusion-2-base",
        'sdm_2_1_base': "stabilityai/stable-diffusion-2-1-base",
    }
    if sdm_version not in model_ids:
        raise ValueError(
            f"Unsupported sdm_version {sdm_version!r}; expected one of {sorted(model_ids)}"
        )
    return StableDiffuser(model_ids[sdm_version])


def _build_dataset(dataset_type, img_dir, annotation_file, csv_dir):
    """Build the dataset object selected by ``dataset_type``.

    Args:
        dataset_type: One of ``"COCO-IT"``, ``"coco_ours"`` or ``"custom"``.
        img_dir: Image directory passed to the COCO-style datasets.
        annotation_file: Annotation JSON path passed to the COCO-style datasets.
        csv_dir: Path of the CSV file passed to every dataset type.

    Returns:
        The dataset instance created by the matching ``utils`` class.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported values.
    """
    if dataset_type == "COCO-IT":
        return utils.CocoDataset(img_dir, annotation_file, csv_dir)
    if dataset_type == "coco_ours":
        return utils.CocoDatasetOurs(img_dir, annotation_file, csv_dir)
    if dataset_type == "custom":
        return utils.CustomDataset(csv_dir)
    raise ValueError(
        f"Unsupported dataset_type {dataset_type!r}; "
        "expected 'COCO-IT', 'coco_ours' or 'custom'"
    )


def _edit_sample(sdm, step_function, schedule, batch, dataset_type):
    """Invert one image and regenerate it with the original and edited prompts.

    Args:
        sdm: The ``StableDiffuser`` used for encoding, sampling and decoding.
        step_function: ODE step function handed to ``utils.reverse`` / ``utils.generate``.
        schedule: Sigma schedule handed to ``utils.reverse`` / ``utils.generate``.
        batch: One batch from the dataloader; its entries are read at index 0.
        dataset_type: Dataset key; ``"COCO-IT"`` batches get their fields aliased
            to the ``full`` / ``obj1`` / ``obj2`` names used below.

    Returns:
        A ``(original, modified)`` tuple: the decoded reconstruction as a numpy
        array, and a stacked numpy array with one decoded image per entry of
        ``batch['obj1'] + batch['obj2']``.
    """
    if dataset_type == "COCO-IT":
        batch['full'] = batch['caption']
        batch['obj1'] = batch['category']
        batch['obj2'] = batch['context']

    img = batch['image'][0]
    prompt = batch['full'][0]
    # presumably lists of strings produced by the default collate — verify against the dataset
    objs = batch['obj1'] + batch['obj2']

    # Encode the image to the SD latent space, then run the flow in reverse to get the latent.
    x_real = sdm.encode_latents(img)
    latent_real = utils.reverse(sdm, step_function, schedule, x_real, prompt)

    # No intervention: run forward with the unchanged prompt to check recovery of
    # the real image. The daam trace context is kept around the generation call.
    with t.cuda.amp.autocast(dtype=t.float16), t.no_grad():
        with trace(sdm.sdm_pipe):
            recover_real = utils.generate(sdm, step_function, schedule, latent_real, prompt)
    original = np.array(sdm.decode_latents(recover_real)[0])

    # Intervention: regenerate from the same latent once per object, with that
    # object's entry in the prompt swapped for '_' via utils.perform_word_swaps.
    # NOTE(review): unlike the baseline above, these generations are not wrapped
    # in autocast/no_grad — confirm that this difference is intended.
    modified = []
    for obj in objs:
        mod_prompt = utils.perform_word_swaps(prompt, {obj: '_'})
        recover_mod = utils.generate(sdm, step_function, schedule, latent_real, mod_prompt)
        modified.append(np.array(sdm.decode_latents(recover_mod)[0]))

    return original, np.stack(modified)


def _save_results(orig, mod, save_path):
    """Stack the accumulated images and write them to ``save_path`` via ``t.save``.

    Failures are reported and swallowed so that a bad write does not abort the run.

    Returns:
        True if the file was written, False if stacking or saving raised.
    """
    try:
        t.save({'original': np.stack(orig), 'modified': np.stack(mod)}, save_path)
    except Exception as e:
        print(f"An error occurred: {e}")
        return False
    return True


def main():
    """Run the prompt-intervention image-editing experiment and save the results.

    For every sample of the configured dataset the image is encoded to the
    latent space, inverted with a deterministic second-order ODE solver,
    regenerated with the original prompt, and regenerated once per object with
    that object swapped out of the prompt. Decoded images are accumulated and
    written to ``res_out_dir`` as ``{sdm_version}-{csv_name}-{batch_idx}.pt``
    every ``save_freq`` batches and once more after the last batch.

    A sample that raises is reported and skipped; it contributes to neither
    the ``original`` nor the ``modified`` array, so the two stay aligned.

    Raises:
        ValueError: If ``config.sdm_version`` or ``config.dataset_type`` is not
            a supported value.
    """
    config = parse_args_and_update_config()
    t.manual_seed(config.seed)

    # set hyper-parameters
    data_in_dir = config.data_in_dir
    csv_name = config.csv_name
    res_out_dir = config.res_out_dir
    num_steps = config.num_steps
    sdm_version = config.sdm_version
    clip = config.clip
    dataset_type = config.dataset_type
    save_freq = config.save_freq

    # load diffusion models
    sdm = _load_stable_diffuser(sdm_version)

    logsnr_max, logsnr_min = sdm.logsnr_max, sdm.logsnr_min
    logsnr_loc = logsnr_min + 0.5 * (logsnr_max - logsnr_min)
    logsnr_scale = (1. / (2. * clip)) * (logsnr_max - logsnr_min)

    latent_shape = (sdm.channels, sdm.width, sdm.height)
    itd = DiffusionModel(sdm.unet, latent_shape, logsnr_loc=logsnr_loc, logsnr_scale=logsnr_scale, clip=clip,
                         logsnr2t=sdm.logsnr2t).to(sdm.device)

    # Define the range of sigma/SNR to use during sampling from the model's log-SNR limits.
    sigma_min, sigma_max = utils.logsnr2sigma(logsnr_max), utils.logsnr2sigma(logsnr_min)

    # Set schedule in Karras et al. terms, "sigmas", where z = x + sigma * epsilon.
    schedule = utils.get_sigmas_karras(num_steps, sigma_min, sigma_max, device=itd.device)
    # For reversible sampling, drop the last entry: the schedule must not go all the
    # way to the limit sigma=0 (snr=inf), where the score can't be approximated and
    # the flow therefore can't be reversed.
    schedule_reversible = schedule[:-1]

    # Step function for ODE flow. Second-order "Heun" solver; s_churn=0. makes it deterministic.
    step_function = utils.get_step(order=2, s_churn=0.)

    # load data
    img_dir = os.path.join(data_in_dir, 'val2017')
    csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
    annotation_file = os.path.join(data_in_dir, 'annotations/instances_val2017.json')

    dataset = _build_dataset(dataset_type, img_dir, annotation_file, csv_dir)
    dataloader = t.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    os.makedirs(res_out_dir, exist_ok=True)

    orig = []
    mod = []
    last_idx = None

    # Wrap the dataloader (not the enumerate object) so tqdm can read its length.
    for batch_idx, batch in enumerate(tqdm(dataloader)):
        print(f"\nProcessing {batch_idx}\n")
        last_idx = batch_idx

        try:
            original, modified = _edit_sample(sdm, step_function, schedule_reversible, batch, dataset_type)
        except Exception as e:
            print(f"An error occurred: {e}")
        else:
            # Append both results together, only once the whole sample succeeded,
            # so 'original' and 'modified' always have matching entries.
            orig.append(original)
            mod.append(modified)

        # Periodic checkpoint. Evaluated outside the per-sample try so a failing
        # sample cannot suppress it; a falsy save_freq disables periodic saves.
        is_checkpoint = bool(save_freq) and batch_idx > 0 and batch_idx % save_freq == 0
        if is_checkpoint and orig:
            out_file_name = f'{sdm_version}-{csv_name}-{batch_idx}.pt'
            if _save_results(orig, mod, os.path.join(res_out_dir, out_file_name)):
                orig = []
                mod = []

    # Final flush: write whatever is still pending, even if the last batch failed.
    if orig:
        out_file_name = f'{sdm_version}-{csv_name}-{last_idx}.pt'
        _save_results(orig, mod, os.path.join(res_out_dir, out_file_name))

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
