import argparse
from sqlite3 import NotSupportedError
import torch
import numpy as np
import sys
import os

sys.path.append(".")
sys.path.append("..")

from configs import data_configs, paths_config
from datasets.inference_dataset import InferenceDataset, EditingDataset
from torch.utils.data import DataLoader
from utils.model_utils import setup_model
from utils.common import tensor2im
from PIL import Image
from editings import latent_editor, face_editor
from options.test_options import TestOptions
from tqdm import tqdm

def main(args, test_opts):
    """Run high-fidelity attribute editing over every image in the dataset.

    For each image yielded by the data loader this:

    1. computes the distortion map ``x - G(w)`` between the input and its
       plain inversion (the inversion is resized to 256x256 to match ``x``),
    2. produces an initial edit with ``edit_batch_swagan``,
    3. aligns the distortion map to the edited image with ``net.grid_align``
       and regenerates the edited latent conditioned on it
       (``_fuse_residual``),
    4. writes the initial edit (``<stem>_ori.jpg``), the raw and aligned
       distortion maps (``<stem>_res.jpg`` / ``<stem>_res_aligned.jpg``) and
       the final result (original file name) to
       ``<save_dir>/<edit_attribute>/``.

    Args:
        args: parsed command-line namespace (``ckpt``, ``images_dir``,
            ``save_dir``, ``batch``, ``n_sample``, ``edit_attribute``,
            ``edit_degree``). ``setup_data_loader`` fills in ``n_sample``
            when it is ``None``.
        test_opts: options forwarded to ``setup_data_loader`` and
            ``edit_batch_swagan``.

    Raises:
        NotImplementedError: if ``args.edit_attribute`` is not one of
            ``'age'``, ``'smile'``, ``'pose'``, ``'deage'``.
        ValueError: if ``args.save_dir`` is not set.
    """
    supported_edits = ('age', 'smile', 'pose', 'deage')
    # Validate up front, before loading the checkpoint and encoding the dataset.
    if args.edit_attribute not in supported_edits:
        raise NotImplementedError(
            f"edit attribute {args.edit_attribute!r} is not supported; "
            f"expected one of {supported_edits}")
    if args.save_dir is None:
        raise ValueError("args.save_dir must be set")

    net, opts = setup_model(args.ckpt, device, is_swagan=True)

    is_cars = 'car' in opts.dataset_type
    generator = net.decoder
    generator.eval()
    args, dataset, data_loader = setup_data_loader(args, opts, test_opts)
    editor = face_editor.FaceEditor(net.decoder)

    # Initial inversion. get_all_latents walks the same unshuffled loader, so
    # latent_codes[k] belongs to the k-th image visited by the loop below.
    latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
    avg_image = get_average_image(net)

    interfacegan_directions = paths_config.interfacegan_edit_paths
    print(f'{args.edit_attribute} interface direction: ', interfacegan_directions[args.edit_attribute])

    edit_directory_path = os.path.join(args.save_dir, args.edit_attribute)
    os.makedirs(edit_directory_path, exist_ok=True)

    global_i = 0
    # Pure inference: no autograd graph is needed anywhere below.
    with torch.no_grad():
        for batch, landmarks_transform in tqdm(data_loader):
            if args.n_sample is not None and global_i >= args.n_sample:
                print('inference finished!')
                break
            x = batch.to(device).float()

            # Initial edit for the whole batch.
            img_edits, edit_latents, _ = edit_batch_swagan(inputs=x,
                                                           net=net,
                                                           avg_image=avg_image,
                                                           latent_editor=editor,
                                                           opts=test_opts,
                                                           landmarks_transform=landmarks_transform.to(device).float(),
                                                           args=args)

            for i in range(x.shape[0]):
                # Stop exactly at n_sample even in the middle of a batch.
                if args.n_sample is not None and global_i >= args.n_sample:
                    break
                im_path = dataset.paths[global_i]

                # Distortion map: what the plain inversion of *this* image failed to reproduce.
                inversion, _, _ = generator([latent_codes[global_i].unsqueeze(0).to(device)], None, None,
                                            input_is_latent=True, randomize_noise=False,
                                            return_latents=True, is_inference=True)
                res_i = x[i:i + 1] - torch.nn.functional.interpolate(
                    torch.clamp(inversion, -1., 1.), size=(256, 256), mode='bilinear')

                img_edit = torch.nn.functional.interpolate(
                    torch.clamp(img_edits[i][None, :], -1., 1.), size=(256, 256), mode='bilinear')
                edit_latent = edit_latents[i][None, :]

                # Align this image's distortion map to its own edited counterpart.
                res_aligned = net.grid_align(torch.cat((res_i, img_edit), 1))

                result = _fuse_residual(net, generator, res_aligned, edit_latent)
                if is_cars:
                    result = result[:, :, 64:448, :]

                _save_outputs(edit_directory_path, im_path, img_edit, res_i, res_aligned, result)
                global_i += 1


