import os
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
sys.path.append(osp.dirname(osp.abspath(__file__)))
from datetime import datetime
import debugpy
import hydra
import json
from alive_progress import alive_bar, alive_it
from omegaconf import DictConfig, OmegaConf
import torch
import random
import numpy as np

from .model.pipelines import StableDiffusion3Pipeline
from .model.transformers.transformer_sd3 import SD3Transformer2DModel

from logic import Tora

from gear.logger import clogger
from gear.configuration import SafeOmegaDict

def fix_seed(seed):
    """Seed every RNG this script relies on and request deterministic kernels.

    Seeds torch (CPU and all CUDA devices), Python's ``random`` and NumPy's
    global RNG, and configures cuDNN / torch for deterministic execution.

    Args:
        seed: integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU
    torch.backends.cudnn.deterministic = True
    # Must be *called*: assigning `True` to this attribute replaces torch's
    # function with a bool, never enables determinism, and breaks later calls.
    # warn_only=True: ops lacking a deterministic implementation (e.g. cuBLAS
    # without CUBLAS_WORKSPACE_CONFIG set) emit a warning instead of raising
    # in the middle of image generation.
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)

def makedir(path):
    """Ensure directory ``path`` exists (creating missing parents) and return it.

    An already-existing directory is not an error, so calling this repeatedly
    with the same path is safe.
    """
    target_dir = path
    os.makedirs(name=target_dir, exist_ok=True)
    return target_dir

# Directory Hydra searches for experiment configs: ./A_exps relative to the
# current working directory at import time (read by @hydra.main below).
CONFIG_PATH = f"{os.getcwd()}/A_exps"
print(
    "\n-*-*-",
    f"Checking config_path: {CONFIG_PATH}",
    f"Config path exists: {os.path.exists(CONFIG_PATH)}",
    "-*-*-\n",
    sep="\n",
)

# Upper bound, in UTF-8 bytes, for the prompt-derived part of an image
# filename. Common filesystems (e.g. ext4) cap one path component at 255
# bytes, and the filename also carries the experiment name, id, category
# and timestamp.
_MAX_PROMPT_FRAGMENT_BYTES = 150


def _split_exp_name(exp_name):
    """Split ``exp_name`` at its last ``-`` into ``(big_name, small_name)``.

    ``big_name`` selects the output directory; ``small_name`` prefixes every
    image filename.

    Raises:
        ValueError: if ``exp_name`` contains no ``-``.
    """
    big_name, sep, small_name = str(exp_name).rpartition("-")
    if not sep:
        raise ValueError(
            f"exp_name must look like '<big>-<small>', got {exp_name!r}"
        )
    return big_name, small_name


def _load_prompts(path):
    """Read a UTF-8 JSONL prompt file into a list of dicts, skipping blank lines.

    Records are used below via their ``id``, ``seed`` and ``prompt`` keys, with
    ``category`` optional.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _prompt_to_filename_fragment(prompt):
    """Turn a prompt into a fragment that is safe inside a single filename.

    Spaces become ``_``. Path separators are also replaced with ``_`` so a
    prompt such as ``"cat/dog"`` cannot redirect the image into a
    sub-directory. The result is truncated to ``_MAX_PROMPT_FRAGMENT_BYTES``
    UTF-8 bytes, dropping any partial trailing character. Uniqueness comes from
    the record id and timestamp elsewhere in the filename.
    """
    fragment = prompt.replace(" ", "_")
    for sep in ("/", os.sep):
        fragment = fragment.replace(sep, "_")
    encoded = fragment.encode("utf-8")
    if len(encoded) > _MAX_PROMPT_FRAGMENT_BYTES:
        fragment = encoded[:_MAX_PROMPT_FRAGMENT_BYTES].decode(
            "utf-8", errors="ignore"
        )
    return fragment


