import os
import re
import sys
import weakref
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from PIL import Image
from torch.utils.tensorboard import SummaryWriter  # TensorBoard logging
from torchvision import transforms
from tqdm import tqdm

# Absolute path of the directory containing this script.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Its parent directory (presumably the project root that holds `pre_exp/` — confirm).
parent_dir = os.path.dirname(current_dir)

# Append the parent directory to sys.path so the `pre_exp` import below resolves
# when this file is run directly as a script; this must happen before that import.
sys.path.append(parent_dir)
from pre_exp.STF import stf_targets, pred_score_batch_eps

class OODDetector:
    """Score rendered 3D-generation views against a LoRA-tuned Stable Diffusion prior.

    For each tested 3D-generation checkpoint, the rendered views and a random
    sample of reference images are encoded into SD latents. The renders are
    noised at several timesteps, and the UNet noise prediction is compared (MSE)
    with the ``eps`` that ``pre_exp.STF.pred_score_batch_eps`` returns for the
    reference latents. Statistics are logged to TensorBoard, plotted, and
    written to ``<cfg.trial_dir>/data.txt``.

    ``cfg`` must provide ``image_size``, ``trial_dir``, ``ref_base_dir``,
    ``render_dir``, ``batch_size``, ``run_step_index`` and ``load_lora``.
    """

    # Render views are stored as ``render_view_<k>.png`` (see ``forward``).
    _RENDER_VIEW_RE = re.compile(r"render_view_(\d+)\.png$")

    def __init__(self, cfg, device):
        self.cfg = cfg
        self.device = device
        self.global_step = 0
        self.save = True
        self.image_size = cfg.image_size
        self.ref_guidance = None  # StableDiffusionPipeline, (re)created in detect_ood
        self.load_lora1 = False  # set to True by load_base_lora
        self._lora2_pipeline_ref = None  # weakref to the pipeline that already holds "lora2"
        self.iterations = 5000
        self.test_ratio = 50
        self.prompt = "a DSLR photo of the Imperial State Crown of England."
        self.dtype = torch.float16
        # Append a timestamp so each run logs into its own directory.
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        log_dir = os.path.join(cfg.trial_dir, f'log_{timestamp}')
        self.writer = SummaryWriter(log_dir)  # TensorBoard writer

    def fill_as(self, inputs, target) -> torch.Tensor:
        """Append trailing singleton dims to ``inputs`` so it broadcasts against ``target``."""
        return inputs.view((*inputs.size(), *[1] * (len(target.size()) - len(inputs.size()))))

    def add_noise(self, x, sigma):
        """Return ``(x + sigma * N(0, 1), sigma_per_sample)``; ``sigma`` is broadcast over the batch."""
        sigma = torch.ones(x.shape[0], device=x.device) * sigma
        noise = torch.normal(0, 1, x.size(), device=x.device)
        perturbed_x = x + noise * self.fill_as(sigma, x)
        return perturbed_x, sigma

    def get_diff(self, x, ref, prompt, timestep):
        """Return the MSE between the UNet noise prediction and the STF ``eps`` estimate.

        ``x`` is noised to ``timestep`` with the pipeline scheduler. The UNet,
        conditioned on ``prompt`` (one string per sample), predicts the noise.
        That prediction is compared with the ``eps`` returned by
        ``pred_score_batch_eps(ref, x_t, atbar)``. ``x`` and ``ref`` are
        presumably latents from ``encode_imgs`` (as in ``detect_ood``).
        """
        x = x.to(self.device)
        t = torch.tensor([timestep] * x.shape[0], dtype=torch.long, device=self.device)
        noise = torch.randn_like(x)
        x_t = self.ref_guidance.scheduler.add_noise(x, noise, t)

        text_embeddings = self.get_text_embeds(prompt)

        noise_pred = self.ref_guidance.unet(x_t.to(self.dtype), t.to(self.dtype), encoder_hidden_states=text_embeddings.to(self.dtype)).sample

        # Cumulative alpha at this timestep, repeated once per sample.
        atbar = torch.tensor(
            [self.ref_guidance.scheduler.alphas_cumprod[timestep]] * x_t.shape[0],
            device=self.device,
        )
        _, eps, _ = pred_score_batch_eps(ref, x_t, atbar)
        diff = F.mse_loss(eps, noise_pred)
        return diff

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        """Encode ``prompt`` (str or list of str) with the pipeline's tokenizer and text encoder."""
        inputs = self.ref_guidance.tokenizer(prompt, padding='max_length', max_length=self.ref_guidance.tokenizer.model_max_length, truncation=True, return_tensors='pt')
        embeddings = self.ref_guidance.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def encode_imgs(self, imgs):
        """Encode images in [0, 1] (shape [B, 3, H, W]) into scaled VAE latents, keeping ``imgs.dtype``."""
        target_dtype = imgs.dtype
        imgs = 2 * imgs - 1  # VAE expects inputs in [-1, 1]

        posterior = self.ref_guidance.vae.encode(imgs.to(self.ref_guidance.vae.dtype)).latent_dist
        latents = posterior.sample() * self.ref_guidance.vae.config.scaling_factor

        return latents.to(target_dtype)

    def load_base_lora(self, lora1_path='resources/backup/lora_ckpt_exp6/aki-000400.safetensors'):
        """Load ``lora1_path`` as adapter "lora1", enable xformers and move the pipeline to ``self.device``."""
        self.load_lora1 = True
        print(f'[INFO] loading lora1: {lora1_path}')
        self.ref_guidance.load_lora_weights(lora1_path, adapter_name="lora1")
        self.ref_guidance.enable_xformers_memory_efficient_attention()
        self.ref_guidance = self.ref_guidance.to(self.device)

    def load_lora(self, iteration, load_strategy, lora2_path='resources/backup/lora2/aki-000200.safetensors'):
        """Optionally attach adapter "lora2" and return its weight for ``iteration``.

        When ``cfg.load_lora`` is false nothing is loaded and 0 is returned.
        Otherwise "lora2" is loaded onto the current pipeline, once per pipeline
        object. Its weight is ``1 - ((iteration - 1) / 1000) ** load_strategy``
        for ``iteration <= 1000`` and 0 afterwards. "lora1" is kept at weight
        1.0 if ``load_base_lora`` loaded it.
        """
        if not self.cfg.load_lora:
            return 0

        # Loading only at ``iteration == 1`` breaks when the pipeline is rebuilt
        # (detect_ood creates a fresh one per test iteration): set_adapters
        # would then reference a "lora2" adapter that does not exist. Track the
        # pipeline object instead; a weakref avoids keeping an old pipeline alive.
        holder = getattr(self, "_lora2_pipeline_ref", None)
        if holder is None or holder() is not self.ref_guidance:
            print(f'[INFO] loading lora2: {lora2_path}')
            self.ref_guidance.load_lora_weights(lora2_path, adapter_name="lora2")
            self._lora2_pipeline_ref = weakref.ref(self.ref_guidance)

        if iteration <= 1000:
            weight = 1.0 - ((iteration - 1) / 1000) ** load_strategy
        else:
            weight = 0

        if getattr(self, "load_lora1", False):
            self.ref_guidance.set_adapters(["lora1", "lora2"], adapter_weights=[1.0, weight])
        else:
            self.ref_guidance.set_adapters(["lora2"], adapter_weights=[weight])
        return weight

    def detect_ood(self):
        """Run the sweep over 3D-generation checkpoints and save the results.

        For each iteration in ``test_iters[:cfg.run_step_index]`` the method:

        * builds a fresh SD 2.1-base pipeline (local files only) with lora1 and,
          optionally, lora2 (see ``load_lora``);
        * reads references from ``<cfg.ref_base_dir>/<lora2 weight rounded to 3 decimals>``;
        * samples the diff 5 times for each timestep in 100, 200, ..., 900 and
          logs its mean and std to TensorBoard.

        Plots and ``data.txt`` are then written to ``cfg.trial_dir``.
        """
        print("detect ood!")

        timesteps = list(range(100, 901, 100))
        samples_per_timestep = 5
        # lora_gen_means[t][k]: mean diff at timestep t for the k-th evaluated iteration.
        lora_gen_means = {t: [] for t in timesteps}
        final_gen_means = []

        lora_name = "aki-000000.safetensors"  # only used as a TensorBoard tag / log label
        prompt = self.prompt
        test_iters = [1] + [k * self.iterations // self.test_ratio for k in range(1, self.test_ratio)] + [self.iterations]
        run_iters = test_iters[:self.cfg.run_step_index]

        for index, test_iter in enumerate(run_iters):
            print(f"test_iter: {test_iter}")
            self.ref_guidance = StableDiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16, local_files_only=True
            )
            self.load_base_lora()
            weight = round(self.load_lora(test_iter, load_strategy=0.3), 3)
            print(f"round weight: {weight}")
            # Honour the configured device (previously hard-coded to "cuda");
            # this also moves any lora2 weights loaded after load_base_lora.
            self.ref_guidance.to(self.device)
            ref_dir = os.path.join(self.cfg.ref_base_dir, str(weight))

            gen_temp_all = []
            for i, timestep in enumerate(tqdm(timesteps, position=0, desc="i", leave=False, colour='green', ncols=80)):
                gen_temp = []
                for _ in tqdm(range(samples_per_timestep), position=1, desc="j", leave=False, colour='red', ncols=80):
                    gens, refs = self.forward(ref_dir, test_iter)
                    gens = self.encode_imgs(gens.to(self.device, dtype=self.dtype))
                    refs = self.encode_imgs(refs.to(self.device, dtype=self.dtype))

                    with torch.no_grad():
                        diff_gen = self.get_diff(gens, refs, [prompt] * len(gens), timestep)
                    torch.cuda.empty_cache()
                    gen_temp.append(diff_gen.item())
                    self.global_step += 1

                gen_mean = np.mean(gen_temp)
                gen_std = np.std(gen_temp)
                gen_temp_all.append(gen_mean)
                print(f'timesteps[{i}]: {timestep}, gen_mean: {gen_mean}, gen_std: {gen_std}')
                self.writer.add_scalars(f'gen_mean/{lora_name}', {str(test_iter): gen_mean}, timestep)
                self.writer.add_scalars(f'gen_std/{lora_name}', {str(test_iter): gen_std}, timestep)
                lora_gen_means[timestep].append(gen_mean)

            final_gen_mean = np.mean(gen_temp_all)
            final_gen_std = np.std(gen_temp_all)
            final_gen_means.append(final_gen_mean)
            print(f'Final gen_mean for {lora_name}, run step {test_iter}: {final_gen_mean}')
            self.writer.add_scalars(f'final_gen_mean_with_lora_step/{lora_name}', {str(test_iter): final_gen_mean}, index*100)
            self.writer.add_scalars(f'final_gen_std_with_lora_step/{lora_name}', {str(test_iter): final_gen_std}, index*100)

        print('done')
        self._save_results(timesteps, run_iters, lora_gen_means, final_gen_means)
        self.writer.flush()

    def _save_results(self, timesteps, run_iters, lora_gen_means, final_gen_means):
        """Write the two summary plots and ``data.txt`` into ``cfg.trial_dir``.

        ``lora_gen_means[t][k]`` is the mean diff at timestep ``t`` for
        ``run_iters[k]``; ``final_gen_means[k]`` is its mean over timesteps.
        """
        trial_dir = self.cfg.trial_dir
        done_iters = run_iters[:len(final_gen_means)]

        # Plot 1: gen_mean vs. timestep, one curve per evaluated iteration.
        plt.figure()
        for k, test_iter in enumerate(done_iters):
            means = [lora_gen_means[t][k] for t in timesteps]
            plt.plot(timesteps, means, marker='o', label=f"{test_iter}_iter")
        plt.xlabel('Timesteps')
        plt.ylabel('gen_mean')
        plt.yscale("log")
        plt.legend(loc='best')
        plt.savefig(f'{trial_dir}/gen_mean_vs_timesteps.png')
        plt.close()

        # Plot 2: timestep-averaged gen_mean vs. 3D generation step.
        plt.figure()
        plt.plot(done_iters, final_gen_means, marker='o', label='Final gen_mean')
        plt.xlabel('3D Gen Steps')
        plt.ylabel('final_gen_mean')
        plt.legend(loc='best')
        plt.savefig(f'{trial_dir}/final_gen_mean_vs_steps.png')
        plt.close()

        with open(os.path.join(trial_dir, 'data.txt'), 'w', encoding='utf-8') as f:
            f.write("Timesteps:\n")
            f.write(", ".join(map(str, timesteps)) + "\n\n")

            f.write("gen_mean for each 3D Gen Steps:\n")
            for k, test_iter in enumerate(done_iters):
                f.write(f"{test_iter}:\n")
                f.write(", ".join(map(str, [lora_gen_means[t][k] for t in timesteps])) + "\n")
            f.write("\n")

            f.write("Final gen_mean for each 3D Gen Steps:\n")
            for test_iter, gen_mean in zip(done_iters, final_gen_means):
                f.write(f"{test_iter}: {gen_mean}\n")

    @staticmethod
    def _select_render_files(file_names, max_view=7):
        """Return the ``render_view_<k>.png`` names with ``k <= max_view``, sorted by ``k``."""
        selected = []
        for name in file_names:
            match = OODDetector._RENDER_VIEW_RE.match(name)
            if match is not None and int(match.group(1)) <= max_view:
                selected.append((int(match.group(1)), name))
        return [name for _, name in sorted(selected)]

    @staticmethod
    def _load_rgb(path):
        """Decode the image at ``path`` as RGB and close the underlying file."""
        with Image.open(path) as img:
            # convert() loads the pixels into a new image, so the file can be closed.
            return img.convert("RGB")

    def forward(self, ref_dir, test_iter):
        """Load the rendered views for ``test_iter`` and a random batch of references.

        Renders come from
        ``<cfg.render_dir>/test_six_views/<test_iter>_iteration/render_view_<k>.png``
        for ``k`` in 0..7. ``cfg.batch_size`` references are drawn with
        replacement from ``ref_dir``. Every image is decoded as RGB, resized to
        ``image_size`` x ``image_size`` and converted with ``ToTensor``.

        Returns:
            ``(gens, refs)`` tensors of shape (N, 3, S, S) and (cfg.batch_size, 3, S, S).

        Raises:
            FileNotFoundError: if no matching render view or no reference file exists.
        """
        render_view_dir = os.path.join(self.cfg.render_dir, "test_six_views", f"{test_iter}_iteration")
        render_files = self._select_render_files(os.listdir(render_view_dir))
        if not render_files:
            raise FileNotFoundError(f"no render_view_<0-7>.png files in {render_view_dir}")
        gens = [self._load_rgb(os.path.join(render_view_dir, name)) for name in render_files]

        # Sorted so that draws are reproducible under a fixed NumPy seed.
        ref_files = sorted(os.listdir(ref_dir))
        if not ref_files:
            raise FileNotFoundError(f"no reference images in {ref_dir}")
        selected_ref_files = np.random.choice(ref_files, size=self.cfg.batch_size)
        refs = [self._load_rgb(os.path.join(ref_dir, name)) for name in selected_ref_files]

        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ])

        gens = torch.stack([transform(img) for img in gens])
        refs = torch.stack([transform(img) for img in refs])
        return gens, refs

# Example usage: run the OOD sweep for the Imperial State Crown ablation.
if __name__ == "__main__":
    class Config:
        """Attribute bag holding the settings that OODDetector reads."""

        def __init__(self):
            experiment_root = "output/ablation/load_strategy_0_3/A_DSLR_photo_of_the_Imperial_State_Crown_of_England"
            # Renders are read from <render_dir>/test_six_views/<iter>_iteration.
            self.render_dir = experiment_root
            # TensorBoard logs, plots and data.txt are written here.
            self.trial_dir = experiment_root + "/ood/no_load_lora"
            # NOTE(review): not read by OODDetector; the relative "root/..." path
            # may be missing a leading "/" — confirm before relying on it.
            self.lora_dir = "root/autodl-tmp/PathTracing_LucidDreamer/resources/lora_ckpt"
            # References are read from <ref_base_dir>/<rounded lora2 weight>.
            self.ref_base_dir = "resources/imgs_lora_sample_with_weight"
            self.image_size = 512
            self.batch_size = 32
            self.run_step_index = 51
            self.load_lora = False

    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    detector = OODDetector(Config(), run_device)
    detector.detect_ood()
