


from itertools import count

from torch import nn


import _init_paths
import argparse
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import warnings
warnings.filterwarnings("ignore")
import random
import time
from copy import deepcopy
import math
import numpy as np
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.tensorboard import SummaryWriter
import torch, gc
from datasets.linemod.dataset_cnn_tsn import PoseDataset as PoseDataset_ycb 

from models import CTNetModel
from lib.linemod_evaluator import YCBEval
from lib.utils import setup_logger
from lib.utils import warnup_lr
from lib.utils import post_processing_ycb_quaternion_wi_vote, save_pred_and_gt_json 
import torch
import torchvision
from thop import profile

st_time = time.time()

# Command-line options.  Every option that is used in arithmetic declares an
# explicit ``type``: argparse only converts command-line values when ``type`` is
# given, so without it ``--lr 1e-4`` would yield the *string* '1e-4' (only the
# default stays numeric) and e.g. ``opt.lr * opt.gpu_number`` would silently
# become string repetition.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=1, help='local process rank')
parser.add_argument('--dataset', type=str, default='linemod', help='ycb or linemod')
parser.add_argument('--dataset_root', type=str, default='',
                    help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')')

# NOTE: main() overwrites opt.gpu_number with torch.cuda.device_count().
parser.add_argument('--gpu_number', type=int, default=1, help='gpu number')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.00005, help='learning rate')
parser.add_argument('--lr_rate', type=float, default=0.1, help='learning rate decay rate')
parser.add_argument('--warnup_iters', type=int, default=500,
                    help='number of learning rate warm-up iterations')
parser.add_argument('--decay_epoch', type=int, default=90, help='learning rate decay epoch')
parser.add_argument('--cos', type=int, default=0, help='cosine lr schedule')
parser.add_argument('--noise_trans', type=float, default=0.03,
                    help='range of the random noise of translation added to the training data')
parser.add_argument('--nepoch', type=int, default=100, help='max number of epochs to train')
parser.add_argument('--resume', type=str, default='', help='resume PoseNet model')
parser.add_argument('--start_epoch', type=int, default=0, help='which epoch to start')
opt = parser.parse_args()

from thop import clever_format
from thop import profile




def main():
    """Configure the run and spawn one training process per visible CUDA device.

    Mutates the module-level ``opt`` namespace (random seed, device count,
    dataset settings, output directories), creates the output directories and
    hands ``opt`` to ``per_processor`` via ``mp.spawn``.

    Raises:
        RuntimeError: if no CUDA device is visible (``mp.spawn`` with zero
            processes would otherwise return without training anything).
        ValueError: if ``opt.dataset`` is not a configured dataset.
    """
    print(torch.cuda.is_available())

    torch.backends.cudnn.enabled = True
    # NOTE(review): this seeds only the parent process.  Processes started by
    # mp.spawn begin with fresh RNG state, so seeding that should affect
    # training has to use opt.manualSeed inside the worker.
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    # NOTE(review): per_processor initialises the process group with an
    # explicit tcp:// init_method, so these variables are presumably unused.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    opt.gpu_number = torch.cuda.device_count()
    if opt.gpu_number == 0:
        raise RuntimeError('no CUDA device is visible; at least one GPU is required')

    if opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 512
        # Output locations; empty strings presumably are placeholders to fill in.
        opt.outf = ''
        opt.log_dir = ''
    else:
        # Fail here instead of with an AttributeError inside every worker.
        raise ValueError("unsupported dataset {!r}; only 'linemod' is configured".format(opt.dataset))

    # An empty path means "current directory" and must not be created:
    # os.mkdir('') raises FileNotFoundError.
    for out_dir in (opt.outf, opt.log_dir):
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    print(opt)

    mp.spawn(per_processor, nprocs=opt.gpu_number, args=(opt,))


import os
import torchvision.transforms as transforms
import torchvision.utils as vutils


