import torch
import json
from training.prompting_utils import UniversalPrompting, create_attention_mask_predict_next, \
    create_attention_mask_for_mmu 
from PIL import Image
import numpy as np
import wandb
from training.utils import image_transform
from torchvision import transforms
from tqdm import tqdm
import os
import json

@torch.no_grad()
def eval_t2i(
    model,
    vq_model,
    uni_prompting,
    accelerator,
    config,
    global_step,
    logger,
    mask_schedule,
    mask_schedule_type,
):
    """Generate images for the T2I validation prompts and log them to wandb.

    Loads the JSON list at ``config.experiment.eval_t2i_file`` (every entry
    must provide ``"caption"`` and ``"image"``; ``"image"`` may be the literal
    string ``"None"``), runs ``t2i_generate`` starting from fully masked image
    tokens, decodes the result with ``vq_model.decode_code`` and logs one
    picture per prompt under the wandb key ``"Generated images"``. Each
    picture is the concatenation (axis 2 of the stacked batch) of the resized
    reference image, its VQ reconstruction and the generated image; entries
    without a reference image get black 256x256 placeholders for the first
    two panels.

    Args:
        model: Model exposing ``t2i_generate``; it is unwrapped through
            ``accelerator.unwrap_model`` before generation.
        vq_model: Provides ``decode_code`` and a forward pass that is used to
            reconstruct the reference images.
        uni_prompting: Prompt builder called with the ``'t2i_gen'`` task; its
            ``sptids_dict`` supplies the pad/soi/eoi ids.
        accelerator: Supplies the device, the mixed-precision mode and
            ``unwrap_model``.
        config: Experiment configuration; the ``experiment``, ``model.Face``
            and ``training`` sections are read.
        global_step: Step passed to ``wandb.log``.
        logger: Logger used for the start message.
        mask_schedule: Forwarded to ``t2i_generate`` as ``noise_schedule``.
        mask_schedule_type: Forwarded to ``t2i_generate``.

    The model is put into eval mode at the start and back into train mode
    before returning. Nothing is returned.
    """
    logger.info("Evaluating T2I...")
    model.eval()

    with open(config.experiment.eval_t2i_file, "r") as file:
        test_captions = json.load(file)
    validation_prompts = [entry["caption"] for entry in test_captions]

    # The attention mask is cast to the dtype of the token embeddings.
    core_model = model.module if hasattr(model, 'module') else model
    mask_dtype = core_model.Face.model.embed_tokens.weight.dtype

    # The last vocabulary id is used as the mask token; generation starts
    # from a canvas in which every image position holds that id.
    mask_token_id = config.model.Face.vocab_size - 1
    image_tokens = torch.ones(
        (len(test_captions), config.model.Face.num_vq_tokens),
        dtype=torch.long,
        device=accelerator.device,
    ) * mask_token_id

    input_ids, _ = uni_prompting((validation_prompts, image_tokens), 't2i_gen')

    if config.training.guidance_scale > 0:
        # Guidance needs a second batch built from empty prompts; the mask
        # then covers conditional and unconditional rows together.
        uncond_input_ids, _ = uni_prompting(([''] * len(validation_prompts), image_tokens), 't2i_gen')
        mask_source = torch.cat([input_ids, uncond_input_ids], dim=0)
    else:
        uncond_input_ids = None
        mask_source = input_ids

    attention_mask = create_attention_mask_predict_next(
        mask_source,
        pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
        soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
        eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
        rm_pad_in_image=True,
    ).to(mask_dtype)

    weight_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }.get(accelerator.mixed_precision, torch.float32)

    unwrapped_model = accelerator.unwrap_model(model)
    with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"):
        # Generate images
        gen_token_ids = unwrapped_model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            predict_all_tokens=config.training.get("predict_all_tokens", False),
            seq_len=config.model.Face.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
            mask_schedule_type=mask_schedule_type,
            t2i_loss_type=config.training.get("t2i_loss_type", "cross_entropy"),
        )

    # Keep every generated id inside the codebook before decoding.
    gen_token_ids = torch.clamp(gen_token_ids, max=unwrapped_model.config.codebook_size - 1, min=0)
    predicted_images = vq_model.decode_code(gen_token_ids)

    # NOTE: `vq_model` is intentionally NOT deleted here, even when
    # `training.pre_encode` is set. It is still needed below to reconstruct
    # the reference images, and unbinding the parameter would not release the
    # model anyway because the caller keeps its own reference.

    # Map the decoder output from [-1, 1] to uint8 pixels, channels last.
    predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0)
    predicted_images *= 255.0
    predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

    resize_to_256 = transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC)

    origin_images = []
    rec_images = []
    for entry in test_captions:
        image_path = entry["image"]
        if image_path == "None":
            # No reference image: fill both comparison panels with black.
            origin_images.append(np.zeros((256, 256, 3), dtype=np.uint8))
            rec_images.append(np.zeros((256, 256, 3), dtype=np.uint8))
            continue

        # Open the file once and close it as soon as the RGB copy exists.
        with Image.open(image_path) as handle:
            reference = handle.convert("RGB")

        origin_images.append(np.array(resize_to_256(reference)))

        # Round-trip the reference through the VQ model for the middle panel.
        rec_image = image_transform(reference)
        rec_image = vq_model(rec_image.unsqueeze(0).to(accelerator.device))
        rec_image = torch.clamp((rec_image.squeeze() + 1.0) / 2.0, min=0.0, max=1.0)
        rec_image *= 255.0
        rec_images.append(rec_image.permute(1, 2, 0).cpu().numpy().astype(np.uint8))

    origin_images = np.stack(origin_images, axis=0)
    rec_images = np.stack(rec_images, axis=0)

    # Join the three panels per sample along axis 2 of the stacked batch.
    images = np.concatenate([origin_images, rec_images, predicted_images], 2)
    pil_images = [Image.fromarray(image) for image in images]

    # Log images
    wandb_images = [
        wandb.Image(image, caption=prompt)
        for image, prompt in zip(pil_images, validation_prompts)
    ]
    wandb.log({"Generated images": wandb_images}, step=global_step)

    model.train()



