import os
import ot
import Imath
import numpy as np
import OpenEXR
from matplotlib import colors
from PIL import Image
import mitsuba as mi
import matplotlib.pyplot as plt
import glob

# Select Mitsuba's scalar (single-ray, CPU) RGB variant before any scene is
# loaded or rendered.
mi.set_variant('scalar_rgb')

# Scene preamble: path tracer, perspective camera and the "surfaceMaterial"
# BSDF used by the ground plane. Compared with the template this scene was
# adapted from, the sampler is 'independent' (256 samples per pixel) and the
# film is 'hdrfilm' instead of 'ldrfilm'.
xml_head = """
<scene version="0.6.0">
    <integrator type="path">
        <integer name="maxDepth" value="-1"/>
    </integrator>
    <sensor type="perspective">
        <float name="farClip" value="100"/>
        <float name="nearClip" value="0.1"/>
        <transform name="toWorld">
            <lookat origin="3,3,3" target="0,0,0" up="0,0,1"/>
        </transform>
        <float name="fov" value="25"/>
        <sampler type="independent">
            <integer name="sampleCount" value="256"/>
        </sampler>
        <film type="hdrfilm">
            <integer name="width" value="400"/>
            <integer name="height" value="300"/>
            <rfilter type="gaussian"/>
        </film>
    </sensor>
    <bsdf type="roughplastic" id="surfaceMaterial">
        <string name="distribution" value="ggx"/>
        <float name="alpha" value="0.05"/>
        <float name="intIOR" value="1.46"/>
        <rgb name="diffuseReflectance" value="1,1,1"/> <!-- default 0.5 -->
    </bsdf>
"""

# One diffuse sphere per point. The six placeholders are filled with
# str.format as (x, y, z, r, g, b). The sphere radius (0.025) is smaller than
# in the template this scene was adapted from.
xml_ball_segment = """
    <shape type="sphere">
        <float name="radius" value="0.025"/>
        <transform name="toWorld">
            <translate x="{}" y="{}" z="{}"/>
        </transform>
        <bsdf type="diffuse">
            <rgb name="reflectance" value="{},{},{}"/>
        </bsdf>
    </shape>
"""

# Scene tail: a scaled ground rectangle translated to z = -0.5 that uses the
# shared "surfaceMaterial" BSDF, an emitting rectangle placed at z = 20 and
# aimed at the origin, and the closing </scene> tag.
xml_tail = """
    <shape type="rectangle">
        <ref name="bsdf" id="surfaceMaterial"/>
        <transform name="toWorld">
            <scale x="10" y="10" z="1"/>
            <translate x="0" y="0" z="-0.5"/>
        </transform>
    </shape>
    <shape type="rectangle">
        <transform name="toWorld">
            <scale x="10" y="10" z="1"/>
            <lookat origin="-4,4,20" target="0,0,0" up="0,0,1"/>
        </transform>
        <emitter type="area">
            <rgb name="radiance" value="6,6,6"/>
        </emitter>
    </shape>
</scene>
"""


def colormap(x, y, z, color_name="rainbow", alpha=0.7):
    """Return an ``[r, g, b]`` colour for one point.

    With the default ``"rainbow"`` name the colour encodes the position:
    ``(x, y, z)`` is clipped to ``[0.001, 1]`` per component and then scaled
    to unit Euclidean length. Any other name is resolved with
    ``matplotlib.colors.to_rgb``. Each channel is then multiplied by
    ``alpha``, and the coordinates are ignored.
    """
    if color_name == "rainbow":
        clipped = np.clip(np.array([x, y, z]), 0.001, 1.0)
        # The lower clip bound keeps the norm strictly positive.
        length = np.sqrt((clipped ** 2).sum())
        rgb = clipped / length
    else:
        rgb = [channel * alpha for channel in colors.to_rgb(color_name)]

    red, green, blue = rgb[0], rgb[1], rgb[2]
    return [red, green, blue]


