# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn.functional as F
import datetime
import logging
import math
import time
import sys

from clip import clip_pc_img, clip_3_branch, clip_3_branch_dist_aware

from torch.distributed.distributed_c10d import reduce
from utils.ap_calculator import APCalculator
from utils.misc import SmoothedValue
from utils.dist import (
    all_gather_dict,
    all_reduce_average,
    is_primary,
    reduce_dict,
    barrier,
)
from models.DETR.util.misc import NestedTensor


def compute_learning_rate(args, curr_epoch_normalized):
    """Return the learning rate for a normalized training progress value.

    Args:
        args: configuration object providing ``warm_lr_epochs``, ``max_epoch``,
            ``warm_lr``, ``base_lr`` and ``final_lr``.
        curr_epoch_normalized: training progress, which must lie in ``[0, 1]``.

    Returns:
        The learning rate for this progress value. While warmup is enabled
        (``args.warm_lr_epochs > 0``) and the progress is within the first
        ``args.warm_lr_epochs / args.max_epoch`` of training, the rate grows
        linearly from ``args.warm_lr`` towards ``args.base_lr``. Otherwise a
        cosine curve running from ``args.base_lr`` (progress 0) to
        ``args.final_lr`` (progress 1) is evaluated at the current progress.
    """
    assert 0.0 <= curr_epoch_normalized <= 1.0

    warmup_fraction = args.warm_lr_epochs / args.max_epoch
    if curr_epoch_normalized <= warmup_fraction and args.warm_lr_epochs > 0:
        # Linear warmup: per-epoch slope times the number of elapsed epochs.
        lr_per_epoch = (args.base_lr - args.warm_lr) / args.warm_lr_epochs
        return args.warm_lr + curr_epoch_normalized * args.max_epoch * lr_per_epoch

    # Cosine learning rate schedule.
    cosine_term = 1 + math.cos(math.pi * curr_epoch_normalized)
    return args.final_lr + 0.5 * (args.base_lr - args.final_lr) * cosine_term


def adjust_learning_rate(args, optimizer, curr_epoch):
    """Apply the scheduled learning rate to every optimizer parameter group.

    ``curr_epoch`` is forwarded unchanged to ``compute_learning_rate``, which
    requires it to be a normalized progress value in ``[0, 1]``.

    Returns:
        The learning rate that was written into ``optimizer.param_groups``.
    """
    scheduled_lr = compute_learning_rate(args, curr_epoch)
    for group in optimizer.param_groups:
        group["lr"] = scheduled_lr
    return scheduled_lr


def imagenet_loss(predict, ground_truth, dataset_config):
    """Compute the classification loss for the ImageNet image branch.

    For every image, the query whose predicted box has the largest
    ``pred_boxes[..., 2] * pred_boxes[..., 3]`` product (presumably
    width * height -- confirm against the box parameterization) is used as the
    prediction for that image's single label. The loss is the sum of:

    * cross-entropy of that query's ``imagenet_logits`` against the label, over
      all images, and
    * cross-entropy of that query's ``pred_logits`` (last logit dropped)
      against the label, over only the images whose label index is smaller
      than ``len(dataset_config.type2class)``.

    Args:
        predict: mapping with tensors ``"pred_boxes"``, ``"pred_logits"`` and
            ``"imagenet_logits"``, each indexed as ``[batch, query, ...]``.
        ground_truth: sequence with one mapping per image; its ``"labels"``
            entry must hold exactly one integer class label.
        dataset_config: object exposing a ``type2class`` mapping; its length
            is the number of classes handled by ``pred_logits``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the images do not carry exactly one label each.
    """
    pred_logits = predict["pred_logits"]
    batch_size = pred_logits.shape[0]

    # Pick, per image, the query with the largest predicted box.
    pred_boxes = predict["pred_boxes"]
    box_area = pred_boxes[:, :, 2] * pred_boxes[:, :, 3]
    selected_ind = torch.argmax(box_area, dim=1)
    batch_ind = torch.arange(batch_size, device=selected_ind.device)

    # Gather the single label of every image into one 1-D tensor.
    labels = torch.hstack([ground_truth[ind]["labels"] for ind in range(batch_size)])
    if labels.numel() != batch_size:
        raise ValueError(
            "imagenet_loss expects exactly one label per image, got "
            f"{labels.numel()} labels for {batch_size} images"
        )

    # ImageNet head: every image contributes.
    imagenet_logits = predict["imagenet_logits"][batch_ind, selected_ind]
    loss = F.cross_entropy(imagenet_logits, labels)

    # Detection head: drop the last logit (presumably the no-object slot --
    # confirm) and only supervise images whose label this head can express.
    selected_logits = pred_logits[batch_ind, selected_ind, :-1]
    known_mask = labels < len(dataset_config.type2class)
    if bool(known_mask.any()):
        loss = loss + F.cross_entropy(selected_logits[known_mask], labels[known_mask])
    else:
        # No usable label in this batch: add a zero-valued term so that
        # ``pred_logits`` still takes part in the autograd graph (presumably
        # required so distributed training sees a gradient for every
        # parameter -- confirm).
        loss = loss + 0 * torch.sum(selected_logits)

    return loss