@hydra.main(config_path=CONFIG_PATH, version_base=None)
def main(cfgs: DictConfig):
    """Generate one SD3 image per prompt record and log per-image metadata.

    Steps:
    1. Read the JSONL prompt file ``cfgs.dir_prompt``.
    2. Optionally activate Tora (``cfgs.on_prompt_pca``).
    3. Load an fp16 SD3 transformer and pipeline from
       ``cfgs.dir_checkpoint/cfgs.model_name`` onto ``cuda:{cfgs.device}``.
    4. For each record, save
       ``<dir_output>/<big>/images/<small>-<id>-<prompt>-<category>-<ts>.png``
       and append one JSON line to
       ``<dir_output>/<big>/metadata/metadata_<ts>.jsonl``.

    ``<big>`` and ``<small>`` come from splitting ``cfgs.exp_name`` at its
    last ``-``.
    """
    print(OmegaConf.to_yaml(cfgs))
    cfgs = SafeOmegaDict(cfgs)

    # Default to GPU 0 when no device is configured.
    if not hasattr(cfgs, "device"):
        cfgs.device = 0

    if getattr(cfgs, "debug", False):
        ADDRESS = "0.0.0.0"
        PORT = 7737
        debugpy.listen((ADDRESS, PORT))
        print(f"Waiting for debugger to attach on {ADDRESS}:{PORT}...")
        debugpy.wait_for_client()
        print("Debugger attached")

    # Validate cheap inputs before paying for the model load.
    big_exp_name, small_exp_name = _split_exp_name(cfgs.exp_name)

    # One timestamp per run, shared by every image and the metadata file.
    # "%f" yields 6 microsecond digits; dropping the last 4 keeps hundredths.
    timestamp = datetime.now().strftime("%m%d_%H%M%S%f")[:-4]
    data = _load_prompts(cfgs.dir_prompt)

    # prompt pca
    if cfgs.on_prompt_pca:
        Tora.activate()
        clogger.info("Tora activated")

    # Assign transformer and pipeline (both in float16).
    model_dir = osp.join(cfgs.dir_checkpoint, cfgs.model_name)
    transformer = SD3Transformer2DModel.from_pretrained(
        model_dir,
        subfolder="transformer",
        torch_dtype=torch.float16,
    )
    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
        transformer=transformer,
    )
    pipe = pipe.to(device=f"cuda:{cfgs.device}")

    # Output locations do not depend on the record: resolve them once.
    exp_dir = osp.join(cfgs.dir_output, big_exp_name)
    path_output_image = makedir(osp.join(exp_dir, "images"))
    path_metadata = makedir(osp.join(exp_dir, "metadata"))
    metadata_file = osp.join(path_metadata, f"metadata_{timestamp}.jsonl")

    for d in alive_it(data):
        # Re-seed per record. No generator is passed to the pipeline, so it
        # presumably draws its noise from the global torch RNG seeded here.
        fix_seed(d["seed"])
        prompt = d["prompt"]

        image = pipe(
            prompt,
            negative_prompt=None,
            num_inference_steps=cfgs.inference_opt.num_steps,
            guidance_scale=cfgs.inference_opt.guidance_scale,
        ).images[0]

        category = d.get("category", None)
        image_name = (
            f"{small_exp_name}-{d['id']}-{_prompt_to_filename_fragment(prompt)}"
            f"-{category}-{timestamp}.png"
        )
        image.save(osp.join(path_output_image, image_name))

        output_data = {
            "id": d["id"],
            "seed": d["seed"],
            "prompt": prompt,
            "image": image_name,
        }
        # Open, append and close per record so metadata for already-generated
        # images is on disk even if a later iteration fails.
        with open(metadata_file, "a", encoding="utf-8") as f:
            json.dump(output_data, f, ensure_ascii=False)
            f.write("\n")

        torch.cuda.empty_cache()


if __name__ == "__main__":
    # @hydra.main composes the config from the command line and passes it in.
    main()