import json
import os
import sys
from argparse import ArgumentParser
from shutil import copyfile

import torch
import torch.nn.functional as F
from loguru import logger
from PIL import Image, ImageDraw
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.models.bert.tokenization_bert import BertTokenizer
from models.tv_interaction_models import GPT2Captions, RobertaCaptions, BertCaptions
from config import get_cfg_defaults
from transformers import (
    BertForSequenceClassification,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    GPT2Config,
    BertConfig,
    RobertaTokenizer,
    RobertaForSequenceClassification,
    RobertaConfig,
)
from transformers.optimization import AdamW
from utils.dataloaders.tv_interaction_dataloader import TVInteractionDataset
from utils.helper_utils import (
    WarmupLinearScheduleNonZero,
    batch_iter,
    create_folders_if_necessary,
    get_dir,
    prune_illegal_collate,
    seed,
)
from utils.table_visualizer import TableVisualizer
from sklearn.model_selection import KFold
import numpy as np


def _viz_table_columns():
    """Return the column layout of the HTML table written by ``evaluate``."""
    return [
        {
            "id": "idx",
            "display_name": "idx",
            "type": "text",
            "sortable": True,
            "width": "5%",
        },
        {
            "id": "Captions",
            "display_name": "Captions",
            "type": "text",
            "sortable": True,
            "width": "20%",
        },
        {
            "id": "Images",
            "display_name": "Images",
            "type": "image",
            "height": 200,
            "width": "30%",
        },
        {
            "id": "GT",
            "display_name": "GT Class",
            "type": "text",
            "sortable": True,
            "width": "10%",
        },
        {
            "id": "Scores",
            "display_name": "Scores",
            "type": "text",
            "sortable": True,
            "width": "10%",
        },
    ]


def _save_labelled_gif(frame_paths, out_path, label_color):
    """Stamp each frame with its position and save all frames as a looping GIF."""
    frames = []
    try:
        for frame_no, frame_path in enumerate(frame_paths):
            frame = Image.open(frame_path)
            frames.append(frame)
            ImageDraw.Draw(frame).text((10, 10), str(frame_no), fill=label_color)

        first_frame, *other_frames = frames
        first_frame.save(
            fp=out_path,
            format="GIF",
            append_images=other_frames,
            save_all=True,
            duration=1000,
            loop=0,
        )
    finally:
        # Release the image handles whether or not the GIF was written.
        for frame in frames:
            frame.close()


