

import os
import torch
import numpy as np
from scene import Scene
from tqdm import tqdm
from random import randint
from os import makedirs
from gaussian_renderer import render, render_jac
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams, get_combined_args
from gaussian_renderer import GaussianModel
from utils.loss_utils import l1_loss, ssim
from utils.image_utils import psnr
import wandb
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F
from PIL import Image
import math



def _sample_balanced_pixels(obj_mask, batch_size):
    """Draw ``batch_size`` flat pixel indices, spread evenly over the mask labels.

    Args:
        obj_mask: 1-D tensor of per-pixel labels (the flattened object mask).
        batch_size: Positive number of indices to draw.

    Returns:
        A shuffled 1-D numpy array of ``batch_size`` indices into ``obj_mask``.
        Labels with fewer pixels than their quota are revisited on later
        passes, so an index may appear more than once.
    """
    kinds = torch.unique(obj_mask)
    # Quota of at least one: with more labels than batch_size the integer
    # division would be 0 and the sampling loop below would never terminate.
    per_label = max(1, batch_size // len(kinds))
    # Loop-invariant, so compute each label's pixel indices only once.
    # reshape(-1) keeps a 1-D tensor even when a label owns a single pixel.
    label_pixels = [(obj_mask == v).nonzero().reshape(-1) for v in kinds]

    chosen = []
    remaining = batch_size
    while remaining > 0:
        for pixels in label_pixels:
            take = min(remaining, per_label, pixels.shape[0])
            picks = np.random.choice(pixels.shape[0], take, replace=False)
            chosen.append(pixels[picks].cpu().numpy())
            remaining -= take
            if remaining == 0:
                break

    pixel_idx = np.concatenate(chosen)
    np.random.shuffle(pixel_idx)
    return pixel_idx


def _mutual_information_loss(features, labels):
    """Contrastive loss over pairwise cosine similarities of ``features``.

    Args:
        features: 2-D tensor, one feature row per sampled pixel.
        labels: 1-D tensor with the mask label of each row.

    Returns:
        Scalar tensor: mean negative-pair term minus mean positive-pair term.
    """
    # keepdim=True so every row is divided by its own norm; a bare (N,) norm
    # vector would broadcast along the feature axis instead of the row axis.
    features = features / (features.norm(dim=1, keepdim=True) + 1e-7)
    similarity = torch.mm(features, features.t())

    same_label = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
    diff_label = 1 - same_label

    # Masked-out pairs contribute exactly 0 to each term because
    # softplus(0) == log(2).
    # NOTE(review): the usual Jensen-Shannon estimator uses softplus(-x) in the
    # positive term; the original softplus(+x) is kept here - confirm the sign.
    e_pos = math.log(2.) - F.softplus(similarity * same_label)
    e_neg = F.softplus(-(similarity * diff_label)) + (similarity * diff_label) - math.log(2.)

    # A view with a single label has no negative pairs; clamping the count
    # turns the former 0/0 = NaN into 0 and leaves every other case unchanged.
    neg_term = e_neg.sum() / diff_label.sum().clamp(min=1)
    pos_term = e_pos.sum() / same_label.sum().clamp(min=1)
    return neg_term - pos_term


def _neighbor_consistency_loss(xyz, jacobians, sample_idx, num_neighbors=5):
    """Mean (1 - cosine similarity) between sampled rows and their neighbours.

    Args:
        xyz: Positions, one row per Gaussian.
        jacobians: Per-Gaussian rows, indexed the same way as ``xyz``.
        sample_idx: Indices of the Gaussians to regularise.
        num_neighbors: Number of nearest positions to compare against. The
            nearest set is taken over all of ``xyz``, so it normally includes
            the sampled Gaussian itself.

    Returns:
        Scalar tensor with the averaged dissimilarity.
    """
    # topk cannot return more entries than there are Gaussians.
    k = min(num_neighbors, xyz.shape[0])
    sampled_grads = jacobians[sample_idx]
    dists = torch.cdist(xyz[sample_idx], xyz)
    _, neighbor_idx = dists.topk(k, largest=False)
    neighbor_grads = jacobians[neighbor_idx]
    expanded = sampled_grads.unsqueeze(1).expand(-1, k, -1)
    return (1. - F.cosine_similarity(expanded, neighbor_grads, dim=2)).mean()


def shaping(opt, dataset : ModelParams, iterations : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
    """Fine-tune the model's networks, save them, then print an evaluation.

    Each step renders one random training view with ``render_jac`` and
    optimises the sum of an L1/SSIM image loss, a mutual-information loss on
    the accumulated jacobians of sampled pixels (weighted by
    ``opt.lambda_mi``), and a neighbour-consistency term on per-Gaussian
    jacobians (weighted by 0.1). Views whose object mask is entirely zero are
    skipped but still count as an iteration.

    Args:
        opt: Optimisation settings; ``lambda_dssim`` and ``lambda_mi`` are read.
        dataset: Model parameters; ``white_background`` and ``model_path`` are
            read, and the object is passed to ``GaussianModel`` and ``Scene``.
        iterations: Number of steps. A non-positive value runs no steps; the
            current weights are still saved and evaluated.
        pipeline: Pipeline parameters forwarded to the renderers.
        skip_train: Unused; kept for interface compatibility.
        skip_test: Unused; kept for interface compatibility.
    """
    gaussians = GaussianModel(dataset)
    scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)

    gaussians.finetune_setup()

    bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    viewpoint_stack = scene.getTrainCameras().copy()

    gs_num = gaussians.get_xyz.shape[0]
    batch_size = 512
    # Sampling without replacement cannot draw more Gaussians than exist.
    reg_batch = min(batch_size, gs_num)

    for _ in tqdm(range(iterations), desc="Training", unit="epoch"):
        # Refill the stack once every training view has been visited.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
        if torch.all(viewpoint_cam.obj_mask == 0).item():
            continue

        obj_mask = viewpoint_cam.obj_mask.reshape(-1)
        pixel_idx = _sample_balanced_pixels(obj_mask, batch_size)
        reg_idx = np.random.choice(gs_num, reg_batch, replace=False)

        # Binary image marking the sampled pixels, handed to the renderer.
        mask_img = torch.zeros_like(viewpoint_cam.obj_mask).reshape(-1)
        mask_img[pixel_idx] = 1
        render_pkg = render_jac(viewpoint_cam, gaussians, pipeline, background, mask_img=mask_img)
        image, jacobians, acc_jacobians = render_pkg["render"], render_pkg["jacobians"], render_pkg["acc_jacobians"]
        gt_image = viewpoint_cam.original_image.cuda()

        Ll1 = l1_loss(image, gt_image)
        img_loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))

        mi_loss = _mutual_information_loss(acc_jacobians[pixel_idx], obj_mask[pixel_idx])
        reg_cos = _neighbor_consistency_loss(gaussians.get_xyz, jacobians, reg_idx)

        loss = img_loss + opt.lambda_mi * mi_loss + 0.1 * reg_cos

        loss.backward()

        gaussians.optimizer_net.step()
        gaussians.optimizer_net.zero_grad(set_to_none=True)
        gaussians.scheduler_net.step()

    saved_path = os.path.join(dataset.model_path, 'tuned_models')
    os.makedirs(saved_path, exist_ok=True)
    torch.save(torch.nn.ModuleList([gaussians.recolor, gaussians.mlp1, gaussians.mlp2]).state_dict(), os.path.join(saved_path, "point_cloud{}.pth".format(iterations)))

    # Report under the same number used in the checkpoint name. The loop
    # variable would be unbound when no iteration ran.
    tuning_report(iterations, l1_loss, scene, render, (pipeline, background))

