import os
import torch
import random
import json
import numpy as np

from copy import deepcopy
from tqdm import trange
from argparse import ArgumentParser
from diffusers import AutoPipelineForText2Image

def parse_args():
    """Parse the command-line options of the generation script.

    Returns:
        argparse.Namespace with the attributes ``model_id``, ``save_name``,
        ``latent_shape``, ``data_path``, ``output_dir``, ``seed``, ``device``,
        ``start_index`` and ``end_index``.
    """
    parser = ArgumentParser()
    parser.add_argument('--model_id', type=str,
                        default='stabilityai/stable-diffusion-3-medium-diffusers',
                        help='model identifier passed to from_pretrained')
    parser.add_argument('--save_name', type=str, default='sd3',
                        help='sub-directory of --output_dir that receives the results')
    # ``type=list`` would split the raw argument string into single characters
    # ("16" -> ['1', '6']); parse whitespace-separated integers instead, e.g.
    # ``--latent_shape 16 128 128``. The default is unchanged.
    parser.add_argument('--latent_shape', type=int, nargs='+', default=[16, 128, 128],
                        help='shape of the initial latent noise, without the leading dimension')
    parser.add_argument('--data_path', type=str, default='m3t2ibench_data/',
                        help='directory containing the <index>.json input files')
    parser.add_argument('--output_dir', type=str, default='./outputs/')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--start_index', type=int, default=0,
                        help='index of the first input to process')
    parser.add_argument('--end_index', type=int, default=100,
                        help='index one past the last input to process; -1 means all inputs')
    return parser.parse_args()

def is_serializable(obj):
    """Return True if ``json.dumps`` can encode ``obj``, False otherwise.

    Args:
        obj: any Python object.

    Returns:
        bool: whether ``json.dumps(obj)`` succeeds.
    """
    try:
        json.dumps(obj)
    except (TypeError, OverflowError, ValueError):
        # ValueError is what json.dumps raises for circular references; such
        # an object is not serializable, so report False instead of crashing.
        return False
    return True

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def prepare_all_args(args, dt, generator):
    """Build the keyword arguments for one pipeline call.

    Draws float16 standard-normal noise of shape ``args.latent_shape`` on
    ``args.device`` from ``generator`` and prepends a dimension of size 1.

    Args:
        args: namespace providing ``latent_shape`` and ``device``.
        dt: mapping with a ``'prompt'`` entry.
        generator: ``torch.Generator`` used to draw the noise.

    Returns:
        dict with the keys ``'prompt'`` and ``'latents'``.
    """
    noise = torch.randn(
        args.latent_shape,
        generator=generator,
        dtype=torch.float16,
        device=args.device,
    ).unsqueeze(0)
    return {'prompt': dt['prompt'], 'latents': noise}

def save_image(base_path, i, image, generation_args, dt):
    """Write the image, latents and metadata of one sample to ``<base_path>/<i>/``.

    Files written: ``gen.png`` (first element of ``image``), ``latents.npy``
    (``generation_args['latents']`` as a float16 array) and ``metadata.json``
    (``dt``, indented by 4 spaces).

    Args:
        base_path: directory under which the per-sample directory is created.
        i: sample index, used as the name of the per-sample directory.
        image: sequence whose first element has a ``save(path)`` method.
        generation_args: dict of generation arguments; must contain a
            ``'latents'`` tensor.
        dt: metadata dict. It is modified in place: ``'gen_image'`` is set to
            the image path and ``'generation_args'`` to the JSON-serializable
            entries of ``generation_args``.
    """
    sample_dir = os.path.join(base_path, f'{i}')
    os.makedirs(sample_dir, exist_ok=True)

    png_path = os.path.join(sample_dir, 'gen.png')
    image[0].save(png_path)

    latent_array = generation_args['latents'].detach().cpu().to(torch.float16).numpy()
    np.save(os.path.join(sample_dir, 'latents.npy'), latent_array)

    dt['gen_image'] = png_path
    # Entries that cannot be JSON-encoded (e.g. the latents tensor) are left out.
    dt['generation_args'] = {
        key: value
        for key, value in generation_args.items()
        if is_serializable(value)
    }

    metadata_path = os.path.join(sample_dir, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(dt, f, indent=4)

def get_input_list(args, num_inputs=10000):
    """Load the input files ``0.json`` .. ``<num_inputs - 1>.json`` from ``args.data_path``.

    Args:
        args: namespace providing ``data_path``, the directory with the files.
        num_inputs: number of consecutively numbered files to load. Defaults
            to 10000, the count that was previously hard-coded.

    Returns:
        list of the decoded JSON objects, ordered by file index.

    Raises:
        FileNotFoundError: if one of the expected files does not exist.
    """
    inputs = []
    for i in range(num_inputs):
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(os.path.join(args.data_path, f'{i}.json'), 'r', encoding='utf-8') as f:
            inputs.append(json.load(f))
    return inputs

@torch.no_grad()
def generate_image(args, base_args, pipe, base_path, idx, dt):
    """Run the pipeline for one sample and save its outputs.

    Args:
        args: parsed command-line namespace; its attribute dict is stored in
            ``base_args['args']`` after generation.
        base_args: keyword arguments for ``pipe``; modified in place.
        pipe: callable pipeline whose result exposes an ``images`` attribute.
        base_path: directory under which the sample directory is created.
        idx: index of the sample.
        dt: metadata dict of the sample, passed on to ``save_image``.
    """
    result = pipe(**base_args)
    generated = result.images
    # Added only after the pipeline call, so it is not passed to the pipeline.
    base_args['args'] = vars(args)
    save_image(base_path, idx, generated, base_args, dt)


if __name__ == '__main__':
    # Script entry point: generate one image per input in
    # [start_index, end_index) and store the results under
    # <output_dir>/<save_name>/<index>/.
    args = parse_args()
    set_seed(args.seed)
    device = args.device
    # A single generator is shared by all samples, so the latents drawn for an
    # index depend on how many samples were drawn before it in this run.
    generator = torch.Generator(device=device).manual_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    base_path = os.path.join(args.output_dir, f'{args.save_name}')
    os.makedirs(base_path, exist_ok=True)

    pipe = AutoPipelineForText2Image.from_pretrained(
        args.model_id, torch_dtype=torch.float16).to(args.device)

    input_jsons = get_input_list(args)
    # -1 selects everything up to the last input. Any other value is clamped to
    # the number of inputs, so an end_index past the end finishes the run
    # cleanly instead of raising IndexError part-way through.
    if args.end_index == -1:
        args.end_index = len(input_jsons)
    else:
        args.end_index = min(args.end_index, len(input_jsons))

    for i in trange(args.start_index, args.end_index):
        torch.cuda.empty_cache()
        dt = input_jsons[i]
        base_args = prepare_all_args(args, dt, generator)
        # save_image mutates the metadata dict, so hand it a private copy and
        # keep the loaded inputs untouched.
        temp_dt = deepcopy(dt)
        generate_image(args, base_args, pipe, base_path, i, temp_dt)