def evaluate(
    cfg,
    model,
    dataset,
    split="trainval",
    test_ids=None,
    viz_path=None,
):
    """Score ``model`` on the samples in ``test_ids`` and write an HTML report.

    For every evaluated sample a GIF of its frames and one table row
    (running index, decoded caption, GIF, ground-truth class, per-class
    softmax scores) are written under ``viz_path``; the table itself is
    rendered to ``<viz_path>/index.html``.

    Args:
        cfg: Config mapping; reads ``cfg["SYSTEM"]["NUM_GPUS"]``,
            ``cfg["SYSTEM"]["NUM_WORKERS"]`` and ``cfg["TV"]["BATCH_PER_GPU"]``.
        model: Module called as ``model(batch)`` and returning class logits.
        dataset: Dataset exposing ``split``, ``classes``, ``decode`` and
            ``get_image_paths``.
        split: Value assigned to ``dataset.split`` for the duration of the call.
        test_ids: Dataset indices to evaluate (fed to ``SubsetRandomSampler``).
            Must be provided despite the ``None`` default.
        viz_path: Existing output directory. Must be provided despite the
            ``None`` default, because it is joined into file paths.

    Returns:
        dict: ``{"r@1": accuracy}`` where accuracy is the fraction of samples
        whose arg-max class equals the target (``0.0`` if nothing was
        evaluated).

    Note:
        The model is always left in training mode on exit, and
        ``dataset.split`` is restored even if evaluation raises.
    """
    table_viz = TableVisualizer(
        _viz_table_columns(),
        "assets/table_visualizer_style.css",
        os.path.join(viz_path, "index.html"),
    )

    old_split = dataset.split
    model.eval()
    dataset.split = split
    try:
        device = torch.device(
            "cuda"
            if torch.cuda.is_available() and cfg["SYSTEM"]["NUM_GPUS"] > 0
            else "cpu"
        )
        # NUM_GPUS == 0 selects the CPU above; keep the batch size positive then.
        vbs = cfg["TV"]["BATCH_PER_GPU"] * max(1, cfg["SYSTEM"]["NUM_GPUS"])
        sampler = torch.utils.data.SubsetRandomSampler(test_ids)

        val_dataloader = DataLoader(
            dataset,
            batch_size=vbs,
            num_workers=cfg["SYSTEM"]["NUM_WORKERS"],
            drop_last=False,
            pin_memory=False,
            collate_fn=prune_illegal_collate,
            sampler=sampler,
        )

        label_color = "rgb(255, 255, 255)"  # white frame-number overlay
        # Plain Python ints: one running count doubles as sample index and total.
        num_correct = 0
        num_seen = 0
        with torch.no_grad():
            for _, _, batch in tqdm(batch_iter(val_dataloader, 1)):
                for k, v in batch.items():
                    batch[k] = v.to(device)

                captions = batch["captions"]
                captions_length = batch["captions_length"]
                target_labels = batch["target"]
                indices = batch["index"]
                image_ids = batch["image_ids"]

                logits = model(batch)
                normalized_logits = F.softmax(logits, dim=1)
                predictions = torch.argmax(normalized_logits, dim=1)
                num_correct += int((predictions == target_labels).sum().item())

                for j in range(normalized_logits.shape[0]):
                    gif_name = "%d.gif" % num_seen
                    decoded_input_caption = dataset.decode(
                        captions[j], captions_length[j]
                    )

                    paths = dataset.get_image_paths(
                        indices[j].item(), image_ids[j].tolist()
                    )
                    _save_labelled_gif(
                        paths, os.path.join(viz_path, gif_name), label_color
                    )

                    gt_class = dataset.classes[target_labels[j].item()]
                    scores = normalized_logits[j].tolist()
                    table_viz.add_row(
                        [
                            str(num_seen),
                            decoded_input_caption,
                            gif_name,
                            gt_class,
                            [
                                dataset.classes[class_id] + "(%0.3f)" % score
                                for class_id, score in enumerate(scores)
                            ],
                        ]
                    )
                    num_seen += 1

        # Only top-1 accuracy is reported; guard against an empty evaluation set.
        metrics = {"r@1": num_correct / num_seen if num_seen else 0.0}
    finally:
        model.train()
        dataset.split = old_split

    table_viz.render()

    return metrics


def _evaluate_checkpoint(cfg, model, dataset, test_ids, checkpoint_viz_dir):
    """Prepare ``<checkpoint_viz_dir>/trainval`` and run ``evaluate`` into it."""
    trainval_dir = os.path.join(checkpoint_viz_dir, "trainval")
    # presumably creates the missing parent folders of the given file path —
    # confirm in utils.helper_utils
    create_folders_if_necessary(os.path.join(trainval_dir, "viz_path.txt"))
    copyfile("assets/list.min.js", os.path.join(trainval_dir, "list.min.js"))
    logger.info("Evaluating ...")
    return evaluate(
        cfg,
        model,
        dataset,
        split="trainval",
        test_ids=test_ids,
        viz_path=trainval_dir,
    )


