import os
import datetime
import cv2
import json
import torch
import argparse
import numpy as np

from torch.utils.tensorboard import SummaryWriter
import torch.optim as opt
import torch.optim.lr_scheduler as lrs


from latent_diffusion import LatentDiffusionWraper
from embedding_manger import EmbeddingWraper
from MyClip import ClipWraper
from loss import LossWarper
from utils import log_creator, set_seed, print_args


def _build_optimizer(optim_cfg, params):
    """Instantiate the ``torch.optim`` optimizer described by ``optim_cfg["optimizer"]``.

    ``optim_cfg["optimizer"]`` must contain ``"type"`` (the name of a class in
    ``torch.optim``) and ``"kwargs"`` (keyword arguments for its constructor).
    """
    optimizer_cfg = optim_cfg["optimizer"]
    return getattr(opt, optimizer_cfg["type"])(params, **optimizer_cfg["kwargs"])


def _build_scheduler(optim_cfg, optimizer):
    """Instantiate the ``torch.optim.lr_scheduler`` described by ``optim_cfg["scheduler"]``.

    Returns None when ``optim_cfg`` has no ``"scheduler"`` entry.
    """
    if "scheduler" not in optim_cfg:
        return None
    scheduler_cfg = optim_cfg["scheduler"]
    return getattr(lrs, scheduler_cfg["type"])(optimizer, **scheduler_cfg["kwargs"])


def _log_losses(writer, logger, loss_dict, optimizer, tag, step, header, total_fmt, exclude):
    """Write one training step's losses and learning rates to TensorBoard and the log.

    Args:
        writer: TensorBoard ``SummaryWriter``.
        logger: Logger returned by ``log_creator``.
        loss_dict: Mapping of loss name -> scalar tensor; must contain ``"all_loss"``.
        optimizer: Optimizer whose per-group ``lr`` is recorded.
        tag: TensorBoard tag prefix, e.g. ``"epoch_0.stage_m"``.
        step: TensorBoard step.
        header: Start of the log line (everything before ``" Loss ..."``).
        total_fmt: Format spec used for ``all_loss`` in the log line.
        exclude: Loss names containing this substring are left out of the
            per-term breakdown in the log line.
    """
    values = {name: value.detach().cpu().item() for name, value in loss_dict.items()}
    for name, value in values.items():
        writer.add_scalar(f"{tag}.{name}", value, step)
    # Every param group is written under the same tag (the last one wins per step).
    for group in optimizer.param_groups:
        writer.add_scalar(f"{tag}.lr", group["lr"], step)

    info = f"{header} Loss {values['all_loss']:{total_fmt}}"
    info += "".join(f" {name} {value:.6f}" for name, value in values.items() if exclude not in name)
    logger.info(info)


