import os
import os.path as osp
import sys
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
sys.path.append(osp.dirname(osp.abspath(__file__)))
import random
import numpy as np
import torch
import debugpy
import json
from omegaconf import DictConfig, OmegaConf
from datetime import datetime
import hydra
from gear.logger import clogger

from .model.pipelines import FluxPipeline
from .model.transformers import FluxTransformer2DModel

from logic import Tora, Ops

from alive_progress import alive_it

def fix_seed(seed):
    """Seed every RNG used by the script and request deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus all CUDA
    devices), and configures cuDNN / PyTorch for deterministic execution.

    Args:
        seed: integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU
    torch.backends.cudnn.deterministic = True
    # This must be *called*: assigning `torch.use_deterministic_algorithms = True`
    # replaces the torch function with a bool and enables nothing.
    # warn_only=True makes ops that lack a deterministic implementation emit a
    # warning instead of raising, so a generation run is not aborted by them.
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    np.random.seed(seed)

def makedir(path):
    """Create directory *path* (and any missing parents) if needed; return *path*.

    An already-existing directory is not an error.
    """
    target_dir = path
    os.makedirs(name=target_dir, exist_ok=True)
    return target_dir

# Hydra config directory: "A_exps" resolved against the current working
# directory at import time (not against this file's location).
CONFIG_PATH = f"{os.getcwd()}/A_exps"
# Print a small banner so a wrong launch directory is obvious at startup.
print(
    "\n-*-*-",
    f"Checking config_path: {CONFIG_PATH}",
    f"Config path exists: {os.path.exists(CONFIG_PATH)}",
    "-*-*-\n",
    sep="\n",
)

def _load_prompts(path):
    """Read a JSON Lines prompt file into a list of dicts, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _split_exp_name(exp_name):
    """Split ``"<big>-<small>"`` at the LAST hyphen and return ``(big, small)``.

    Raises:
        ValueError: if ``exp_name`` contains no hyphen.
    """
    big, sep, small = exp_name.rpartition("-")
    if not sep:
        raise ValueError(
            f"exp_name must look like '<big>-<small>', got {exp_name!r}"
        )
    return big, small


def _image_filename(small_exp_name, item_id, prompt, category, timestamp, max_bytes=255):
    """Build the output PNG filename for one prompt.

    Spaces and path separators in the prompt become underscores, so a prompt can
    never introduce a sub-directory. If the name would exceed ``max_bytes``
    (UTF-8; 255 is a typical per-file-name limit, e.g. on ext4) the prompt part
    is shortened until it fits. Names already within the limit are unchanged.
    """
    short_prompt = prompt.replace(" ", "_").replace("/", "_").replace(os.sep, "_")

    def build(p):
        return f"{small_exp_name}-{item_id}-{p}-{category}-{timestamp}.png"

    name = build(short_prompt)
    while short_prompt and len(name.encode("utf-8")) > max_bytes:
        short_prompt = short_prompt[:-1]
        name = build(short_prompt)
    return name


@hydra.main(config_path=CONFIG_PATH, version_base=None)
def main(cfgs: DictConfig):
    """Generate one image per prompt listed in ``cfgs.dir_prompt`` and log metadata.

    Config keys read: ``dir_prompt``, ``attention_type``, ``on_prompt_pca``,
    ``dir_checkpoint``, ``model_name``, ``exp_name`` (``"<big>-<small>"``),
    ``dir_output``, ``inference_opt.{num_steps,guidance_scale,max_sequence_length}``,
    plus the optional ``device`` (default 0) and ``debug`` (default False).

    Each prompt record is read for ``seed``, ``prompt``, ``id`` and optionally
    ``category``. Images go to ``<dir_output>/<big>/images/``, and one JSON
    record per image is appended to
    ``<dir_output>/<big>/metadata/metadata_<timestamp>.jsonl``.
    """
    print(OmegaConf.to_yaml(cfgs))
    # %f is microseconds; dropping the last 4 digits keeps hundredths of a second.
    TIMESTAMP = datetime.now().strftime("%m%d_%H%M%S%f")[:-4]

    # Hydra passes the config in struct mode, where assigning a new key
    # (`cfgs.device = 0`) raises. Read the default into a local instead.
    device = getattr(cfgs, "device", 0)

    if getattr(cfgs, "debug", False):
        # NOTE(review): binding to 0.0.0.0 exposes the debug adapter (which allows
        # arbitrary code execution) on every network interface; prefer 127.0.0.1
        # with port forwarding unless remote attach is really required.
        ADDRESS = "0.0.0.0"
        PORT = 7738
        clogger.info(f"Waiting for debugger to attach on {ADDRESS}:{PORT}...")
        debugpy.listen((ADDRESS, PORT))
        debugpy.wait_for_client()
        clogger.info("Debugger attached")

    # Validate the experiment name and load prompts before the slow model load,
    # so bad inputs fail fast.
    BIG_EXP_NAME, SMALL_EXP_NAME = _split_exp_name(cfgs.exp_name)
    data = _load_prompts(cfgs.dir_prompt)
    file_mode = "a"  # append: each record is persisted as soon as its image is saved

    # model and experiment Ops
    Ops.set_attention_type(cfgs.attention_type)

    # prompt pca
    if cfgs.on_prompt_pca:
        Tora.activate()
        clogger.info("Tora activated")

    # Assign transformer and pipeline
    model_dir = osp.join(cfgs.dir_checkpoint, cfgs.model_name)
    transformer = FluxTransformer2DModel.from_pretrained(
        model_dir,
        subfolder="transformer",
        torch_dtype=torch.float16,
    )
    pipe = FluxPipeline.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
        transformer=transformer,
    )
    pipe = pipe.to(device=f"cuda:{device}")

    # Output locations do not depend on the prompt; create them once.
    path_output_image = makedir(osp.join(cfgs.dir_output, BIG_EXP_NAME, "images"))
    path_metadata = makedir(osp.join(cfgs.dir_output, BIG_EXP_NAME, "metadata"))
    metadata_file = osp.join(path_metadata, f"metadata_{TIMESTAMP}.jsonl")

    for d in alive_it(data):
        # Set seed for reproducibility
        fix_seed(d["seed"])
        PROMPT = d["prompt"]

        # NOTE(review): index 1 assumes this (local) pipeline returns at least two
        # images per call. A stock diffusers FluxPipeline returns one image per
        # prompt by default, so `[1]` would raise IndexError there. Confirm.
        image = pipe(
            PROMPT,
            negative_prompt=None,
            num_inference_steps=cfgs.inference_opt.num_steps,
            guidance_scale=cfgs.inference_opt.guidance_scale,
            max_sequence_length=cfgs.inference_opt.max_sequence_length,
        ).images[1]

        # Build the filename once so the saved file and its metadata agree.
        image_name = _image_filename(
            SMALL_EXP_NAME, d["id"], PROMPT, d.get("category", None), TIMESTAMP
        )
        image.save(osp.join(path_output_image, image_name))

        output_data = {
            "id": d["id"],
            "seed": d["seed"],
            "prompt": d["prompt"],
            "image": image_name,
        }
        with open(metadata_file, file_mode, encoding="utf-8") as f:
            json.dump(output_data, f, ensure_ascii=False)
            f.write("\n")

        torch.cuda.empty_cache()
    