import copy
import json
import os
import sys
from argparse import ArgumentParser
from shutil import copyfile

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

import transformers
from config import get_cfg_defaults
from models.vist_models import LSTM, NearestNeighborResnet, BertCaptions, GPT2Captions, RobertaCaptions
from transformers import (
    BertForNextSentencePrediction,
    BertTokenizer,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    GPT2Config,
    BertConfig,
    RobertaTokenizer,
    RobertaForSequenceClassification,
    RobertaConfig
)
from transformers.optimization import AdamW
from utils.dataloaders.vist_dataloader import VISTDataset
from utils.helper_utils import (
    WarmupLinearScheduleNonZero,
    batch_iter,
    create_folders_if_necessary,
    get_dir,
    prune_illegal_collate,
    seed,
    sequence_mask,
)
from utils.table_visualizer import TableVisualizer

def _viz_table_columns():
    """Return the column configuration for the evaluation visualization table.

    The layout is four text columns (index, scores, input caption, candidate
    captions) followed by six image columns (input image, ground-truth target
    and four distractors).
    """

    def text_column(name, sortable, width):
        return {
            "id": name,
            "display_name": name,
            "type": "text",
            "sortable": sortable,
            "width": width,
        }

    def image_column(column_id, display_name):
        return {
            "id": column_id,
            "display_name": display_name,
            "type": "image",
            "height": 200,
            "width": "13%",
        }

    columns = [
        text_column("idx", True, "1%"),
        text_column("Scores", False, "1%"),
        text_column("Input Caption", True, "3%"),
        text_column("Candidate Captions", True, "12%"),
        image_column("InputImage", "Input Image"),
        image_column("GT Target", "GT Target"),
    ]
    columns.extend(
        image_column("Distractor%d" % i, "Distractor-%d" % i) for i in range(1, 5)
    )
    return columns


