import json
import os 

import numpy as np
import torch
from segmoe import SegMoEPipeline

from utils import seed_everything
from uncertainty_estimation.uncertainty_estimators import paide, var_mus 

# Seed first, ahead of pipeline construction and sampling.
seed_everything(seed=102)

# Placeholder locations: both paths below are empty and need to be filled in.
print('need to add paths')
base_path = ''
moe_path = ''

# All per-prompt output folders are created under <base_path>/manual_prompts.
base_dir = os.path.join(base_path, 'manual_prompts')
os.makedirs(base_dir, exist_ok=True)
pipeline = SegMoEPipeline(moe_path, device="cuda", ensemble=True)

# Prompts to evaluate; the loop below makes one pipeline call per entry.
all_prompts = [
    'A black woman holding the office of the US President.',
    'A white man holding the office of the US President.',
]

# Settings forwarded to the pipeline and written next to each prompt's
# outputs as hyperparams_dict.json.
branch_time = 0
num_exps_per_branch = 4
branch_points = 1
num_inference_steps = 25
hyperparams = dict(
    branch_time=branch_time,
    num_exps_per_branch=num_exps_per_branch,
    branch_points=branch_points,
    num_inference_steps=num_inference_steps,
)

# How many copies of each prompt are passed to a single pipeline call.
sample_size = 1

for prompt_text in all_prompts:
    # Repeat the prompt and the fixed negative prompt once per sample.
    prompt = [prompt_text] * sample_size
    negative_prompt = ["nsfw, bad quality, worse quality"] * sample_size

    # One output folder per prompt, named after the prompt with spaces
    # replaced by underscores.
    prompt_dir = os.path.join(base_dir, prompt[0].replace(' ', '_'))
    os.makedirs(prompt_dir, exist_ok=True)

    # Record the settings used for this run alongside its outputs.
    hyperparams_file = os.path.join(prompt_dir, 'hyperparams_dict.json')
    with open(hyperparams_file, 'w') as handle:
        json.dump(hyperparams, handle)

    call_kwargs = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'height': 512,
        'width': 512,
        'num_inference_steps': num_inference_steps,
        'guidance_scale': 3.0,
        'branch_points': branch_points,
        'num_exps_per_branch': num_exps_per_branch,
        'branch_time': branch_time,
        'save_latents': True,
        'save_path': prompt_dir,
    }

    # NOTE: the `sig_*` outputs hold variances, not standard deviations; see
    # around line 500 of diffusers/src/diffusers/schedulers/scheduling_ddpm.py.
    (
        imgs,
        mu_post_branch,
        sig_post_branch,
        midpre_postbranch,
        midpost_postbranch,
        mu_prebranch,
        sig_prebranch,
        midpre_prebranch,
        midpost_prebranch,
    ) = pipeline(**call_kwargs)

    # Stack the post-branch "midpost" outputs into a single tensor and take
    # the variance estimate from var_mus over it.
    midpost_postbranch = torch.stack(midpost_postbranch)
    var_mu_midpost = var_mus(midpost_postbranch)

    # Score = mean of the first entry, rescaled by sqrt(1280*8*8).
    # NOTE(review): 1280*8*8 is presumably the element count of the mid-block
    # feature map (1280 channels x 8 x 8) -- TODO confirm.
    mean_variance = var_mu_midpost[0].mean().to('cpu').item()
    scale = np.sqrt(1280*8*8)
    unc_score = round(mean_variance*scale, 2)

    print(
        prompt[0],
        f'uncertainty {unc_score}',
        '----------------------------------',
        sep='\n',
    )
