import os
import time
import json
import pprint
import random
import numpy as np
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from bm_detr.config import BaseOptions
from bm_detr.start_end_dataset import \
    StartEndDataset, start_end_collate, prepare_batch_inputs
from bm_detr.inference import eval_epoch, start_inference, setup_model
from bm_detr.sampler import build_batch_sampler
from utils.basic_utils import AverageMeter, dict_to_markdown
from utils.model_utils import count_parameters


import logging
# Module logger plus a process-wide INFO-level format with millisecond timestamps.
_LOG_FORMAT = "%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)


def set_seed(seed, use_cuda=True):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    Args:
        seed: integer seed applied to every generator.
        use_cuda: when True, also seed every CUDA device via
            ``torch.cuda.manual_seed_all``.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if use_cuda:
        seeders.append(torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
    """Train ``model`` for one epoch and log the averaged losses and timings.

    For every batch the inputs are prepared with ``prepare_batch_inputs``, the
    model is run, the weighted loss is computed from ``criterion.weight_dict``
    (averaged with the auxiliary loss groups named in
    ``opt.additional_losses`` when any are configured), and one optimizer step
    is taken with optional gradient-norm clipping.  After the epoch, the
    per-term loss averages (including ``loss_overall``) are written to
    TensorBoard and appended to ``opt.train_log_filepath``.

    Args:
        model: network called as ``model(**model_inputs)``.
        criterion: loss module returning ``(loss_dict, aux_loss_dict)`` and
            exposing ``weight_dict``.
        train_loader: iterable of collated batches; ``batch[0]`` is a list of
            per-sample dicts with a ``'vid'`` key and ``batch[1]`` is passed
            to ``prepare_batch_inputs``.
        optimizer: optimizer stepping the model parameters.
        opt: parsed options namespace.
        epoch_i: zero-based epoch index (logged as ``epoch_i + 1``).
        tb_writer: TensorBoard ``SummaryWriter``.
    """
    logger.info(f"[Epoch {epoch_i+1}]")
    model.train()
    criterion.train()

    # Per-phase wall-clock timings and per-loss-term running averages.
    time_meters = defaultdict(AverageMeter)
    loss_meters = defaultdict(AverageMeter)

    num_training_examples = len(train_loader)
    timer_dataloading = time.time()
    for batch_idx, batch in tqdm(enumerate(train_loader),
                                 desc="Training Iteration",
                                 total=num_training_examples):
        time_meters["dataloading_time"].update(time.time() - timer_dataloading)

        timer_start = time.time()
        if not opt.use_random_nq:
            # Without random negative queries every batch is expected to hold
            # at most one query per video.
            vids = [d['vid'] for d in batch[0]]
            assert len(vids) == len(set(vids)), "queries from same the video contain in the batch."

        model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory,
                                                     max_v_l=opt.max_v_l)
        time_meters["prepare_inputs_time"].update(time.time() - timer_start)

        timer_start = time.time()
        outputs = model(**model_inputs)
        loss_dict, aux_loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        if opt.additional_losses:
            aux_losses = 0
            for additional_loss in opt.additional_losses:
                assert additional_loss in aux_loss_dict, f'check the loss dict has additional loss term, {additional_loss}'
                aux_terms = aux_loss_dict[additional_loss]
                aux_losses += sum(aux_terms[k] * weight_dict[k] for k in aux_terms.keys() if k in weight_dict)
            # Average the main weighted loss with one weighted sum per auxiliary group.
            losses = (losses + aux_losses) / (len(opt.additional_losses) + 1)

        time_meters["model_forward_time"].update(time.time() - timer_start)

        timer_start = time.time()
        optimizer.zero_grad()
        losses.backward()
        if opt.grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()
        time_meters["model_backward_time"].update(time.time() - timer_start)

        # Record the overall loss before updating the meters so that it is
        # included in the TensorBoard and text logs.
        loss_dict["loss_overall"] = float(losses)
        for k, v in loss_dict.items():
            loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        for _k, _loss_dict in aux_loss_dict.items():
            for k, v in _loss_dict.items():
                loss_meters[f'{_k}_{k}'].update(float(v) * weight_dict[k] if k in weight_dict else float(v))

        timer_dataloading = time.time()
        if opt.debug and batch_idx == 3:
            break

    # print/add logs
    tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
    _display_dict = {}
    for k, v in loss_meters.items():
        _display_dict[k] = "{:.4f}".format(v.avg)
        tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)

    if opt.debug:
        logger.info("Check the loss terms and values..")
        print(dict_to_markdown(_display_dict))

    to_write = opt.train_log_txt_formatter.format(
        time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
        epoch=epoch_i+1,
        loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
    with open(opt.train_log_filepath, "a") as f:
        f.write(to_write)

    logger.info("Epoch time stats:")
    for name, meter in time_meters.items():
        d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
        logger.info(f"{name} ==> {d}")


def _build_train_loader(train_dataset, opt):
    """Create the training DataLoader.

    With ``opt.use_random_nq`` a plain shuffled loader of ``opt.bsz`` samples
    is used; otherwise batches come from ``build_batch_sampler``.  Only the
    loader that is actually used gets built.
    """
    if opt.use_random_nq:
        return DataLoader(
            train_dataset,
            collate_fn=start_end_collate,
            batch_size=opt.bsz,
            num_workers=opt.num_workers,
            shuffle=True,
            pin_memory=opt.pin_memory,
        )
    batch_sampler = build_batch_sampler(dataset=train_dataset, bsz=opt.bsz)
    return DataLoader(
        train_dataset,
        collate_fn=start_end_collate,
        num_workers=opt.num_workers,
        batch_sampler=batch_sampler,
        pin_memory=opt.pin_memory,
    )


def _make_checkpoint(model, optimizer, lr_scheduler, epoch_i, opt):
    """Bundle model/optimizer/scheduler state, epoch and options for ``torch.save``."""
    return {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": epoch_i,
        "opt": opt,
    }


def _to_best_path(path):
    """Map a ``latest`` result file to its ``best`` name, touching only the file name."""
    dirname, basename = os.path.split(path)
    return os.path.join(dirname, basename.replace("latest", "best"))


def _gspread_report_best(sheet, opt, epoch_i, brief_metrics):
    """Write the best epoch and its brief metrics into this run's sheet column.

    The "epochs" row is 7 in "methods" mode and 11 in "loss" mode; the metric
    values follow in the rows directly below it.  Other modes are ignored.
    """
    epochs_row = {"methods": 7, "loss": 11}.get(opt.gspread_mode)
    if epochs_row is None:
        return
    sheet.update_cell(epochs_row, opt.gspread_col, epoch_i+1)
    for row, value in enumerate(brief_metrics.values(), start=epochs_row + 1):
        sheet.update_cell(row, opt.gspread_col, value)


def _gspread_report_finished(sheet, opt):
    """Stamp the completion time and an 'O' marker below the metric rows."""
    if opt.gspread_mode == "methods":
        _idx = 15 if opt.dset_name == "hl" else 12
    elif opt.gspread_mode == "loss":
        _idx = 19 if opt.dset_name == "hl" else 16
    else:
        return
    sheet.update_cell(_idx, opt.gspread_col, time.strftime("%Y_%m_%d_%H_%M"))
    sheet.update_cell(_idx+1, opt.gspread_col, 'O')


def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt, sheet=None):
    """Run the training loop with periodic evaluation, checkpointing and early stopping.

    Epoch ``-1`` (used when ``opt.eval_untrained`` is set and
    ``opt.start_epoch`` is None) only evaluates; every other epoch runs
    ``train_epoch`` followed by one ``lr_scheduler.step()``.  When
    ``opt.eval_path`` is set, the model is evaluated every
    ``opt.eval_epoch_interval`` epochs.  The brief metric
    ``opt.stop_metric`` decides whether the ``_best`` checkpoint and
    prediction files are refreshed, and drives early stopping after more than
    ``opt.max_es_cnt`` non-improving evaluations (-1 disables it).  Extra
    numbered checkpoints are kept every 100 trained epochs and at multiples of
    ``opt.lr_drop``.

    Args:
        model, criterion, optimizer, lr_scheduler: objects returned by ``setup_model``.
        train_dataset: dataset used to build the training loader.
        val_dataset: dataset passed to ``eval_epoch``.
        opt: parsed options namespace; log formatters are stored on it.
        sheet: gspread worksheet, used only when ``opt.gspread`` is set.
    """
    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)

    tb_writer = SummaryWriter(opt.tensorboard_log_dir)
    tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
    opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
    opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"

    train_loader = _build_train_loader(train_dataset, opt)

    prev_best_score = 0.
    es_cnt = 0
    if opt.start_epoch is None:
        start_epoch = -1 if opt.eval_untrained else 0
    else:
        start_epoch = opt.start_epoch
    save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
    save_interval = 100

    try:
        for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
            trained = epoch_i > -1
            if trained:
                train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
                lr_scheduler.step()

            if opt.eval_path is not None and (epoch_i + 1) % opt.eval_epoch_interval == 0:
                with torch.no_grad():
                    metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
                        eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)

                # log
                to_write = opt.eval_log_txt_formatter.format(
                    time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
                    epoch=epoch_i,
                    loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
                    eval_metrics_str=json.dumps(metrics_no_nms))
                with open(opt.eval_log_filepath, "a") as f:
                    f.write(to_write)
                logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
                if metrics_nms is not None:
                    logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))

                metrics = metrics_no_nms
                for k, v in metrics["brief"].items():
                    tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)

                stop_score = metrics["brief"][opt.stop_metric]
                if stop_score >= prev_best_score:
                    es_cnt = 0
                    prev_best_score = stop_score
                    torch.save(_make_checkpoint(model, optimizer, lr_scheduler, epoch_i, opt),
                               opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
                    # Promote this evaluation's "latest" prediction files to "best".
                    for src in latest_file_paths:
                        os.renames(src, _to_best_path(src))
                    if opt.gspread:
                        _gspread_report_best(sheet, opt, epoch_i, metrics["brief"])
                    logger.info("The checkpoint file has been updated.")
                else:
                    es_cnt += 1
                    if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt:  # early stop
                        with open(opt.train_log_filepath, "a") as f:
                            f.write(f"Early Stop at epoch {epoch_i}\n")
                        logger.info(f"\n>>>>> Early stop at epoch {epoch_i}  {prev_best_score}\n")
                        break

                torch.save(_make_checkpoint(model, optimizer, lr_scheduler, epoch_i, opt),
                           opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))

            # Additional numbered copies; skipped for the evaluation-only epoch -1
            # and guarded against lr_drop == 0.
            at_save_interval = (epoch_i + 1) % save_interval == 0
            at_lr_drop = bool(opt.lr_drop) and (epoch_i + 1) % opt.lr_drop == 0
            if trained and (at_save_interval or at_lr_drop):
                torch.save(_make_checkpoint(model, optimizer, lr_scheduler, epoch_i, opt),
                           opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))

            if opt.debug:
                break
    finally:
        # Flush TensorBoard events even if training raises.
        tb_writer.close()

    # Mark the run as finished in the spreadsheet.
    if opt.gspread:
        _gspread_report_finished(sheet, opt)


