import torch
import torch.nn as nn
import numpy as np
from typing import Tuple, Optional, Callable
from scipy.integrate import solve_ivp
from torchdiffeq import odeint, odeint_adjoint
import math
from matplotlib import pyplot as plt
from tqdm import tqdm

from dataset_preparation import CenterSquareCrop
from torchvision import transforms
from einops import rearrange
from PIL import Image
from omegaconf import OmegaConf

from stable_diffusion.ldm.util import log_txt_as_img, exists, default

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from sampling.stable_diffusion.sample_compvis_single import load_model_from_config

import os
from argparse import ArgumentParser


def _parse_args():
    """Parse and validate the command-line arguments of the evaluation script."""
    parser = ArgumentParser()
    parser.add_argument("--config", default="configs/fade_sd.yaml", type=str)
    parser.add_argument("--unlearn_ckpt", required=True, type=str)
    parser.add_argument("--retain_ckpt", required=True, type=str)
    parser.add_argument("--data_path", required=True, type=str)
    parser.add_argument("--interval", default=10, type=int)
    parser.add_argument("--samples", default=1, type=int)
    parser.add_argument("--theme", required=True, type=str)
    parser.add_argument("--tag", required=True, type=str)
    # Previously hard-coded; the default keeps the old behaviour.
    parser.add_argument("--device", default="cuda:0", type=str)
    args = parser.parse_args()
    # range(..., 0) raises deep inside the loop, and 0 samples would silently
    # produce all-zero losses -- reject both up front.
    if args.interval < 1:
        parser.error("--interval must be a positive integer")
    if args.samples < 1:
        parser.error("--samples must be a positive integer")
    return args


def _load_model(config, ckpt_path, device):
    """Instantiate ``config.model``, load ``ckpt_path`` into it, move it to
    ``device`` and switch it to eval mode.

    Raises:
        ValueError: if the model's ``parameterization`` is not ``"eps"``; the
            loss below uses the sampled noise as the regression target.
    """
    model = instantiate_from_config(config.model)
    # NOTE(review): torch.load unpickles the file -- only pass trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    # strict=False tolerates partial checkpoints, but silently ignoring keys can
    # leave weights un-loaded; make any mismatch visible.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    del state_dict  # drop the CPU copy before the next checkpoint is read
    if missing or unexpected:
        print(f"[{ckpt_path}] missing keys: {len(missing)}, "
              f"unexpected keys: {len(unexpected)}")
    # The loss uses `noise` as target, which is only right for eps-prediction.
    # NOTE(review): assumes models expose `parameterization` (absent -> "eps").
    parameterization = getattr(model, "parameterization", "eps")
    if parameterization != "eps":
        raise ValueError(
            f"{ckpt_path}: expected an eps-parameterized model, "
            f"got parameterization={parameterization!r}")
    model = model.to(device)
    model.eval()
    return model


def _weighted_vlb(model, x_noisy, t, cond, noise):
    """Per-example ``model.lvlb_weights[t]``-weighted noise-prediction loss."""
    model_output = model.apply_model(x_noisy, t, cond)
    # mean over every non-batch dim (assumes 4-D latents, as before)
    loss = model.get_loss(model_output, noise, mean=False).mean(dim=(1, 2, 3))
    return model.lvlb_weights[t] * loss


@torch.no_grad()
def _batch_losses(model_u, model_r, batch, interval, samples):
    """Return per-example loss sums ``(unlearned, retain)`` for one batch.

    Each sum runs over timesteps ``1, 1 + interval, ...`` below
    ``model_u.num_timesteps`` and over ``samples`` noise draws per timestep.
    Both models see exactly the same noisy inputs so their losses are paired.
    """
    # Encode once per batch (previously repeated for every timestep).
    # model_u's get_input is used for both models, as before.
    # NOTE(review): if get_input samples the encoder posterior, one latent is
    # now shared by all timesteps of a batch -- confirm this is intended.
    x, cond = model_u.get_input(batch, model_u.first_stage_key)
    batch_size = x.shape[0]
    sum_u = torch.zeros(batch_size)
    sum_r = torch.zeros(batch_size)

    for timestep in tqdm(range(1, model_u.num_timesteps, interval)):
        t = torch.full((batch_size,), timestep, device=x.device, dtype=torch.long)
        for _ in range(samples):
            noise = torch.randn_like(x)  # sample epsilon; it is also the target
            x_noisy = model_u.q_sample(x_start=x, t=t, noise=noise)
            sum_u += _weighted_vlb(model_u, x_noisy, t, cond, noise).cpu()
            sum_r += _weighted_vlb(model_r, x_noisy, t, cond, noise).cpu()
    return sum_u, sum_r


def main():
    """Compare weighted diffusion losses of an unlearned and a retain model.

    For every validation example, sums ``lvlb_weights[t]``-weighted
    noise-prediction losses over every ``--interval``-th timestep (starting at
    1) and ``--samples`` noise draws, for both checkpoints.

    Saves a dict with keys ``"unlearned"`` and ``"retain"`` (1-D per-example
    tensors), ``"filenames"``, ``"interval"`` and ``"samples"`` to
    ``output_fade/<tag>.pt``. The totals are raw sums; ``interval`` and
    ``samples`` are stored presumably so consumers can normalise them.
    """
    args = _parse_args()

    config = OmegaConf.load(args.config)
    config.data.params.validation.params.path = args.data_path
    config.data.params.validation.params.theme = args.theme

    model_u = _load_model(config, args.unlearn_ckpt, args.device)
    model_r = _load_model(config, args.retain_ckpt, args.device)

    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    dataloader = data.val_dataloader()

    # Create the output directory before the (long) evaluation, not after.
    os.makedirs("output_fade", exist_ok=True)

    losses_u = []
    losses_r = []
    filenames = []
    for batch in dataloader:
        sum_u, sum_r = _batch_losses(model_u, model_r, batch,
                                     args.interval, args.samples)
        losses_u.append(sum_u)
        losses_r.append(sum_r)
        filenames.extend(batch['filename'])

    if not losses_u:
        raise RuntimeError("validation dataloader yielded no batches; "
                           "check --data_path and --theme")

    torch.save({"unlearned": torch.cat(losses_u),
                "retain": torch.cat(losses_r),
                "filenames": filenames,
                "interval": args.interval,
                "samples": args.samples}, f"output_fade/{args.tag}.pt")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()