import yaml 
import logging
# Configure the root logger, which the module-level logging.info(...) calls use.
logger = logging.getLogger()
# The root logger defaults to WARNING; without lowering it, every logging.info()
# record is discarded before it ever reaches the INFO-level handler below.
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from diffusers import DDIMScheduler
from datasets import load_dataset, concatenate_datasets

from main.wmdiffusion import WMDetectStableDiffusionPipeline
from main.wmpatch import GTWatermark, GTWatermarkMulti
from main.wmpatch_cached import GTWatermarkCached
from main.wmpach_cached_learned_wm import GTWatermarkCachedLearnedWM
from main.utils import *
from loss.loss import LossProvider
from loss.pytorch_ssim import ssim
from main.dataset import *
from main.nf_flow_models import *
from torch.utils.data import Dataset, DataLoader
import time

from data.data_loader_functions import *

import argparse

# Prefer reproducible cuDNN kernel selection over speed: enabling benchmark mode
# can be faster, but it may choose non-deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --------------------------------------------------------------------------------------------#

def init_pipelines(cfgs, device, batch_size):
    """Build the watermark-detection Stable Diffusion pipeline.

    Args:
        cfgs: Parsed config mapping; only ``cfgs['model_id']`` is read.
        device: Device the pipeline is moved to.
        batch_size: Accepted for call-site compatibility; not used here.

    Returns:
        A ``WMDetectStableDiffusionPipeline`` loaded from ``model_id`` with a
        DDIM scheduler and its progress bar disabled.
    """
    model_id = cfgs['model_id']
    ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    sd_pipe = WMDetectStableDiffusionPipeline.from_pretrained(model_id, scheduler=ddim)
    sd_pipe = sd_pipe.to(device)
    sd_pipe.set_progress_bar_config(disable=True)
    return sd_pipe

# --------------------------------------------------------------------------------------------#

def get_init_latent(img_tensor, pipe, text_embeddings, guidance_scale=1.0):
    """Recover an image's starting diffusion latent via DDIM inversion.

    The image is first encoded with ``pipe.get_image_latents(..., sample=False)``.
    The encoding is then run through ``pipe.forward_diffusion`` for a fixed 50
    inference steps.

    Args:
        img_tensor: Image batch handed to the pipeline's encoder.
        pipe: Pipeline exposing ``get_image_latents`` and ``forward_diffusion``.
        text_embeddings: Conditioning embeddings used during inversion.
        guidance_scale: Guidance scale used during inversion (default 1.0).

    Returns:
        Whatever ``pipe.forward_diffusion`` returns (the inverted latents).
    """
    inversion_kwargs = dict(
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        num_inference_steps=50,
    )
    encoded = pipe.get_image_latents(img_tensor, sample=False)
    return pipe.forward_diffusion(latents=encoded, **inversion_kwargs)
            
def _write_loss_summary(write_file, header, losses, counter):
    """Append the running loss averages (each accumulator / ``counter``) to ``write_file``."""
    denom = max(counter, 1)  # avoid ZeroDivisionError when the loader yielded no batches
    with open(write_file, "a") as log_file:
        log_file.write(f"{header}: {losses['total_train_loss']/denom} \n")
        log_file.write(f"W: {losses['W']/denom}, I: {losses['I']/denom}, P: {losses['P']/denom}, S: {losses['S']/denom}\n")
        log_file.write(f"L2: {losses['L2']/denom}\n\n")


def _save_invertible_maps(args, map_real, map_imag, tag):
    """Checkpoint both invertible maps into ``args.wm_path`` as ``model_{real,imag}_{tag}.pt``."""
    import os

    real_path = os.path.join(args.wm_path, f"model_real_{tag}.pt")
    imag_path = os.path.join(args.wm_path, f"model_imag_{tag}.pt")
    if args.inv_map_type in ("mlp", "unet", "residual_unet"):
        # state_dict-based map types
        torch.save(map_real.state_dict(), real_path)
        torch.save(map_imag.state_dict(), imag_path)
    else:
        # remaining map types provide their own .save(path)
        map_real.save(real_path)
        map_imag.save(imag_path)


