import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL
from safetensors.torch import load_file
from PIL import Image
from diffusers.models.autoencoders.vae import Encoder
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from projection import Projection
import random
import numpy as np
import os
from torchvision import models, transforms

### PRETRAINED MODEL IMPORT
# Load the SDXL base pipeline (fp16 variant) onto the GPU, then apply the LoRA
# weights and the textual-inversion embeddings stored in the same model folder.
print("\n\nPorosity model\n\n")
# SDXL base weights in half precision, moved to the default CUDA device.
pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
).to("cuda")
# Folder containing "pytorch_lora_weights.safetensors" and the embedding file
# referenced by `embedding_path` below.
lora_path = '/path/to/wd/models/porosity_original/microstructures_model'
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
# NOTE(review): `text_encoders` and `tokenizers` are not used anywhere in the
# visible code — confirm they are needed before relying on / removing them.
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
# safetensors file holding one embedding entry per SDXL text encoder, read
# below under the keys "clip_l" and "clip_g".
embedding_path = "/path/to/wd/models/porosity_original/microstructures_model/microstructures_model_emb.safetensors"
state_dict = load_file(embedding_path)
# Both encoders register the same two tokens, "<s0>" and "<s1>", each with its
# own encoder-specific embedding.
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)

### MODEL SETTINGS
# NOTE(review): the attributes assigned on `pipe` below are not part of the
# stock diffusers pipeline API; they are presumably read by a customized
# pipeline implementation — confirm their exact meaning there.
num_images = 1  # number of images generated by the MAIN loop
pipe.encoder_train = True
prompt = "microstructures"
pipe.porosity = 0.4  # porosity setting reported for the constraint model and used in output paths
method = "c"  # "o": original model, "c": projected (constrained) model
pipe.starting_step = 15  # constraint starting step
pipe.proj_normalization = False  # if True, save_tensor_as_image rescales pixel values before saving
pipe.threshold = 0.0  # values below this are counted as pores when reporting image porosity
pipe.framework = "porosity"

# Fail fast on a typo: any value other than "o"/"c" would otherwise silently
# run the original (unconstrained) model.
if method not in ("o", "c"):
    raise ValueError(f"Unknown method {method!r}; expected 'o' (original) or 'c' (projected)")

pipe.constraint = method == "c"

if pipe.constraint:
    # constraint model initialization
    print(f"Model: Constraint (our model) | Porosity: {pipe.porosity}")
else:
    # Original model initialization
    print("Model: Original")

### UTILS
def convert_to_grayscale(image_tensor):
    """Collapse an RGB image tensor to a single luminance channel.

    Each output pixel is ``0.2989 * R + 0.5870 * G + 0.1140 * B``.

    Args:
        image_tensor: tensor whose third-from-last axis is the 3-channel RGB
            axis, e.g. ``(3, H, W)`` or ``(B, 3, H, W)``.

    Returns:
        Tensor of the same rank with that channel axis reduced to size 1,
        e.g. ``(1, H, W)`` or ``(B, 1, H, W)``.
    """
    # Create the weights on the input's device rather than a hard-coded
    # 'cuda:0', so CPU tensors and tensors on other GPUs also work.
    weights = torch.tensor([0.2989, 0.5870, 0.1140], device=image_tensor.device).view(3, 1, 1)
    # Reduce over the channel axis counted from the end: identical to dim=0 for
    # a (3, H, W) input, and it keeps a leading batch axis intact instead of
    # summing across the batch.
    grayscale_image = (weights * image_tensor).sum(dim=-3, keepdim=True)
    return grayscale_image


def save_tensor_as_image(tensor, file_path, saving_mode='RGB', proj_normalization=None):
    """Quantize an image tensor with values in [-1, 1] to 8 bits and save it.

    Args:
        tensor: channel-first image tensor ``(3, H, W)``. For ``'L'`` it is
            first converted to grayscale and squeezed to ``(H, W)``.
        file_path: destination path; PIL picks the file format from it.
        saving_mode: PIL image mode, either ``'RGB'`` or ``'L'``.
        proj_normalization: if truthy, values are shifted by -1/3 and scaled
            by 0.75 (after clipping) before quantization. ``None`` (default)
            uses the module-level ``pipe.proj_normalization`` setting.

    Raises:
        ValueError: if ``saving_mode`` is not ``'RGB'`` or ``'L'``.
    """
    # Only these two layouts are handled below ('L' -> 2-D, 'RGB' -> HWC).
    if saving_mode not in ('RGB', 'L'):
        raise ValueError(f"saving_mode must be 'RGB' or 'L', got {saving_mode!r}")
    if proj_normalization is None:
        proj_normalization = pipe.proj_normalization

    # Grayscale if necessary
    if saving_mode == 'L':
        tensor = convert_to_grayscale(tensor).squeeze()

    # Detach before the device copy so the CPU copy is not tracked by autograd.
    array = tensor.detach().cpu().numpy()

    # Always map [-1, 1] -> [0, 255]; out-of-range values are clipped first.
    array = np.clip(array, -1, 1)
    if proj_normalization:
        array = (array - 1/3) * .75
    array = np.round((array + 1) * 255 / 2).astype(np.uint8)

    # PIL expects channel-last (H, W, C) arrays for multi-channel images.
    if saving_mode == 'RGB':
        array = np.transpose(array, (1, 2, 0))

    image = Image.fromarray(array, mode=saving_mode)
    image.save(file_path)



### MAIN

# NOTE(review): `images` is never appended to in the visible code; kept in
# case code elsewhere refers to it.
images = []

# Denoising steps per image. Also used for the output folder name below, so
# the two values cannot drift apart.
num_inference_steps = 25

# The output folder, file-name tag and save mode depend only on the fixed
# settings above, so resolve them once instead of on every iteration.
if pipe.constraint:
    # Index of the last denoising step minus the constraint starting step
    # (24 - starting_step when num_inference_steps is 25).
    last_steps = num_inference_steps - 1 - pipe.starting_step
    dir_path = f"/path/to/Porosity/Results/microstructures_projected_p{pipe.porosity}_last_{last_steps}"
    file_tag = f"projected_p{pipe.porosity}"
else:
    dir_path = "/path/to/Porosity/Results/microstructures_original"
    file_tag = "original"
saving_mode = "RGB"

for i in range(num_images):

    print(f"\nImage n {i+1}")

    # Skip images already produced by a previous run.
    file_name = f"{i+1}_{file_tag}.png"
    save_path = os.path.join(dir_path, file_name)
    if os.path.exists(save_path):
        print("    Already exists")
        continue

    # Diffusion process. The customized pipeline appears to return the image
    # tensor directly (no `.images[0]`); NOTE(review): presumably in [-1, 1],
    # which is the range save_tensor_as_image expects — confirm.
    image = pipe(prompt=prompt, num_inference_steps=num_inference_steps, cross_attention_kwargs={"scale": 1.0})

    # Porosity = fraction of entries below the threshold. NOTE(review): this
    # averages over every element of the output tensor, i.e. all channels.
    porosity = (image < pipe.threshold).float().mean().item()
    print(f"    Image porosity: {porosity:.4f}")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dir_path, exist_ok=True)
    save_tensor_as_image(image.squeeze(), save_path, saving_mode)
