import sys
import os
# Resolve the directory two levels above this file and make sure it is
# importable, so the absolute `main.*` imports resolve no matter which
# working directory the script is launched from.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if project_root not in sys.path:
    # Front of the list: this checkout wins over any installed copy.
    sys.path.insert(0, project_root)

from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler, StableDiffusionXLPipeline, EulerDiscreteScheduler, DiffusionPipeline, LCMScheduler, AutoPipelineForText2Image
from main.coco_eval.coco_evaluator import evaluate_model, compute_clip_score, compute_image_reward
from main.sdxl.sdxl_text_encoder import SDXLTextEncoder
from accelerate.utils import ProjectConfiguration
from huggingface_hub import hf_hub_download
from accelerate.logging import get_logger
from main.utils import create_image_grid
from safetensors.torch import load_file
from main.utils import SDTextDataset
from transformers import AutoTokenizer
from accelerate.utils import set_seed
from accelerate import Accelerator
from peft import LoraConfig
from tqdm import tqdm 
import numpy as np 
import argparse 
import logging 
import wandb 
import torch 
import glob 
import time 
import os 
import accelerate

# Module-level logger created through accelerate's `get_logger` helper, with
# its level set to INFO.
logger = get_logger(__name__, log_level="INFO")

def create_generator(checkpoint_path, base_model=None, args=None):
    """Build (or reuse) the generator network and load checkpoint weights.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that
            contains the generator's state dict.
        base_model: Optional, already constructed model. When given it is
            reused as is and only its weights are replaced; when ``None`` a
            fresh SDXL base UNet is instantiated (frozen, float32).
        args: Optional namespace. Only consulted when ``base_model`` is
            ``None``. If it has a truthy ``generator_lora`` attribute, a LoRA
            adapter configured from ``lora_rank``, ``lora_alpha`` and
            ``lora_dropout`` is attached before the weights are loaded.
            ``None`` (the default) means "no LoRA adapter".

    Returns:
        The model with the checkpoint loaded (``strict=True``, so any missing
        or unexpected key raises).
    """
    if base_model is None:
        generator = UNet2DConditionModel.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet"
        ).float()
        generator.requires_grad_(False)

        # `args` defaults to None, so look the flag up defensively instead of
        # dereferencing it unconditionally (which raised AttributeError).
        if getattr(args, "generator_lora", False):
            lora_target_modules = [
                "to_q",
                "to_k",
                "to_v",
                "to_out.0",
                "proj_in",
                "proj_out",
                "ff.net.0.proj",
                "ff.net.2",
                "conv1",
                "conv2",
                "conv_shortcut",
                "downsamplers.0.conv",
                "upsamplers.0.conv",
                "time_emb_proj",
            ]
            lora_config = LoraConfig(
                r=args.lora_rank,
                target_modules=lora_target_modules,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout
            )
            # The adapter has to exist before loading so the LoRA keys stored
            # in the checkpoint match under strict loading.
            generator.add_adapter(lora_config)
    else:
        generator = base_model

    # NOTE(security): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # load_state_dict returns a key-matching report; print it for the run log.
    print(generator.load_state_dict(state_dict, strict=True))

    return generator

def build_condition_input(resolution, accelerator):
    """Create the single-row ``time_ids`` conditioning tensor.

    The row is laid out as (original height/width, crop top/left, target
    height/width), using a square ``resolution`` for both sizes and a zero
    crop offset.

    Args:
        resolution: Side length used for both the original and target size.
        accelerator: Object exposing a ``device`` attribute; the tensor is
            created on that device.

    Returns:
        A float32 tensor of shape ``(1, 6)``.
    """
    square = (resolution, resolution)
    no_crop = (0, 0)
    row = [*square, *no_crop, *square]
    return torch.tensor([row], dtype=torch.float32, device=accelerator.device)