def tuning_report(iteration, l1_loss, scene : Scene, renderFunc, renderArgs):
    """Print mean L1 and PSNR for the test cameras and a few training cameras.

    Args:
        iteration: Value shown in the ``[ITER ...]`` prefix of each line.
        l1_loss: Callable ``(image, gt_image)`` returning a tensor.
        scene: Scene providing ``getTestCameras()``, ``getTrainCameras()`` and
            the ``gaussians`` model to render.
        renderFunc: Called as ``renderFunc(viewpoint, scene.gaussians,
            *renderArgs)``; its result must expose a ``"render"`` entry.
        renderArgs: Extra positional arguments for ``renderFunc``.
    """
    torch.cuda.empty_cache()

    test_cameras = scene.getTestCameras()
    train_cameras = scene.getTrainCameras()
    # Up to five training views, wrapping around for small sets. An empty
    # training set gives an empty selection rather than a modulo-by-zero.
    if len(train_cameras) > 0:
        train_subset = [train_cameras[idx % len(train_cameras)] for idx in range(5, 30, 5)]
    else:
        train_subset = []

    validation_configs = ({'name': 'test', 'cameras' : test_cameras},
                          {'name': 'train', 'cameras' : train_subset})

    scene.gaussians.precompute()
    # Evaluation needs no gradients. Without no_grad the running sums below
    # would keep the autograd graph of every rendered view alive.
    with torch.no_grad():
        for config in validation_configs:
            cameras = config['cameras']
            if not cameras:
                continue

            l1_test = 0.0
            psnr_test = 0.0
            for viewpoint in cameras:
                image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
                gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)

                l1_test += l1_loss(image, gt_image).mean().double()
                psnr_test += psnr(image, gt_image).mean().double()
            psnr_test /= len(cameras)
            l1_test /= len(cameras)
            print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))

    torch.cuda.empty_cache()


if __name__ == "__main__":
    # Register the parameter groups and the script-specific options.
    cli = ArgumentParser(description="Testing script parameters")
    model_group = ModelParams(cli, sentinel=True)
    pipeline_group = PipelineParams(cli)
    optim_group = OptimizationParams(cli)

    # NOTE: the default of -1 is handed to shaping() unchanged, where it
    # becomes range(-1), i.e. no optimisation steps are performed.
    cli.add_argument("--iteration", default=-1, type=int)
    for switch in ("--skip_train", "--skip_test", "--quiet"):
        cli.add_argument(switch, action="store_true")

    args = get_combined_args(cli)
    print("Shaping " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    # Extract each parameter group, then run the fine-tuning.
    optim_params = optim_group.extract(args)
    model_params = model_group.extract(args)
    pipeline_params = pipeline_group.extract(args)
    shaping(
        optim_params,
        model_params,
        args.iteration,
        pipeline_params,
        args.skip_train,
        args.skip_test,
    )