def _collect_2d_targets(batch_size, boxes, labels, box_counts):
    """Return one ``{"boxes", "labels"}`` dict per sample, trimmed to its box count."""
    targets = []
    for ind in range(batch_size):
        num_boxes = box_counts[ind]
        targets.append(
            {
                "boxes": boxes[ind, :num_boxes, :],
                "labels": labels[ind, :num_boxes],
            }
        )
    return targets


def _merge_leading_dims(tensor, num_trailing_dims):
    """Collapse every leading dimension into one, keeping the last ``num_trailing_dims``."""
    return torch.reshape(tensor, (-1, *tensor.shape[-num_trailing_dims:]))


def train_one_epoch(
    args,
    curr_epoch,
    model,
    optimizer,
    criterion,
    criterion_img,
    dataset_config,
    dataset_loader,
    logger,
):
    """Run one training epoch over ``dataset_loader``.

    Each iteration:

    1. sets the learning rate from the global progress ``curr_iter / max_iters``;
    2. moves every entry of the batch dict to the model's device;
    3. runs the model on the point cloud and image inputs and computes the
       point cloud loss (``criterion``) and the weighted image loss
       (``criterion_img``);
    4. runs the model a second time on the ImageNet images
       (``use_1k_header=True``, no point cloud) and computes ``imagenet_loss``;
    5. computes the distance-aware CLIP loss across the three branches;
    6. back-propagates ``loss_pc + loss_img + loss_imagenet + 10 * clip_loss``,
       optionally clips gradients, and steps the optimizer.

    The process exits with status 1 if the rank-averaged point cloud loss is
    not finite.

    Args:
        args: run configuration; reads ``max_epoch``, ``clip_gradient``,
            ``log_metrics_every``, ``log_every`` and the learning-rate fields.
        curr_epoch: index of the epoch being trained.
        model: network called as ``model(pc_input, img_input)``.
        optimizer: optimizer whose parameter groups are updated.
        criterion: point cloud loss, returns ``(loss, loss_dict)``.
        criterion_img: image loss returning a dict of terms; exposes
            ``weight_dict``.
        dataset_config: dataset description passed to ``APCalculator`` and
            ``imagenet_loss``.
        dataset_loader: iterable of batch dicts of tensors.
        logger: object with ``log_scalars(dict, step, prefix=...)``.

    Returns:
        The ``APCalculator`` (approximate mode), updated every
        ``args.log_metrics_every`` iterations with local predictions.
    """
    ap_calculator = APCalculator(
        dataset_config=dataset_config,
        ap_iou_thresh=[0.25, 0.5],
        class2type_map=dataset_config.class2type,
        exact_eval=False,
    )

    curr_iter = curr_epoch * len(dataset_loader)
    max_iters = args.max_epoch * len(dataset_loader)
    net_device = next(model.parameters()).device

    time_delta = SmoothedValue(window_size=10)
    loss_avg = SmoothedValue(window_size=10)
    img_loss_avg = SmoothedValue(window_size=10)
    imagenet_loss_avg = SmoothedValue(window_size=10)
    clip_loss_avg = SmoothedValue(window_size=10)

    model.train()
    barrier()

    for batch_data_label in dataset_loader:
        curr_time = time.time()
        curr_lr = adjust_learning_rate(args, optimizer, curr_iter / max_iters)
        for key in batch_data_label:
            batch_data_label[key] = batch_data_label[key].to(net_device)

        # Construct image branch input and ground truth.
        img_input = NestedTensor(batch_data_label["image"], batch_data_label["mask"])
        img_ground_truth = _collect_2d_targets(
            batch_data_label["image"].shape[0],
            batch_data_label["bboxes_2d"],
            batch_data_label["bboxes_2d_label"],
            batch_data_label["bbox_num"],
        )

        # Forward pass
        optimizer.zero_grad()
        pc_input = {
            "point_clouds": batch_data_label["point_clouds"],
            "point_cloud_dims_min": batch_data_label["point_cloud_dims_min"],
            "point_cloud_dims_max": batch_data_label["point_cloud_dims_max"],
        }
        pc_output, img_output = model(pc_input, img_input)

        # Compute point cloud loss
        loss_pc, loss_dict = criterion(pc_output, batch_data_label)

        loss_reduced = all_reduce_average(loss_pc)
        loss_dict_reduced = reduce_dict(loss_dict)

        if not math.isfinite(loss_reduced.item()):
            logging.info("Loss is not finite. Training will be stopped.")
            sys.exit(1)

        # Compute image loss: weighted sum of the terms that have a weight.
        img_loss_dict = criterion_img(img_output, img_ground_truth)
        weight_dict = criterion_img.weight_dict
        loss_img = sum(
            img_loss_dict[k] * weight_dict[k] for k in img_loss_dict if k in weight_dict
        )

        # Construct ImageNet branch input: fold all leading dimensions of the
        # ImageNet tensors into a single batch dimension (written back into
        # the batch dict).
        batch_data_label["image_imagenet"] = _merge_leading_dims(
            batch_data_label["image_imagenet"], 3
        )
        batch_data_label["mask_imagenet"] = _merge_leading_dims(
            batch_data_label["mask_imagenet"], 2
        )
        batch_data_label["bboxes_2d_imagenet"] = _merge_leading_dims(
            batch_data_label["bboxes_2d_imagenet"], 2
        )
        batch_data_label["bboxes_2d_label_imagenet"] = _merge_leading_dims(
            batch_data_label["bboxes_2d_label_imagenet"], 1
        )
        batch_data_label["bbox_num_imagenet"] = batch_data_label[
            "bbox_num_imagenet"
        ].flatten()

        imagenet_input = NestedTensor(
            batch_data_label["image_imagenet"], batch_data_label["mask_imagenet"]
        )
        # Construct ImageNet ground truth
        imagenet_ground_truth = _collect_2d_targets(
            batch_data_label["image_imagenet"].shape[0],
            batch_data_label["bboxes_2d_imagenet"],
            batch_data_label["bboxes_2d_label_imagenet"],
            batch_data_label["bbox_num_imagenet"],
        )
        # Image-only pass: no point cloud input for the ImageNet images.
        _, imagenet_output = model(None, imagenet_input, use_1k_header=True)

        # Compute imagenet_loss
        loss_imagenet = imagenet_loss(
            imagenet_output, imagenet_ground_truth, dataset_config
        )

        # Compute the distance-aware CLIP loss between the 3 branches.
        clip_loss = clip_3_branch_dist_aware(
            args,
            img_output,
            pc_output,
            img_ground_truth,
            batch_data_label,
            imagenet_output,
        )

        # The CLIP term enters the objective with a fixed weight of 10.
        final_loss = loss_pc + loss_img + loss_imagenet + 10 * clip_loss

        final_loss.backward()

        if args.clip_gradient > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()

        if curr_iter % args.log_metrics_every == 0:
            # This step is slow. AP is computed approximately and locally during training.
            # It will gather outputs and ground truth across all ranks.
            # It is memory intensive as point_cloud ground truth is a large tensor.
            # If GPU memory is not an issue, uncomment the following lines.
            # pc_output["outputs"] = all_gather_dict(pc_output["outputs"])
            # batch_data_label = all_gather_dict(batch_data_label)
            ap_calculator.step_meter(pc_output, batch_data_label)

        time_delta.update(time.time() - curr_time)
        loss_avg.update(loss_reduced.item())
        img_loss_avg.update(loss_img.item())
        imagenet_loss_avg.update(loss_imagenet.item())
        clip_loss_avg.update(clip_loss.item())

        # logging
        if is_primary() and curr_iter % args.log_every == 0:
            mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
            eta_seconds = (max_iters - curr_iter) * time_delta.avg
            eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
            print(
                f"Epoch [{curr_epoch}/{args.max_epoch}]; Iter [{curr_iter}/{max_iters}]; Clip Loss {clip_loss_avg.avg:0.2f}; Loss {loss_avg.avg:0.2f}; Img Loss {img_loss_avg.avg:0.2f}; ImageNet Loss {imagenet_loss_avg.avg:0.2f}; LR {curr_lr:0.2e}; Iter time {time_delta.avg:0.2f}; ETA {eta_str}; Mem {mem_mb:0.2f}MB"
            )
            logger.log_scalars(loss_dict_reduced, curr_iter, prefix="Train_details/")

            train_dict = {
                "lr": curr_lr,
                "memory": mem_mb,
                "loss": loss_avg.avg,
                "batch_time": time_delta.avg,
                "img_loss": img_loss_avg.avg,
                "imagenet_loss": imagenet_loss_avg.avg,
                "clip_loss": clip_loss_avg.avg,
            }
            logger.log_scalars(train_dict, curr_iter, prefix="Train/")

        curr_iter += 1
        barrier()

    return ap_calculator