def _write_gspread_row_labels(sheet, opt):
    """Write the column-1 row labels of a freshly created worksheet.

    The layout is: setting rows starting at row 2, an "epochs" row (7 for
    "methods", 11 for "loss"), then one row per brief metric.
    """
    if opt.dset_name == 'hl':
        metric_labels = ['MR-R1@0.5', 'MR-R1@0.7', 'mAP', 'mAP@0.5', 'mAP@0.75',
                         'VeryGood-MAP', 'VeryGood-Hit1']
    else:
        metric_labels = ['MR-R1@0.3', 'MR-R1@0.5', 'MR-R1@0.7', 'mIoU']

    if opt.gspread_mode == "methods":
        setting_labels = ['Intra PQ', 'Intra NQ', 'Inter NQ', 'TS']
        epochs_row = 7
    elif opt.gspread_mode == "loss":
        setting_labels = ['lr', 'lr_drop', 'lw_saliency', 'lw_contrastive_loss_coef', 'lw_prob_loss_coef']
        epochs_row = 11
    else:
        return

    for row, label in enumerate(setting_labels, start=2):
        sheet.update_cell(row, 1, label)
    sheet.update_cell(epochs_row, 1, 'epochs')
    for row, label in enumerate(metric_labels, start=epochs_row + 1):
        sheet.update_cell(row, 1, label)


