import torch
import torch.nn as nn
import os
import json
from tools import builder
from utils import misc, dist_utils
import time
from utils.logger import *

import cv2
import numpy as np


def test_net(args, config):
    """Build the data loader and model, restore the checkpoint, and run ``test``.

    Args:
        args: run-time options; this function reads ``log_name``, ``ckpts``,
            ``use_gpu``, ``local_rank`` and ``distributed``.
        config: experiment configuration; this function reads
            ``config.dataset.train`` and ``config.model``.

    Raises:
        NotImplementedError: if ``args.distributed`` is set.
    """
    logger = get_logger(args.log_name)
    print_log('Tester start ... ', logger = logger)

    # Reject the unsupported distributed mode before paying for dataset
    # construction and checkpoint loading.
    if args.distributed:
        raise NotImplementedError('Distributed (DDP) visualization is not supported.')

    # NOTE(review): the *train* dataset config is used here although the
    # variable is called test_dataloader; the output paths in test() also say
    # "ModelNet40_train", so this looks deliberate -- confirm.
    _, test_dataloader = builder.dataset_builder(args, config.dataset.train)

    base_model = builder.model_builder(config.model)
    base_model.load_model_from_ckpt(args.ckpts)

    if args.use_gpu:
        base_model.to(args.local_rank)

    test(base_model, test_dataloader, args, config, logger=logger)


# visualization
def test(base_model, test_dataloader, args, config, logger = None):
    """Write a raw point-cloud image and a score heat-map image for every sample.

    For each batch the model is called with ``vis=True`` and must return
    ``(raw_points, centers, scores)``. The per-group score is appended to every
    point as a fourth channel, both point sets are rendered with
    ``misc.get_ptcloud_img`` / ``misc.get_score_img``, cropped, colour-mapped
    and saved as ``plot_{batch index}_{sample index}.jpg`` inside a directory
    named after the sample's label.

    Args:
        base_model: model to visualise; it is switched to eval mode here.
        test_dataloader: iterable yielding ``(taxonomy_ids, model_ids, data)``
            where ``data[0]`` holds the points and ``data[1]`` the labels.
        args: unused here; kept for interface compatibility.
        config: unused here; kept for interface compatibility.
        logger: unused here; kept for interface compatibility.

    Returns:
        None. The images written to disk are the only output.
    """
    base_model.eval()  # set model to eval mode

    raw_root = './visual/ModelNet40_train/PointACL-MAE/raw_visual'
    heat_root = './visual/ModelNet40_train/PointACL-MAE/score_visual'
    # Region of the rendered figure that is kept (rows, columns).
    crop = (slice(150, 650), slice(150, 675))
    # Feed the inputs to wherever the model actually lives instead of
    # assuming the default CUDA device.
    device = _model_device(base_model)

    with torch.no_grad():
        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
            points = data[0].to(device)
            # Labels are only used to name output directories, so they do
            # not need to be moved to the GPU.
            labels = data[1]
            raw_points, centers, scores = base_model(points, vis=True)

            # Broadcast each group's score to all of its points and append it
            # as a 4th channel: [B, G] -> [B, G, M, 1]; result is [B, G, M, 4].
            new_score = scores.unsqueeze(-1).repeat(1, 1, raw_points.size(2)).unsqueeze(-1)
            score_points = torch.cat((raw_points, new_score), dim=3)

            # One device-to-host copy per batch instead of one per sample.
            raw_np = raw_points.reshape(raw_points.size(0), -1, 3).cpu().numpy()
            score_np = score_points.reshape(score_points.size(0), -1, 4).cpu().numpy()

            for b in range(points.size(0)):
                rendered_raw = misc.get_ptcloud_img(raw_np[b], roll=30, pitch=-40)
                rendered_score = misc.get_score_img(score_np[b], roll=30, pitch=-40)

                # Swap the R and B channels before handing the images to OpenCV.
                raw_image = cv2.cvtColor(rendered_raw[crop], cv2.COLOR_BGR2RGB)
                score_image = cv2.cvtColor(rendered_score[crop], cv2.COLOR_BGR2RGB)

                heat_image = cv2.applyColorMap(score_image, cv2.COLORMAP_JET)
                # Repaint pixels that are exactly (128, 0, 0) white
                # (presumably the colour-mapped background -- confirm).
                _paint_matching_pixels_white(heat_image, (128, 0, 0))

                label_num = labels[b]
                # A one-element label (e.g. shape [1]) would otherwise be
                # formatted as "tensor([..])" in the directory name.
                if label_num.numel() == 1:
                    label_num = label_num.item()

                raw_dir_path = os.path.join(raw_root, f'{label_num}')
                heat_dir_path = os.path.join(heat_root, f'{label_num}')
                os.makedirs(raw_dir_path, exist_ok=True)
                os.makedirs(heat_dir_path, exist_ok=True)

                file_name = f'plot_{idx}_{b}.jpg'
                cv2.imwrite(os.path.join(raw_dir_path, file_name), raw_image)
                cv2.imwrite(os.path.join(heat_dir_path, file_name), heat_image)


def _model_device(model):
    """Return the device of the model's first parameter (CUDA if it has none)."""
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device('cuda')
    return first_param.device


def _paint_matching_pixels_white(image, colour):
    """Set every pixel of ``image`` equal to ``colour`` to white, in place."""
    matches = np.all(image == np.asarray(colour, dtype=image.dtype), axis=-1)
    image[matches] = (255, 255, 255)