def standardize_bbox(pcl, points_per_object):
    """Randomly subsample a point cloud and fit it into a unit bounding box.

    Args:
        pcl: array of shape (N, D), one point per row.
        points_per_object: number of points to keep. If the cloud has fewer
            points, all of them are kept (in random order) instead of raising.

    Returns:
        float32 array of shape (min(N, points_per_object), D). It is centred
        on the bounding-box centre and divided by the largest box side, so
        coordinates lie in [-0.5, 0.5].
    """
    n_points = pcl.shape[0]
    n_keep = min(n_points, points_per_object)
    pt_indices = np.random.choice(n_points, n_keep, replace=False)
    # choice(replace=False) already returns a random order. The extra shuffle
    # is kept so the global RNG stream, and hence seeded results, stay unchanged.
    np.random.shuffle(pt_indices)
    pcl = pcl[pt_indices]  # n_keep by D
    mins = np.amin(pcl, axis=0)
    maxs = np.amax(pcl, axis=0)
    center = (mins + maxs) / 2.0
    scale = np.amax(maxs - mins)
    print("Center: {}, Scale: {}".format(center, scale))
    if scale == 0:
        # All points coincide: avoid 0/0 (NaN) and map them to the origin.
        scale = 1.0
    result = ((pcl - center) / scale).astype(np.float32)  # [-0.5, 0.5]
    return result


# Debugging helper: dump points to an ASCII PLY file for inspection.
def writeply(vertices, ply_file):
    """Write ``vertices`` as an ASCII PLY point cloud.

    Args:
        vertices: sequence/array of points; only the first three components of
            each point are written, as the x, y and z float properties.
        ply_file: output path; an existing file is overwritten.
    """
    points = ["%f %f %f\n" % (vertex[0], vertex[1], vertex[2]) for vertex in vertices]
    print(np.shape(points))
    # PLY header keywords must start at the beginning of the line, so the
    # header is built without any indentation.
    header = (
        "ply\n"
        "format ascii 1.0\n"
        "element vertex %d\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "end_header\n"
    ) % len(points)
    with open(ply_file, "w") as file:
        file.write(header)
        file.write("".join(points))


# Adapted from https://gist.github.com/drakeguan/6303065
def ConvertEXRToJPG(exrfile, jpgfile):
    """Convert the R, G, B float channels of an OpenEXR file to an sRGB JPEG.

    Each channel is read as 32-bit float, clamped to [0, 1], encoded with the
    sRGB transfer function, scaled to [0, 255] and saved as an 8-bit JPEG with
    quality 95.

    Args:
        exrfile: path of the input ``.exr`` file; it must contain R, G and B channels.
        jpgfile: path of the JPEG file to write.
    """
    exr = OpenEXR.InputFile(exrfile)
    try:
        pixel_type = Imath.PixelType(Imath.PixelType.FLOAT)
        window = exr.header()["dataWindow"]
        size = (window.max.x - window.min.x + 1, window.max.y - window.min.y + 1)
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        channels = [np.frombuffer(exr.channel(c, pixel_type), dtype=np.float32) for c in "RGB"]
    finally:
        exr.close()

    planes = []
    for linear in channels:
        # Clamp before encoding. This keeps the fractional power well defined
        # (no NaN for negative values) and saturates highlights at 255, as
        # the 8-bit output would anyway.
        linear = np.clip(linear, 0.0, 1.0)
        srgb = np.where(linear <= 0.0031308, linear * 12.92, 1.055 * linear ** (1.0 / 2.4) - 0.055)
        # PIL's "F" mode expects raw 32-bit floats; ndarray.tobytes() replaces
        # tostring(), which was removed in NumPy 2.0.
        data = (srgb * 255.0).astype(np.float32).tobytes()
        planes.append(Image.frombytes("F", size, data).convert("L"))
    Image.merge("RGB", planes).save(jpgfile, "JPEG", quality=95)


def _load_point_cloud_sequence(path_to_file):
    """Load the point-cloud array from a ``.npy`` file or the ``"pred"`` entry of a ``.npz`` file.

    Prints a message and returns None for any other extension.
    """
    extension = os.path.splitext(path_to_file)[1]
    if extension == ".npy":
        return np.load(path_to_file)
    if extension == ".npz":
        # np.load keeps an .npz archive open; the context manager closes it.
        with np.load(path_to_file) as archive:
            return archive["pred"]
    print("unsupported file format.")
    return None


