import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
import matplotlib.pyplot as plt
import os
import cv2
import argparse
from glob import glob
from tqdm import tqdm


def inv_project(depth, intrinsics):
    """Back-project a depth map into a per-pixel 3-D point cloud.

    Uses the pinhole camera model: a pixel at column u, row v with depth z maps
    to ``x = (u - cx) * z / fx`` and ``y = (v - cy) * z / fy``.

    Args:
        depth: 2-D array of shape (H, W) holding the z value of each pixel.
        intrinsics: sequence whose first four entries are (fx, fy, cx, cy).

    Returns:
        Array of shape (H, W, 3) whose last axis holds (x, y, z).
    """
    h, w = depth.shape
    fx, fy, cx, cy = intrinsics[:4]
    z = depth

    # Pixel-coordinate grids of shape (H, W): xx[v, u] == u and yy[v, u] == v.
    xx, yy = np.meshgrid(
        np.arange(w, dtype=np.float32),
        np.arange(h, dtype=np.float32),
    )

    x = (xx - cx) * z / fx
    # The vertical axis is scaled by fy (previously fx was reused here, which
    # is only correct when fx == fy).
    y = (yy - cy) * z / fy

    return np.stack([x, y, z], axis=-1)

def names2model(names):
    """Render feature names as a readable linear-model expression.

    The i-th name becomes the term ``C[i]*<name>``, with every space inside
    the name replaced by ``*``. Terms are joined with `` + ``, for example
    ``['1', 'X Y'] -> 'C[0]*1 + C[1]*X*Y'``.
    """
    terms = []
    for idx, feature in enumerate(names):
        product = feature.replace(' ', '*')
        terms.append(f"C[{idx}]*{product}")
    return ' + '.join(terms)

def quadric_fitting(pts, order=2, offset=0.05, verbose=False):
    """Fit Z as a polynomial surface in (X, Y) and evaluate it at the inputs.

    Args:
        pts: array whose last axis holds (x, y, z) coordinates, e.g. shape
            (N, 3) or (H, W, 3). Leading axes are flattened for fitting.
        order: polynomial degree passed to ``PolynomialFeatures``.
        offset: unused; kept only for backward compatibility (it belonged to a
            removed grid-evaluation snippet).
        verbose: if True, print the fitted model, its coefficients and R^2.

    Returns:
        ``(corrected_Z, r2)``. ``corrected_Z`` is the model's prediction of Z
        at every input point, shaped like ``pts[..., 2]``. ``r2`` is the R^2
        score of the fit on those same points. For empty input the function
        returns an empty array and NaN instead of raising.
    """
    pts = np.asarray(pts)
    X, Y, Z = pts[..., 0], pts[..., 1], pts[..., 2]

    if Z.size == 0:
        # Nothing to fit (e.g. a segment label absent from the frame). NaN
        # compares False against any R^2 threshold.
        return np.empty(Z.shape, dtype=np.float64), float("nan")

    # Fit and predict on the same flattened design matrix so the result is
    # correct whatever the leading shape of `pts` is.
    features = np.column_stack([X.ravel(), Y.ravel()])
    target = Z.ravel()

    model = make_pipeline(
        PolynomialFeatures(degree=order),
        # PolynomialFeatures (include_bias=True by default) already supplies
        # the constant column, so the regression needs no extra intercept.
        LinearRegression(fit_intercept=False))
    model.fit(features, target)
    r2 = model.score(features, target)  # R-squared

    if verbose:
        m = names2model(model[0].get_feature_names_out(['X', 'Y']))
        C = model[1].coef_.T  # coefficients
        print(f'data = {Z.size}x3')
        print(f'model = {m}')
        print(f'coefficients =\n{C}')
        print(f'R2 = {r2}')

    corrected_Z = model.predict(features).reshape(Z.shape)
    return corrected_Z, r2


def _load_depth(depth_path):
    """Load one predicted depth map.

    PNG files are read with OpenCV using ``cv2.IMREAD_UNCHANGED``, so the
    stored bit depth is preserved. Every other file is read with ``np.load``.

    Raises:
        IOError: if OpenCV cannot read the PNG (``cv2.imread`` returned None).
    """
    if depth_path.endswith("png"):
        depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        if depth is None:
            raise IOError("Could not read depth image: {}".format(depth_path))
        return depth
    return np.load(depth_path)


