import argparse
from sqlite3 import NotSupportedError
import os
import pickle
import torch
import numpy as np
import torchvision
import clip
import math
from tqdm import tqdm

import sys
sys.path.append(".")
sys.path.append("..")

# from editing.styleclip.model import Generator

from configs import data_configs, paths_config
from datasets.inference_dataset import InferenceDataset, EditingDataset
from torch.utils.data import DataLoader
from utils.model_utils import setup_model
from utils.common import tensor2im
from PIL import Image
from editings import latent_editor, face_editor
from options.test_options import TestOptions
from tqdm import tqdm
from criteria.clip_loss import CLIPLoss
from criteria.id_loss import IDLoss
from torch import optim


def get_latents(net, x, is_cars=False):
    """Encode ``x`` with ``net.encoder`` and return the resulting codes.

    When ``net.opts.start_from_latent_avg`` is truthy, ``net.latent_avg`` is
    added to the encoder output (for 2-D codes only entry 0 along the second
    axis of the tiled average is used). If ``is_cars`` is set and the codes
    have 18 entries along dim 1, only the first 16 are kept.
    """
    latents = net.encoder(x)

    if net.opts.start_from_latent_avg:
        # Tile the average latent once per sample in the batch.
        offset = net.latent_avg.repeat(latents.shape[0], 1, 1)
        if latents.ndim == 2:
            # Single-vector codes: keep one vector of the average per sample.
            offset = offset[:, 0, :]
        latents = latents + offset

    needs_trim = latents.shape[1] == 18 and is_cars
    return latents[:, :16, :] if needs_trim else latents

def get_average_image(net):
    """Return a detached float CUDA tensor derived from ``net.latent_avg``.

    ``net.latent_avg`` is repeated 18 times along a new leading axis and
    entry 0 along the second axis is selected.

    NOTE(review): despite the name, the value is built from the average
    latent, not from an image -- confirm the intent with callers.
    """
    tiled = net.latent_avg.repeat(18, 1, 1)
    selected = tiled[:, 0, :]
    return selected.to('cuda').float().detach()