def main(args):
    """Train two invertible maps that turn Fourier-space latents into a learned watermark.

    For each batch, the images are DDIM-inverted to latents, and the latents are
    FFT'd and fft-shifted. The watermark mask region is overwritten with the
    batch's ``gt_patch``. The real and imaginary parts are then passed through
    separate invertible maps. The resulting complex watermark is injected into
    the latents, which are decoded by the pipeline.

    The loss is the ``LossProvider`` loss plus ``args.loss_weight`` times a
    negative-MSE term. That term rewards the learned watermark for differing
    from the original (un-patched) spectrum inside the mask.

    Running averages are appended to ``args.write_file``. Both maps are
    checkpointed into ``args.wm_path`` every ``args.save_interval`` optimizer
    steps.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        NotImplementedError: If ``args.inv_map_type`` is not
            ``"residual_complex_learned_wm"``.
    """
    print(args)
    logging.info(f'===== Load Config =====')
    with open(args.cfg_path, 'r') as file:
        cfgs = yaml.safe_load(file)
    logging.info(cfgs)

    device = torch.device('cuda')

    # get train dataset (TODO: migrate to better impl)
    logging.info(f'===== Load Dataset =====')
    if args.dataset == 'all':
        dataloader_train = create_dataloder_all(args, is_train=True)
    else:
        dataloader_train = create_dataloder(args, is_train=True)
    # NOTE(review): assumes a torch DataLoader, whose len() counts batches, not images.
    logging.info(f"Num batches: {len(dataloader_train)}")

    if args.inv_map_type == "residual_complex_learned_wm":
        invertible_map_real = create_invertible_residual_basic(args.layers).to(device)
        invertible_map_imag = create_invertible_residual_basic(args.layers).to(device)
    else:
        raise NotImplementedError(f"Unsupported --inv_map_type: {args.inv_map_type!r}")

    invertible_map_real.train()
    invertible_map_imag.train()
    optimizer = optim.Adam(
        [{'params': invertible_map_real.parameters()}, {'params': invertible_map_imag.parameters()}],
        lr=args.lr,
    )
    # Milestones are epoch indices, so the scheduler is stepped once per epoch below.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.3)

    totalLoss = LossProvider(cfgs['loss_weights'], device)

    logging.info(f'===== Init Pipeline =====')
    pipe = init_pipelines(cfgs, device, args.batch_size)
    loss_mse = torch.nn.MSELoss(reduction='mean')

    global_step = 0  # optimizer steps across all epochs; drives checkpoint cadence
    for i in range(args.epochs):
        logging.info(f"Epoch: {i + 1}")
        # 'W' is never updated in this script; it is kept so the log format stays stable.
        losses = {'total_train_loss': 0., 'W': 0., 'I': 0., 'P': 0., 'S': 0., 'L2': 0.}

        counter = 0
        for j, batch in enumerate(dataloader_train):
            counter += 1
            global_step += 1

            imgs = batch['image'].detach().clone().to(device)
            # Size the empty prompt to the actual batch; a final partial batch may be smaller.
            cur_batch_size = imgs.shape[0]
            prompt = [''] * cur_batch_size

            text_embeddings = pipe.get_text_embedding(prompt)
            init_latents_approx = get_init_latent(imgs, pipe, text_embeddings).detach().clone().to(device)
            fft_init_latents_approx = torch.fft.fftshift(torch.fft.fft2(init_latents_approx), dim=(-1, -2))

            wm_pipe = GTWatermarkCachedLearnedWM(device, batch['watermarking_mask'].squeeze(), wm_pattern=args.wm_pattern, w_radius=args.wm_radius, shape=args.shape)
            gt_patch = batch['gt_patch'].detach().clone().to(device)
            mask = wm_pipe.watermarking_mask

            # Copy before patching: plain assignment would alias the tensor, so the
            # masked write would also overwrite fft_init_latents_approx and the L2
            # term below would compare against gt_patch instead of the original spectrum.
            fft_init_latents_approx_with_patch = fft_init_latents_approx.clone()
            fft_init_latents_approx_with_patch[mask] = gt_patch[mask]

            # Learned watermark: separate maps for the real and imaginary parts.
            learned_wm_real = invertible_map_real.forward(fft_init_latents_approx_with_patch.real)
            learned_wm_imag = invertible_map_imag.forward(fft_init_latents_approx_with_patch.imag)
            learned_wm = torch.complex(learned_wm_real, learned_wm_imag)

            init_latents_wm = wm_pipe.inject_watermark(learned_wm, init_latents_approx)

            if cfgs['empty_prompt']:
                pred_img_tensor = pipe(prompt, guidance_scale=1.0, num_inference_steps=50, output_type='tensor', use_trainable_latents=True, init_latents=init_latents_wm).images
            else:
                pred_img_tensor = pipe(prompt, num_inference_steps=50, output_type='tensor', use_trainable_latents=True, init_latents=init_latents_wm).images

            loss, lossI, lossP, lossS = totalLoss(pred_img_tensor, imgs, init_latents_wm, wm_pipe, invertible_map=None, mode='single', batch_size=cur_batch_size)

            # Negative MSE: minimizing it pushes the learned watermark away from the
            # original spectrum inside the mask; optional clamp bounds the reward.
            l2_loss_real = -1 * loss_mse(fft_init_latents_approx.real[mask], learned_wm_real[mask])
            l2_loss_imag = -1 * loss_mse(fft_init_latents_approx.imag[mask], learned_wm_imag[mask])
            l2_loss = l2_loss_real + l2_loss_imag
            if args.clamp:
                l2_loss = torch.clamp(l2_loss, min=-20000)

            loss = loss + l2_loss * args.loss_weight

            # Accumulate (not overwrite) so the "/ counter" values written below are means.
            losses['total_train_loss'] += loss.item()
            losses['I'] += lossI.item()
            losses['P'] += lossP.item()
            losses['S'] += lossS.item()
            losses['L2'] += l2_loss.item()

            optimizer.zero_grad()
            loss.backward()
            if args.clip:
                torch.nn.utils.clip_grad_norm_(invertible_map_real.parameters(), 1.0)
                torch.nn.utils.clip_grad_norm_(invertible_map_imag.parameters(), 1.0)
            optimizer.step()

            if global_step % args.save_interval == 0:
                _write_loss_summary(args.write_file, f"Epoch {i}, Step {j} Loss", losses, counter)
                _save_invertible_maps(args, invertible_map_real, invertible_map_imag, f"{i}_step_{j}")

        # MultiStepLR milestones [30, 80] count epochs, so advance it once per epoch.
        scheduler.step()

        print(losses['total_train_loss'])
        _write_loss_summary(args.write_file, f"Epoch {i} End Loss", losses, counter)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='diffusion watermark')
    parser.add_argument('--dataset', default='diffusiondb', choices=['coco', 'diffusiondb', 'wikiart', 'all'])
    parser.add_argument('--dataset_path', default='/localhome/data/datasets/watermarking')
    parser.add_argument("--use_cached_latents", action="store_true", help="directly use latents to train instead of images")
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--wm_path', default='wm_examples')
    parser.add_argument('--cfg_path', default='./example/config/config.yaml')
    parser.add_argument('--method', default='ours')
    # main() only implements 'residual_complex_learned_wm'; any other value raises
    # NotImplementedError, so that is the only usable default.
    parser.add_argument('--inv_map_type', default='residual_complex_learned_wm')
    parser.add_argument('--num_images', default=300, type=int)
    parser.add_argument('--layers', default=8, type=int)
    parser.add_argument('--num_workers', default=1, type=int)
    parser.add_argument('--write_file', default='temp', type=str)
    parser.add_argument('--wm_pattern', default='rings', type=str)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--lr', default=0.0001, type=float)
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--wm_radius', type=int, default=10)
    parser.add_argument('--shape', type=str, default='circle')
    parser.add_argument('--loss_weight', default=1e-4, type=float)
    parser.add_argument('--save_interval', type=int, default=300)
    parser.add_argument('--clamp', action='store_true', default=False)
    parser.add_argument('--clip', action='store_true', default=False)
    args = parser.parse_args()

    # torch.manual_seed seeds the RNG on every device (CUDA included), so the
    # separate torch.cuda.manual_seed / manual_seed_all calls are unnecessary.
    torch.manual_seed(args.seed)

    main(args)