def _scene_xml(pcl, color_name):
    """Build the full Mitsuba scene XML with one coloured sphere per row of ``pcl``."""
    segments = [xml_head]
    for x, y, z in pcl:
        # The colour is computed from the position before the 0.0125 z-lift
        # applied by the caller, shifted from [-0.5, 0.5] to [0, 1].
        color = colormap(x + 0.5, y + 0.5, z + 0.5 - 0.0125, color_name)
        segments.append(xml_ball_segment.format(x, y, z, *color))
    segments.append(xml_tail)
    return "".join(segments)


def main(argv):
    """Render every frame of a saved point-cloud sequence with Mitsuba.

    Args:
        argv: sys.argv-style list. ``argv[1]`` is the path to a ``.npy`` or
            ``.npz`` file (for ``.npz`` the ``"pred"`` entry is used).
            Optional ``argv[2]`` is a colour name passed to ``colormap``
            (default ``"rainbow"``).

    The array is either (frames, points, dims), or (points, dims), which is
    treated as a single frame. Each frame goes through these steps:
    - subsample to 2048 points and normalise (``standardize_bbox``);
    - permute the axes, mirror x and lift z by 0.0125;
    - write a scene XML next to the input file and render it.
    If an EXR with the frame's name already exists next to the input, it is
    loaded instead of rendering. This function never writes that EXR itself.

    Returns:
        A list with one image per frame, or None if ``argv`` is too short or
        the file type is unsupported.
    """
    images = []
    if len(argv) < 2:
        print("filename to npy/ply is not passed as argument. terminated.")
        return

    pathToFile = argv[1]
    color_name = "rainbow" if len(argv) == 2 else argv[2]

    folder = os.path.dirname(pathToFile)
    filename = os.path.basename(pathToFile)

    # Supported inputs: .npy and .npz.
    pclTime = _load_point_cloud_sequence(pathToFile)
    if pclTime is None:
        return

    if pclTime.ndim < 3:
        # A single cloud becomes a one-frame sequence. reshape avoids the
        # reference-count check that in-place ndarray.resize performs.
        pclTime = pclTime.reshape(1, pclTime.shape[0], pclTime.shape[1])

    for pcli in range(pclTime.shape[0]):
        pcl = standardize_bbox(pclTime[pcli, :, :], 2048)
        pcl = pcl[:, [2, 0, 1]]
        pcl[:, 0] *= -1
        pcl[:, 2] += 0.0125

        stem = "%s_%02d" % (filename, pcli)
        # os.path.join avoids producing "/name.xml" (filesystem root) when the
        # input path has no directory component.
        xmlFile = os.path.join(folder, stem + ".xml")
        with open(xmlFile, "w") as f:
            f.write(_scene_xml(pcl, color_name))

        exrFile = os.path.join(folder, stem + ".exr")
        if not os.path.exists(exrFile):
            print(["Running Mitsuba, writing to: ", xmlFile])
            scene = mi.load_file(xmlFile)
            image = mi.render(scene)
        else:
            print("skipping rendering because the EXR file already exists")
            # Use the cached frame instead of the previous frame's image (or an
            # unbound name on the first frame).
            image = np.array(mi.Bitmap(exrFile))
        images.append(image)

    return images


# Seed the global NumPy RNG so that the random point subsampling in
# standardize_bbox (reached through main) is reproducible across runs. The same
# seed also selects which saved trajectory files are loaded below.
seed = 2
np.random.seed(seed)
def compute_true_Wasserstein(X, Y, p=2):
    """Exact optimal-transport cost between two uniformly weighted point clouds.

    Solves the discrete OT problem with ``ot.emd2`` using the ground cost
    ``||x - y||_2 ** p``. The result is therefore ``W_p(X, Y) ** p``; for the
    default ``p=2`` that is the squared 2-Wasserstein distance. Take
    ``** (1 / p)`` to obtain ``W_p`` itself.

    Args:
        X: array of shape (n, d); every point has mass 1/n.
        Y: array of shape (m, d); every point has mass 1/m.
        p: exponent applied to the Euclidean ground distance.
    """
    if p == 2:
        # ot.dist's default metric is squared Euclidean; use it directly so
        # p=2 results are bit-identical to computing it without the exponent.
        M = ot.dist(X, Y)
    else:
        M = ot.dist(X, Y, metric="euclidean") ** p
    a = np.ones((X.shape[0],)) / X.shape[0]
    b = np.ones((Y.shape[0],)) / Y.shape[0]
    return ot.emd2(a, b, M)