def _open_gspread_sheet(opt):
    """Open (creating and labelling on first use) this run's Google worksheet.

    Advances ``opt.gspread_col`` to the first column whose row-1 cell is empty
    and writes the run header plus method flags or loss hyper-parameters there.

    Returns:
        The gspread worksheet.

    Raises:
        ValueError: if ``opt.dset_name`` has no associated spreadsheet.
    """
    print('Report performances on google spread sheet.')
    import gspread
    gc = gspread.service_account()

    if 'tacos' == opt.dset_name:
        spreadsheet_title = "Tacos"
    elif 'hl' == opt.dset_name:
        spreadsheet_title = "QVHighlights"
    elif 'charades' in opt.dset_name:
        spreadsheet_title = "Charades"
    elif 'activity' in opt.dset_name:
        spreadsheet_title = "ActivityNet"
    else:
        raise ValueError(f"No Google spreadsheet is configured for dataset '{opt.dset_name}'.")
    g_sheet = gc.open(spreadsheet_title)

    sheet_name = f'{opt.dset_name}_{opt.v_feat_type}_{opt.gspread_mode}'
    try:
        sheet = g_sheet.worksheet(sheet_name)
    except gspread.exceptions.WorksheetNotFound:
        # First run for this configuration: create and label the worksheet.
        sheet = g_sheet.add_worksheet(title=sheet_name, rows=50, cols=50)
        _write_gspread_row_labels(sheet, opt)

    while sheet.cell(1, opt.gspread_col).value:
        opt.gspread_col += 1

    # column name
    sheet.update_cell(1, opt.gspread_col, opt.exp_id + '/' + str(opt.seed))

    if opt.gspread_mode == "methods":
        flags = [opt.use_intra_pq, opt.use_intra_nq, opt.use_inter_nq, opt.use_temporal_shifting]
        for row, enabled in enumerate(flags, start=2):
            if enabled:
                sheet.update_cell(row, opt.gspread_col, '✓')
    elif opt.gspread_mode == "loss":
        values = [opt.lr, opt.lr_drop, opt.lw_saliency, opt.lw_contrastive_loss_coef, opt.lw_prob_loss_coef]
        for row, value in enumerate(values, start=2):
            sheet.update_cell(row, opt.gspread_col, value)
    return sheet