@torch.no_grad()
def eval_mmu(
    model,
    vq_model,
    uni_prompting,
    accelerator,
    config,
    global_step,
    logger,
    **kwargs,
):
    """Answer the MMU validation questions and log the results to wandb.

    Loads the JSON list at ``config.experiment.eval_mmu_file``; every entry
    must provide ``"id"``, ``"image"`` and ``"conversations"`` (a list of
    turns with ``"from"`` and ``"value"``). For each ``"human"`` turn the
    model generates an answer with ``mmu_generate``; for each ``"gpt"`` turn
    the most recent question, the generated answer and the reference answer
    are recorded. One wandb image per sample is logged under the key
    ``"MMU_Images"``, captioned with the JSON dump of that sample's recorded
    conversation.

    Args:
        model: Model exposing ``mmu_generate``; it is unwrapped through
            ``accelerator.unwrap_model`` before generation, as in
            ``eval_t2i``.
        vq_model: Provides ``get_code`` for turning an image into tokens.
        uni_prompting: Supplies ``text_tokenizer`` and ``sptids_dict``.
        accelerator: Supplies the device, the mixed-precision mode and
            ``unwrap_model``.
        config: Experiment configuration; ``experiment.eval_mmu_file``,
            ``experiment.eval_mmu_max_new_tokens`` and
            ``dataset.params.resolution`` are read.
        global_step: Step passed to ``wandb.log``.
        logger: Logger used for the start message.
        **kwargs: Optional ``clip_features_dir`` and ``face_features_dir``.
            When given, ``<dir>/<id>.npy`` is loaded for every sample and
            passed to ``mmu_generate``; otherwise ``None`` is passed.

    The model is put into eval mode at the start and back into train mode
    before returning. Nothing is returned.
    """
    logger.info("Evaluating MMU...")
    model.eval()

    device = accelerator.device
    # A wrapped model (e.g. a distributed wrapper) does not expose
    # `mmu_generate` itself, so always call it on the underlying model.
    unwrapped_model = accelerator.unwrap_model(model)

    with open(config.experiment.eval_mmu_file, "r") as file:
        test_captions = json.load(file)

    weight_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }.get(accelerator.mixed_precision, torch.float32)

    clip_features_dir = kwargs.get("clip_features_dir")
    face_features_dir = kwargs.get("face_features_dir")
    resize_to_256 = transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC)

    def load_feature(directory, sample_id):
        # Returns the stored feature with a leading batch axis, or None when
        # no directory was configured.
        if directory is None:
            return None
        feature = np.load(f"{os.path.join(directory, sample_id)}.npy")
        return torch.tensor(feature).to(device).unsqueeze(0)

    responses = []
    for sample in tqdm(test_captions):
        sample_id = sample['id']
        convs = sample['conversations']

        response = {"conversation": []}

        image_ori = Image.open(sample['image']).convert("RGB")
        response["image"] = Image.fromarray(np.array(resize_to_256(image_ori)))
        image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
        image = image.unsqueeze(0)
        # Image codes are shifted by the tokenizer length, presumably so they
        # sit after the text ids in a shared vocabulary - TODO confirm.
        image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)

        # Load Clip and Face Features
        clip_feature = load_feature(clip_features_dir, sample_id)
        face_feature = load_feature(face_features_dir, sample_id)

        # Reset per sample so a leading "gpt" turn can never be paired with a
        # question/answer that belongs to a previous image.
        curr_question = None
        curr_response = None

        for conv in convs:
            if conv["from"] == "human":
                question = conv["value"]
                text_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])['input_ids']
                text_ids = torch.tensor(text_ids).to(device)

                def special(name):
                    # (batch, 1) column filled with one special-token id.
                    return (torch.ones(text_ids.shape[0], 1) * uni_prompting.sptids_dict[name]).to(device)

                # Layout: <|mmu|> <|soi|> image tokens <|eoi|> <|sot|> text
                input_ids = torch.cat([
                    special('<|mmu|>'),
                    special('<|soi|>'),
                    image_tokens,
                    special('<|eoi|>'),
                    special('<|sot|>'),
                    text_ids,
                ], dim=1).long()

                attention_mask = create_attention_mask_for_mmu(
                    input_ids.to(device),
                    eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                )

                with torch.autocast("cuda", dtype=weight_dtype, enabled=accelerator.mixed_precision != "no"):
                    cont_toks_list = unwrapped_model.mmu_generate(
                        input_ids,
                        attention_mask=attention_mask,
                        max_new_tokens=config.experiment.eval_mmu_max_new_tokens,
                        top_k=1,
                        eot_token=uni_prompting.sptids_dict['<|eot|>'],
                        temperature=0.8,
                        clip_feature=clip_feature,
                        face_feature=face_feature,
                    )

                # Collapse the generated tokens into one (1, n) row to decode.
                cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
                text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)[0]
                curr_question = question
                curr_response = text
            elif conv["from"] == "gpt":
                # Pair the reference answer with the latest generated one.
                response["conversation"].append({
                    'human': curr_question,
                    'Face': curr_response,
                    'gpt': conv["value"],
                })

        responses.append(response)

    wandb_images = [
        wandb.Image(res["image"], caption=json.dumps(res["conversation"], indent=4))
        for res in responses
    ]
    wandb.log({"MMU_Images": wandb_images}, step=global_step)

    model.train()
            


