from dataloaders.dataloader import TextDataLoader
import torch
import torch.nn as nn
import argparse
import sys
import os
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
from models.losses import *
import time
from PIL import Image

from datetime import datetime
from models.text_classifier import Classifier
from models.image_classifier import ImageClassifier
from models.image_models import *
from models.model import *
from utils import post_process_depth, flip_lr
def convert_arg_line_to_args(arg_line):
    """Split one line of an ``@``-prefixed argument file into tokens.

    Lets a single line of the argument file carry several
    whitespace-separated arguments.

    Args:
        arg_line: One raw line read from the argument file.

    Yields:
        Every non-blank, whitespace-delimited token of ``arg_line``, in order.
    """
    yield from (token for token in arg_line.split() if token.strip())

# Command-line interface. Arguments may also be supplied through a file
# referenced as ``@path``; each line of that file is tokenised by
# convert_arg_line_to_args so it can hold several arguments.
parser = argparse.ArgumentParser(description='BTS PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--mode',                                       type=str,   help='train or test', default='train')
parser.add_argument('--model_name',                                 type=str,   help='model name', default='bts_eigen_v2')
parser.add_argument('--main_path',                                  type=str,   help='main path of data', required=True)
parser.add_argument('--no_radar',                                               help='if set, train the text+image model', action='store_true')

# Validation data locations. The help texts of the text-feature and
# weather-condition options used to be copy-pasted radar descriptions.
parser.add_argument('--validation_image_path',                      type=str,   help='path of validation image', required=True)
parser.add_argument('--validation_text_feature_general_path',       type=str,   help='path of validation general text feature', required=True)
parser.add_argument('--validation_text_feature_left_path',          type=str,   help='path of validation left text feature', required=True)
parser.add_argument('--validation_text_feature_mid_left_path',      type=str,   help='path of validation mid-left text feature', required=True)
parser.add_argument('--validation_text_feature_mid_right_path',     type=str,   help='path of validation mid-right text feature', required=True)
parser.add_argument('--validation_text_feature_right_path',         type=str,   help='path of validation right text feature', required=True)
parser.add_argument('--validation_weather_condition_path',          type=str,   help='path of validation weather condition', required=True)
parser.add_argument('--validation_ground_truth_path',               type=str,   help='path of testing ground truth', required=True)
parser.add_argument('--validation_radar_path',                      type=str,   help='path of testing radar', required=True)

parser.add_argument('--k',                                          type=int,   help='k nearest neighbor', default=4)
parser.add_argument('--encoder_radar',                              type=str,   help='type of encoder of radar channels, resnet34', default='resnet18')
parser.add_argument('--radar_input_channels',                       type=int,   help='number of input radar channels', default=4)
parser.add_argument('--radar_gcn_channel_in',                       type=int,   help='input channels',  default=6)
parser.add_argument('--encoder',                                    type=str,   help='type of encoder', default='resnet34_bts')
# nargs='+' keeps a command-line override a list of ints like the default;
# without it an override would be parsed as one bare int.
parser.add_argument('--n_filters_decoder',                          type=int,   help='number of decoder filters', nargs='+', default=[256, 256, 128, 64, 32])
parser.add_argument('--input_height',                               type=int,   help='input height', default=352)
parser.add_argument('--input_width',                                type=int,   help='input width',  default=704)
parser.add_argument('--max_depth',                                  type=float, help='maximum depth in estimation', default=100)