def get_all_latents(net, data_loader, n_images=None, is_cars=False):
    """Encode the images yielded by ``data_loader`` into latent codes.

    Args:
        net: Model forwarded to ``get_latents``.
        data_loader: Iterable yielding ``(batch, _)`` pairs; only the first
            element of each pair is used.
        n_images: Optional cap on the number of images to encode. ``None``
            encodes everything the loader yields.
        is_cars: Forwarded to ``get_latents``.

    Returns:
        Tuple ``(latents, inputs)``, each concatenated along dim 0 and, when
        ``n_images`` is given, truncated to at most ``n_images`` entries.
        Two empty tensors are returned when nothing was encoded (empty
        loader or ``n_images <= 0``).
    """
    all_latents = []
    all_inputs = []
    n_seen = 0
    with torch.no_grad():
        for batch, _ in data_loader:
            # Stop as soon as the quota is met; `>=` (rather than `>`)
            # avoids encoding one extra batch.
            if n_images is not None and n_seen >= n_images:
                break
            # `device` is the module-level global set in the __main__ block.
            inputs = batch.to(device).float()
            latents = get_latents(net, inputs, is_cars)
            all_latents.append(latents)
            all_inputs.append(inputs)
            n_seen += len(latents)

    if not all_latents:
        # torch.cat rejects an empty list; return empty results instead.
        return torch.empty(0), torch.empty(0)

    latents = torch.cat(all_latents)
    inputs = torch.cat(all_inputs)
    if n_images is not None:
        # The last batch may overshoot the quota; trim both tensors alike.
        latents = latents[:n_images]
        inputs = inputs[:n_images]
    return latents, inputs

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate for progress ``t`` of the optimization.

    The schedule multiplies ``initial_lr`` by a cosine ramp-down factor
    (active while ``1 - t`` is below ``rampdown``) and a linear warm-up
    factor (active while ``t`` is below ``rampup``).
    """
    # Fraction of the ramp-down window still remaining, capped at 1.
    remaining = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(remaining * math.pi)

    # Linear warm-up, capped at 1.
    warmup_factor = min(1, t / rampup)

    return initial_lr * (cosine_factor * warmup_factor)

def setup_data_loader(args, opts, test_opts):
    """Build the editing dataset and a sequential data loader over it.

    Args:
        args: Parsed CLI namespace; reads ``images_dir``, ``batch`` and
            ``n_sample``. ``n_sample`` is set to the dataset length when it
            is ``None``.
        opts: Model options; ``opts.dataset_type`` selects the entry of
            ``data_configs.DATASETS``.
        test_opts: Supplies ``landmarks_transforms_path`` for the dataset.

    Returns:
        Tuple ``(args, test_dataset, data_loader)``.
    """
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()

    # An explicit --images_dir wins over the dataset's configured test root.
    images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
    print(f"images path: {images_path}")

    test_dataset = EditingDataset(root=images_path,
                                landmarks_transforms_path=test_opts.landmarks_transforms_path,
                                transform=transforms_dict['transform_inference'])

    # shuffle=False keeps loader order aligned with the dataset's own order.
    # NOTE(review): drop_last=True discards a trailing partial batch, so with
    # batch > 1 the last few images may be skipped -- confirm this is wanted.
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, test_dataset, data_loader

def main(args, test_opts):
    """Invert the input images and run the text-guided edit on each of them.

    Args:
        args: Parsed CLI namespace; reads ``ckpt``, ``n_sample``,
            ``save_dir`` and ``description``. ``output_path`` is set on it
            to ``<save_dir>/wagi/<description>``.
        test_opts: Forwarded to ``setup_data_loader``.
    """
    # `device` is the module-level global set in the __main__ block.
    net, opts = setup_model(args.ckpt, device, is_swagan=True)
    net = net.cuda()
    is_cars = 'car' in opts.dataset_type
    args, dataset, data_loader = setup_data_loader(args, opts, test_opts)

    # Encode the inputs into latent codes.
    latents, inputs = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)

    # Prepare the output directory.
    args.output_path = os.path.join(args.save_dir, "wagi", args.description)
    os.makedirs(args.output_path, exist_ok=True)

    # Edit every image. `dataset.paths` is indexed by enumeration order,
    # which assumes the loader yields images in dataset order.
    for idx, (latent, img) in enumerate(zip(latents, inputs)):
        if args.n_sample is not None and idx >= args.n_sample:
            break
        image_name = os.path.basename(dataset.paths[idx])
        edit_image(image_name, latent, img, net, args)


def _save_normalized_image(tensor, path):
    """Save ``tensor`` to ``path``, normalizing values from the range (-1, 1).

    torchvision renamed the ``range`` keyword of ``save_image``/``make_grid``
    to ``value_range``. The new spelling is tried first and the old one is
    used as a fallback, so the call works on either side of the rename.
    """
    try:
        torchvision.utils.save_image(tensor, path, normalize=True,
                                     value_range=(-1, 1), padding=0)
    except TypeError:
        torchvision.utils.save_image(tensor, path, normalize=True,
                                     range=(-1, 1), padding=0)


def edit_image(image_name, latent, x, net, args):
    """Optimize one latent code towards ``args.description`` and save results.

    The latent is optimized with Adam for a fixed number of steps. The loss
    is the CLIP loss of the conditioned decoder output, plus a small L2
    penalty on the latent displacement, plus an MSE term between the
    unconditioned decoder outputs of the initial and the current latent.

    Args:
        image_name: File name of the source image; its stem names the outputs.
        latent: Latent code of the image (a batch dim is added here).
        x: Input image tensor the residual is computed against.
        net: Model providing ``decoder``, ``grid_align``, ``fresidue1``,
            ``fresidue2``, ``wresidue`` and ``dwt``.
        args: Namespace with ``description`` and ``output_path``; also passed
            to ``CLIPLoss``.

    Side effects:
        Writes ``<stem>_ori.jpg`` (decoder output without conditions) and
        ``<stem>.jpg`` (decoder output with conditions) to
        ``args.output_path``. Both come from the last loop iteration, i.e.
        they are rendered before that iteration's optimizer step.
    """
    swagan_model = net.decoder
    text_inputs = torch.cat([clip.tokenize(args.description)]).cuda()
    clip_loss = CLIPLoss(args)
    print(f'Editing {image_name}')

    latent_code_init = latent.cuda().unsqueeze(0)  # add a leading batch dim
    latent_optim = latent_code_init.detach().clone()
    latent_optim.requires_grad = True
    lr_init = 0.05
    step = 300
    mse_loss = torch.nn.MSELoss()
    interpolate = torch.nn.functional.interpolate

    optimizer = optim.Adam([latent_optim], lr=lr_init)

    # Keyword arguments shared by every decoder call below.
    decoder_kwargs = dict(input_is_latent=True, randomize_noise=False,
                          return_latents=True, is_inference=True)

    pbar = tqdm(range(step))

    # Reconstruction from the un-edited latent; used as a fixed reference.
    imgs_origin, _, _ = swagan_model([latent_code_init], None, None, **decoder_kwargs)
    imgs_origin = imgs_origin.detach()

    for i in pbar:
        t = i / step
        lr = get_lr(t, lr_init)
        optimizer.param_groups[0]["lr"] = lr
        imgs, _, _ = swagan_model([latent_optim], None, None, **decoder_kwargs)

        # The conditions are derived from the un-edited reconstruction, so
        # they do not depend on `latent_optim`.
        # NOTE(review): as written they are identical on every iteration and
        # could be hoisted out of the loop if the decoder does not modify
        # them in place -- confirm before changing.
        img_edit = interpolate(torch.clamp(imgs_origin, -1., 1.), size=(256, 256), mode='bilinear')
        res = x - img_edit

        res_aligned = net.grid_align(torch.cat((res, img_edit), 1))
        res_512 = interpolate(res_aligned, size=(512, 512), mode='bilinear')
        fconditions_list = [net.fresidue1(res_aligned), net.fresidue2(res_aligned)]
        wconditions = net.wresidue(res_512)  # indexed below as (scale, shift)
        wc_scale = net.dwt(wconditions[0])
        wc_shift = net.dwt(wconditions[1])
        wconditions_list = [[wc_scale, wc_shift]]
        imgs_final, _, _ = swagan_model([latent_optim], fconditions_list, wconditions_list, **decoder_kwargs)

        c_loss = clip_loss(imgs_final, text_inputs)
        l2_loss = ((latent_code_init - latent_optim) ** 2).sum()
        img_loss = mse_loss(imgs_origin, imgs)
        loss = c_loss + 0.0005 * l2_loss + img_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pbar.set_description(
            (
                f"loss: {loss.item():.4f}; lr: {lr:.4f};"
            )
        )

    # splitext keeps everything before the final extension, so names that
    # contain extra dots do not collapse onto the same output file.
    stem = os.path.splitext(image_name)[0]
    _save_normalized_image(imgs, os.path.join(args.output_path, f"{stem}_ori.jpg"))
    _save_normalized_image(imgs_final, os.path.join(args.output_path, f"{stem}.jpg"))

# Script entry point: parse the CLI options and run the editing pipeline.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--images_dir", type=str, default='/home/nvadmin/seungjun/CelebA-HQ-test/img', help="The directory to the images")
    parser.add_argument("--save_dir", type=str, default="./styleclip",
                        help="Root directory for the edited images (written to <save_dir>/wagi/<description>).")
    parser.add_argument("--n_sample", type=int, default=10000, help="number of the samples to infer.")
    parser.add_argument("--ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")
    # The default is a float, so the option must parse floats as well;
    # with type=int a value such as 0.7 would be rejected.
    parser.add_argument("--stylegan_truncation", type=float, default=1.)
    parser.add_argument("--stylegan_truncation_mean", type=int, default=4096)
    parser.add_argument("--num_alphas", type=int, default=11)
    parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
    parser.add_argument("--description", type=str, default="red lipstick", help="the text that guides the editing/generation")
    args = parser.parse_args()
    # Module-level global read by main() and get_all_latents().
    device = "cuda"
    # NOTE(review): the TestOptions class itself is passed, not an instance
    # or parsed options -- confirm `landmarks_transforms_path` is reachable
    # on it.
    main(args, test_opts=TestOptions)
