# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# 
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# obj file dataset
import json
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from numpy.lib.format import MAGIC_PREFIX
from PIL import Image
from torch.autograd import Variable


def load_obj(filename):
    """Parse vertices, UVs and face indices from a Wavefront OBJ file.

    Only ``v``, ``vt`` and ``f`` records are read; every other record
    (comments, normals, groups, ...) is ignored. Every face token must carry
    a UV index (``v/vt`` or ``v/vt/vn``) because both indices are extracted.

    Args:
        filename: path to the ``.obj`` file.

    Returns:
        dict with
            ``"verts"``:    float32 array, one row per ``v`` record.
            ``"uvs"``:      float32 array, one row per ``vt`` record.
            ``"vert_ids"``: int32 array of 0-based vertex indices per face.
            ``"uv_ids"``:   int32 array of 0-based UV indices per face.
    """
    vertices = []
    uvs = []
    faces_vertex, faces_uv = [], []
    with open(filename, "r") as f:
        for line in f:
            # split() on arbitrary whitespace: split(" ") would produce empty
            # tokens for tabs / repeated spaces and crash in float()/int().
            parts = line.split()
            if not parts:
                continue
            tag, values = parts[0], parts[1:]
            if tag == "vt":
                uvs.append([float(x) for x in values])
            elif tag == "v":
                vertices.append([float(x) for x in values])
            elif tag == "f":
                faces_vertex.append([int(x.split("/")[0]) for x in values])
                faces_uv.append([int(x.split("/")[1]) for x in values])
    # OBJ indices are 1-based; shift to 0-based
    obj = {
        "verts": np.array(vertices, dtype=np.float32),
        "uvs": np.array(uvs, dtype=np.float32),
        "vert_ids": np.array(faces_vertex, dtype=np.int32) - 1,
        "uv_ids": np.array(faces_uv, dtype=np.int32) - 1,
    }
    return obj


def check_path(path):
    """Terminate the process (exit status -1) when ``path`` is missing.

    A diagnostic naming the missing path is written to stderr first.
    Returns None when the path exists.
    """
    if os.path.exists(path):
        return
    message = "%s does not exist!\n" % (path)
    sys.stderr.write(message)
    sys.exit(-1)


def load_krt(path):
    """Parse a KRT camera calibration file into a dict keyed by camera name.

    Each record is laid out as:
        1 line   camera name
        3 lines  intrinsic matrix rows
        1 line   distortion coefficients
        3 lines  extrinsic matrix rows
        1 line   separator (read and ignored)

    Blank lines where a camera name is expected are skipped, so extra blank
    lines between or after records do not create a bogus ``""`` camera.

    Args:
        path: path to the KRT text file.

    Returns:
        ``{name: {"intrin": ndarray, "dist": ndarray, "extrin": ndarray}}``
        where ``intrin``/``extrin`` have 3 rows and ``dist`` is 1-D.
    """
    cameras = {}

    with open(path, "r") as f:
        while True:
            name = f.readline()
            if name == "":
                break  # EOF

            # Drop only the line terminator (name[:-1] would also chop a real
            # character if the last line had no trailing newline).
            name = name.rstrip("\n")
            if not name.strip():
                continue

            intrin = [[float(x) for x in f.readline().split()] for _ in range(3)]
            dist = [float(x) for x in f.readline().split()]
            extrin = [[float(x) for x in f.readline().split()] for _ in range(3)]
            f.readline()  # record separator

            cameras[name] = {
                "intrin": np.array(intrin),
                "dist": np.array(dist),
                "extrin": np.array(extrin),
            }

    return cameras


