#!/usr/bin/env python
# coding: utf-8

import torch
import tqdm
import time
import warnings
import os.path as osp
import torch.nn as nn
from torch import optim
from termcolor import colored

from referit3d.utils.tf_visualizer import Visualizer  # should be imported first (https://github.com/pytorch/pytorch/issues/30651)
from referit3d.in_out.arguments import parse_arguments
from referit3d.in_out.neural_net_oriented import load_scan_related_data, load_referential_data
from referit3d.in_out.neural_net_oriented import compute_auxiliary_data, trim_scans_per_referit3d_data
from referit3d.in_out.pt_datasets.listening_dataset import make_data_loaders
from referit3d.in_out.pt_datasets.utils import create_sr3d_classes_2_idx
from referit3d.utils import set_gpu_to_zero_position, create_logger, seed_training_code

from referit3d.models.referit3d_net import ReferIt3DNet_transformer
from referit3d.models.referit3d_net_utils import single_epoch_train, evaluate_on_dataset
from referit3d.models.utils import load_state_dicts, save_state_dicts
from referit3d.analysis.deepnet_predictions import analyze_predictions
from transformers import BertTokenizer, BertModel


def log_train_test_information():
    """Log the metrics and timings of the current epoch for both phases.

    Takes no arguments and returns nothing. It reads the module-level names
    ``logger``, ``epoch``, ``train_meters``, ``test_meters``, ``args``,
    ``timings``, ``best_test_acc`` and ``best_test_epoch`` that the script
    section below assigns.
    """
    logger.info('Epoch:{}'.format(epoch))

    for phase in ('train', 'test'):
        # Each phase keeps its own meter dict; keys are prefixed by the phase name.
        meters = train_meters if phase == 'train' else test_meters

        parts = ['{}: Total-Loss {:.4f}, Listening-Acc {:.4f}'.format(phase,
                                                                      meters[phase + '_total_loss'],
                                                                      meters[phase + '_referential_acc'])]

        # The auxiliary accuracies are reported only when their loss weight is positive.
        if args.obj_cls_alpha > 0:
            parts.append(', Object-Clf-Acc: {:.4f}'.format(meters[phase + '_object_cls_acc']))
            parts.append(', Target-Clf-Acc: {:.4f}'.format(meters[phase + '_target_cls_acc']))

        if args.lang_cls_alpha > 0:
            parts.append(', Text-Clf-Acc: {:.4f}'.format(meters[phase + '_txt_cls_acc']))

        logger.info(''.join(parts))
        logger.info('{}: Epoch-time {:.3f}'.format(phase, timings[phase]))

    logger.info('Best so far {:.3f} (@epoch {})'.format(best_test_acc, best_test_epoch))