@torch.no_grad()
def infer_t2i(
    model,
    vq_model,
    uni_prompting,
    config,
    mask_schedule,
    device,
    batch_size,
    output_dir,
    num=100000,
):
    """Generate one PNG per caption of the T2I evaluation file.

    Loads the JSON list at ``config.experiment.eval_t2i_file`` (every entry
    must provide ``"id"`` and ``"caption"``), processes the first
    ``min(num, len(entries))`` captions in batches of ``batch_size`` and
    writes each generated image to ``<output_dir>/<id>.png``.

    Args:
        model: Model exposing ``t2i_generate`` and ``config.codebook_size``.
            If it has a ``module`` attribute, that inner module is used.
        vq_model: Provides ``decode_code`` for turning tokens into images.
        uni_prompting: Prompt builder called with the ``'t2i_gen'`` task; its
            ``sptids_dict`` supplies the pad/soi/eoi ids.
        config: Experiment configuration; the ``experiment``, ``model.Face``
            and ``training`` sections are read.
        mask_schedule: Forwarded to ``t2i_generate`` as ``noise_schedule``.
        device: Device on which the masked image tokens are created.
        batch_size: Number of captions generated per call.
        output_dir: Directory receiving the PNG files; created if missing.
        num: Upper bound on the number of captions processed.

    The model is put into eval mode and left there. Nothing is returned.
    """
    model.eval()

    with open(config.experiment.eval_t2i_file, "r") as file:
        test_captions = json.load(file)
    validation_ids = [entry["id"] for entry in test_captions]
    validation_prompts = [entry["caption"] for entry in test_captions]

    # Use the inner module for everything (dtype lookup, generation and
    # codebook size) so a wrapper exposing `.module` is handled uniformly.
    core_model = model.module if hasattr(model, 'module') else model
    mask_dtype = core_model.Face.model.embed_tokens.weight.dtype

    # The last vocabulary id is used as the mask token.
    mask_token_id = config.model.Face.vocab_size - 1

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Never go past the requested count or the number of available prompts.
    total = min(num, len(validation_prompts))
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)

        print(f"Generating images for batch {start} to {stop}")

        ids = validation_ids[start:stop]
        prompts = validation_prompts[start:stop]

        # Generation starts from a canvas in which every image position
        # holds the mask token.
        image_tokens = torch.ones(
            (len(ids), config.model.Face.num_vq_tokens),
            dtype=torch.long,
            device=device,
        ) * mask_token_id

        input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')

        if config.training.guidance_scale > 0:
            # Guidance needs a second batch built from empty prompts; the
            # mask then covers conditional and unconditional rows together.
            uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
            mask_source = torch.cat([input_ids, uncond_input_ids], dim=0)
        else:
            uncond_input_ids = None
            mask_source = input_ids

        attention_mask = create_attention_mask_predict_next(
            mask_source,
            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
            rm_pad_in_image=True,
        ).to(mask_dtype)

        gen_token_ids = core_model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            predict_all_tokens=config.training.get("predict_all_tokens", False),
            seq_len=config.model.Face.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
            mask_schedule_type=config.training.mask_schedule,
            t2i_loss_type=config.training.t2i_loss_type,
        )

        # Keep every generated id inside the codebook before decoding.
        gen_token_ids = torch.clamp(gen_token_ids, max=core_model.config.codebook_size - 1, min=0)
        predicted_images = vq_model.decode_code(gen_token_ids)

        # NOTE: `vq_model` must stay bound for the next batch. Deleting the
        # local name here would break the second iteration and would not
        # free the model, since the caller keeps its own reference.

        # Map the decoder output from [-1, 1] to uint8 pixels, channels last.
        predicted_images = torch.clamp((predicted_images + 1.0) / 2.0, min=0.0, max=1.0)
        predicted_images *= 255.0
        predicted_images = predicted_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

        for sample_id, pixels in zip(ids, predicted_images):
            Image.fromarray(pixels).save(os.path.join(output_dir, f"{sample_id}.png"))


