import torch
from huggingface_hub import hf_hub_download, upload_file
from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL
from safetensors.torch import load_file
from PIL import Image
from diffusers.models.autoencoders.vae import Encoder
from shape2stress_model import FrameTopoStressCNNSimple, FrameTopoStressViT
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import numpy as np
import os
import json
from torchvision import models, transforms
from datetime import datetime
import pandas as pd

model_folder = "metamaterial_shape_original"

### PRETRAINED MODEL IMPORT
print("\n\nMetamaterial shape model\n\n")

# SDXL base pipeline, fp16 weights, moved to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = pipe.to("cuda")

# Fine-tuned LoRA weights for the metamaterial shape model.
lora_path = f'/path/to/wd/models/{model_folder}/metamaterial_shape_full_model'
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")

text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

# Learned embeddings for the placeholder tokens <s0>/<s1>. The safetensors file
# holds one entry per text encoder:
#   "clip_l" -> text_encoder 1 (CLIP ViT-L/14)
#   "clip_g" -> text_encoder 2 (CLIP ViT-G/14)
embedding_path = f"/path/to/wd/models/{model_folder}/metamaterial_shape_full_model/metamaterial_shape_model_emb.safetensors"
state_dict = load_file(embedding_path)
for _emb_key, _encoder, _tokenizer in zip(("clip_l", "clip_g"), text_encoders, tokenizers):
    pipe.load_textual_inversion(
        state_dict[_emb_key],
        token=["<s0>", "<s1>"],
        text_encoder=_encoder,
        tokenizer=_tokenizer,
    )
# Loop variables are scratch only; keep them out of the module namespace.
del _emb_key, _encoder, _tokenizer

### MODEL SETTINGS
# Script-level generation settings.
num_images = 1
prompt = "metamaterial"
method = "c"  # "o": original model, "c": constraint model
csv_target_path = "/path/to/wd/metamaterial/target.csv"  # target curve

# Extra attributes attached to the pipeline object. They are not standard
# diffusers settings; presumably they are read by the customized pipeline
# (TODO confirm against its implementation).
pipe.encoder_train = True
pipe.starting_step = 23  # step at which the constraint starts
pipe.phase = 2
pipe.framework = "metamaterial"
pipe.abaqus_tmp_dir = "/path/to/Metamaterial/Abaqus_Temp"

# Model preparation
pipe.current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Reject unknown methods instead of silently falling back to the original
# model (e.g. a typo such as "C" would otherwise skip the constraint).
if method not in ("o", "c"):
    raise ValueError(
        f"Unknown method {method!r}; expected 'o' (original) or 'c' (constraint)"
    )
pipe.constraint = method == "c"
pipe.early_prj_stop_flag = False
saving_mode = "L"  # Pillow mode used when saving; identical for both models

if pipe.constraint:
    # Constraint model initialization
    print("Model: constraint (our model)")
    pipe.folder_name = f"run_{pipe.current_date}_projected"
    # Target curve: first row of the header-less CSV, as a float32 tensor on the GPU.
    pipe.stress_strain_target = torch.tensor(
        pd.read_csv(csv_target_path, header=None).iloc[0].values,
        dtype=torch.float32,
    ).to("cuda:0")
else:
    # Original model initialization
    print("Model: Original")
    pipe.folder_name = f"run_{pipe.current_date}_original"

### UTILS
def save_tensor_as_image(tensor, file_path, saving_mode='L'):
    """Save a tensor with values nominally in [-1, 1] as an 8-bit image.

    Values are clipped to [-1, 1], rescaled to [0, 255] and rounded to the
    nearest integer (the same convention as diffusers' ``numpy_to_pil``),
    then written with Pillow.

    Args:
        tensor: torch tensor laid out as Pillow expects for an image array,
            e.g. ``(H, W)`` for a single-channel image. It may be on any
            device, require grad, or be half precision.
        file_path: destination path; Pillow infers the file format from the
            extension.
        saving_mode: Pillow mode of the saved image (default ``'L'``,
            8-bit grayscale). The array's natural mode is converted to it
            when they differ.
    """
    # Detach before moving to CPU so autograd history is not carried along.
    # Upcast to float32: numpy has no bfloat16, and fp16 arithmetic on
    # values near 255 loses precision.
    array = tensor.detach().cpu().to(torch.float32).numpy()

    # Map [-1, 1] -> [0, 1] -> [0, 255]. Round rather than truncate, so
    # pixels are not biased downward by up to one level.
    array = (np.clip(array, -1.0, 1.0) + 1.0) * 0.5
    array = np.rint(array * 255.0).astype(np.uint8)

    # Let Pillow infer the mode from the uint8 array ('L' for 2-D input), then
    # convert only if a different mode was requested. This avoids the
    # deprecated `mode=` argument of Image.fromarray.
    image = Image.fromarray(array)
    if image.mode != saving_mode:
        image = image.convert(saving_mode)

    image.save(file_path)


### MAIN

images = []  # NOTE(review): not used in the visible part of this script
new_ds_metadata = {}  # NOTE(review): not used in the visible part of this script
dir_path = f"/path/to/Metamaterial/Results/{pipe.folder_name}/images"

# Create the output directory once. exist_ok=True makes this idempotent and
# avoids the race of checking os.path.exists() before os.makedirs().
os.makedirs(dir_path, exist_ok=True)

for i in range(num_images):
    image_number = i + 1
    print(f"\nImage n {image_number}")

    # Per-image state set on the pipeline before each generation.
    pipe.image_count = image_number
    pipe.early_prj_stop_flag = False

    # Constraint model -> "_projected", original model -> "_original".
    suffix = "projected" if pipe.constraint else "original"
    file_name = f"{image_number}_{suffix}.png"
    save_path = os.path.join(dir_path, file_name)

    # Never overwrite an image that already exists at this path.
    if os.path.exists(save_path):
        print("    Already exists")
        continue

    # Diffusion process
    pipe.first_cycle = True
    # NOTE(review): the customized pipeline appears to return a tensor directly
    # (the standard `.images[0]` accessor is not used and the result goes to
    # torch.mean below). Confirm against the pipeline implementation.
    image = pipe(
        prompt=prompt,
        height=128,
        width=128,
        num_inference_steps=25,
        cross_attention_kwargs={"scale": 1.0},
    )

    # Average over dim 1 (presumably the channel axis of a (B, C, H, W)
    # tensor; TODO confirm) to get one grayscale plane, then drop the
    # singleton dims before saving.
    image = torch.mean(image, dim=1, keepdim=True)
    save_tensor_as_image(image.squeeze(), save_path, saving_mode)