def save_images(tensor, path, prefix):
    """Write an image tensor, or a batch of them, to disk as PNG files.

    A sub-directory of ``path`` named after the current local time
    (``%Y%m%d_%H%M%S``) is created, or reused if it already exists.  A 4-D
    tensor is treated as a batch: item ``idx`` is written as
    ``{prefix}_{idx}.png``.  Any other tensor is written as ``{prefix}.png``.
    Conversion is done by ``torchvision.transforms.ToPILImage``.
    """
    from datetime import datetime

    out_dir = os.path.join(path, datetime.now().strftime("%Y%m%d_%H%M%S"))
    os.makedirs(out_dir, exist_ok=True)

    # Single image: one file, no index suffix.
    if tensor.ndimension() != 4:
        single = transforms.ToPILImage()(tensor.cpu())
        single.save(os.path.join(out_dir, f"{prefix}.png"))
        return

    # Batch: one file per leading-dimension entry.
    for idx in range(tensor.size(0)):
        pil_img = transforms.ToPILImage()(tensor[idx].cpu())
        pil_img.save(os.path.join(out_dir, f"{prefix}_{idx}.png"))

def predict(data, estimator, lossor, opt, mode='train'):
    """Run the network and its loss on one batch; in 'test' mode also decode poses.

    Args:
        data: batch dict.  Reads 'class_id', 'rgb', 'xyz', 'mask', 'target_r',
            'target_t' and, in 'test' mode, 'model_xyz'.  The dict's tensors are
            not modified; only ``preds`` gets an extra 'xyz' entry in 'test' mode.
        estimator: model called as ``estimator(rgb, depth, cls_ids)`` and
            returning a dict of predictions.
        lossor: loss module called as ``lossor(preds, mask, gt_r, gt_t, cls_ids)``
            and returning ``(loss, loss_dict)``.
        opt: options namespace.  Uses ``opt.gpu`` (target device) and, in 'test'
            mode, ``opt.sym_list`` and ``opt.obj_radius``.
        mode: 'train' or 'test'.

    Returns:
        'train': ``(loss, loss_dict)``.
        'test': ``(loss, loss_dict, rt_list, gt_rt_list, gt_cls_list, model_list)``
        with one numpy entry per batch item.  Column 3 of each predicted and
        ground-truth pose matrix and the model points are multiplied by
        ``opt.obj_radius[class_id]`` (presumably undoing a per-object
        normalisation — confirm).  Class ids are shifted by +1.

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'test'.
    """
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

    cls_ids = data['class_id'].to(opt.gpu)
    rgb = data['rgb'].to(opt.gpu)
    depth = data['xyz'].to(opt.gpu)
    mask = data['mask'].to(opt.gpu)
    gt_r = data['target_r'].to(opt.gpu)
    gt_t = data['target_t'].to(opt.gpu)

    preds = estimator(rgb, depth, cls_ids)
    loss, loss_dict = lossor(preds, mask, gt_r, gt_t, cls_ids)

    if mode == 'train':
        return loss, loss_dict

    # ---- mode == 'test': decode poses and collect per-item results ----
    preds['xyz'] = depth
    res_T = post_processing_ycb_quaternion_wi_vote(preds, opt.sym_list)
    bs, _, _ = res_T.size()  # also asserts the result is 3-D
    res_T = res_T.cpu().numpy()

    # Ground truth pose as [R | t]: t is appended as an extra column (index 3
    # when gt_r has 3 columns — presumably 3x3, confirm).
    tar_T = torch.cat([gt_r, gt_t.unsqueeze(dim=2)], dim=2).cpu().numpy()

    gt_cls = data['class_id'].cpu().numpy().astype(int)

    # .copy(): for a CPU tensor, .cpu().numpy() shares memory with
    # data['model_xyz'], and the in-place rescale below would otherwise modify
    # the caller's batch.
    model_xyz = data['model_xyz'].cpu().numpy().copy()

    rt_list, gt_rt_list, gt_cls_list, model_list = [], [], [], []
    for i in range(bs):
        scale = opt.obj_radius[int(gt_cls[i])]
        res_T[i, :, 3] *= scale
        tar_T[i, :, 3] *= scale
        model_xyz[i] *= scale
        rt_list.append(res_T[i])
        gt_rt_list.append(tar_T[i])
        gt_cls_list.append(gt_cls[i] + 1)
        model_list.append(model_xyz[i])

    return loss, loss_dict, rt_list, gt_rt_list, gt_cls_list, model_list


import torch
import torch.optim as optim