if __name__ == '__main__':
    # Parse arguments
    args = parse_arguments()

    # Read the scan related information. `add_no_obj` is switched on when
    # args.anchors is not 'none' or when args.predict_lang_anchors is truthy.
    use_no_obj_class = args.anchors != 'none' or args.predict_lang_anchors
    all_scans_in_dict, scans_split, class_to_idx = load_scan_related_data(args.scannet_file,
                                                                           add_no_obj=use_no_obj_class)

    # Plain substring check on the referential file path; when it matches, the
    # class mapping is replaced by the one built from the semantic-class json.
    is_nr = 'nr' in args.referit3D_file
    if is_nr:
        class_to_idx = create_sr3d_classes_2_idx(json_pth="referit3d/data/mappings/scannet_instance_class_to_semantic_class.json")

    # Read the linguistic data of ReferIt3D
    referit_data = load_referential_data(args, args.referit3D_file, scans_split)

    # Prepare data & compute auxiliary meta-information.
    all_scans_in_dict = trim_scans_per_referit3d_data(referit_data, all_scans_in_dict)
    mean_rgb, vocab = compute_auxiliary_data(referit_data, all_scans_in_dict, args)
    data_loaders = make_data_loaders(args, referit_data, vocab, class_to_idx, all_scans_in_dict, mean_rgb)

    # Prepare GPU environment
    set_gpu_to_zero_position(args.gpu)

    device = torch.device('cuda')
    seed_training_code(args.random_seed)

    # Losses: left empty here and handed as-is to the train/evaluate helpers.
    criteria = dict()

    # Prepare the Listener
    n_classes = len(class_to_idx) - 1  # -1 to ignore the <pad> class
    pad_idx = class_to_idx['pad']

    # Object-type classification: the class names, in the iteration order of
    # class_to_idx, are tokenized once and kept on the GPU for the model.
    class_name_list = list(class_to_idx)

    tokenizer = BertTokenizer.from_pretrained(args.bert_pretrain_path)
    class_name_tokens = tokenizer(class_name_list, return_tensors='pt', padding=True)
    for name in class_name_tokens.data:
        class_name_tokens.data[name] = class_name_tokens.data[name].cuda()

    # Number of GPUs requested, counted from the comma-separated id list.
    gpu_num = len(args.gpu.strip(',').split(','))

    # Fail with an explicit error on an unsupported model name. The previous
    # `assert False` is removed under `python -O`, which would let the script
    # continue until it hit an undefined `model`.
    if args.model != 'referIt3DNet_transformer':
        raise ValueError('Unknown model: {}'.format(args.model))
    model = ReferIt3DNet_transformer(args, n_classes, class_name_tokens, ignore_index=pad_idx, class_to_idx=class_to_idx)

    # With more than one GPU the network is wrapped; its sub-modules are then
    # reachable through `model.module`.
    if gpu_num > 1:
        model = nn.DataParallel(model)

    model = model.to(device)
    print(model)

    # <1>
    # Build the optimizer parameter groups, one per sub-module, each with its
    # own learning rate. nn.DataParallel (used when gpu_num > 1) keeps the real
    # network under `.module`, so unwrap once instead of duplicating the list.
    core_model = model.module if gpu_num > 1 else model
    base_lr = args.init_lr

    # NOTE(review): the two former branches disagreed on this learning rate
    # (init_lr with several GPUs, init_lr * 0.1 with one). Both values are
    # preserved here; confirm which one is actually intended.
    obj_lang_clf_lr = base_lr if gpu_num > 1 else base_lr * 0.1

    param_list = [
        {'params': core_model.language_encoder.parameters(), 'lr': base_lr * 0.1},
        {'params': core_model.refer_encoder.parameters(), 'lr': base_lr * 0.1},
        {'params': core_model.object_encoder.parameters(), 'lr': base_lr},
        {'params': core_model.obj_feature_mapping.parameters(), 'lr': base_lr},
        {'params': core_model.box_feature_mapping.parameters(), 'lr': base_lr},
        {'params': core_model.language_clf.parameters(), 'lr': base_lr},
        {'params': core_model.object_language_clf.parameters(), 'lr': obj_lang_clf_lr},
    ]
    if not args.label_lang_sup:
        param_list.append({'params': core_model.obj_clf.parameters(), 'lr': base_lr})
    if args.anchors == 'cot':
        # These groups used to be registered only in the single-GPU branch, so
        # with DataParallel their parameters were never handed to the optimizer.
        param_list.append({'params': core_model.parallel_embedding.parameters(), 'lr': base_lr})
        param_list.append({'params': core_model.object_language_clf_parallel.parameters(), 'lr': base_lr})
        param_list.append({'params': core_model.fc_out.parameters(), 'lr': base_lr})

    optimizer = optim.Adam(param_list, lr=args.init_lr)
    # Multiply every group's learning rate by 0.65 at each listed epoch milestone.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40, 50, 60, 70, 80, 90], gamma=0.65)

    # Bookkeeping for the training loop; -1 marks "no evaluation seen yet".
    start_training_epoch = 1
    best_test_acc = -1
    best_test_epoch = -1
    last_test_acc = -1
    last_test_epoch = -1

    if args.resume_path:
        warnings.warn('Resuming assumes that the BEST per-val model is loaded!')
        # perhaps best_test_acc, best_test_epoch, best_test_epoch =  unpickle...
        loaded_epoch = load_state_dicts(args.resume_path, map_location=device, model=model)
        print('Loaded a model stopped at epoch: {}.'.format(loaded_epoch))
        if not args.fine_tune:
            print('Loaded a model that we do NOT plan to fine-tune.')
            # Also restore the optimizer and scheduler state from the checkpoint.
            load_state_dicts(args.resume_path, optimizer=optimizer, lr_scheduler=lr_scheduler)
            # Continue right after the checkpointed epoch. A stray
            # `start_training_epoch = 0` used to overwrite this value, which
            # restarted the epoch counter while the scheduler state was restored.
            start_training_epoch = loaded_epoch + 1
            best_test_epoch = loaded_epoch
            # The checkpoint's accuracy is not recovered here, so start from 0.
            best_test_acc = 0
            print('Loaded model had {} test-accuracy in the corresponding dataset used when trained.'.format(
                best_test_acc))
        else:
            print('Parameters that do not allow gradients to be back-propped:')
            ft_everything = True
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    print(name)
                    # At least one frozen parameter exists; the original code
                    # assigned an unused variable here, so the summary line
                    # below was printed unconditionally.
                    ft_everything = False
            if ft_everything:
                print('None, all wil be fine-tuned')
            # if you fine-tune the previous epochs/accuracy are irrelevant.
            dummy = args.max_train_epochs + 1 - start_training_epoch
            print('Ready to *fine-tune* the model for a max of {} epochs'.format(dummy))

    # Training.
    if args.mode == 'train':
        train_vis = Visualizer(args.tensorboard_dir)
        logger = create_logger(args.log_dir)
        logger.info('Starting the training. Good luck!')

        with tqdm.trange(start_training_epoch, args.max_train_epochs + 1, desc='epochs') as bar:
            # Wall-clock duration of each phase of the current epoch, in minutes.
            timings = dict()
            for epoch in bar:
                print("cnt_lr", lr_scheduler.get_last_lr())
                # Train:
                tic = time.time()
                train_meters = single_epoch_train(model, data_loaders['train'], criteria, optimizer,
                                                  device, pad_idx, args=args, tokenizer=tokenizer, epoch=epoch)
                toc = time.time()
                timings['train'] = (toc - tic) / 60

                # Evaluate:
                tic = time.time()
                test_meters = evaluate_on_dataset(model, data_loaders['test'], criteria, device, pad_idx, args=args,
                                                  tokenizer=tokenizer, epoch=epoch)
                toc = time.time()
                timings['test'] = (toc - tic) / 60

                eval_acc = test_meters['test_referential_acc']

                last_test_acc = eval_acc
                last_test_epoch = epoch

                lr_scheduler.step()

                # Always keep the most recent state, independent of its accuracy.
                save_state_dicts(osp.join(args.checkpoint_dir, 'last_model.pth'),
                                 epoch, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)

                if best_test_acc < eval_acc:
                    logger.info(colored('Test accuracy, improved @epoch {}'.format(epoch), 'green'))
                    best_test_acc = eval_acc
                    best_test_epoch = epoch

                    save_state_dicts(osp.join(args.checkpoint_dir, 'best_model.pth'),
                                     epoch, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
                else:
                    logger.info(colored('Test accuracy, did not improve @epoch {}'.format(epoch), 'red'))

                # Reads the module-level epoch/meters/timings/best-* names set above.
                log_train_test_information()
                # Merge the test meters into the train meters so both phases are
                # pushed to the visualizer together.
                train_meters.update(test_meters)
                train_vis.log_scalars({k: v for k, v in train_meters.items() if '_acc' in k}, step=epoch,
                                      main_tag='acc')
                train_vis.log_scalars({k: v for k, v in train_meters.items() if '_loss' in k},
                                      step=epoch, main_tag='loss')

                bar.refresh()

        # Write one record per line; the two records used to be written back to
        # back without a separator and ended up fused on a single line.
        with open(osp.join(args.checkpoint_dir, 'final_result.txt'), 'w') as f_out:
            f_out.write('Best accuracy: {:.4f} (@epoch {})\n'.format(best_test_acc, best_test_epoch))
            f_out.write('Last accuracy: {:.4f} (@epoch {})\n'.format(last_test_acc, last_test_epoch))

        logger.info('Finished training successfully.')

    elif args.mode == 'evaluate':

        meters = evaluate_on_dataset(model, data_loaders['test'], criteria, device, pad_idx, args=args,
                                     tokenizer=tokenizer)
        print('Reference-Accuracy: {:.4f}'.format(meters['test_referential_acc']))
        print('Object-Clf-Accuracy: {:.4f}'.format(meters['test_object_cls_acc']))
        # The colon used to trail the number ('Text-Clf-Accuracy 0.1234:').
        print('Text-Clf-Accuracy: {:.4f}'.format(meters['test_txt_cls_acc']))

        out_file = osp.join(args.checkpoint_dir, 'test_result.txt')
        res = analyze_predictions(model, data_loaders['test'].dataset, class_to_idx, pad_idx, device,
                                  args, out_file=out_file, tokenizer=tokenizer)
        print(res)