@torch.no_grad()
def infer_mmu(
    model,
    vq_model,
    uni_prompting,
    config,
    device,
    output_dir,
    num=100000,
):
    """Answer the MMU questions and write the answers as two JSONL files.

    Loads the JSON list at ``config.experiment.eval_mmu_file``; every entry
    must provide ``"id"``, ``"image"`` and ``"conversations"``. Turns at even
    positions of ``"conversations"`` are treated as questions (their
    ``"value"`` is the prompt); turns at odd positions are skipped. The
    question immediately before the final turn is recorded with category
    ``"caption"``, every other question with category ``"conv"``.

    Output, both inside ``output_dir`` (created if missing):
        * ``captions_gt.jsonl`` - one record per ``"caption"`` question.
        * ``convs_gt.jsonl`` - one record per ``"conv"`` question.
    Each record holds ``question_id`` (a running index within its file),
    ``image``, ``text`` (the question), ``answer`` (the generated text) and
    ``category``.

    Args:
        model: Model exposing ``mmu_generate``.
        vq_model: Provides ``get_code`` for turning an image into tokens.
        uni_prompting: Supplies ``text_tokenizer`` and ``sptids_dict``.
        config: Experiment configuration; ``experiment.eval_mmu_file``,
            ``experiment.eval_mmu_max_new_tokens`` and
            ``dataset.params`` (``resolution``, ``clip_features_dir``,
            ``face_features_dir``) are read.
        device: Device on which inputs and features are placed.
        output_dir: Directory receiving the two JSONL files.
        num: Upper bound on the number of samples processed; a value of zero
            or less processes nothing.

    The model is put into eval mode and left there. Nothing is returned.
    """
    model.eval()
    with open(config.experiment.eval_mmu_file, "r") as file:
        test_captions = json.load(file)

    # Clamp so a negative limit cannot turn into a "drop from the end" slice.
    samples = test_captions[:max(num, 0)]

    clip_features_dir = config.dataset.params.clip_features_dir
    face_features_dir = config.dataset.params.face_features_dir

    caption_gts = []
    conv_gts = []
    for sample in tqdm(samples):
        sample_id = sample['id']
        img = sample['image']
        convs = sample['conversations']

        # The last matching tag wins when the id contains several of them;
        # without a match the prefix is the string "None".
        dataset_type = None
        for candidate in ('ffhq', 'CelebV', 'MM-CelebA'):
            if candidate in sample_id:
                dataset_type = candidate
        new_image_path = f"{dataset_type}-{os.path.basename(img)}"

        image_ori = Image.open(img).convert("RGB")
        image = image_transform(image_ori, resolution=config.dataset.params.resolution).to(device)
        image = image.unsqueeze(0)
        # Image codes are shifted by the tokenizer length, presumably so they
        # sit after the text ids in a shared vocabulary - TODO confirm.
        image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)

        # Load Clip and Face Features (per sample, so only one pair is held
        # in memory at a time), adding a leading batch axis.
        clip_feature = np.load(f"{os.path.join(clip_features_dir, sample_id)}.npy")
        clip_feature = torch.tensor(clip_feature).to(device).unsqueeze(0)

        face_feature = np.load(f"{os.path.join(face_features_dir, sample_id)}.npy")
        face_feature = torch.tensor(face_feature).to(device).unsqueeze(0)

        # Even positions hold the questions; odd positions are skipped.
        for turn_idx in range(0, len(convs), 2):
            question = convs[turn_idx]["value"]
            text_ids = uni_prompting.text_tokenizer(['USER: \n' + question + ' ASSISTANT:'])['input_ids']
            text_ids = torch.tensor(text_ids).to(device)

            def special(name):
                # (batch, 1) column filled with one special-token id.
                return (torch.ones(text_ids.shape[0], 1) * uni_prompting.sptids_dict[name]).to(device)

            # Layout: <|mmu|> <|soi|> image tokens <|eoi|> <|sot|> text
            input_ids = torch.cat([
                special('<|mmu|>'),
                special('<|soi|>'),
                image_tokens,
                special('<|eoi|>'),
                special('<|sot|>'),
                text_ids,
            ], dim=1).long()
            attention_mask = create_attention_mask_for_mmu(
                input_ids.to(device),
                eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
            )
            cont_toks_list = model.mmu_generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=config.experiment.eval_mmu_max_new_tokens,
                top_k=1,
                eot_token=uni_prompting.sptids_dict['<|eot|>'],
                temperature=0.8,
                clip_feature=clip_feature,
                face_feature=face_feature,
            )
            # Collapse the generated tokens into one (1, n) row to decode.
            cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
            text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list, skip_special_tokens=True)[0]

            # The question whose following turn is the final one is the
            # caption request; everything earlier is regular conversation.
            is_caption = turn_idx + 1 == len(convs) - 1
            target = caption_gts if is_caption else conv_gts
            target.append({
                "question_id": len(target),
                "image": new_image_path,
                "text": question,
                "answer": text,
                "category": "caption" if is_caption else "conv",
            })

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    captions_gt_file = os.path.join(output_dir, "captions_gt.jsonl")
    with open(captions_gt_file, 'w') as f:
        for caption_gt in caption_gts:
            json.dump(caption_gt, f)
            f.write('\n')

    convs_gt_file = os.path.join(output_dir, "convs_gt.jsonl")
    with open(convs_gt_file, 'w') as f:
        for conv_gt in conv_gts:
            json.dump(conv_gt, f)
            f.write('\n')