def _build_datasets(opt):
    """Build the training dataset and, when ``opt.eval_path`` is set, the evaluation dataset.

    The evaluation dataset reuses the training config but disables text
    dropout and the intra-PQ / intra-NQ / temporal-shifting options.

    Returns:
        ``(train_dataset, eval_dataset)``; ``eval_dataset`` is None without an eval path.
    """
    dataset_config = dict(
        dset_name=opt.dset_name,
        data_path=opt.train_path,
        meta_by_qid_path_path=opt.meta_by_qid_path_path,
        v_feat_dirs=opt.v_feat_dirs,
        q_feat_dir=opt.t_feat_dir,
        q_feat_type=opt.t_feat_type,
        max_q_l=opt.max_q_l,
        max_v_l=opt.max_v_l,
        ctx_mode=opt.ctx_mode,
        data_ratio=opt.data_ratio,
        normalize_v=not opt.no_norm_vfeat,
        normalize_t=not opt.no_norm_tfeat,
        clip_len=opt.clip_length,
        max_windows=opt.max_windows,
        span_loss_type=opt.span_loss_type,
        txt_drop_ratio=opt.txt_drop_ratio,
        use_intra_pq=opt.use_intra_pq,
        use_intra_nq=opt.use_intra_nq,
        use_temporal_shifting=opt.use_temporal_shifting
    )
    train_dataset = StartEndDataset(**dataset_config)

    if opt.eval_path is None:
        return train_dataset, None

    dataset_config.update(
        data_path=opt.eval_path,
        txt_drop_ratio=0,
        use_intra_pq=False,
        use_intra_nq=False,
        use_temporal_shifting=False,
    )
    eval_dataset = StartEndDataset(**dataset_config)
    return train_dataset, eval_dataset


def start_training():
    """Parse options, build datasets and model, and run training.

    Returns:
        ``(best_ckpt_path, eval_split_name, eval_path, debug)`` where
        ``best_ckpt_path`` is ``opt.ckpt_filepath`` with ``.ckpt`` replaced by
        ``_best.ckpt``.
    """
    logger.info("Setup config, data and model...")
    opt = BaseOptions().parse()
    set_seed(opt.seed)

    if opt.debug:  # keep the model run deterministically
        # 'cudnn.benchmark = True' enables auto-finding the best algorithm for a
        # specific input/net config; enable it only when input size is fixed.
        cudnn.benchmark = False
        cudnn.deterministic = True

    sheet = _open_gspread_sheet(opt) if opt.gspread else None

    train_dataset, eval_dataset = _build_datasets(opt)

    if opt.lr_warmup > 0:
        # Warm-up is counted in epochs: train() steps the scheduler once per epoch.
        total_steps = opt.n_epoch
        warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
        opt.lr_warmup = [warmup_steps, total_steps]

    model, criterion, optimizer, lr_scheduler = setup_model(opt)
    count_parameters(model)
    logger.info("Start Training...")

    train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt, sheet)

    return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug


if __name__ == '__main__':
    best_ckpt_path, eval_split_name, eval_path, debug = start_training()
    if debug:
        pass
    elif eval_path is None:
        # Without an eval path, training never writes a "_best" checkpoint, and
        # None cannot be forwarded as a command-line argument.
        logger.info("\n\n\nFINISHED TRAINING!!! (no eval_path given; skipping evaluation)")
    else:
        import sys

        # Re-parse options inside start_inference() with evaluation-only arguments.
        sys.argv[1:] = ["--resume", best_ckpt_path,
                        "--eval_split_name", eval_split_name,
                        "--eval_path", eval_path]
        logger.info("\n\n\nFINISHED TRAINING!!!")
        logger.info("Evaluating model at {}".format(best_ckpt_path))
        logger.info("Input args {}".format(sys.argv[1:]))
        start_inference()