def evaluate(
    cfg,
    model,
    dataset,
    split="val",
    visualize=False,
    viz_path=None,
    num_samples_visualize=700,
):
    """Score every sample of ``split`` and return ranking metrics.

    For each sample the model scores ``NUM_DISTRACTORS + 1`` candidates, with
    the ground-truth target placed at candidate index 0. Candidates are ranked
    by descending score.

    Args:
        cfg: Config mapping; reads ``cfg["VIST"]`` (``MODEL_TYPE``,
            ``NUM_DISTRACTORS``, ``BATCH_PER_GPU``, ``MAX_LEN_CAPTION``) and
            ``cfg["SYSTEM"]`` (``NUM_GPUS``, ``NUM_WORKERS``).
        model: Callable on a batch dict; switched to eval mode for the pass
            and put back in train mode afterwards.
        dataset: Dataset exposing a mutable ``split`` attribute plus
            ``decode`` and ``get_image_urls`` (used for visualization only).
        split: Name of the split to evaluate on.
        visualize: When true, an HTML table of the first
            ``num_samples_visualize`` samples is rendered to
            ``viz_path/index.html``.
        viz_path: Output directory. When given, the per-sample scores are also
            dumped to ``viz_path/scores_dump.json``. When ``None`` nothing is
            written to disk.
        num_samples_visualize: Maximum number of rows in the HTML table.

    Returns:
        dict with ``"r1"`` and ``"r3"``: the fraction of samples whose
        ground-truth candidate is ranked first / within the top three. Both
        are ``0.0`` if the split yields no samples.

    Raises:
        ValueError: if ``visualize`` is true but ``viz_path`` is ``None``.
        NotImplementedError: if ``MODEL_TYPE`` is not supported.
    """
    if visualize and viz_path is None:
        raise ValueError("viz_path must be provided when visualize=True")

    table_viz = None
    if visualize:
        table_viz = TableVisualizer(
            _viz_table_columns(),
            "assets/table_visualizer_style.css",
            os.path.join(viz_path, "index.html"),
        )

    model_type = cfg["VIST"]["MODEL_TYPE"]
    num_candidates = cfg["VIST"]["NUM_DISTRACTORS"] + 1
    device = torch.device(
        "cuda" if torch.cuda.is_available() and cfg["SYSTEM"]["NUM_GPUS"] > 0 else "cpu"
    )
    vbs = cfg["VIST"]["BATCH_PER_GPU"] * cfg["SYSTEM"]["NUM_GPUS"]

    # Hit counters are kept as Python ints so the final ratio is a plain float
    # division, independent of how the installed torch divides integer tensors.
    r1_hits = 0
    r3_hits = 0
    tot_samples = 0
    viz_samples_counter = 0
    scores_log = []

    model.eval()
    old_split = dataset.split
    dataset.split = split
    try:
        val_dataloader = DataLoader(
            dataset,
            batch_size=vbs,
            shuffle=False,
            num_workers=cfg["SYSTEM"]["NUM_WORKERS"],
            drop_last=False,
            pin_memory=False,
            collate_fn=prune_illegal_collate,
        )

        with torch.no_grad():
            for _, _, batch in tqdm(batch_iter(val_dataloader, 1)):
                for k, v in batch.items():
                    batch[k] = v.to(device)

                image_ids = batch["image_id"]
                distractor_image_ids = batch["distractor_image_ids"]
                # The last two images of each sequence are the input image and
                # the ground-truth target respectively.
                input_image_ids = image_ids[:, -2]
                target_image_ids = image_ids[:, -1]
                # Candidate 0 is always the ground-truth target.
                candidate_image_ids = torch.cat(
                    [target_image_ids.unsqueeze(1), distractor_image_ids], dim=1
                )

                captions = batch["captions"]
                captions_length = batch["captions_length"]
                distractor_captions = batch["distractor_captions"]
                distractor_captions_length = batch["distractor_captions_length"]
                target_captions = captions[:, -1, :]
                target_captions_length = captions_length[:, -1]
                # Target caption followed by the distractor captions.
                candidates = torch.cat(
                    [target_captions.unsqueeze(1), distractor_captions], dim=1
                )
                candidates_length = torch.cat(
                    [target_captions_length.unsqueeze(1), distractor_captions_length],
                    dim=1,
                )

                # Pack all history captions (everything but the last) into one
                # contiguous token sequence per sample, dropping the padding
                # that sits between the individual captions.
                input_captions = captions[:, :-1, :]
                input_captions_length = captions_length[:, :-1]
                num_captions_history = input_captions.shape[-2]
                input_captions = input_captions.view(input_captions.shape[0], -1)
                input_captions_length = torch.sum(input_captions_length, dim=1)
                max_len = torch.max(input_captions_length)
                input_captions_concat = (
                    torch.zeros(input_captions_length.shape[0], max_len)
                    .long()
                    .to(device)
                )
                sequence_masks = torch.cat(
                    [
                        sequence_mask(
                            captions_length[:, i],
                            max_len=cfg["VIST"]["MAX_LEN_CAPTION"],
                        )
                        for i in range(num_captions_history)
                    ],
                    dim=1,
                )
                assert torch.equal(
                    torch.sum(sequence_masks, dim=1), input_captions_length
                )
                input_captions_concat[
                    sequence_mask(input_captions_length, max_len=max_len)
                ] = input_captions[sequence_masks]
                input_captions = input_captions_concat

                if model_type in ("GPT2", "BERT", "ROBERTA"):
                    # Two-way classifiers: class 0 is the "positive" label, so
                    # its probability is the candidate score.
                    logits = model(batch)
                    probs = F.softmax(logits, dim=1)[:, 0]
                elif model_type == "NearestNeighbor":
                    feat_diff = model(batch)
                    probs = F.softmax(feat_diff, dim=1)
                elif model_type == "LSTM":
                    _, feat_diff = model(batch)
                    probs = F.softmax(feat_diff, dim=1)
                else:
                    raise NotImplementedError()

                # One row per sample, one column per candidate; higher is better.
                scores = probs.view(-1, num_candidates)
                _, sorted_indices = torch.sort(-scores)
                batch_samples = scores.shape[0]
                scores_list = scores.tolist()

                for j in range(batch_samples):
                    cur_candidates = candidate_image_ids[j].tolist()
                    scores_log.append(
                        {
                            "context_image_ids": image_ids[j, :-1].tolist()
                            + [cur_candidates[0]],
                            "candidates": cur_candidates[1:],
                            "scores": scores_list[j],
                        }
                    )

                # Count how often candidate 0 (the ground truth) is ranked
                # first, and how often it lands within the top three.
                r1_hits += torch.sum(sorted_indices[:, 0] == 0).item()
                r3_hits += torch.sum(sorted_indices[:, :3] == 0).item()
                tot_samples += batch_samples

                if visualize:
                    for j in range(batch_samples):
                        if viz_samples_counter >= num_samples_visualize:
                            break
                        row = [str(viz_samples_counter), scores_list[j]]
                        decoded_input_caption = dataset.decode(
                            input_captions[j : j + 1],
                            input_captions_length[j : j + 1],
                        )
                        decoded_captions_candidates = dataset.decode(
                            candidates[j], candidates_length[j]
                        )

                        if model_type == "NearestNeighbor":
                            # Caption columns are left blank for this model.
                            row.append("")
                            row.append("")
                        else:
                            row.append(decoded_input_caption)
                            row.append(decoded_captions_candidates)

                        input_image_url = dataset.get_image_urls(
                            [input_image_ids[j].item()]
                        )[0]
                        candidate_urls = dataset.get_image_urls(
                            candidate_image_ids[j].tolist()
                        )
                        row.append(input_image_url)
                        row.extend(candidate_urls)
                        table_viz.add_row(row)
                        viz_samples_counter += 1
    finally:
        # Always hand the model and dataset back in training configuration,
        # even if evaluation failed part-way through.
        model.train()
        dataset.split = old_split

    metrics = {
        "r1": r1_hits / tot_samples if tot_samples else 0.0,
        "r3": r3_hits / tot_samples if tot_samples else 0.0,
    }

    if visualize:
        table_viz.render()
    if viz_path is not None:
        with open(os.path.join(viz_path, "scores_dump.json"), "w") as f:
            json.dump(scores_log, f)
    return metrics