class Dataset(torch.utils.data.Dataset):
    """Dataset of per-frame average UV textures normalized by a global mean/std.

    Rows of ``framelistpath`` are read as ``(sentnum, frame)`` pairs and
    subsampled every ``frame_stride`` rows. Optional prefix filters are applied
    to the first column. A row is kept only if its average texture
    ``<uvpath>/<sentnum>/average/<frame>.png`` exists. Each kept row is expanded
    into one ``(sentnum, frame, cam)`` entry per camera that also has
    ``<uvpath>/<sentnum>/<cam>/<frame>.png``.

    Args:
        base_dir: root containing ``unwrapped_uv_1024``, ``tracked_mesh``,
            ``tex_mean.png`` and ``tex_var.txt``.
        krt_dir: path of the KRT calibration file (parsed by ``load_krt``).
        framelistpath: text file of frame rows (two columns expected, since
            entries are later unpacked as ``sentnum, frame, cam``).
        size: side length that returned textures are resized to.
        camset: optional cameras to use instead of every camera in the KRT file.
        valid_prefix: if given, keep only rows whose first column starts with
            one of these prefixes.
        exclude_prefix: if given, drop rows whose first column starts with one
            of these prefixes.
        frame_stride: keep every ``frame_stride``-th row of the frame list
            (default 100, the value previously hard-coded).
    """

    def __init__(
        self,
        base_dir,
        krt_dir,
        framelistpath,
        size=1024,
        camset=None,
        valid_prefix=None,
        exclude_prefix=None,
        frame_stride=100,
    ):
        self.uvpath = "{}/unwrapped_uv_1024".format(base_dir)
        self.meshpath = "{}/tracked_mesh".format(base_dir)
        self.photopath = "{}/images".format(base_dir)
        self.size = size

        check_path(self.uvpath)
        check_path(self.meshpath)
        check_path(framelistpath)

        framelist = np.genfromtxt(framelistpath, dtype=np.str_)
        # genfromtxt squeezes a single-row file to shape (ncols,); restore the
        # row axis so the loop below iterates rows instead of characters.
        if framelist.ndim == 1 and framelist.size > 0:
            framelist = framelist[np.newaxis, :]
        framelist = framelist[::frame_stride]
        self.mesh_topology = None

        # camera ids follow KRT file order, even when camset narrows the set used
        krt = load_krt(krt_dir)
        self.krt = krt
        self.cameras = list(krt.keys())
        self.camera_ids = {cam: i for i, cam in enumerate(self.cameras)}

        if camset is not None:
            self.cameras = camset
        self.allcameras = sorted(self.cameras)

        # hoist prefix tuples out of the loop (str.startswith accepts a tuple)
        include = tuple(valid_prefix) if valid_prefix is not None else None
        exclude = tuple(exclude_prefix) if exclude_prefix is not None else None

        # build the sample list, keeping only frames whose textures exist on disk
        self.framelist = []
        for i, row in enumerate(framelist):
            if i % 1000 == 0:
                print("checking {}".format(i))

            sentnum, frame = row[0], row[1]
            if not self._passes_prefix_filters(sentnum, include, exclude):
                continue

            avgf = "{}/{}/average/{}.png".format(self.uvpath, sentnum, frame)
            if not os.path.isfile(avgf):
                continue

            # one entry per camera that has a per-view unwrapped texture
            for cam in self.cameras:
                path = "{}/{}/{}/{}.png".format(self.uvpath, sentnum, cam, frame)
                if os.path.isfile(path):
                    self.framelist.append(tuple(row) + (cam,))

        # camera positions -R^T t from each extrinsic [R | t] (assumes extrin
        # maps world to camera coordinates — confirm). Kept on the instance;
        # this also raises KeyError for camset names missing from the KRT file.
        self.campos = {}
        for cam in self.cameras:
            extrin = krt[cam]["extrin"]
            self.campos[cam] = -np.dot(extrin[:3, :3].T, extrin[:3, 3])

        # mean texture (flipped vertically to match __getitem__) and global std
        with Image.open("{}/tex_mean.png".format(base_dir)) as img:
            texmean = np.asarray(img, dtype=np.float32)
        self.texmean = np.copy(np.flip(texmean, 0))
        self.texstd = float(np.genfromtxt("{}/tex_var.txt".format(base_dir)) ** 0.5)
        # normalized values corresponding to raw pixel values 0 and 255
        self.texmin = (
            np.zeros_like(self.texmean, dtype=np.float32) - self.texmean
        ) / self.texstd
        self.texmax = (
            np.ones_like(self.texmean, dtype=np.float32) * 255 - self.texmean
        ) / self.texstd

    @staticmethod
    def _passes_prefix_filters(name, include, exclude):
        """Return True if ``name`` matches ``include`` and not ``exclude``.

        ``include`` / ``exclude`` are tuples of prefixes or None (filter off).
        """
        if include is not None and not name.startswith(include):
            return False
        if exclude is not None and name.startswith(exclude):
            return False
        return True

    def __len__(self):
        return len(self.framelist)

    def __getitem__(self, idx):
        """Return ``{"avg_tex": array}``: the normalized average texture.

        The texture is flipped vertically, mean-subtracted and divided by the
        global std. Raw zero values are forced back to 0 (presumably uncovered
        texels — confirm). The result is resized to ``(size, size)`` and
        returned channel-first.
        """
        sentnum, frame, _cam = self.framelist[idx]

        path = "{}/{}/average/{}.png".format(self.uvpath, sentnum, frame)
        # np.array always copies, so the in-place ops below have a writable buffer
        with Image.open(path) as img:
            avgtex = np.array(img, dtype=np.float32)[::-1, ...]
        mask = avgtex == 0
        avgtex -= self.texmean
        avgtex /= self.texstd
        avgtex[mask] = 0.0
        avgtex = cv2.resize(avgtex, (self.size, self.size)).transpose((2, 0, 1))

        return {
            "avg_tex": avgtex,
        }