def _fuse_residual(net, generator, res_aligned, edit_latent):
    """Regenerate ``edit_latent`` conditioned on an aligned distortion map.

    ``res_aligned`` feeds ``net.fresidue1`` / ``net.fresidue2`` at its own
    resolution. Upsampled to 512x512, it also feeds ``net.wresidue``, whose
    first two outputs (scale, shift) are each passed through ``net.dwt``.
    Returns the generator's image output.
    """
    res_512 = torch.nn.functional.interpolate(res_aligned, size=(512, 512), mode='bilinear')
    fconditions_list = [net.fresidue1(res_aligned), net.fresidue2(res_aligned)]
    wconditions = net.wresidue(res_512)
    wconditions_list = [[net.dwt(wconditions[0]), net.dwt(wconditions[1])]]
    imgs, _, _ = generator([edit_latent], fconditions_list, wconditions_list,
                           input_is_latent=True, randomize_noise=False,
                           return_latents=True, is_inference=True)
    return imgs


def _save_outputs(save_dir, im_path, img_edit, res, res_aligned, result):
    """Save the four per-image outputs of ``main``, named after ``im_path``."""
    file_name = os.path.basename(im_path)
    stem = os.path.splitext(file_name)[0]
    outputs = ((img_edit, stem + '_ori.jpg'),
               (res, stem + '_res.jpg'),
               (res_aligned, stem + '_res_aligned.jpg'),
               (result, file_name))
    for tensor, name in outputs:
        Image.fromarray(np.array(tensor2im(tensor[0]))).save(os.path.join(save_dir, name))