parser.add_argument('--fuse',                                       type=str,   help='fusion of radar & image', default='sim_base')
parser.add_argument('--text_fuse',                                  type=str,   help='fusion of text & image', default='add')
parser.add_argument('--text_encode_mode',                           type=str,   help='text feature processing mode', default='average')
parser.add_argument('--text_hidden_dim',                            type=int,   help='hidden dimension of text feature',  default=128)
parser.add_argument('--use_img_feat',                                           help='if set, during weather classification, also use img feature', action='store_true')
parser.add_argument('--point_hidden_dim',                           type=int,   help='hidden dimension of radar point feature',  default=128)
parser.add_argument('--min_depth_eval',                             type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval',                             type=float, help='maximum depth for evaluation', default=80)
parser.add_argument('--checkpoint_path',                            type=str,   help='path to a checkpoint to load', default='')

parser.add_argument('--num_threads',                                type=int,   help='number of threads to use for data loading', default=1)
parser.add_argument('--store_prediction',                                       help='if set, store the predicted depth and radar confidence', action='store_true')


# A single command-line argument is taken to be the path of an argument
# file and is expanded through the '@' prefix; anything else is parsed
# as ordinary options.
if len(sys.argv) == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()

# Directory that holds the checkpoint; it is added to the import path.
model_dir = os.path.dirname(args.checkpoint_path)
sys.path.append(model_dir)
if args.store_prediction:
    # Results go under a folder named after the part of the checkpoint
    # directory name that precedes its first underscore.
    save_dir = './eval_result/' + model_dir.split('/')[-1].split('_')[0]
    pred_depth_dir = save_dir + '/pred_depth'

    # exist_ok=True creates the sub-folder even when save_dir is already
    # present and tolerates a concurrent creation, while real failures
    # (e.g. missing permissions) surface here instead of at save time.
    os.makedirs(pred_depth_dir, exist_ok=True)
def compute_errors(gt, pred):
    """Compute depth-estimation error metrics between two arrays.

    Args:
        gt: Array of ground-truth depth values.
        pred: Array of predicted depth values, same shape as ``gt``.

    Returns:
        A list ``[silog, log10, abs_rel, sq_rel, rmse, rmse_log, d1, d2, d3,
        mae]`` of scalar metrics.
    """
    diff = gt - pred
    abs_diff = np.abs(diff)
    sq_diff = diff ** 2

    # Threshold accuracies: fraction of entries whose ratio to the ground
    # truth (in either direction) stays below 1.25, 1.25^2 and 1.25^3.
    ratio = np.maximum((gt / pred), (pred / gt))
    d1 = (ratio < 1.25).mean()
    d2 = (ratio < 1.25 ** 2).mean()
    d3 = (ratio < 1.25 ** 3).mean()

    mae = np.mean(abs_diff)
    rmse = np.sqrt(sq_diff.mean())

    log_diff_sq = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(log_diff_sq.mean())

    abs_rel = np.mean(abs_diff / gt)
    sq_rel = np.mean(sq_diff / gt)

    # Scale-invariant log error, scaled by 100.
    log_err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(log_err ** 2) - np.mean(log_err) ** 2) * 100

    log10 = np.mean(np.abs(np.log10(pred) - np.log10(gt)))

    return [silog, log10, abs_rel, sq_rel, rmse, rmse_log, d1, d2, d3, mae]

def get_num_lines(file_path):
    """Return the number of lines in the text file at ``file_path``.

    The file is opened in a ``with`` block so the handle is closed even if
    reading fails, and lines are counted while streaming instead of loading
    the whole file into a list first.

    Args:
        file_path: Path of the file to read.

    Returns:
        The line count as an ``int``.
    """
    with open(file_path, 'r') as f:
        return sum(1 for _ in f)