def per_processor(gpu, opt):
    """Body of one spawned training process (one per visible GPU).

    ``mp.spawn`` calls this with the process index as ``gpu``.  That index is
    used as the distributed rank and as the CUDA device id.  Rank 0 also owns
    the TensorBoard writer, checkpointing and evaluation.

    Args:
        gpu: process index supplied by ``mp.spawn``.
        opt: options namespace prepared by ``main``.  It is mutated here:
            ``gpu``, ``sym_list``, ``num_points_mesh``, ``obj_radius`` and
            ``cur_epoch`` are set.
    """
    opt.gpu = gpu

    # Spawned processes start with fresh RNG state; re-apply the seed chosen in
    # main() so that it actually affects training.
    seed = getattr(opt, 'manualSeed', None)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    tensorboard_writer = 0
    if gpu == 0:
        tensorboard_writer = SummaryWriter(opt.log_dir)

    torch.distributed.init_process_group(backend='gloo', init_method='tcp://localhost:23456',
                                         rank=gpu, world_size=opt.gpu_number)
    print("init gps:{}".format(gpu))
    torch.cuda.set_device(gpu)

    estimator = CTNetModel.CTNet(num_class=opt.num_objects).to(gpu)
    estimator = torch.nn.parallel.DistributedDataParallel(
        estimator, device_ids=[gpu], output_device=gpu, find_unused_parameters=True)

    # Learning rate is scaled linearly with the number of processes.  float()
    # guards against a string value coming from an untyped CLI option.
    optimizer = optim.Adam(estimator.parameters(), lr=float(opt.lr) * opt.gpu_number)

    if opt.resume != '':
        loc = 'cuda:{}'.format(gpu)
        checkpoint = torch.load(os.path.join(opt.outf, opt.resume), map_location=loc)
        model_dict = estimator.state_dict()
        # Only keys present in the current model are restored; any other
        # checkpoint entries are ignored.  The optimizer state is not restored.
        same_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
        model_dict.update(same_dict)
        estimator.load_state_dict(model_dict)
        print("loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch']))

    dataset = PoseDataset_ycb('train', opt.num_points, opt.dataset_root, True, opt.noise_trans)
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=False,
                                             num_workers=opt.workers, pin_memory=True, sampler=sampler)
    test_loader = None
    if gpu == 0:
        test_set = PoseDataset_ycb('test', opt.num_points, opt.dataset_root, False, opt.noise_trans)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=opt.batch_size * 2, shuffle=False,
                                                  num_workers=opt.workers * 2, pin_memory=True)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()
    opt.obj_radius = dataset.obj_radius

    lossor = CTNetModel.get_loss(dataset).to(gpu)

    tensorboard_loss_list = []
    tensorboard_test_list = []
    try:
        for epoch in range(opt.start_epoch, opt.nepoch):
            sampler.set_epoch(epoch)  # reshuffle the distributed split each epoch
            opt.cur_epoch = epoch

            print('>>>>>>>>>>>train>>>>>>>>>>>')
            train(dataloader, estimator, lossor, optimizer, epoch, tensorboard_writer, tensorboard_loss_list, opt)

            if gpu == 0:
                print('>>>>>>>>>>>save checkpoint>>>>>>>>>>')
                # os.path.join: with an empty opt.outf this writes to the current
                # directory instead of the filesystem root.
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': estimator.state_dict()},
                    os.path.join(opt.outf, 'checkpoint_{:04d}.pth.tar'.format(epoch)))

                print('>>>>>>>>>>>test>>>>>>>>>>>')
                # Only rank 0 evaluates, so evaluate through the wrapped module:
                # a forward through the DDP wrapper may issue collectives (e.g.
                # buffer broadcast) that the other ranks never join.
                test(test_loader, estimator.module, lossor, epoch, tensorboard_writer, tensorboard_test_list, opt)
    finally:
        if gpu == 0:
            tensorboard_writer.close()
        torch.distributed.destroy_process_group()


if __name__ == '__main__':
    # NOTE(review): train() and test() are defined *below* this guard.  This
    # works because mp.spawn uses the 'spawn' start method, whose child
    # processes re-import this whole module before calling per_processor.
    # Moving this guard to the end of the file would be more robust.
    main()