def get_x0_from_noise(sample, model_output, timestep, alphas_cumprod):
    """Recover the predicted clean sample from a noise prediction.

    Computes ``(sample - sqrt(1 - a_t) * model_output) / sqrt(a_t)`` where
    ``a_t = alphas_cumprod[timestep]``.

    Args:
        sample: Noisy input the model was evaluated on.
        model_output: The model's prediction for ``sample``.
        timestep: Index (or tensor of indices) into ``alphas_cumprod``.
        alphas_cumprod: Table of cumulative alpha products.

    Returns:
        The predicted original sample, broadcast against ``sample``.
    """
    # One coefficient per selected timestep, shaped (-1, 1, 1, 1) so it
    # broadcasts over the remaining three dimensions.
    alpha_bar = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
    # NOTE: an earlier variant pinned alpha_bar to 0.0047, which the original
    # comment gives as the alphas_cumprod value of the last timestep (999).
    noise_scale = (1 - alpha_bar) ** (0.5)
    signal_scale = alpha_bar ** (0.5)

    return (sample - noise_scale * model_output) / signal_scale

@torch.no_grad()
def sample(
        noise, unet_added_conditions, model, vae, noise_scheduler, prompt_embed,
        device="cuda", num_step=1, conditioning_timestep=999
    ):
    """Run the few-step generator and decode the result to uint8 images.

    Each step feeds the current latents to ``model``, converts the output to
    a clean-sample estimate with ``get_x0_from_noise`` and re-noises that
    estimate for the next step.

    Args:
        noise: Starting latents.
        unet_added_conditions: Passed to ``model`` as ``added_cond_kwargs``.
        model: Callable whose result exposes a ``.sample`` attribute.
        vae: Decoder providing ``decode`` and ``config.scaling_factor``.
        noise_scheduler: Provides ``alphas_cumprod`` and ``add_noise``.
        prompt_embed: Prompt embeddings; its length sets the batch size and
            its dtype is the dtype latents are cast back to between steps.
        device: Device for the timestep tensors and ``alphas_cumprod``.
        num_step: Number of sampling steps; only 1, 2 and 4 are supported.
        conditioning_timestep: Timestep used when ``num_step == 1``. The 2-
            and 4-step schedules are fixed and ignore it.

    Returns:
        uint8 tensor: the decoded output permuted with ``(0, 2, 3, 1)``
        (channels-last, assuming the decoder returns channels-first).

    Raises:
        NotImplementedError: For any other ``num_step``.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(device)

    if num_step == 1:
        timestep_schedule, stride = [conditioning_timestep], 0
    elif num_step == 2:
        timestep_schedule, stride = [999, 499], 500
    elif num_step == 4:
        timestep_schedule, stride = [999, 749, 499, 249], 250
    else:
        raise NotImplementedError()

    working_dtype = prompt_embed.dtype
    latents = noise

    for t in timestep_schedule:
        timesteps = torch.ones(len(prompt_embed), device=device, dtype=torch.long) * t

        prediction = model(
            latents, timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions
        ).sample
        denoised = get_x0_from_noise(
            latents, prediction, timesteps, alphas_cumprod
        ).float()

        # Re-noise the estimate for the following step. This also runs after
        # the last step (with timestep -1 for the 2/4-step schedules); that
        # result is unused, but the random draw is kept so the RNG stream is
        # unchanged.
        fresh_noise = torch.randn_like(denoised)
        latents = noise_scheduler.add_noise(
            denoised, fresh_noise, timesteps - stride
        ).to(working_dtype)

    decoded = vae.decode(denoised / vae.config.scaling_factor, return_dict=False)[0]
    # Affine map (x + 1) * 127.5, clamped to [0, 255], then cast to uint8.
    pixels = ((decoded + 1.0) * 127.5).clamp(0, 255).to(torch.uint8)
    return pixels.permute(0, 2, 3, 1)

def _parse_eval_args():
    """Define and parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: The parsed options. ``--checkpoint_path`` is
        validated here so a missing value is reported before any model is
        loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_name", type=str)
    parser.add_argument("--latent_resolution", type=int, default=128)
    parser.add_argument("--image_resolution", type=int, default=1024)
    parser.add_argument("--num_train_timesteps", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--eval_batch_size", type=int, default=4)
    parser.add_argument("--conditioning_timestep", type=int, default=999)
    parser.add_argument("--eval_res", type=int, default=256)
    parser.add_argument("--ref_dir", type=str)
    parser.add_argument("--total_eval_samples", type=int, default=30000)
    parser.add_argument("--per_image_object", type=int, default=9)
    parser.add_argument("--test_visual_batch_size", type=int, default=81)
    parser.add_argument("--predict_x0", action="store_true")
    parser.add_argument("--anno_path", type=str)
    parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-xl-base-1.0")
    parser.add_argument("--revision", type=str)
    parser.add_argument("--sdxl_lightning_4step", action="store_true")
    parser.add_argument("--sdxl_lightning_1step", action="store_true")
    parser.add_argument("--clip_score", action="store_true")
    parser.add_argument("--sdxl_teacher", action="store_true")
    parser.add_argument("--num_step", type=int, default=1)
    parser.add_argument("--lcm_1step", action="store_true")
    parser.add_argument("--lcm_4step", action="store_true")
    parser.add_argument("--turbo_1step", action="store_true")
    parser.add_argument("--turbo_4step", action="store_true")
    parser.add_argument("--image_reward", action="store_true")
    parser.add_argument("--pick_score", action="store_true")
    parser.add_argument("--checkpoint_path", type=str, help="specify a single checkpoint instead of a folder")
    parser.add_argument("--guidance_scale", type=float, default=6)
    parser.add_argument("--result_path", type=str)
    parser.add_argument("--generator_lora", action="store_true")
    parser.add_argument("--lora_rank", type=int, default=64)
    parser.add_argument("--lora_alpha", type=float, default=8)
    parser.add_argument("--lora_dropout", type=float, default=0.0)

    args = parser.parse_args()

    # The checkpoint path is dereferenced unconditionally later on; reject a
    # missing value now instead of failing after the models were loaded.
    if not args.checkpoint_path:
        parser.error("--checkpoint_path is required")

    return args


@torch.no_grad()
def evaluate():
    """Generate images with one checkpoint and log evaluation metrics.

    Steps:
      1. Parse the CLI options and set up accelerate, logging and (on the
         main process) wandb.
      2. Split the first ``total_eval_samples`` prompts into contiguous,
         per-process shards.
      3. Load the VAE, the scheduler and the generator checkpoint, then
         generate images for this process's shard.
      4. Gather images and captions from all processes; on the main process
         compute FID / patch FID (and optionally CLIP score and ImageReward)
         and log them, together with image statistics and image grids, to
         wandb under the step parsed from the checkpoint directory name.
    """
    torch.set_grad_enabled(False)

    args = _parse_eval_args()

    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="no",
        log_with="wandb",
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    if accelerator.is_main_process:
        wandb.init(
            config=args,
            mode="online",
            entity=args.wandb_entity,
            project=args.wandb_project,
        )
        wandb.run.name = args.wandb_name

    text_encoder = SDXLTextEncoder(args, accelerator)
    tokenizer_one = AutoTokenizer.from_pretrained(
        args.model_id, subfolder="tokenizer", revision=args.revision, use_fast=False
    )

    # SDXL ships a second tokenizer in the "tokenizer_2" subfolder; loading
    # "tokenizer" twice fed the first tokenizer's ids to the second encoder.
    tokenizer_two = AutoTokenizer.from_pretrained(
        args.model_id, subfolder="tokenizer_2", revision=args.revision, use_fast=False
    )

    dataset = SDTextDataset(
        anno_path=args.anno_path,
        tokenizer_one=tokenizer_one,
        tokenizer_two=tokenizer_two,
        is_sdxl=True
    )

    # Contiguous shard per process; the last process also takes the
    # remainder, so rank-ordered gathering reproduces the dataset order.
    total_samples = min(len(dataset), args.total_eval_samples)
    num_processes = accelerator.num_processes
    samples_per_process = total_samples // num_processes
    start_idx = accelerator.process_index * samples_per_process
    end_idx = start_idx + samples_per_process
    if accelerator.process_index == num_processes - 1:
        end_idx = total_samples
    subset_indices = list(range(start_idx, end_idx))
    subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

    dataloader = torch.utils.data.DataLoader(
        subset_dataset, batch_size=args.eval_batch_size,
        shuffle=False, drop_last=False, num_workers=8
    )
    base_add_time_ids = build_condition_input(args.image_resolution, accelerator)

    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae"
    ).to(accelerator.device).float()

    scheduler = DDIMScheduler.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="scheduler"
    )

    checkpoint = args.checkpoint_path

    print(f"Evaluating {checkpoint}")
    # The wandb step is the integer after the last "_" of the checkpoint
    # path (slashes stripped first so a trailing "/" does not break it).
    model_index = int(checkpoint.replace("/", "").split("_")[-1])

    generator = create_generator(
        os.path.join(checkpoint, "pytorch_model.bin"),
        base_model=None,
        args=args
    )

    generator = generator.to(accelerator.device)

    set_seed(args.seed+accelerator.process_index)

    all_images = []
    all_captions = []

    # total is this process's number of batches, not the whole dataset's.
    for i, data in tqdm(enumerate(dataloader), disable=not accelerator.is_main_process, total=len(dataloader)):
        all_captions.append(data["key"])

        prompt_embed, pooled_prompt_embed = text_encoder(data)

        # NOTE(review): the seed is the process-local batch index, so every
        # process draws the same initial-noise sequence for its own prompts.
        noise = torch.randn(len(prompt_embed), 4,
            args.latent_resolution, args.latent_resolution,
            dtype=torch.float32,
            generator=torch.Generator().manual_seed(i)
        ).to(accelerator.device)

        add_time_ids = base_add_time_ids.repeat(noise.shape[0], 1)

        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": pooled_prompt_embed
        }

        eval_images = sample(
                noise, unet_added_conditions, generator, vae, scheduler, prompt_embed,
                device=accelerator.device, num_step=args.num_step,
                conditioning_timestep=args.conditioning_timestep
            )

        all_images.append(eval_images.cpu().numpy())

    print(f"Process {accelerator.process_index} finished generation. Gathering results for evaluation...")
    accelerator.wait_for_everyone()

    # Each gathered element is expected to be one batch (an ndarray of images
    # / a per-batch collection of captions) - TODO confirm against the
    # installed accelerate version's gather_object flattening behaviour.
    all_images_gathered = accelerate.utils.gather_object(all_images)
    all_captions_gathered = accelerate.utils.gather_object(all_captions)

    torch.cuda.empty_cache()

    if accelerator.is_main_process:

        # Flatten the batches into per-image / per-caption lists.
        all_images = [img for sublist in all_images_gathered for img in sublist]
        all_captions = [cap for sublist in all_captions_gathered for cap in sublist]

        # Stack the individual images along a new leading axis.
        all_images = np.stack(all_images, axis=0)[:args.total_eval_samples]
        all_captions = all_captions[:args.total_eval_samples]

        data_dict = {"all_images": all_images, "all_captions": all_captions}

        fid = evaluate_model(
            args, accelerator.device, data_dict["all_images"], patch_fid=False
        )
        patch_fid = evaluate_model(
            args, accelerator.device, data_dict["all_images"], patch_fid=True
        )
        print(f"checkpoint {checkpoint} fid {fid} patch fid {patch_fid}")

        wandb.log(
            {"fid": fid, "patch_fid": patch_fid},
            step=model_index
        )

        if args.clip_score:
            clip_score = compute_clip_score(
                images=data_dict["all_images"],
                captions=data_dict["all_captions"],
                clip_model="ViT-G/14",
                device=accelerator.device,
                how_many=args.total_eval_samples
            )
            print(f"checkpoint {checkpoint} clip score {clip_score}")
            wandb.log(
                {"clip_score": clip_score},
                step=model_index
            )

        if args.image_reward:
            image_reward = compute_image_reward(
                images=data_dict["all_images"],
                captions=data_dict["all_captions"],
                device=accelerator.device
            )
            print(f"checkpoint {checkpoint} image reward {image_reward}")
            wandb.log(
                {"image_reward": image_reward},
                step=model_index
            )

        visualize_images = all_images[:args.test_visual_batch_size]

        # Mean and standard deviation of the visualised images in [0, 1].
        image_brightness = (visualize_images / 255.0).mean()
        image_std = (visualize_images / 255.0).std()

        wandb.log(
            {
                "image_brightness": image_brightness,
                "image_std": image_std
            },
            step=model_index
        )

        # Log the visualised images as grids of `per_image_object` images.
        for start in range(0, len(visualize_images), args.per_image_object):
            end = min(start + args.per_image_object, len(visualize_images))

            if start >= end:
                continue

            eval_images_grid = create_image_grid(args, visualize_images[start:end], None)

            wandb.log(
                {f"generated_image_grid_{start:04d}_{end:04d}": wandb.Image(eval_images_grid)},
                step=model_index
            )

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    evaluate()    