def train(
    cfg,
    model,
    dataset,
    train_split="trainval",
    train_ids=None,
    test_ids=None,
    fold_id=None,
    evaluate_during_training=True,
):
    """Fine-tune ``model`` on one fold and return the best held-out accuracy.

    Runs ``cfg["TV"]["NUM_EPOCHS"]`` epochs of cross-entropy training with
    AdamW over the samples in ``train_ids``. When
    ``evaluate_during_training`` is true the model is evaluated on
    ``test_ids`` once before training and then after every epoch whose index
    is a multiple of ``cfg["TV"]["EVALUATE_EVERY_N_EPOCHS"]``.

    Args:
        cfg: Config mapping; reads the ``SYSTEM`` and ``TV`` sections.
        model: Module called as ``model(batch)`` and returning class logits.
        dataset: Dataset exposing ``split`` and ``num_data_points_per_split``.
        train_split: Value assigned to ``dataset.split`` while training.
        train_ids: Dataset indices used for training (``SubsetRandomSampler``).
        test_ids: Dataset indices passed to ``evaluate``.
        fold_id: Integer fold number used in log messages and folder names;
            must be an int despite the ``None`` default (``%d`` formatting).
        evaluate_during_training: Whether to run the evaluations above.

    Returns:
        The highest ``r@1`` seen in the per-epoch evaluations (the pre-training
        evaluation is logged but not counted), or ``0`` if none ran.

    Note:
        Relies on the module-level names ``save_path``, ``viz_path_root`` and
        ``logger_eval``, which are only created by the ``__main__`` block of
        this file.
    """
    model.train()
    old_split = dataset.split
    dataset.split = train_split
    try:
        device = torch.device(
            "cuda"
            if torch.cuda.is_available() and cfg["SYSTEM"]["NUM_GPUS"] > 0
            else "cpu"
        )
        loss_fct = CrossEntropyLoss()
        # NUM_GPUS == 0 selects the CPU above; keep the batch size positive then.
        bs = cfg["TV"]["BATCH_PER_GPU"] * max(1, cfg["SYSTEM"]["NUM_GPUS"])
        sampler = torch.utils.data.SubsetRandomSampler(train_ids)

        dataloader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=cfg["SYSTEM"]["NUM_WORKERS"],
            drop_last=False,
            pin_memory=False,
            collate_fn=prune_illegal_collate,
            sampler=sampler,
        )
        create_folders_if_necessary(
            os.path.join(save_path, "saved-checkpoints", "checkpoint_logs.txt")
        )
        # NOTE(review): this uses the size of the "train" split, while the
        # batches actually come from ``train_ids``; confirm the scheduler
        # horizon below is the intended one.
        num_iter_per_epoch = dataset.num_data_points_per_split["train"] // bs

        optimizer = AdamW(model.parameters(), lr=cfg["TV"]["LR"])
        scheduler = WarmupLinearScheduleNonZero(
            optimizer,
            warmup_steps=5,
            t_total=cfg["TV"]["NUM_EPOCHS"] * num_iter_per_epoch,
        )

        optimizer.zero_grad()
        best_acc = 0

        # Baseline evaluation of the untrained (for this fold) model.
        if evaluate_during_training:
            viz_path = os.path.join(
                viz_path_root, "[fold_%d]_init-checkpoint" % fold_id
            )
            metrics_trainval = _evaluate_checkpoint(
                cfg, model, dataset, test_ids, viz_path
            )
            init_msg = "[fold_%d] Init Checkpoint, metrics trainval: %s" % (
                fold_id,
                json.dumps(metrics_trainval),
            )
            logger.info(init_msg)
            logger_eval.info(init_msg)

        for epoch in range(cfg["TV"]["NUM_EPOCHS"]):
            for _, iter_id, batch in batch_iter(dataloader, 1):
                for k, v in batch.items():
                    batch[k] = v.to(device)

                logits = model(batch)
                loss = loss_fct(logits, batch["target"])

                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                if iter_id % cfg["TV"]["LOG_EVERY"] == 0:
                    logger.info(
                        "[fold:%d] iter id: %d loss: %f"
                        % (fold_id, iter_id, loss.item())
                    )

            if not evaluate_during_training:
                continue
            if epoch % cfg["TV"]["EVALUATE_EVERY_N_EPOCHS"] != 0:
                continue

            idx = epoch * num_iter_per_epoch
            viz_path = os.path.join(
                viz_path_root, "[fold_%d]iter_%d" % (fold_id, idx)
            )
            metrics_trainval = _evaluate_checkpoint(
                cfg, model, dataset, test_ids, viz_path
            )
            if metrics_trainval["r@1"] > best_acc:
                best_acc = metrics_trainval["r@1"]

            epoch_msg = "[fold:%d]epoch: %d, metrics trainval: %s" % (
                fold_id,
                epoch,
                json.dumps(metrics_trainval),
            )
            best_msg = "best acc: %0.4f" % best_acc
            logger.info(epoch_msg)
            logger_eval.info(epoch_msg)
            logger.info(best_msg)
            logger_eval.info(best_msg)

        return best_acc
    finally:
        # Hand the dataset back in its original split even if training failed.
        dataset.split = old_split