def train(train_loader, estimator, lossor, optimizer, epoch, tensorboard_writer, tensorboard_loss_list, opt):
    """Run one training epoch.

    Every process performs forward / backward / optimizer step per batch.  The
    learning rate is set by ``adjust_learning_rate`` using a global iteration
    counter that starts at 1 within each epoch.  Rank 0 (``opt.gpu == 0``)
    additionally does the following:

    * logs every iteration through ``log_function``;
    * every 50 iterations, appends the mean of the last 50 loss dicts to
      ``tensorboard_loss_list`` and redraws the 'train' curves.

    NOTE(review): ``adjust_learning_rate``, ``log_function`` and
    ``draw_loss_list`` are not defined or imported in the visible part of this
    file — confirm where they come from.
    """
    is_main = opt.gpu == 0
    if is_main:
        train_loss_list = []
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))

    estimator.train()
    optimizer.zero_grad()

    iters_per_epoch = len(train_loader)
    for step, data in enumerate(train_loader, start=1):
        global_iter = epoch * iters_per_epoch + step
        cur_lr = adjust_learning_rate(optimizer, epoch, global_iter, opt)

        loss, loss_dict = predict(data, estimator, lossor, opt, mode='train')
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if not is_main:
            continue

        train_loss_list.append(loss_dict)
        log_function(train_loss_list, logger, epoch, step, cur_lr)

        if len(train_loss_list) % 50 != 0:
            continue

        # Mean over the most recent 50 loss dicts.
        window = train_loss_list[-50:]
        averaged = deepcopy(window[0])
        for entry in window[1:]:
            for key in entry:
                averaged[key] += entry[key]
        for key in averaged:
            averaged[key] = averaged[key] / 50.0

        tensorboard_loss_list.append(averaged)
        draw_loss_list('train', tensorboard_loss_list, tensorboard_writer)

def test(test_loader, estimator, lossor, epoch, tensorboard_writer, tensorboard_test_list, opt):
    """Evaluate on the test set; all evaluation and logging happen on rank 0.

    Runs ``predict(..., mode='test')`` on every batch under ``torch.no_grad()``.
    On rank 0 (``opt.gpu == 0``) it also does the following:

    * feeds each batch to a ``YCBEval`` evaluator and logs per-batch losses;
    * saves all predictions and ground truth with ``save_pred_and_gt_json``
      into ``opt.log_dir``;
    * appends the epoch-mean loss dict plus ``add_s_2cm`` to
      ``tensorboard_test_list``, redraws the 'test' curves, and logs a summary.

    On any other rank only the forward passes run; nothing is evaluated, saved
    or logged.  If the loader yields no batches, a short notice is logged and
    the metric summary is skipped.
    """
    is_main = opt.gpu == 0
    if is_main:
        test_loss_list = []
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'test_%d_log.txt' % epoch))
        ycb_evaluator = YCBEval()

    estimator.eval()

    with torch.no_grad():
        total_rt_list = []
        total_gt_list = []
        total_cls_list = []
        for i, data in enumerate(test_loader, start=1):
            _, test_loss_dict, rt_list, gt_rt_list, gt_cls_list, model_list = predict(
                data, estimator, lossor, opt, mode='test')
            if not is_main:
                continue

            total_rt_list.extend(rt_list)
            total_gt_list.extend(gt_rt_list)
            total_cls_list.extend(gt_cls_list)

            # NOTE(review): gt_cls_list is passed twice — presumably as both the
            # predicted and the ground-truth class ids; confirm against YCBEval.
            ycb_evaluator.eval_pose_parallel(rt_list, gt_cls_list, gt_rt_list, gt_cls_list, model_list)

            test_loss_list.append(test_loss_dict)
            log_function(test_loss_list, logger, epoch, i, opt.lr)

        if not is_main:
            return

        save_pred_and_gt_json(total_rt_list, total_gt_list, total_cls_list, opt.log_dir)

        if not test_loss_list:
            logger.info('TEST ENDING: no test batches')
            return

        cur_eval_info_dict = ycb_evaluator.cal_auc()

        # Epoch mean of every per-batch loss entry.
        mean_loss = deepcopy(test_loss_list[0])
        for ld in test_loss_list[1:]:
            for key in ld:
                mean_loss[key] += ld[key]
        for key in mean_loss:
            mean_loss[key] = mean_loss[key] / len(test_loss_list)

        mean_loss['add_s_2cm'] = cur_eval_info_dict['add_s_2cm']
        tensorboard_test_list.append(mean_loss)
        draw_loss_list('test', tensorboard_test_list, tensorboard_writer)

        log_tmp = 'TEST ENDING: '
        for key in mean_loss:
            log_tmp = log_tmp + ' {}:{:.4f}'.format(key, mean_loss[key])
        logger.info(log_tmp)
        # Append every evaluation metric, then emit the complete line once.
        for key in cur_eval_info_dict:
            log_tmp = log_tmp + ' {}:{:.2f}'.format(key, cur_eval_info_dict[key])
        logger.info(log_tmp)