def _frame_index(depth_path):
    """Return the integer after the last '_' in the file name, before the first '.'.

    For example ``.../depth_12.npy`` -> 12.
    """
    stem = os.path.basename(depth_path).split(".")[0]
    return int(stem.split("_")[-1])


def _fit_segments(depth, segmentation, intrinsics, topk, threshold):
    """Replace the depth of each well-fit segment with its quadric-surface fit.

    Segments labelled 0 .. topk-1 in ``segmentation`` are fitted independently.
    A segment's pixels are overwritten only when the fit's R^2 exceeds
    ``threshold``. Returns a new array; ``depth`` itself is not modified.
    """
    pc = inv_project(depth, intrinsics)

    # Work in a floating-point buffer. An integer depth map (e.g. a 16-bit
    # PNG) would otherwise silently truncate the fitted values on assignment.
    fitted_depth = depth.astype(np.result_type(depth.dtype, np.float32))

    for seg_num in range(topk):
        mask = segmentation == seg_num
        # Labels missing from this frame give an empty mask, which cannot be
        # fitted. R^2 is undefined for a single point, so such a segment could
        # never pass the threshold either.
        if np.count_nonzero(mask) < 2:
            continue

        corrected, r2 = quadric_fitting(pc[mask], order=2)
        if r2 > threshold:
            fitted_depth[mask] = corrected

    return fitted_depth


def main():
    """Fit quadric surfaces to the segmented predicted depth of each scene.

    For every scene under --data_root, each ``pred_depth/*`` file is paired
    with ``segmentation/segment<index>.npy``. The corrected depth is written
    to ``<output_path>/<scene>/fitted_depth/depth_<index>.npy``.
    """
    parser = argparse.ArgumentParser(description="Quadric fitting and correction.")
    parser.add_argument(
        "--data_root", dest="data_root",
        default="/data/nerf_dataset/raw_replica/Replica_full",
        help="root directory containing one sub-directory per scene")
    parser.add_argument(
        "--output_path", dest="output_path",
        default="/data/nerf_dataset/raw_replica/Replica_full",
        help="root directory under which <scene>/fitted_depth is written")
    parser.add_argument(
        "--topk", type=int, default=20,
        help="fit segments labelled 0 .. topk-1")
    parser.add_argument(
        "--threshold", type=float, default=0.85,
        help="minimum R^2 for a fitted segment to replace the depth")
    parser.add_argument(
        "--intrinsics", type=float, nargs=4,
        default=[308, 308, 307.5, 171.5],
        metavar=("FX", "FY", "CX", "CY"),
        help="pinhole camera intrinsics")
    parser.add_argument(
        "--scenes", nargs="+",
        default=["room_0", "room_1", "room_2", "office_0",
                 "office_1", "office_2", "office_3", "office_4"],
        help="scene sub-directories to process")
    args = parser.parse_args()

    intrinsics = np.array(args.intrinsics, dtype=np.float64)

    for scene in args.scenes:
        scene_path = os.path.join(args.data_root, scene)
        if not os.path.exists(scene_path):
            print("Scene {} not found.".format(scene))
            continue

        print("Processing scene:{}".format(scene))

        all_depths = sorted(glob(os.path.join(scene_path, "pred_depth", "*")))

        # Create the output directory once per scene rather than once per frame.
        save_path = os.path.join(args.output_path, scene, "fitted_depth")
        if all_depths:
            os.makedirs(save_path, exist_ok=True)

        for depth_path in tqdm(all_depths):
            depth = _load_depth(depth_path)
            index = _frame_index(depth_path)
            segmentation = np.load(
                os.path.join(scene_path, "segmentation", "segment{}.npy".format(index)))

            fitted_depth = _fit_segments(
                depth, segmentation, intrinsics, args.topk, args.threshold)

            np.save(os.path.join(save_path, "depth_{}.npy".format(index)), fitted_depth)


if __name__ == "__main__":
    main()