# Column order of the list returned by compute_errors().
_METRIC_NAMES = ('silog', 'log10', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log', 'd1', 'd2', 'd3', 'mae')


def _print_eval_measures(title, eval_measures):
    """Print the per-sample average of one accumulator as a metric table.

    Args:
        title: Banner line printed above the table.
        eval_measures: 11-element tensor whose first 10 entries are metric
            sums in ``_METRIC_NAMES`` order and whose last entry is the
            number of accumulated samples.
    """
    eval_measures_cpu = eval_measures.cpu()
    cnt = eval_measures_cpu[-1].item()
    # A subset without samples prints nan entries instead of raising.
    averaged = (eval_measures_cpu / cnt).tolist()

    print(title)
    print('Computing errors for {} eval samples'.format(int(cnt)))
    print(', '.join('{:>7}'.format(name) for name in _METRIC_NAMES))
    print(', '.join('{:7.3f}'.format(value) for value in averaged[:len(_METRIC_NAMES)]))


def test(params):
    """Evaluate a TRIDENT checkpoint on the 'test' split and print metrics.

    Loads the checkpoint at ``args.checkpoint_path``, runs the model over
    every batch of ``TextDataLoader(args, 'test')``, merges each prediction
    with a second pass on ``flip_lr``-ed inputs through
    ``post_process_depth``, and accumulates ``compute_errors`` results for
    the whole set and for the normal / rain / night subsets. The subset is
    chosen from the ground-truth label (argmax 2 = rain, 1 = night, anything
    else = normal). When ``args.store_prediction`` is set, each raw
    prediction is also written as a PNG into the module-level
    ``pred_depth_dir``.

    Args:
        params: Unused; configuration is read from the module-level ``args``
            (the ``__main__`` guard passes that same object).
    """
    dataloader = TextDataLoader(args, 'test')

    model = TRIDENT(args)
    model.decoder.apply(weights_init_xavier)

    model = torch.nn.DataParallel(model)
    checkpoint = torch.load(args.checkpoint_path, map_location=args.device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    # Use the same device as every input tensor below instead of a
    # hard-coded .cuda().
    model.to(args.device)
    num_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters: {}".format(num_params))

    # Ten metric sums followed by a sample counter, one accumulator per subset.
    eval_measures = torch.zeros(11, device=args.device)
    eval_measures_rain = torch.zeros(11, device=args.device)
    eval_measures_night = torch.zeros(11, device=args.device)
    eval_measures_normal = torch.zeros(11, device=args.device)

    # NOTE(review): the classification counters below are accumulated but
    # never reported anywhere in this function.
    correct = 0
    total = 0
    correct_rain = 0
    total_rain = 0
    correct_night = 0
    total_night = 0
    post_process = True
    with torch.no_grad():
        for i, eval_sample_batched in enumerate(tqdm(dataloader.data)):
            image = eval_sample_batched['image'].to(args.device)
            text_mask = eval_sample_batched['text_mask'].to(args.device)
            text_feature_general = eval_sample_batched['text_feature_general'].to(args.device)
            text_feature_left = eval_sample_batched['text_feature_left'].to(args.device)
            text_feature_mid_left = eval_sample_batched['text_feature_mid_left'].to(args.device)
            text_feature_mid_right = eval_sample_batched['text_feature_mid_right'].to(args.device)
            text_feature_right = eval_sample_batched['text_feature_right'].to(args.device)
            text_length = eval_sample_batched['text_length'].to(args.device)

            # NOTE(review): with --no_radar the radar tensors are not read
            # from the batch; None is passed so the model call no longer
            # dies with a NameError - confirm TRIDENT accepts None here.
            radar_channels = None
            radar_points = None
            if not args.no_radar:
                radar_channels = eval_sample_batched['radar_channels'].to(args.device)
                radar_points = eval_sample_batched['radar_points'].to(args.device)

            label = eval_sample_batched['label'].to(args.device)
            gt_depth = eval_sample_batched['depth']

            pred_depth, class_pred = model(image, radar_channels, radar_points, text_feature_general,
                                           text_feature_left, text_feature_mid_left,
                                           text_feature_mid_right, text_feature_right, text_mask, text_length)
            if post_process:
                image_flipped = flip_lr(image)
                radar_flipped = flip_lr(radar_channels) if radar_channels is not None else None
                text_mask_flipped = flip_lr(text_mask)
                pred_depth_flipped, _ = model(image_flipped, radar_flipped, radar_points, text_feature_general,
                                              text_feature_left, text_feature_mid_left,
                                              text_feature_mid_right, text_feature_right,
                                              text_mask_flipped, text_length)
                pred_depth = post_process_depth(pred_depth, pred_depth_flipped)
            _, predicted = class_pred.max(1)
            _, targets = label.max(1)

            total += label.size(0)
            correct += predicted.eq(targets).sum().item()

            # The truth tests on `targets` require a single-element tensor,
            # i.e. presumably a batch size of 1 - TODO confirm.
            if targets == 2:
                # rain
                total_rain += 1
                correct_rain += (predicted == 2).sum().item()
                subset_measures = eval_measures_rain
            elif targets == 1:
                # night
                total_night += 1
                correct_night += (predicted == 1).sum().item()
                subset_measures = eval_measures_night
            else:
                subset_measures = eval_measures_normal

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

            if args.store_prediction:
                # Stored before the evaluation clamping below is applied.
                # pred_depth_dir only exists at module level when
                # --store_prediction was given.
                pred_d = np.uint32(pred_depth * 256.0)
                pred_d = Image.fromarray(pred_d, mode='I')
                pred_d.save(pred_depth_dir + '/' + str(i) + '.png')

            # Clamp predictions into the evaluation range and replace
            # non-finite values.
            pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
            pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
            pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
            pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

            valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)

            # Computed once and added to both the subset and the overall
            # accumulator (previously evaluated twice on the same inputs).
            measures = torch.tensor(compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])).to(args.device)
            for accumulator in (subset_measures, eval_measures):
                accumulator[:-1] += measures
                accumulator[-1] += 1

    _print_eval_measures('------------------WHOLE-------------------', eval_measures)
    _print_eval_measures('------------------Normal-------------------', eval_measures_normal)
    _print_eval_measures('-----------------RAIN---------------------', eval_measures_rain)
    _print_eval_measures('------------------NIGHT-------------------', eval_measures_night)

if __name__ == '__main__':
    # Script entry point: run a single-process evaluation on the GPU when
    # one is available, otherwise on the CPU.
    use_cuda = torch.cuda.is_available()
    args.device = torch.device('cuda') if use_cuda else torch.device('cpu')
    args.distributed = False

    test(args)