# Trajectories saved by each method for this seed, indexed by frame along the
# first axis (presumably (frames, points, dims) — confirm against the saver).
sw_points = np.load(f"saved/sw_lr0.01_src8_tgt21_seed{seed}_points.npy")
twd_points = np.load(f"saved/twd_lr0.01_src8_tgt21_seed{seed}_points.npy")
fw_points = np.load(f"saved/fw_twd_lr0.01_src8_tgt21_seed{seed}_points.npy")
fw_rp_points = np.load(f"saved/fw_twd_rp_lr0.01_src8_tgt21_seed{seed}_points.npy")

# OT cost from every frame of every trajectory to one shared reference cloud:
# the last frame of the SW trajectory. The SW frame count drives all four
# lists, and each frame is evaluated method by method in a fixed order.
Y = sw_points[-1]
n_frames = sw_points.shape[0]
sw_dist, twd_dist, fw_dist, fw_rp_dist = [], [], [], []
_dist_lists = (sw_dist, twd_dist, fw_dist, fw_rp_dist)
_trajectories = (sw_points, twd_points, fw_points, fw_rp_points)
for i in range(n_frames):
    for dists, trajectory in zip(_dist_lists, _trajectories):
        dists.append(compute_true_Wasserstein(trajectory[i], Y))


# Render each saved trajectory as a list of per-frame images. The SW trajectory
# is rendered twice: in a neutral colour (used for the "Source"/"Target"
# panels) and in its own method colour. Each main() call consumes the global
# RNG through point subsampling, so the order below must stay fixed.
_render_jobs = [
    (f"saved/sw_lr0.01_src8_tgt21_seed{seed}_points.npy", "silver"),
    (f"saved/sw_lr0.01_src8_tgt21_seed{seed}_points.npy", "lightcoral"),
    (f"saved/twd_lr0.01_src8_tgt21_seed{seed}_points.npy", "gold"),
    (f"saved/fw_twd_lr0.01_src8_tgt21_seed{seed}_points.npy", "lightgreen"),
    (f"saved/fw_twd_rp_lr0.01_src8_tgt21_seed{seed}_points.npy", "deepskyblue"),
]
sample, sw_images, twd_images, fw_images, fw_rp_images = [
    main([None, path, color]) for path, color in _render_jobs
]



# Panels as (title, distance or None, image). The methods are shown at the
# second-to-last frame; the last frame of the neutral SW rendering is the
# "Target" panel.
# NOTE(review): compute_true_Wasserstein returns the OT cost under a squared
# Euclidean ground cost (W2 squared), although the label reads "W2" — confirm.
images = [
    ("Source", None, sample[0]),
    ("SW", sw_dist[-2], sw_images[-2]),
    ("Db-TSW", twd_dist[-2], twd_images[-2]),
    ("FW-TSW", fw_dist[-2], fw_images[-2]),
    ("FW-TSW*", fw_rp_dist[-2], fw_rp_images[-2]),
    ("Target", None, sample[-1]),
]

cols, rows = 3, 2
fig = plt.figure(figsize=(5 * cols, 5 * rows))
for idx, (title, dist, image) in enumerate(images, start=1):
    ax = fig.add_subplot(rows, cols, idx)
    ax.imshow(image)
    label = title if dist is None else f"{title}, W2: {dist:.2e}"
    ax.set_title(label, fontsize=16)
    ax.axis("off")
fig.tight_layout()
fig.savefig("point-cloud.pdf")
plt.close(fig)

# Clean up: remove every *.xml file in saved/. This includes, but is not
# limited to, the scene files main() wrote there. The cleanup is best effort:
# a file that cannot be removed is reported and skipped.
xml_files = glob.glob(os.path.join("saved", "*.xml"))
for file_path in xml_files:
    try:
        os.remove(file_path)
    except OSError as e:
        print(f"Error deleting {file_path}: {e}")