def main(cfg):
    """Run the alternating mining / verifying / circumventing training loop.

    Each global epoch:
      1. generates reference images from the prompt "a <prefix with * replaced by
         erased_concept> photo" via ``generator.generate_with_original_unet``;
      2. mining: optimizes the prompt learner's parameters on the ``stage_m`` loss
         for ``epochs_stage_m`` steps and saves its checkpoint;
      3. verifying: generates images with the mined embeddings and computes the
         mean cosine similarity between (reference - pival) and (sample - pival)
         CLIP feature directions; the loop stops when it falls below
         ``thr_continue`` (this is the only exit of the loop);
      4. circumventing: fine-tunes ``generator.finetuned_params`` on the
         ``stage_c`` loss for ``epochs_stage_c`` steps;
      5. saves post-circumventing samples and the generator checkpoint.

    Args:
        cfg: Parsed JSON config. Reads the ``task_config``, ``generator``,
            ``clip``, ``embedding_manger`` and ``loss`` sections.
    """
    tsk_cfg = cfg["task_config"]
    log_path = tsk_cfg["log_path"]
    samples_path = os.path.join(log_path, "checkpoints", "samples")

    set_seed(tsk_cfg["seed"])

    # create log
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
    logger = log_creator(os.path.join(log_path, "train." + timestamp + ".log"))
    writer = SummaryWriter(log_path)
    print_args(cfg, logger)

    if tsk_cfg.get("debug"):
        torch.autograd.set_detect_anomaly(True)

    # models
    generator = LatentDiffusionWraper("Stable_Diffusion", cfg["generator"]["wraper"], logger)
    logger.info("Build generator!")
    aligner = ClipWraper("Clip", cfg["clip"]["wraper"], logger)
    logger.info("Build aligner!")
    prompt_learner = EmbeddingWraper("Embedding", cfg["embedding_manger"]["wraper"], logger,
                                     initializer_words=[tsk_cfg["erased_concept"]],
                                     embedder=generator.diffusion.cond_stage_model)
    logger.info("Build prompt learner!")

    # losses
    stage_m_loss = LossWarper(cfg["loss"]["stage_m"])
    stage_c_loss = LossWarper(cfg["loss"]["stage_c"])

    m_batch_size = tsk_cfg["batch_size_stage_m"]
    m_optim_cfg = cfg["embedding_manger"]["optim"]
    c_optim_cfg = cfg["generator"]["optim"]

    # "pival" images and their CLIP features are computed once and reused every epoch
    pivals = generator.generate([tsk_cfg["pival"]] * m_batch_size, prompt_manager=None, verbose=False)
    pivals_f = aligner.encode_image(aligner.preprocess(pivals)[0], ret_cls=True)

    epoch_global = 0
    try:
        while True:
            # sample reference images (regenerated in each epoch)
            concept = tsk_cfg["prefix"].replace("*", tsk_cfg["erased_concept"])
            ref_prompt = [f"a {concept} photo"] * m_batch_size
            refs = generator.generate_with_original_unet(ref_prompt, prompt_manager=None, verbose=False)
            save_sample(refs, samples_path, "ref_epoch{}_".format(epoch_global))
            refs_f = aligner.encode_image(aligner.preprocess(refs)[0], ret_cls=True)

            # ============= mining stage
            if epoch_global > 0:
                prompt_learner.restart(use_initial_embedding=False)

            # Optimizers are rebuilt every epoch (after restart) so they start with fresh state.
            stage_m_optim = _build_optimizer(m_optim_cfg, prompt_learner.model.parameters())
            stage_c_optim = _build_optimizer(c_optim_cfg, iter(generator.finetuned_params))
            stage_m_scheduler = _build_scheduler(m_optim_cfg, stage_m_optim)
            stage_c_scheduler = _build_scheduler(c_optim_cfg, stage_c_optim)
            # linear warm-up of the circumventing lr: factor 1e-3 -> 1 over 50 steps
            stage_c_warmup = lrs.LinearLR(stage_c_optim, 1e-3, 1, 50)

            # The previous circumventing stage back-propagates through prompt_learner.model
            # and may have left gradients on it; drop them so the first mining step
            # only uses the mining loss.
            stage_m_optim.zero_grad(set_to_none=True)

            for epoch_stage_m in range(tsk_cfg["epochs_stage_m"]):
                batch = {"img": refs,
                         "txt": [tsk_cfg["prefix"]] * m_batch_size,
                         "embedding_pool": prompt_learner.get_embedding_tool() if hasattr(prompt_learner.model, "string_to_embedding_pool") else None,
                         "cur_embedding": prompt_learner.get_embeddings()}

                output = generator.eval(batch=batch, prompt_manager=prompt_learner.model)

                loss_dict = stage_m_loss(batch, output)
                loss_dict["all_loss"].backward()

                # clip exactly the parameters this optimizer updates
                torch.nn.utils.clip_grad_norm_(prompt_learner.model.parameters(), 10.0)

                stage_m_optim.step()
                if stage_m_scheduler is not None:
                    stage_m_scheduler.step()
                stage_m_optim.zero_grad(set_to_none=True)

                _log_losses(writer, logger, loss_dict, stage_m_optim,
                            "epoch_{}.stage_m".format(epoch_global), epoch_stage_m,
                            "Epoch {}: Stage mining Epoch {}:".format(epoch_global, epoch_stage_m),
                            ".4f", "all")

            prompt_learner.save_checkpoint(epoch_global, log_path)

            # ============= verifying stage
            # generate images using mined embeddings
            samples = generator.generate([tsk_cfg["prefix"]] * m_batch_size,
                                         prompt_manager=prompt_learner.model, verbose=False)
            # trailing "_" keeps "<epoch>_<idx>" unambiguous (e.g. epoch 1/idx 10 vs epoch 11/idx 0)
            save_sample(samples, samples_path, "before_circumventing_epoch{}_".format(epoch_global))
            samples_f = aligner.encode_image(aligner.preprocess(samples)[0], ret_cls=True)

            # cosine similarity between (reference - pival) and (sample - pival) directions;
            # F.normalize avoids NaN for a zero-length direction
            ref_dir = torch.nn.functional.normalize((refs_f - pivals_f).squeeze(1), dim=1)
            sample_dir = torch.nn.functional.normalize((samples_f - pivals_f).squeeze(1), dim=1)
            indicator = (ref_dir @ sample_dir.T).mean().item()
            logger.info(f'Epoch {epoch_global} Verifying indiactor {indicator}')
            writer.add_scalar("Verifying_indiactor", indicator, epoch_global)

            if indicator < tsk_cfg["thr_continue"]:
                break

            # ============= circumventing stage
            generator.to_train(verbose=False)
            # The mining stage may have left gradients on the fine-tuned generator params.
            stage_c_optim.zero_grad(set_to_none=True)

            for epoch_stage_c in range(tsk_cfg["epochs_stage_c"]):
                cond = [tsk_cfg["prefix"]]
                pival_cond = [tsk_cfg["pival"]]

                output = generator.train(cond, pival_cond, prompt_learner.model,
                                         batch_size=tsk_cfg["batch_size_stage_c"])

                loss_dict = stage_c_loss(None, output)
                loss_dict["all_loss"].backward()

                torch.nn.utils.clip_grad_norm_(generator.finetuned_params, 100.0)

                stage_c_optim.step()
                if stage_c_scheduler is not None:
                    stage_c_scheduler.step()
                stage_c_warmup.step()
                stage_c_optim.zero_grad(set_to_none=True)

                _log_losses(writer, logger, loss_dict, stage_c_optim,
                            "epoch_{}.stage_c".format(epoch_global), epoch_stage_c,
                            "Epoch {}: Stage circumventing Epoch {}:".format(epoch_global, epoch_stage_c),
                            ".6f", "all_loss")

            generator.to_eval()

            # generate images after circumventing
            samples = generator.generate([tsk_cfg["prefix"]] * m_batch_size,
                                         prompt_manager=prompt_learner.model, verbose=False)
            save_sample(samples, samples_path, "after_circumventing_epoch{}_".format(epoch_global))

            generator.save_checkpoint(None, epoch_global, log_path)

            epoch_global += 1
    finally:
        # flush pending TensorBoard events whether the loop breaks or raises
        writer.close()


def save_sample(samples: torch.Tensor, path, prefix):
    """Save a batch of images as PNG files named ``<prefix><index>.png`` under ``path``.

    Args:
        samples: Image batch laid out as (N, C, H, W) in RGB channel order.
            Values in [-1, 1] are mapped to [0, 255]; anything outside is clamped.
        path: Output directory; created (with parents) if it does not exist.
        prefix: Filename prefix placed before the image index.

    Raises:
        OSError: If OpenCV fails to write one of the images.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(path, exist_ok=True)

    # [-1, 1] -> [0, 255]; detach first so the conversion is not recorded by autograd
    images = ((samples.detach() + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    # NCHW -> NHWC (still RGB), contiguous on the CPU for OpenCV
    images = images.permute(0, 2, 3, 1).contiguous().cpu().numpy()

    for idx, image in enumerate(images):
        file_path = os.path.join(path, f"{prefix}{idx}.png")
        # OpenCV expects BGR channel order
        if not cv2.imwrite(file_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write sample image: {file_path}")



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        required=True,
        type=str,
        help="Path to the JSON training configuration.",
    )
    args = parser.parse_args()

    # Read the config inside a context manager so the handle is closed before training.
    with open(args.config_file, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    main(cfg)