@torch.no_grad()
def evaluate(
    args,
    curr_epoch,
    model,
    criterion,
    dataset_config,
    dataset_loader,
    logger,
    curr_train_iter,
):
    """Evaluate the point cloud branch of ``model`` on ``dataset_loader``.

    Runs with gradients disabled and the model in eval mode. For every batch
    the model is called with the point cloud and image inputs; only the first
    (point cloud) output is used. Predictions and ground truth are gathered
    across all ranks before they are fed to the exact AP calculator.

    Args:
        args: run configuration; reads ``max_epoch`` and ``log_every``.
        curr_epoch: epoch index, shown in the progress line when positive.
        model: network called as ``model(pc_input, img_input)``.
        criterion: optional point cloud loss returning ``(loss, loss_dict)``;
            pass ``None`` to skip loss computation.
        dataset_config: dataset description passed to ``APCalculator``.
        dataset_loader: iterable of batch dicts of tensors.
        logger: object with ``log_scalars(dict, step, prefix=...)``.
        curr_train_iter: training iteration used as the step for the scalars
            logged at the end of evaluation.

    Returns:
        The ``APCalculator`` (exact mode) holding the accumulated results.
    """
    # ap calculator is exact for evaluation. This is slower than the ap calculator used during training.
    ap_calculator = APCalculator(
        dataset_config=dataset_config,
        ap_iou_thresh=[0.25, 0.5],
        class2type_map=dataset_config.class2type,
        exact_eval=True,
    )

    curr_iter = 0
    net_device = next(model.parameters()).device
    num_batches = len(dataset_loader)

    time_delta = SmoothedValue(window_size=10)
    loss_avg = SmoothedValue(window_size=10)

    # Defined up front so the final logging below is well defined even when
    # the loader yields no batches.
    loss_dict_reduced = None
    test_dict = {}

    model.eval()
    barrier()
    epoch_str = f"[{curr_epoch}/{args.max_epoch}]" if curr_epoch > 0 else ""

    for batch_data_label in dataset_loader:
        curr_time = time.time()
        for key in batch_data_label:
            batch_data_label[key] = batch_data_label[key].to(net_device)

        # Construct image branch and point cloud branch inputs.
        img_input = NestedTensor(batch_data_label["image"], batch_data_label["mask"])
        pc_input = {
            "point_clouds": batch_data_label["point_clouds"],
            "point_cloud_dims_min": batch_data_label["point_cloud_dims_min"],
            "point_cloud_dims_max": batch_data_label["point_cloud_dims_max"],
        }

        outputs, _ = model(pc_input, img_input)

        # Compute loss
        loss_str = ""
        if criterion is not None:
            loss, loss_dict = criterion(outputs, batch_data_label)

            loss_reduced = all_reduce_average(loss)
            loss_dict_reduced = reduce_dict(loss_dict)
            loss_avg.update(loss_reduced.item())
            loss_str = f"Loss {loss_avg.avg:0.2f};"

        # Memory intensive as it gathers point cloud GT tensor across all ranks
        outputs["outputs"] = all_gather_dict(outputs["outputs"])
        batch_data_label = all_gather_dict(batch_data_label)
        ap_calculator.step_meter(outputs, batch_data_label)
        time_delta.update(time.time() - curr_time)
        if is_primary() and curr_iter % args.log_every == 0:
            mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
            print(
                f"Evaluate {epoch_str}; Batch [{curr_iter}/{num_batches}]; {loss_str} Iter time {time_delta.avg:0.2f}; Mem {mem_mb:0.2f}MB"
            )

            # Snapshot of the running statistics; the most recent snapshot is
            # what gets logged once evaluation finishes.
            test_dict = {}
            test_dict["memory"] = mem_mb
            test_dict["batch_time"] = time_delta.avg
            if criterion is not None:
                test_dict["loss"] = loss_avg.avg
        curr_iter += 1
        barrier()

    if is_primary():
        if criterion is not None and loss_dict_reduced is not None:
            logger.log_scalars(
                loss_dict_reduced, curr_train_iter, prefix="Test_details/"
            )
        if test_dict:
            logger.log_scalars(test_dict, curr_train_iter, prefix="Test/")

    return ap_calculator