def setup_data_loader(args, opts, test_opts):
    """Build the dataset and the unshuffled loader the script iterates over.

    Uses ``InferenceDataset`` when ``args.edit_attribute == 'inversion'``.
    Otherwise it uses ``EditingDataset``, constructed with
    ``test_opts.landmarks_transforms_path``. The loader uses ``drop_last=True``,
    so a trailing partial batch is skipped.

    Returns:
        ``(args, dataset, data_loader)``. ``args.n_sample`` is set to the
        dataset length when it was ``None``.
    """
    cfg = data_configs.DATASETS[opts.dataset_type]
    transforms = cfg['transforms'](opts).get_transforms()

    if args.images_dir is None:
        images_path = cfg['test_source_root']
    else:
        images_path = args.images_dir
    print(f"images path: {images_path}")

    if args.edit_attribute != 'inversion':
        test_dataset = EditingDataset(root=images_path,
                                      landmarks_transforms_path=test_opts.landmarks_transforms_path,
                                      transform=transforms['transform_inference'])
    else:
        test_dataset = InferenceDataset(root=images_path,
                                        transform=transforms['transform_test'],
                                        preprocess=None,
                                        opts=opts)

    loader = DataLoader(test_dataset, batch_size=args.batch, shuffle=False,
                        num_workers=2, drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, test_dataset, loader

def get_average_image(net, n_latents=18, device='cuda'):
    """Return ``net.latent_avg`` tiled into one row per W+ layer.

    Args:
        net: model exposing a ``latent_avg`` tensor.
        n_latents: number of rows to produce. Defaults to 18, the value this
            script historically hard-coded.
        device: device the result is moved to. Defaults to ``'cuda'``, as before.

    Returns:
        A detached float tensor of shape ``(n_latents, D)``, where ``D`` is the
        last dimension of ``net.latent_avg``. Each row is row 0 of the
        repeated average latent.
    """
    avg_image = net.latent_avg.repeat(n_latents, 1, 1)[:, 0, :]
    return avg_image.to(device).float().detach()

def get_latents(net, x, is_cars=False):
    """Encode ``x`` and optionally offset the codes by the average latent.

    When ``net.opts.start_from_latent_avg`` is set, ``net.latent_avg`` is
    repeated per sample and added. For 2-D codes only row 0 of the repeated
    average is used. For car models, 18-row codes are truncated to their
    first 16 rows.
    """
    latents = net.encoder(x)
    if net.opts.start_from_latent_avg:
        tiled_avg = net.latent_avg.repeat(latents.shape[0], 1, 1)
        offset = tiled_avg[:, 0, :] if latents.ndim == 2 else tiled_avg
        latents = latents + offset
    needs_truncation = latents.shape[1] == 18 and is_cars
    return latents[:, :16, :] if needs_truncation else latents


def get_all_latents(net, data_loader, n_images=None, is_cars=False):
    """Encode batches from ``data_loader`` with ``get_latents`` and concatenate them.

    Batches are moved to the module-level ``device`` and cast to float.
    Encoding stops at the start of the first batch after more than
    ``n_images`` latents have been collected, so the result may hold more
    than ``n_images`` rows. Runs under ``torch.no_grad()``.
    """
    collected = []
    seen = 0
    with torch.no_grad():
        for images, _unused in data_loader:
            if n_images is not None and seen > n_images:
                break
            codes = get_latents(net, images.to(device).float(), is_cars)
            collected.append(codes)
            seen += len(codes)
    return torch.cat(collected)


def save_image(img, save_dir, idx):
    """Convert ``img`` with ``tensor2im`` and write it as ``<save_dir>/<idx:05d>.jpg``."""
    pil_image = tensor2im(img)
    target = os.path.join(save_dir, f"{idx:05d}.jpg")
    Image.fromarray(np.array(pil_image)).save(target)



#######################################
from typing import Optional
from editings.face_editor import FaceEditor

def edit_batch_swagan(inputs: torch.Tensor, net, avg_image: torch.Tensor, latent_editor: "FaceEditor", opts: "TestOptions",
               landmarks_transform: Optional[torch.Tensor] = None, args=None):
    """Encode a batch and apply the requested attribute edit to its latents.

    Args:
        inputs: batch of input images, already on the model's device.
        net: model whose ``encoder`` maps ``inputs`` to latent offsets.
        avg_image: average latent (see ``get_average_image``), added to
            every sample's encoder output.
        latent_editor: editor whose ``edit_single`` performs the edit.
        opts: unused; kept for interface compatibility.
        landmarks_transform: unused; kept for interface compatibility.
        args: namespace providing ``edit_attribute`` (the edit direction) and
            ``edit_degree``.

    Returns:
        ``(edited_images, edited_latents, None)``. The first two elements are
        whatever ``latent_editor.edit_single`` returns. The third is always
        ``None`` and exists only so callers can keep unpacking three values.

    Raises:
        ValueError: if ``args`` is ``None``.
    """
    if args is None:
        raise ValueError("edit_batch_swagan needs args.edit_attribute and args.edit_degree")

    with torch.no_grad():
        codes = net.encoder(inputs)
        # presumably [b_size, 18, 512]; avg_image is repeated once per sample
        latent_codes = codes + avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1)
        img_edit, edit_latent = latent_editor.edit_single(latents=latent_codes,
                                                          direction=args.edit_attribute,
                                                          edit_degree=args.edit_degree)
    return img_edit, edit_latent, None
######################################

if __name__ == "__main__":
    # Module-level device, read as a global by main() and get_all_latents().
    device = "cuda"
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default=None, help="The directory to the images")
    # Required: main() joins it into the output path, which fails on None.
    parser.add_argument("--save_dir", type=str, required=True, help="The directory to save.")
    parser.add_argument("--batch", type=int, default=10, help="batch size for the generator")
    parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
    parser.add_argument("--edit_attribute", type=str, default='pose', help="The desired attribute")
    parser.add_argument("--edit_degree", type=int, default=0, help="edit degree")
    parser.add_argument("--factor_ranges", default=[(-5, 5)], help="list of (min, max) edit factor ranges")
    parser.add_argument("--ckpt", metavar="CHECKPOINT", required=True, help="path to generator checkpoint")

    args = parser.parse_args()
    main(args, test_opts=TestOptions)