import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt 
from tqdm import tqdm
from dataclasses import dataclass, field
import time

from copy import deepcopy

from argparse import ArgumentParser
from torchvision.utils import save_image


from pre_exp.STF import stf_targets, pred_score_batch_eps

import threestudio
from threestudio.utils.config import ExperimentConfig, load_config, parse_structured
from threestudio.systems.base import BaseLift3DSystem
from threestudio.data.uncond import RandomCameraDataModuleConfig, RandomCameraIterableDataset
from threestudio.models.guidance import stable_diffusion_guidance
from threestudio.models.prompt_processors import stable_diffusion_prompt_processor
from threestudio.utils.misc import get_device

from gaussiansplatting.gaussian_renderer import render
from gaussiansplatting.scene import GaussianModel
from gaussiansplatting.arguments import PipelineParams
from gaussiansplatting.scene.cameras import Camera
from gaussiansplatting.utils.sh_utils import SH2RGB
from gaussiansplatting.scene.gaussian_model import BasicPointCloud

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config as diffusion_from_config_shape
from shap_e.models.download import load_model
from shap_e.models.download import load_config as shape_load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.notebooks import decode_latent_mesh
import open3d as o3d

class OOD:
    """
    Out-of-distribution (OOD) probe for a 3D Gaussian Splatting (3DGS) asset.

    For several diffusion timesteps, :meth:`detect_ood` noises two kinds of
    images, asks the UNet of the LoRA-loaded ``ref_guidance`` for its noise
    prediction, and measures the MSE between that prediction and the ``eps``
    target returned by ``pred_score_batch_eps`` for a set of reference images:

    * images rendered from the Gaussian model ("Generated image" in the plot);
    * images sampled from ``ref_guidance`` for the same cameras
      ("Real image" in the plot).

    Reference images are sampled from ``ref_guidance`` for cameras drawn from
    the ``refdata`` config. Results are plotted to ``<trial_dir>/ood.png`` and
    stored in ``<trial_dir>/data.npz``.
    """
    # Incremented once per forward pass in detect_ood(); also used as the sampling seed
    # and in saved image file names.
    global_step = 0
    # Point-cloud source used by pcb(): 0 = Shap-E from the prompt (shape()), 1 = mesh at load_path (smpl()).
    load_type = 0
    # Multiplied by the per-source scale in pcb() to scale point coordinates.
    radius = 4
    # Write intermediate images to trial_dir (detect_ood switches this off after 5 steps).
    save = True
    # Load the Gaussians from the .ply at load_path ...
    load_from_path = True
    # ... otherwise, if set, build them from the pcb() point cloud.
    create_from_shape = False
    load_path = "outputs/gaussiandreamer-sd/a_fox@20240620-183916/save/last_3dgs.ply"
    # load_path = "outputs/gaussiandreamer-sd/A_plate_piled_high_with_chocolate_chip_cookies@20240324-212641/save/shape.ply"
    # Side length that images are resized to before being fed to the UNet.
    image_size = 256

    # String annotation: only documents the attribute type, never evaluated eagerly.
    cfg: "ExperimentConfig"

    def __init__(self, args, extras, n_gpus) -> None:
        """Load the experiment config, then build datasets, Gaussians and guidance.

        Args:
            args: parsed CLI namespace; only ``args.config`` (config path) is read.
            extras: extra CLI overrides forwarded to ``load_config``.
            n_gpus: number of GPUs forwarded to ``load_config``.
        """
        self.device = get_device()
        print(f"ood device: {self.device}")
        self.cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)
        self.configure()

    def shape(self):
        """Generate a coloured point cloud for the prompt with Shap-E.

        Loads the Shap-E transmitter and the fine-tuned ``text300M`` model,
        samples one latent for ``cfg.system.prompt_processor.prompt``, stores
        pan-camera preview renders in ``self.shapeimages`` and an Open3D point
        cloud of the decoded mesh vertices in ``self.point_cloud``.

        Returns:
            tuple: ``(coords, rgb, scale)`` -- mesh vertex coordinates, their
            R/G/B vertex channels stacked as columns, and the scale factor
            ``0.4`` consumed by pcb().
        """
        device = self.device
        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        checkpoint = torch.load('./load/shapE_finetuned_with_330kdata.pth', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        diffusion = diffusion_from_config_shape(shape_load_config('diffusion'))

        batch_size = 1
        guidance_scale = 15.0
        prompt = str(self.cfg.system.prompt_processor.prompt)
        print('prompt', prompt)

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        render_mode = 'nerf'  # can also be 'stf'
        size = 256  # render resolution; higher values take longer to render

        cameras = create_pan_cameras(size, device)
        # Render preview images of the first latent from the pan cameras.
        self.shapeimages = decode_latent_images(xm, latents[0], cameras, rendering_mode=render_mode)

        # Decode the first latent to a triangle mesh; its vertices become the point cloud.
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()

        skip = 1  # vertex subsampling stride (1 keeps every vertex)
        coords = mesh.verts[::skip]
        rgb = np.stack([mesh.vertex_channels[c] for c in ('R', 'G', 'B')], axis=1)[::skip]

        self.num_pts = coords.shape[0]
        point_cloud = o3d.geometry.PointCloud()
        point_cloud.points = o3d.utility.Vector3dVector(coords)
        point_cloud.colors = o3d.utility.Vector3dVector(rgb)
        self.point_cloud = point_cloud

        return coords, rgb, 0.4

    def add_points(self, coords, rgb, num_points=1000000, max_dist=0.01, seed=0):
        """Densify a point cloud with random points lying close to it.

        ``num_points`` candidates are drawn uniformly inside the axis-aligned
        bounding box of ``coords``; a candidate is kept when its nearest
        neighbour in ``coords`` is closer than ``max_dist``. A kept point takes
        its neighbour's colour plus uniform jitter in ``[0, 0.2)`` per channel.

        Note: this reseeds NumPy's global RNG with ``seed``.

        Args:
            coords: (N, 3) point coordinates.
            rgb: (N, 3) per-point colours aligned with ``coords``.
            num_points: number of random candidates to test.
            max_dist: acceptance distance to the nearest original point
                (may need tuning per asset).
            seed: value passed to ``np.random.seed``.

        Returns:
            tuple: ``(all_coords, all_rgb)`` -- the kept candidates followed by
            the original points / colours.
        """
        coords = np.asarray(coords)
        rgb = np.asarray(rgb)

        pcd_by3d = o3d.geometry.PointCloud()
        pcd_by3d.points = o3d.utility.Vector3dVector(coords)
        bbox = pcd_by3d.get_axis_aligned_bounding_box()

        np.random.seed(seed)
        candidates = np.random.uniform(
            low=np.asarray(bbox.min_bound),
            high=np.asarray(bbox.max_bound),
            size=(num_points, 3),
        )

        kdtree = o3d.geometry.KDTreeFlann(pcd_by3d)
        # Convert the Open3D points to NumPy once, not once per candidate.
        ref_points = np.asarray(pcd_by3d.points)

        points_inside = []
        color_inside = []
        for point in candidates:
            _, idx, _ = kdtree.search_knn_vector_3d(point, 1)
            nearest = idx[0]
            if np.linalg.norm(point - ref_points[nearest]) < max_dist:
                points_inside.append(point)
                color_inside.append(rgb[nearest] + 0.2 * np.random.random(3))

        # reshape keeps a (0, 3) shape when nothing is accepted, so the
        # concatenation below still works.
        new_coords = np.asarray(points_inside, dtype=np.float64).reshape(-1, 3)
        new_rgb = np.asarray(color_inside, dtype=np.float64).reshape(-1, 3)
        all_coords = np.concatenate([new_coords, coords], axis=0)
        all_rgb = np.concatenate([new_rgb, rgb], axis=0)
        return all_coords, all_rgb

    def smpl(self):
        """Sample a coloured point cloud from the triangle mesh at ``load_path``.

        Samples 50,000 points uniformly on the mesh surface, assigns random
        colours (random values in ``[0, 1/255)`` converted with ``SH2RGB``),
        permutes the axes and recentres the cloud on its centroid.

        Returns:
            tuple: ``(coords, rgb, scale)`` with the scale factor ``0.5``
            consumed by pcb().
        """
        self.num_pts = 50000
        mesh = o3d.io.read_triangle_mesh(self.load_path)
        point_cloud = mesh.sample_points_uniformly(number_of_points=self.num_pts)
        coords = np.array(point_cloud.points)
        shs = np.random.random((self.num_pts, 3)) / 255.0
        rgb = SH2RGB(shs)
        # Cyclic axis permutation (presumably a coordinate-system change):
        # new (x, y, z) = old (z, x, y).
        adjusted = coords[:, [2, 0, 1]]
        # Recentre the cloud on its centroid.
        adjusted -= adjusted.mean(axis=0)
        return adjusted, rgb, 0.5

    def pcb(self):
        """Build the initial ``BasicPointCloud`` used to create the Gaussians.

        The source is chosen by ``load_type`` (0: shape(), 1: smpl()); the
        cloud is densified with add_points() and its coordinates are scaled by
        ``radius * scale``. Normals are all zero.

        Raises:
            NotImplementedError: for any other ``load_type``.
        """
        # Since this data set has no colmap data, we start with random points
        if self.load_type == 0:
            coords, rgb, scale = self.shape()
        elif self.load_type == 1:
            coords, rgb, scale = self.smpl()
        else:
            raise NotImplementedError(f"unsupported load_type: {self.load_type}")

        bound = self.radius * scale
        all_coords, all_rgb = self.add_points(coords, rgb)

        return BasicPointCloud(
            points=all_coords * bound,
            colors=all_rgb,
            normals=np.zeros((all_coords.shape[0], 3)),
        )

    def configure(self):
        """Build camera dataloaders, the Gaussian model, prompt utils and guidance.

        Sets ``camera_dataloader`` (from ``cfg.data``), ``ref_camera_dataloader``
        (from ``cfg.refdata``), ``renderbackground``, ``gaussian``,
        ``prompt_processor`` / ``prompt_utils``, ``lora_dir`` and
        ``ref_guidance`` (with LoRA weights from ``lora_dir`` at step 1200).
        """
        self.data_cfg: RandomCameraDataModuleConfig = parse_structured(
            RandomCameraDataModuleConfig, self.cfg.data
        )
        self.refdata_cfg: RandomCameraDataModuleConfig = parse_structured(
            RandomCameraDataModuleConfig, self.cfg.refdata
        )

        # num_workers=0 is important: with worker processes, changing dataset
        # attributes at runtime (e.g. width/height) would not take effect.
        dataset = RandomCameraIterableDataset(self.data_cfg)
        self.camera_dataloader = torch.utils.data.DataLoader(
            dataset,
            num_workers=0,
            batch_size=None,
            collate_fn=dataset.collate,
        )
        refdataset = RandomCameraIterableDataset(self.refdata_cfg)
        self.ref_camera_dataloader = torch.utils.data.DataLoader(
            refdataset,
            num_workers=0,
            batch_size=None,
            # Collate with the reference dataset itself; it previously used
            # ``dataset.collate`` (the ``cfg.data`` dataset's collate).
            collate_fn=refdataset.collate,
        )

        # White background; use [0, 0, 0] for black.
        bg_color = [1, 1, 1]
        # Keep the background on the same device as everything else.
        self.renderbackground = torch.tensor(bg_color, dtype=torch.float32, device=self.device)

        self.gaussian = GaussianModel(sh_degree=0)
        if self.load_from_path:
            print("load_from_path")
            self.gaussian.load_ply(self.load_path)
        elif self.create_from_shape:
            print("create_from_shape")
            point_cloud = self.pcb()
            self.cameras_extent = 4.0
            self.gaussian.create_from_pcd(point_cloud, self.cameras_extent)
        else:
            # FIXME: initialisation from random points is not implemented;
            # nothing is loaded into the Gaussian model.
            print("warning: no Gaussian initialisation selected (load_from_path and create_from_shape are both False)")

        self.prompt_processor = threestudio.find(self.cfg.system.prompt_processor_type)(
            self.cfg.system.prompt_processor
        )
        self.prompt_utils = self.prompt_processor()

        prompt_slug = str(self.prompt_processor.prompt.replace(' ', '_'))
        self.lora_dir = os.path.join("outputs/lora_shape", prompt_slug)
        self.ref_guidance = threestudio.find(self.cfg.system.guidance_type)(self.cfg.system.guidance)
        self.ref_guidance.load_lora_params(
            load_unet="unet",
            load_step=1200,
            old_load_step=None,
            lora_dir=self.lora_dir,
        )

    def img_save(self, imgs, name):
        """Save every image of a batch as ``<trial_dir>/<name>_global_<step>_<i>.png``."""
        for i, img in enumerate(imgs):
            img_save_path = os.path.join(self.cfg.trial_dir, f'{name}_global_{self.global_step}_{i}.png')
            save_image(img, img_save_path)

    def rendered_generated_samples(self, batch):
        """Render the Gaussian model from every camera in ``batch``.

        Args:
            batch: camera batch from a RandomCameraIterableDataset; reads
                ``c2w_3dgs``, ``fovy``, ``height`` and ``width``.

        Returns:
            dict: the render package of the last view, extended with
            ``comp_rgb`` (stacked renders), ``depth`` (stacked ``depth_3dgs``)
            and ``opacity`` (depth divided by its batch maximum).
        """
        # Pipeline parameters are identical for every view, so build them once.
        self.parser = ArgumentParser(description="Training script parameters")
        self.pipe = PipelineParams(self.parser)

        images = []
        depths = []
        for idx in range(batch['c2w_3dgs'].shape[0]):
            viewpoint_cam = Camera(
                c2w=batch['c2w_3dgs'][idx],
                FoVy=batch['fovy'][idx],
                height=batch['height'],
                width=batch['width'],
            )
            render_pkg = render(viewpoint_cam, self.gaussian, self.pipe, self.renderbackground)
            images.append(render_pkg["render"])
            depths.append(render_pkg["depth_3dgs"])

        images = torch.stack(images, 0)
        depths = torch.stack(depths, 0)
        render_pkg["comp_rgb"] = images
        render_pkg["depth"] = depths
        # Pseudo-opacity derived from depth (normalised by the batch maximum).
        render_pkg["opacity"] = depths / (depths.max() + 1e-5)
        if self.save:
            self.img_save(render_pkg["comp_rgb"], 'render_image')
            self.img_save(render_pkg["depth"], 'render_depth')
            self.img_save(render_pkg["opacity"], 'render_opacity')

        return {
            **render_pkg,
        }

    def pretrained_generated_samples(self, guidance, batch):
        """Sample images from ``guidance`` for the cameras in ``batch``.

        ``self.global_step`` is used as the seed. The sampler output is
        permuted with ``(0, 3, 1, 2)`` -- presumably BHWC -> BCHW; confirm
        against ``guidance.sample``.

        Returns:
            dict: ``{"comp_rgb": images}``.
        """
        imgs = guidance.sample(
            self.prompt_utils, **batch, seed=self.global_step,
        )
        return {"comp_rgb": imgs.permute(0, 3, 1, 2)}

    def _resize(self, imgs):
        """Bilinearly resize a BCHW batch to ``image_size`` x ``image_size``."""
        return F.interpolate(
            imgs, (self.image_size, self.image_size), mode="bilinear", align_corners=False
        )

    @staticmethod
    def _append_zero_channel(imgs):
        """Append one all-zero channel along dim 1 (e.g. 3 -> 4 channels)."""
        return torch.cat((imgs, torch.zeros_like(imgs[:, :1])), dim=1)

    def forward(self):
        """Draw one camera batch from each dataloader and build the UNet inputs.

        Returns:
            tuple: ``(latent1, latent2, latent3, batch)`` where ``latent1`` are
            renders of the Gaussian model for ``batch``, ``latent2`` are
            ``ref_guidance`` samples for the same ``batch``, ``latent3`` are
            ``ref_guidance`` samples for a reference camera batch -- all resized
            to ``image_size`` and padded with one zero channel -- and ``batch``
            is the camera batch used for the first two.
        """
        batch = next(iter(self.camera_dataloader))
        refbatch = next(iter(self.ref_camera_dataloader))

        out1 = self.rendered_generated_samples(batch=batch)
        out2 = self.pretrained_generated_samples(self.ref_guidance, batch=batch)
        out3 = self.pretrained_generated_samples(self.ref_guidance, batch=refbatch)

        latent1 = self._resize(out1["comp_rgb"])
        latent2 = self._resize(out2["comp_rgb"])
        latent3 = self._resize(out3["comp_rgb"])
        if self.save:
            self.img_save(latent1, 'latent1')
            self.img_save(latent2, 'latent2')
            self.img_save(latent3, 'latent3')

        # The zero channel is built per tensor: the reference batch can have a
        # different batch size (and device) than ``batch``.
        latent1 = self._append_zero_channel(latent1)
        latent2 = self._append_zero_channel(latent2)
        latent3 = self._append_zero_channel(latent3)
        return latent1, latent2, latent3, batch

    def fill_as(self, inputs, target) -> torch.Tensor:
        """
        Adds as many dimensions to the first tensor as it needs to reach the number of dimensions of the second tensor.
        Useful for broadcasting.
        """
        return inputs.view((*inputs.size(), *[1] * (len(target.size()) - len(inputs.size()))))

    def add_noise(self, x, sigma):
        """Add Gaussian noise with standard deviation ``sigma`` to ``x``.

        Returns:
            tuple: ``(perturbed_x, sigma)`` where ``sigma`` is broadcast to a
            1-D tensor with one entry per sample in ``x``.
        """
        sigma = torch.ones(x.shape[0], device=x.device) * sigma
        noise = torch.normal(0, 1, x.size(), device=x.device)
        perturbed_x = x + noise * self.fill_as(sigma, x)

        return perturbed_x, sigma

    def get_diff(self, x, ref, batch, timestep):
        """MSE between the UNet noise prediction and the ``pred_score_batch_eps`` target.

        ``x`` is noised to ``timestep`` with the ``ref_guidance`` scheduler, the
        ``ref_guidance`` UNet predicts the noise using view-independent text
        embeddings for ``batch``, and ``pred_score_batch_eps(ref, x_t, atbar)``
        provides the target ``eps``.

        Args:
            x: images to noise (moved to ``self.device``).
            ref: reference images passed to ``pred_score_batch_eps``.
            batch: camera batch providing ``elevation``, ``azimuth`` and
                ``camera_distances``.
            timestep: integer diffusion timestep.

        Returns:
            torch.Tensor: scalar MSE loss.
        """
        x = x.to(self.device)
        scheduler = self.ref_guidance.scheduler
        t = torch.full((x.shape[0],), timestep, dtype=torch.long, device=self.device)
        noise = torch.randn_like(x)
        x_t = scheduler.add_noise(x, noise, t)

        text_embeddings = self.prompt_utils.get_text_embeddings(
            batch["elevation"], batch["azimuth"], batch["camera_distances"], view_dependent_prompting=False
        )
        noise_pred = self.ref_guidance.get_noise_pred(x_t, t, text_embeddings)

        # Cumulative alpha product for this timestep, one entry per sample.
        atbar = torch.as_tensor(scheduler.alphas_cumprod[timestep], device=self.device).repeat(x_t.shape[0])
        _, eps, _ = pred_score_batch_eps(ref, x_t, atbar)
        return F.mse_loss(eps, noise_pred)

    def detect_ood(self, timesteps=None, num_repeats=20):
        """Run the OOD probe and save a bar plot plus the raw statistics.

        For every timestep, ``num_repeats`` forward passes are made. Each yields
        the get_diff() MSE for ``ref_guidance`` samples ("Real image") and for
        Gaussian renders ("Generated image"), both against the reference
        samples. Means and standard deviations are plotted on a log scale to
        ``<trial_dir>/ood.png`` and saved to ``<trial_dir>/data.npz``.

        Args:
            timesteps: diffusion timesteps to probe; defaults to
                ``[1, 2, 5, 10, 20, 50, 100, 200, 500, 700]``.
            num_repeats: number of forward passes per timestep.
        """
        print("detect ood!")

        if timesteps is None:
            timesteps = [1, 2, 5, 10, 20, 50, 100, 200, 500, 700]
        diff_pretrain_mean = []
        diff_gen_mean = []
        diff_pretrain_std = []
        diff_gen_std = []
        outer = tqdm(timesteps, position=0, desc="i", leave=False, colour='green', ncols=80)
        for i, timestep in enumerate(outer):
            pre_temp = []
            gen_temp = []
            for _ in tqdm(range(num_repeats), position=1, desc="j", leave=False, colour='red', ncols=80):
                # Only the first few forward passes write intermediate images.
                if self.global_step == 5:
                    self.save = False
                gens, images, ref, batch = self.forward()
                images = images.to(self.device)
                ref = ref.to(self.device)

                with torch.no_grad():
                    diff_image = self.get_diff(images, ref, batch, timestep)
                    diff_gen = self.get_diff(gens, ref, batch, timestep)
                torch.cuda.empty_cache()
                pre_temp.append(diff_image.item())
                gen_temp.append(diff_gen.item())
                self.global_step += 1
            pre_mean = np.mean(pre_temp)
            pre_std = np.std(pre_temp)
            gen_mean = np.mean(gen_temp)
            gen_std = np.std(gen_temp)
            print(f'step:{i}, pre_mean: {pre_mean}, pre_std: {pre_std}, gen_mean: {gen_mean}, gen_std: {gen_std}')
            diff_pretrain_mean.append(pre_mean)
            diff_gen_mean.append(gen_mean)
            diff_pretrain_std.append(pre_std)
            diff_gen_std.append(gen_std)
        print('done')

        x = np.arange(len(timesteps))
        bar_width = 0.35
        # Draw on a fresh figure and close it afterwards so repeated calls do
        # not accumulate bars on the same axes.
        fig = plt.figure()
        plt.bar(x, diff_pretrain_mean, width=bar_width, yerr=diff_pretrain_std, label="Real image", linewidth=2, capsize=5)
        # The second bar of each group starts one bar width to the right.
        plt.bar(x + bar_width, diff_gen_mean, width=bar_width, yerr=diff_gen_std, label="Generated image", linewidth=2, capsize=5)
        # Centre each tick label between its two bars.
        plt.xticks(x + bar_width / 2, timesteps)
        plt.yscale("log")
        plt.legend(loc='best')
        plt.xlabel('Timesteps')
        plt.ylabel('Mse loss of STF and Unet prediction')
        plt.savefig(f'{self.cfg.trial_dir}/ood.png')
        plt.close(fig)

        np.savez(
            os.path.join(self.cfg.trial_dir, 'data.npz'),
            timesteps=timesteps,
            diff_pretrain_mean=diff_pretrain_mean, diff_pretrain_std=diff_pretrain_std,
            diff_gen_mean=diff_gen_mean, diff_gen_std=diff_gen_std,
        )