def _evaluate_and_log(cfg, model, dataset, viz_path, log_prefix):
    """Evaluate on the val and test splits and log the resulting metrics.

    Creates ``viz_path/val`` and ``viz_path/test``, copies the JS asset needed
    by the HTML table into each, runs ``evaluate`` with visualization enabled
    and writes one line per split to both the main and the eval logger.
    """
    for split in ("test", "val"):
        create_folders_if_necessary(os.path.join(viz_path, split, "viz_path.txt"))
        copyfile("assets/list.min.js", os.path.join(viz_path, split, "list.min.js"))

    for split in ("val", "test"):
        metrics = evaluate(
            cfg,
            model,
            dataset,
            split=split,
            visualize=True,
            viz_path=os.path.join(viz_path, split),
        )
        message = "%s, metrics %s: %s" % (log_prefix, split, json.dumps(metrics))
        logger.info(message)
        logger_eval.info(message)


def train(cfg, model, dataset, train_split="train", evaluate_during_training=True):
    """Train ``model`` on ``dataset`` and periodically evaluate and checkpoint.

    Besides its arguments this function reads the module-level names
    ``dataset_shuffle``, ``save_path``, ``viz_path_root`` and ``logger_eval``,
    which the ``__main__`` section of this file defines.

    Args:
        cfg: Config mapping; reads ``cfg["VIST"]`` (``MODEL_TYPE``,
            ``NUM_DISTRACTORS``, ``BATCH_PER_GPU``, ``LR``, ``NUM_EPOCHS``,
            ``LOG_EVERY``, ``SAVE_EVERY_N_EPOCHS``) and ``cfg["SYSTEM"]``
            (``NUM_GPUS``, ``NUM_WORKERS``).
        model: Model to optimize. For ``BERT``/``ROBERTA``/``GPT2`` it must
            return per-candidate two-way logits; for ``LSTM`` it must return
            ``(loss, _)``.
        dataset: Dataset with a mutable ``split`` attribute and a
            ``num_data_points_per_split`` mapping.
        train_split: Split used for training.
        evaluate_during_training: When true, val/test are evaluated before
            training starts and after every epoch.

    Raises:
        NotImplementedError: on the first training batch if ``MODEL_TYPE`` has
            no training branch. This is raised only after the initial
            evaluation has run.
    """
    model.train()
    old_split = dataset.split
    dataset.split = train_split
    try:
        loss_fct = CrossEntropyLoss()
        device = torch.device(
            "cuda"
            if torch.cuda.is_available() and cfg["SYSTEM"]["NUM_GPUS"] > 0
            else "cpu"
        )
        model_type = cfg["VIST"]["MODEL_TYPE"]
        num_candidates = cfg["VIST"]["NUM_DISTRACTORS"] + 1
        num_epochs = cfg["VIST"]["NUM_EPOCHS"]
        bs = cfg["VIST"]["BATCH_PER_GPU"] * cfg["SYSTEM"]["NUM_GPUS"]
        dataloader = DataLoader(
            dataset,
            batch_size=bs,
            shuffle=dataset_shuffle,
            num_workers=cfg["SYSTEM"]["NUM_WORKERS"],
            drop_last=True,
            pin_memory=False,
            collate_fn=prune_illegal_collate,
        )

        # Use the dataset and split that were actually passed in, rather than
        # a module-level dataset and a hard-coded "train" split.
        num_iter_per_epoch = dataset.num_data_points_per_split[train_split] // bs
        create_folders_if_necessary(
            os.path.join(save_path, "saved-checkpoints", "checkpoint_logs.txt")
        )

        # Evaluate the untrained / freshly loaded model first as a baseline.
        if evaluate_during_training:
            _evaluate_and_log(
                cfg,
                model,
                dataset,
                os.path.join(viz_path_root, "init-checkpoint"),
                "Init Checkpoint",
            )

        optimizer = AdamW(model.parameters(), lr=cfg["VIST"]["LR"])
        scheduler = WarmupLinearScheduleNonZero(
            optimizer,
            warmup_steps=200,
            t_total=num_epochs * num_iter_per_epoch,
        )
        optimizer.zero_grad()

        for epoch in range(num_epochs):
            for _, iter_id, batch in batch_iter(dataloader, 1):
                for k, v in batch.items():
                    batch[k] = v.to(device)

                if model_type in ("BERT", "ROBERTA", "GPT2"):
                    # Label 0 marks the positive candidate and 1 a negative;
                    # the ground-truth candidate is always at index 0.
                    target_labels = torch.ones(
                        batch["captions"].shape[0],
                        num_candidates,
                        dtype=torch.long,
                        device=device,
                    )
                    target_labels[:, 0] = 0
                    logits = model(batch)
                    loss = loss_fct(logits, target_labels.view(-1))
                elif model_type == "LSTM":
                    loss, _ = model(batch)
                else:
                    raise NotImplementedError()

                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                if iter_id % cfg["VIST"]["LOG_EVERY"] == 0:
                    logger.info("iter id: %d loss: %f" % (iter_id, loss.item()))

            if evaluate_during_training:
                # NOTE(review): the directory is labelled with the iteration
                # count at the *start* of the epoch (epoch 0 -> "iter_0"),
                # although it is written after the epoch finished.
                idx = epoch * num_iter_per_epoch
                _evaluate_and_log(
                    cfg,
                    model,
                    dataset,
                    os.path.join(viz_path_root, "iter_%d" % idx),
                    "epoch: %d" % epoch,
                )

            # save checkpoint
            if epoch % cfg["VIST"]["SAVE_EVERY_N_EPOCHS"] == 0:
                # Unwrap DataParallel so the checkpoint keys carry no
                # "module." prefix.
                if isinstance(model, torch.nn.DataParallel):
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()
                torch.save(
                    {
                        "model_state_dict": model_state_dict,
                        "epoch": epoch,
                    },
                    os.path.join(
                        save_path, "saved-checkpoints", "checkpoint-%d" % epoch
                    ),
                )
    finally:
        # Leave the dataset on the split it had on entry, even on failure.
        dataset.split = old_split