def read_command_line(argv=None):
    """Parse the script's command-line options into a plain dict.

    Args:
        argv: Argument strings to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        dict: Option names mapped to values; currently only ``config_path``
        (defaults to ``configs/debug.yaml``).
    """
    arg_parser = ArgumentParser(description="Future Prediction using Language")
    arg_parser.add_argument(
        "--config_path",
        default="configs/debug.yaml",
        help="Path to yaml config file",
    )
    try:
        namespace = arg_parser.parse_args(args=argv)
    except IOError as err:
        # ``error`` prints the usage message and exits, so it never returns.
        arg_parser.error(str(err))
    return vars(namespace)


if __name__ == "__main__":
    # Load the default config and overlay the YAML file given on the command line.
    cmd_line = read_command_line()
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cmd_line["config_path"])
    cfg.freeze()

    # Create the output directory and store a copy of the effective config in it.
    os.makedirs("checkpoints", exist_ok=True)
    save_path = get_dir(cfg["SYSTEM"]["OUT_DIR"], cfg["SYSTEM"]["SEED"], "checkpoints")

    create_folders_if_necessary(os.path.join(save_path, "experiment.yaml"))
    with open(os.path.join(save_path, "experiment.yaml"), "w") as f:
        f.write(cfg.dump())

    # ``save_path``, ``viz_path_root`` and ``logger_eval`` are read as
    # module-level globals by train().
    viz_path_root = os.path.join(save_path, "viz")

    # Logging sinks: a full log file, an eval-only log file, and stderr.
    logger.add(
        os.path.join(save_path, "file_{time}.log"),
        level=cfg["SYSTEM"]["DEBUGGING_LEVEL"],
    )
    logger.add(
        os.path.join(save_path, "eval_{time}.log"),
        level=cfg["SYSTEM"]["DEBUGGING_LEVEL"],
        # ``.get`` because records emitted through a logger that was never
        # bound with ``task`` carry no such key in ``extra``.
        filter=lambda record: record["extra"].get("task") == "eval",
    )
    logger.add(
        sys.stderr,
        format="{time} {level} {message}",
        level=cfg["SYSTEM"]["DEBUGGING_LEVEL"],
    )

    # Messages sent through ``logger_eval`` additionally land in the eval log.
    logger_eval = logger.bind(task="eval")
    logger = logger.bind(task="not-eval")

    logger.info("Save Path: %s" % save_path)

    seed(cfg["SYSTEM"]["SEED"])

    device = torch.device(
        "cuda" if torch.cuda.is_available() and cfg["SYSTEM"]["NUM_GPUS"] > 0 else "cpu"
    )

    K = cfg["TV"]["KFOLDS"]
    kfold = KFold(n_splits=K, shuffle=True)

    # Pick the tokenizer that matches the configured backbone.
    tokenizer = None
    if cfg["TV"]["MODEL_TYPE"] == "GPT2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    elif cfg["TV"]["MODEL_TYPE"] == "ROBERTA":
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    elif cfg["TV"]["MODEL_TYPE"] == "BERT":
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    else:
        raise NotImplementedError(
            "Unsupported TV.MODEL_TYPE: %s" % cfg["TV"]["MODEL_TYPE"]
        )

    tv_dataset = TVInteractionDataset(cfg["TV"], tokenizer)
    # NOTE(review): not referenced anywhere else in the visible code.
    dataset_shuffle = True
    tv_dataset.split = "trainval"

    # Best held-out accuracy per fold, keyed by fold number.
    metrics = {}
    for fold, (train_ids, test_ids) in enumerate(kfold.split(tv_dataset)):
        print("test ids", test_ids)

        # Every fold starts from a freshly built model (4 output classes).
        model = None
        if cfg["TV"]["MODEL_TYPE"] == "GPT2":
            if cfg["TV"]["MODEL_PATH"]:
                backbone = GPT2ForSequenceClassification.from_pretrained(
                    cfg["TV"]["MODEL_PATH"], num_labels=4
                )
            else:
                logger.warning("Random model initialized")
                backbone_config = GPT2Config(num_labels=4)
                backbone = GPT2ForSequenceClassification(backbone_config)
            # Reuse the end-of-sequence token id as the padding token id.
            backbone.config.pad_token_id = backbone.config.eos_token_id
            model = GPT2Captions(backbone, tokenizer, cfg)

        elif cfg["TV"]["MODEL_TYPE"] == "BERT":
            if cfg["TV"]["MODEL_PATH"]:
                backbone_bert_model = BertForSequenceClassification.from_pretrained(
                    cfg["TV"]["MODEL_PATH"],
                    hidden_dropout_prob=cfg["TV"]["DROPOUT"],
                    attention_probs_dropout_prob=cfg["TV"]["DROPOUT"],
                    num_labels=4,
                )
            else:
                logger.warning("Random model initialized")
                backbone_config = BertConfig(
                    hidden_dropout_prob=cfg["TV"]["DROPOUT"],
                    attention_probs_dropout_prob=cfg["TV"]["DROPOUT"],
                    num_labels=4,
                )
                backbone_bert_model = BertForSequenceClassification(backbone_config)

            model = BertCaptions(backbone_bert_model, tokenizer, cfg)

        elif cfg["TV"]["MODEL_TYPE"] == "ROBERTA":
            if cfg["TV"]["MODEL_PATH"]:
                backbone_bert_model = RobertaForSequenceClassification.from_pretrained(
                    cfg["TV"]["MODEL_PATH"],
                    hidden_dropout_prob=cfg["TV"]["DROPOUT"],
                    attention_probs_dropout_prob=cfg["TV"]["DROPOUT"],
                    num_labels=4,
                )
            else:
                logger.warning("Random model initialized")
                # NOTE(review): unlike the BERT branch, the random-init config
                # does not apply cfg["TV"]["DROPOUT"] — confirm this is intended.
                backbone_config = RobertaConfig(num_labels=4)
                backbone_bert_model = RobertaForSequenceClassification(backbone_config)

            model = RobertaCaptions(backbone_bert_model, tokenizer, cfg)

        else:
            raise NotImplementedError(
                "Unsupported TV.MODEL_TYPE: %s" % cfg["TV"]["MODEL_TYPE"]
            )

        # freeze params
        # for n, p in model.named_parameters():
        #     if "layer" in n:
        #         p.requires_grad = False

        model = model.to(device)

        best_acc = train(
            cfg,
            model,
            tv_dataset,
            train_split="trainval",
            train_ids=train_ids,
            test_ids=test_ids,
            fold_id=fold,
        )
        metrics[fold] = best_acc

    # Summarise the per-fold best accuracies.
    best_acc_vec = list(metrics.values())
    best_acc_mean = np.mean(best_acc_vec)
    best_acc_std = np.std(best_acc_vec)
    logger.info("accuracies across splits: %s" % json.dumps(metrics))
    logger.info("best accuracy mean: %0.4f +- %0.4f" % (best_acc_mean, best_acc_std))
    logger_eval.info("accuracies across splits: %s" % json.dumps(metrics))
    logger_eval.info(
        "best accuracy mean: %0.4f +- %0.4f" % (best_acc_mean, best_acc_std)
    )
