import os
import sys
import cv2
import tqdm
import numpy as np
import math

import argparse
from omegaconf import OmegaConf

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from PIL import Image
import rembg
import pytorch_ssim

from gaussian_splatting import BasicPointCloud

base_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(base_dir, 'shape-e'))
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config as diffusion_from_config_shape
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import decode_latent_mesh

from cam_utils import Camera, RandomCameraCam
from gaussian_splatting import Renderer

from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler, AutoencoderKL
from ip_adapter import IPAdapter

from clipconv import CLIPConvLoss

class ControlGaussian:
    """Sketch-conditioned 3D Gaussian-splatting generator and trainer.

    Pipeline (driven by ``train``):
      1. ``prepare_controlnet`` turns the input sketch into an image with a
         Stable Diffusion v1.5 + ControlNet (canny) pipeline, removes its
         background with rembg and recenters it.
      2. ``point_cloud_initialization`` lifts that image to a coarse point
         cloud with Shap-E (skipped when ``opt.load`` is set).
      3. ``prepare_train`` / ``train_step`` optimise the Gaussians with an SDS
         loss, a CLIP sketch loss and an MSE loss against IP-Adapter
         generated multi-view supervision images.
    """

    def __init__(self, opt):
        """Create the trainer; the diffusion models are loaded later.

        Args:
            opt: OmegaConf config. It is stored by reference (not copied) so
                rendering parameters modified in place by the caller are seen
                here. Fields read by this class include ``sh_degree``,
                ``input_sketch``, ``input_prompt``, ``negative_prompt``,
                ``iters``, ``ref_size``, ``stable_diffusion_render_batch``,
                ``stable_diffusion_loss_lambda_sds``, ``density_start_iter``,
                ``density_end_iter``, ``densify_interval``, ``load``,
                ``outdir`` and ``save_path``.
        """
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        # Any value that int() rejects (e.g. "random") makes seed_everything draw a fresh seed.
        self.seed = "random"
        self.last_seed = None  # seed actually used by the last seed_everything() call
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # counters of rendered views / IP-Adapter generations
        self.image_count = 0
        self.ip_image_count = 0

        # camera sampler
        self.cam = RandomCameraCam(opt)

        # input sketch
        self.input_sketch_PIL = None
        self.input_sketch_torch = None

        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # point cloud
        self.cameras_extent = 4.0

        # ControlNet
        self.controlnet_pipe = None
        self.controlnet_img = None
        self.controlnet_img_torch = None
        self.controlnet_img_mask = None
        self.controlnet_img_mask_torch = None
        self.shapeE_input_PIL = None

        # renderer
        self.renderer = Renderer(sh_degree=self.opt.sh_degree)
        self.gaussain_scale_factor = 1  # (sic) not read anywhere in this class

        # stable_diffusion
        self.guidance_stable_diffusion = None
        self.viewspace_point_list = None
        self.visibility_filter = None
        self.radii = None

        # transfer
        self.supervise_mean = None
        self.supervise_variance = None

        # IP Adapter
        self.IP_pipe = None
        self.IP_noise_scheduler = None
        self.IP_vae = None
        self.IP_ip_model = None
        self.supervise_whitebg_PIL = None
        self.multi_supervise_azimuth = []
        self.multi_supervise_elevation = []

        # loss
        self.sketch_loss_calculator = None
        self.ssim_loss = None
        self.sds_loss_azimuth = None
        self.sds_loss_elevation = None
        self.clip_loss = None

        # training settings
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop
        self.last_step_time_ms = None  # CUDA wall time of the last train_step() call

        # load input data from cmdline
        if self.opt.input_sketch is not None:
            self.load_input(self.opt.input_sketch)

        # override prompt from cmdline
        if self.opt.input_prompt is not None:
            self.prompt = self.opt.input_prompt
        if self.opt.negative_prompt is not None:
            self.negative_prompt = self.opt.negative_prompt

    def seed_everything(self):
        """Seed Python hashing, NumPy and PyTorch (CPU + CUDA) RNGs.

        ``self.seed`` is used when ``int()`` accepts it; otherwise (e.g. the
        default ``"random"``) a fresh seed in ``[0, 1000000)`` is drawn. The
        seed actually used is stored on ``self.last_seed`` so a random run can
        be reproduced.
        """
        try:
            seed = int(self.seed)
        except (TypeError, ValueError, OverflowError):
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may select non-deterministic kernels, which would
        # defeat the determinism requested above, so keep it disabled.
        torch.backends.cudnn.benchmark = False
        self.last_seed = seed

    def load_input(self, file):
        """Load the input sketch image from ``file``.

        Sets:
            self.input_sketch_PIL: the sketch in grayscale, binarised so that
                pixels brighter than 128 become 0 and all others 255 (i.e.
                inverted); used as the ControlNet condition image.
            self.input_sketch_torch: the sketch resized to 512x512 and
                converted with ``ToTensor`` to shape (1, C, 512, 512);
                single-channel input is repeated to 3 channels.
        """
        print(f'[INFO] load sketch from {file}...')

        # The context manager releases the file handle; load() reads the pixel
        # data eagerly so the derived images stay valid after closing.
        with Image.open(file) as sketch_img:
            sketch_img.load()
            threshold = 128
            self.input_sketch_PIL = sketch_img.convert('L').point(lambda p: 0 if p > threshold else 255)
            sketch_resized = transforms.Resize((512, 512))(sketch_img)

        input_tensor = transforms.ToTensor()(sketch_resized)
        if input_tensor.shape[0] == 1:
            input_tensor = input_tensor.repeat(3, 1, 1)
        # NOTE(review): non-L/RGB modes (e.g. RGBA) keep their own channel
        # count here — confirm CLIPConvLoss accepts them.
        self.input_sketch_torch = input_tensor.unsqueeze(0)

    def prepare_controlnet(self):
        """Generate the reference image for the sketch with ControlNet.

        Runs Stable Diffusion v1.5 + ControlNet (canny) on
        ``self.input_sketch_PIL`` with ``self.prompt``, removes/recenters the
        background, and sets:

            self.shapeE_input_PIL: 512x512 RGBA image with the colour channels
                in reversed (BGR) order; the Shap-E condition image.
            self.controlnet_img_mask: float32 alpha in [0, 1], shape (512, 512, 1).
            self.controlnet_img: float32 RGB composited on white, (512, 512, 3).
        """
        from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
        controlnet = ControlNetModel.from_pretrained(
            "load/control_v11p_sd15_canny",  # local copy of "lllyasviel/control_v11p_sd15_canny"
            torch_dtype=torch.float16)
        self.controlnet_pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "load/stable-diffusion-v1-5",  # local copy of "runwayml/stable-diffusion-v1-5"
            controlnet=controlnet,
            torch_dtype=torch.float16
        )
        self.controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(self.controlnet_pipe.scheduler.config)
        self.controlnet_pipe.enable_model_cpu_offload()

        # torch.manual_seed reseeds the global CPU RNG (overriding the seed from
        # seed_everything) and returns that generator for the pipeline.
        generator = torch.manual_seed(46)
        # self.prompt is always a str (defaults to ""), unlike opt.input_prompt
        # which may be None and would be rejected by the pipeline.
        pipe_output_PIL = self.controlnet_pipe(
            self.prompt,
            num_inference_steps=20,
            generator=generator,
            image=self.input_sketch_PIL,
        ).images[0]

        # Deliberately reverse the channel order (RGB -> BGR). The swapped image
        # is what rembg and Shap-E receive; point_cloud_initialization reads the
        # Shap-E vertex colours back as (B, G, R) and controlnet_img is flipped
        # back to RGB below.
        swapped = np.array(pipe_output_PIL)[..., ::-1]
        swapped_PIL = Image.fromarray(swapped, 'RGB').convert("RGBA")

        final_rgba = self.remove_rgb_background(swapped_PIL)

        self.shapeE_input_PIL = Image.fromarray(final_rgba.astype(np.uint8))

        final_rgba = final_rgba.astype(np.float32) / 255.0
        self.controlnet_img_mask = final_rgba[..., 3:]
        # composite onto a white background, then undo the channel swap
        self.controlnet_img = final_rgba[..., :3] * self.controlnet_img_mask + (1 - self.controlnet_img_mask)
        self.controlnet_img = self.controlnet_img[..., ::-1].copy()

    def point_cloud_initialization(self):
        """Lift ``self.shapeE_input_PIL`` to a coloured point cloud with Shap-E.

        Returns:
            BasicPointCloud holding every 4th vertex of the decoded Shap-E mesh,
            the matching vertex colours and all-zero normals.
        """
        xm = load_model('transmitter', device=self.device)
        model = load_model('image300M', device=self.device)
        diffusion = diffusion_from_config_shape(load_config('diffusion'))

        batch_size = 1
        guidance_scale = 3.0

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(images=[self.shapeE_input_PIL] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        coords = mesh.verts

        # Colours are read in (B, G, R) order — presumably to undo the RGB->BGR
        # swap applied to the Shap-E input image in prepare_controlnet.
        rgb = np.concatenate(
            [mesh.vertex_channels['B'][:, None], mesh.vertex_channels['G'][:, None], mesh.vertex_channels['R'][:, None]],
            axis=1)

        # keep every 4th vertex to thin the initial cloud
        skip = 4
        coords = coords[::skip]
        rgb = rgb[::skip]

        point_cloud = BasicPointCloud(points=coords, colors=rgb, normals=np.zeros((coords.shape[0], 3)))
        print(f'[INFO] point cloud Initialization Finish !')
        return point_cloud

    def remove_rgb_background(self, controlnet_output_RGBImage_PIL):
        """Remove the background with rembg (u2net) and recenter the object.

        Args:
            controlnet_output_RGBImage_PIL: PIL image to matte.

        Returns:
            uint8 array of shape (512, 512, 4): the foreground bounding box,
            resized (INTER_AREA) so its longer side spans int(512 * 0.8) px and
            centred on an all-zero (transparent) canvas.

        Raises:
            ValueError: if no pixel has alpha above 100 after background removal.
        """
        session = rembg.new_session(model_name='u2net')

        image = np.array(controlnet_output_RGBImage_PIL)

        # remove background
        print(f'[INFO] background removal...')
        if image.dtype != np.uint8:
            image = image.astype(np.uint8)
        matted = rembg.remove(image, session=session)  # [H, W, 4]
        mask = matted[..., -1] > 100

        print(f'[INFO] recenter...')
        final_size, border_ratio = 512, 0.2
        final_rgba = np.zeros((final_size, final_size, 4), dtype=np.uint8)
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            raise ValueError("background removal left no foreground pixels (alpha > 100)")
        # Upper bounds are made exclusive (+1) so the slice below keeps the
        # last foreground row and column.
        x_min, x_max = rows.min(), rows.max() + 1
        y_min, y_max = cols.min(), cols.max() + 1
        h = x_max - x_min
        w = y_max - y_min
        desired_size = int(final_size * (1 - border_ratio))
        scale = desired_size / max(h, w)
        # cv2.resize rejects a zero-sized target, so keep both sides >= 1 px.
        h2 = max(1, int(h * scale))
        w2 = max(1, int(w * scale))
        x2_min = (final_size - h2) // 2
        x2_max = x2_min + h2
        y2_min = (final_size - w2) // 2
        y2_max = y2_min + w2
        final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(matted[x_min:x_max, y_min:y_max], (w2, h2),
                                                              interpolation=cv2.INTER_AREA)

        return final_rgba

    def prepare_train(self):
        """Set up the optimizer, supervision tensors and guidance models.

        Resets ``self.step``, calls ``training_setup`` on the Gaussians and
        converts ``self.controlnet_img`` / ``self.controlnet_img_mask`` into
        (1, C, ref_size, ref_size) tensors on ``self.device`` (plus a PIL copy
        used as the IP-Adapter image prompt). Then loads the Stable Diffusion
        SDS guidance (once), the IP-Adapter img2img pipeline and the CLIP
        sketch loss.
        """
        # The number of iterations is set to 0
        self.step = 0

        # setup training
        self.renderer.gaussian.training_setup(self.opt)

        # initialise self.optimizer
        self.optimizer = self.renderer.gaussian.optimizer

        # ControlNet reference image to torch format (B, C, H, W)
        if self.controlnet_img is not None:
            self.controlnet_img_torch = torch.from_numpy(self.controlnet_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.controlnet_img_torch = F.interpolate(self.controlnet_img_torch, (self.opt.ref_size, self.opt.ref_size),
                                                 mode="bilinear", align_corners=False)

            # per-channel statistics of the reference image
            self.supervise_mean = torch.mean(self.controlnet_img_torch, dim=[0, 2, 3])
            self.supervise_variance = torch.var(self.controlnet_img_torch, dim=[0, 2, 3], unbiased=False)

            # uint8 PIL copy, used as the IP-Adapter image prompt in train_step
            controlnet_img_torch_temp = self.controlnet_img_torch.clone()
            controlnet_img_torch_temp = controlnet_img_torch_temp.squeeze(0).permute(1, 2, 0).detach().cpu()
            controlnet_img_torch_temp = (controlnet_img_torch_temp * 255).byte()
            self.supervise_whitebg_PIL = Image.fromarray(controlnet_img_torch_temp.numpy(), 'RGB')

            self.controlnet_img_mask_torch = torch.from_numpy(self.controlnet_img_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.controlnet_img_mask_torch = F.interpolate(self.controlnet_img_mask_torch, (self.opt.ref_size, self.opt.ref_size),
                                                 mode="bilinear", align_corners=False)

        if self.guidance_stable_diffusion is None:
            print(f"[INFO] loading stable diffusion...")
            from guidance.stable_diffusion_guidance import StableDiffusionGuidance
            # NOTE(review): passes the raw opt values (possibly None) rather than
            # self.prompt / self.negative_prompt — confirm the guidance handles None.
            self.guidance_stable_diffusion = StableDiffusionGuidance(self.device, self.opt.input_prompt, self.opt.negative_prompt)
            print(f"[INFO] loaded stable diffusion Successful !")

        self.IP_noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
        self.IP_vae = AutoencoderKL.from_pretrained("load/IP_Adapter/sd-vae-ft-mse").to(dtype=torch.float16)
        self.IP_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "load/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            scheduler=self.IP_noise_scheduler,
            vae=self.IP_vae,
            feature_extractor=None,
            safety_checker=None
        )
        self.IP_ip_model = IPAdapter(
            self.IP_pipe,
            "load/IP_Adapter/image_encoder/",
            "load/IP_Adapter/ip-adapter_sd15.bin",
            self.device)

        self.sketch_loss_calculator = CLIPConvLoss()

    def train_step(self):
        """Run ``self.train_steps`` optimisation steps on the Gaussians.

        Each step renders ``opt.stable_diffusion_render_batch`` views on a
        camera sweep that alternates every 3 steps between azimuth sweeps
        (``(step - 1) % 6 < 3``) and elevation sweeps, and combines:

          * an SDS loss from ``self.guidance_stable_diffusion`` (elevation
            sweeps weighted 0.3x),
          * a CLIP sketch loss, computed on every 24th rendered view and added
            (x1e4) to the loss of the step in which it was computed,
          * when ``step - 1 > 80`` and ``(step - 1) % 48 < 6``, a pose-weighted
            MSE loss against IP-Adapter supervision images.

        IP-Adapter supervision images are regenerated from the current renders
        during the first 6 steps of every 60-step cycle (3 azimuth steps, then
        3 elevation steps) and reused for the rest of the cycle. This method
        uses CUDA events for timing, so it requires a CUDA device.
        """
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        initial_azimuth = [0, 120, 240]
        initial_elevation = [20, 140, 260]
        batch = self.opt.stable_diffusion_render_batch
        ip_cycle = 3 * 20  # steps between IP-Adapter supervision refreshes
        for _ in range(self.train_steps):
            self.step += 1
            step_idx = self.step - 1
            step_ratio = min(1, self.step / self.opt.iters)

            # update stable_diffusion step
            self.guidance_stable_diffusion.update_step(self.step)

            # update lr
            self.renderer.gaussian.update_learning_rate(self.step)

            loss = 0

            # random view (SDS loss + MSE loss + CLIP loss)
            supervise_mean = []
            supervise_variance = []
            self.viewspace_point_list = []
            images = []

            # view_slot picks one of three start angles; the same slot indexes
            # the matching group of cached IP-Adapter images below.
            view_slot = step_idx % 3
            is_azimuth_step = step_idx % 6 < 3
            if is_azimuth_step:
                cam_params = self.cam.surroundAzimuth(initial_azimuth=initial_azimuth[view_slot])
                self.sds_loss_azimuth = 1
                self.sds_loss_elevation = 0
            else:
                cam_params = self.cam.surroundElevation(initial_elevation=initial_elevation[view_slot])
                self.sds_loss_azimuth = 0
                self.sds_loss_elevation = 1

            ip_phase = step_idx % ip_cycle
            for i in range(batch):
                cur_cam = Camera(c2w=cam_params['c2w_3dgs'][i],
                                 FoVy=cam_params['fovy'][i],
                                 height=cam_params['height'],
                                 width=cam_params['width'])

                bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
                out = self.renderer.render(cur_cam, bg_color=bg_color)

                image = out["image"]
                images.append(image)
                self.viewspace_point_list.append(out["viewspace_points"])
                # element-wise max of the radii over all views of this step
                out_radii = out["radii"]
                self.radii = out_radii if i == 0 else torch.max(out_radii, self.radii)

                if self.image_count % 24 == 0:
                    self.clip_loss = self.sketch_loss_calculator(self.input_sketch_torch, image.unsqueeze(0))
                self.image_count += 1

                # Empty each supervision cache at the start of its generation window.
                if i == 0 and ip_phase == 0:
                    self.multi_supervise_azimuth = []
                elif i == 0 and ip_phase == 3:
                    self.multi_supervise_elevation = []

                if ip_phase < 6:
                    # Convert the render to PIL only when it seeds a new IP-Adapter image.
                    render_uint8 = (image.squeeze(0).permute(1, 2, 0).detach().cpu() * 255).byte()
                    render_PIL = Image.fromarray(render_uint8.numpy(), 'RGB')
                    ip_image_PIL = self.IP_ip_model.generate(
                        pil_image=self.supervise_whitebg_PIL,
                        num_samples=1,
                        num_inference_steps=10,
                        seed=42,
                        image=render_PIL,
                        strength=0.5
                    )[0]
                    self.ip_image_count += 1
                    if ip_phase < 3:
                        self.multi_supervise_azimuth.append(ip_image_PIL)
                    else:
                        self.multi_supervise_elevation.append(ip_image_PIL)

                # Each cache holds 3 consecutive groups of `batch` images, one per view_slot.
                cache = self.multi_supervise_azimuth if is_azimuth_step else self.multi_supervise_elevation
                ip_array = np.array(cache[view_slot * batch + i]).astype(np.float32) / 255.0
                ip_tensor = torch.from_numpy(ip_array).permute(2, 0, 1).unsqueeze(0).to(self.device)
                IP_multiview_supervise_tensor = F.interpolate(ip_tensor,
                                                              (self.opt.ref_size, self.opt.ref_size),
                                                              mode="bilinear",
                                                              align_corners=False)

                supervise_mean.append(torch.mean(IP_multiview_supervise_tensor, dim=[0, 2, 3]))
                supervise_variance.append(torch.var(IP_multiview_supervise_tensor, dim=[0, 2, 3], unbiased=False))

                # MSE loss
                if step_idx > 80 and step_idx % (6 * 8) < 6:
                    # pose weight: |cos(azimuth)| on azimuth sweeps,
                    # 0.3 * |cos(elevation - 20 deg)| on elevation sweeps
                    lambda_pose = abs((self.sds_loss_azimuth * 1 * math.cos(math.radians(cam_params['azimuth'][i])) + self.sds_loss_elevation * 0.3 * math.cos(math.radians(cam_params['elevation'][i] - 20))))
                    loss = loss + 1_000_000 * lambda_pose * step_ratio * F.mse_loss(image.unsqueeze(0), IP_multiview_supervise_tensor)

            images = torch.stack(images, 0)
            self.visibility_filter = self.radii > 0.0

            # compute SDS loss
            diffusion_SDS_loss = self.guidance_stable_diffusion.train_step(
                images,
                cam_params["elevation"],
                cam_params["azimuth"],
                cam_params["camera_distances"],
                supervise_mean,
                supervise_variance,
                as_latents=False,
            )

            sds_weight = self.sds_loss_azimuth * 1 + self.sds_loss_elevation * 0.3
            loss = loss + sds_weight * diffusion_SDS_loss * self.guidance_stable_diffusion.C(self.opt.stable_diffusion_loss_lambda_sds, self.step) \
                + 10000 * self.clip_loss
            # the CLIP loss only contributes to the step in which it was computed
            self.clip_loss = 0

            print(f'[INFO] total loss in step{self.step} is: {loss.item()}...')

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
                # sum the screen-space gradients of all views rendered this step
                viewspace_point_tensor_grad = torch.zeros_like(self.viewspace_point_list[0])
                for viewspace_points in self.viewspace_point_list:
                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_points.grad
                # Keep track of max radii in image-space for pruning
                self.renderer.gaussian.max_radii2D[self.visibility_filter] = torch.max(
                    self.renderer.gaussian.max_radii2D[self.visibility_filter], self.radii[self.visibility_filter])

                self.renderer.gaussian.add_densification_stats(viewspace_point_tensor_grad, self.visibility_filter)

                if self.step > 100 and self.step % self.opt.densify_interval == 0:  # 300 100
                    self.renderer.gaussian.densify_and_prune(max_grad=0.0002, min_opacity=0.05, extent=self.cameras_extent, max_screen_size=0.025)

        ender.record()
        torch.cuda.synchronize()
        self.last_step_time_ms = starter.elapsed_time(ender)

    @torch.no_grad()
    def save_model(self, mode='loss_after'):
        """Save the Gaussians to ``<outdir>/<save_path><mode>_model.ply``.

        Args:
            mode: tag inserted into the file name (e.g. ``'point_cloud_gaussian'``).
        """
        os.makedirs(self.opt.outdir, exist_ok=True)
        path = os.path.join(self.opt.outdir, self.opt.save_path + mode + '_model.ply')
        self.renderer.gaussian.save_ply(path)
        print(f"[INFO] save model to {path}.")

    def train(self, iters):
        """Run the full pipeline and save the initial and optimised Gaussians.

        Args:
            iters: number of ``train_step`` calls.
        """
        self.seed_everything()
        # load controlnet and process output image
        self.prepare_controlnet()

        # generate gaussian from point cloud
        if self.opt.load is not None:
            self.renderer.initialize(self.opt.load)
            print(f"[INFO] generate gaussian from load ply.")
        else:
            point_cloud = self.point_cloud_initialization()
            self.renderer.gaussian.create_from_pcd(point_cloud, self.cameras_extent)
            print(f"[INFO] generate gaussian from shape-E point cloud initialization.")

        self.save_model(mode='point_cloud_gaussian')

        # optimise gaussian
        self.prepare_train()
        for _ in tqdm.trange(iters):
            self.train_step()

        self.save_model(mode='loss_after')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The given config path is honoured; configs/image.yaml is only the default.
    parser.add_argument("--config", default="configs/image.yaml",
                        help="path to the yaml config file (default: configs/image.yaml)")
    args, extras = parser.parse_known_args()

    # Remaining "key=value" CLI arguments override values from the yaml config.
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))

    # A distinct variable name keeps the ControlGaussian class from being shadowed.
    trainer = ControlGaussian(opt)
    trainer.train(opt.iters)