def read_command_line(argv=None):
    """Parse the command-line options and return them as a plain dict.

    Args:
        argv: Argument list to parse; ``None`` makes argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        dict with a single key ``"config_path"`` (defaults to
        ``"configs/debug.yaml"``).
    """
    cli = ArgumentParser(description="Future Prediction using Language")
    cli.add_argument(
        "--config_path",
        default="configs/debug.yaml",
        help="Path to yaml config file",
    )
    try:
        namespace = cli.parse_args(args=argv)
    except IOError as err:
        # ArgumentParser.error prints the message and exits the process.
        cli.error(str(err))
    return vars(namespace)


if __name__ == "__main__":
    # NOTE: everything below intentionally stays at module level. `train`
    # reads `dataset_shuffle`, `vist_dataset`, `save_path`, `viz_path_root`
    # and `logger_eval` as module globals, and both `train` and `evaluate`
    # pick up the re-bound `logger`.

    # Load the default config and overlay the YAML file from the command line.
    cmd_line = read_command_line()
    cfg = get_cfg_defaults()
    cfg.merge_from_file(cmd_line["config_path"])
    cfg.freeze()

    # Create the output directory and store the resolved config next to the
    # results so the run can be reproduced.
    os.makedirs("checkpoints", exist_ok=True)
    save_path = get_dir(cfg["SYSTEM"]["OUT_DIR"], cfg["SYSTEM"]["SEED"], "checkpoints")

    create_folders_if_necessary(os.path.join(save_path, "experiment.yaml"))
    with open(os.path.join(save_path, "experiment.yaml"), "w") as f:
        f.write(cfg.dump())

    viz_path_root = os.path.join(save_path, "viz")
    eval_path_root = os.path.join(save_path, "eval")

    # Three sinks: a full log file, a file with evaluation records only, and
    # stderr.
    log_level = cfg["SYSTEM"]["DEBUGGING_LEVEL"]
    logger.add(os.path.join(save_path, "file_{time}.log"), level=log_level)
    logger.add(
        os.path.join(save_path, "eval_{time}.log"),
        level=log_level,
        # `.get` keeps the filter from raising KeyError on records emitted
        # through a logger that was never bound with a `task` value.
        filter=lambda record: record["extra"].get("task") == "eval",
    )
    logger.add(sys.stderr, format="{time} {level} {message}", level=log_level)

    # Records sent through `logger_eval` carry task="eval" and therefore also
    # end up in the eval-only file; the default logger is tagged "not-eval".
    logger_eval = logger.bind(task="eval")
    logger = logger.bind(task="not-eval")

    logger.info("Save Path: %s" % save_path)

    seed(cfg["SYSTEM"]["SEED"])

    device = torch.device(
        "cuda" if torch.cuda.is_available() and cfg["SYSTEM"]["NUM_GPUS"] > 0 else "cpu"
    )

    # TODO: Create a factory class to initialize tokenizers and model
    model = None
    tokenizer = None
    model_type = cfg["VIST"]["MODEL_TYPE"]
    if model_type == "GPT2":
        if cfg["VIST"]["MODEL_PATH"]:
            backbone = GPT2ForSequenceClassification.from_pretrained(
                cfg["VIST"]["MODEL_PATH"], num_labels=2
            )
        else:
            logger.warning("Random model initialized")
            backbone_config = GPT2Config(num_labels=2)
            backbone = GPT2ForSequenceClassification(backbone_config)
        # No pad token is configured for this backbone here, so the
        # end-of-sequence token id is reused as the padding id.
        backbone.config.pad_token_id = backbone.config.eos_token_id
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = GPT2Captions(backbone, tokenizer, cfg)

    elif model_type == "BERT":
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        if cfg["VIST"]["MODEL_PATH"]:
            backbone_bert_model = BertForNextSentencePrediction.from_pretrained(
                cfg["VIST"]["MODEL_PATH"],
                hidden_dropout_prob=cfg["VIST"]["DROPOUT"],
                attention_probs_dropout_prob=cfg["VIST"]["DROPOUT"],
            )
        else:
            logger.warning("Random model initialized")
            backbone_config = BertConfig(
                hidden_dropout_prob=cfg["VIST"]["DROPOUT"],
                attention_probs_dropout_prob=cfg["VIST"]["DROPOUT"],
            )
            backbone_bert_model = BertForNextSentencePrediction(backbone_config)

        model = BertCaptions(backbone_bert_model, tokenizer, cfg)

    elif model_type == "ROBERTA":
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        if cfg["VIST"]["MODEL_PATH"]:
            backbone_bert_model = RobertaForSequenceClassification.from_pretrained(
                cfg["VIST"]["MODEL_PATH"],
                hidden_dropout_prob=cfg["VIST"]["DROPOUT"],
                attention_probs_dropout_prob=cfg["VIST"]["DROPOUT"],
                num_labels=2,
            )
        else:
            logger.warning("Random model initialized")
            # NOTE(review): unlike the pretrained path above and the BERT
            # random-init path, the configured DROPOUT is not applied here;
            # the library's default dropout values are used. Confirm this is
            # intended.
            backbone_config = RobertaConfig(num_labels=2)
            backbone_bert_model = RobertaForSequenceClassification(backbone_config)

        model = RobertaCaptions(backbone_bert_model, tokenizer, cfg)

    elif model_type == "NearestNeighbor":
        model = NearestNeighborResnet(cfg)
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    elif model_type == "LSTM":
        model = LSTM(cfg)
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    else:
        raise NotImplementedError()

    # Only these two model types have the dataset load images.
    load_images = model_type in ("NearestNeighbor", "LSTM")
    vist_dataset = VISTDataset(cfg["VIST"], tokenizer, load_images=load_images)
    # Keep a fixed sample order when deliberately overfitting.
    dataset_shuffle = not cfg["VIST"]["OVERFIT"]
    vist_dataset.split = "train"

    # To resume from a checkpoint written by `train`, load its
    # "model_state_dict" entry into `model` here, before moving it to the
    # device.

    model = model.to(device)

    train(cfg, model, vist_dataset, train_split="train")
