"""Script to finetune unimodal image model on multiclass classification of image labels.

We use this script to finetune the pretrained ViT image model on the CIFAR-10 dataset to predict
image labels. We do not introduce label flips for finetuning, we simply train on noisy labels.
We load pretrained ViT weights from the Huggingface library.
"""
from transformers import ViTForImageClassification, AdamW
from transformers import CLIPVisionModel
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
import pytorch_lightning as pl
import torch.nn as nn
import os
# Official implementation from baseline repo.
from lib.metrics.confident_learning import get_val_confident_learning
from torchmetrics.utilities import dim_zero_cat
import logging
import argparse
from lib.datasets.utils import (
    get_dataset,
)
from torchmetrics import Metric
import pandas as pd
from tqdm.auto import tqdm
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from lib.models.utils import algorithm_class_from_scratch
import torchmetrics
from pytorch_lightning import loggers as pl_loggers

CAPTION_DATASETS = ["mscoco", "mimiccxr_caption", "flickr30k", "mmimdb"]
CLASSIFICATION_DATASETS = ['cifar10','cifar100','stanford_cars','mini_imagenet']


class CorruptionMetrics(Metric):
    """Accumulate mislabel flags, noisy labels and model outputs across batches.

    ``compute`` concatenates everything seen since the last reset and returns the
    ``(f1, accuracy)`` pair reported by ``get_val_confident_learning`` for
    label-error detection.
    """

    def __init__(self, dataset, **kwargs):
        super().__init__(**kwargs)
        # List states; dist_reduce_fx="cat" concatenates them when synced.
        for state_name in ("is_mislabel", "labels", "logits"):
            self.add_state(state_name, default=[], dist_reduce_fx="cat")
        # Stored for reference only; update/compute do not read it.
        self.is_caption_dataset = dataset in CAPTION_DATASETS

    def update(self, is_mislabel, labels, logits) -> None:
        """Append detached CPU copies of one batch's tensors to the running state."""
        pairs = (
            (self.is_mislabel, is_mislabel),
            (self.labels, labels),
            (self.logits, logits),
        )
        for store, tensor in pairs:
            store.append(tensor.detach().cpu())

    def compute(self):
        """Concatenate the accumulated batches and return ``(f1, accuracy)``."""
        mislabel_flags, noisy_labels, outputs = (
            dim_zero_cat(chunks) for chunks in (self.is_mislabel, self.labels, self.logits)
        )
        f1, acc, _ = get_val_confident_learning(
            mislabel_flags, noisy_labels.numpy(), outputs.numpy()
        )
        return f1, acc


class ViTLightningModule(pl.LightningModule):
    """Vision encoder plus linear head, finetuned on (possibly noisy) class labels.

    ``forward`` returns softmax class probabilities. Despite the name ``logits``
    used throughout this class, every ``logits`` value below is such a
    probability matrix. The same probabilities are passed to the
    confident-learning metrics (``CorruptionMetrics``).

    The ``*_dataloader`` hooks return module-level globals of the same name.
    The ``__main__`` section of this script must define them before calling
    ``Trainer.fit``.
    """

    # Width of the backbone features fed to the linear head, per supported backbone.
    _FEATURE_DIMS = {"clip": 768, "huggingface_clip": 768, "biomed_clip": 512}

    def __init__(self, args):
        """Build the backbone, head and corruption metrics from parsed CLI ``args``.

        Raises:
            ValueError: if ``args.dataset`` or ``args.clip_model`` is not supported.
        """
        super().__init__()
        if args.dataset in CLASSIFICATION_DATASETS:
            is_caption_dataset = False
        elif args.dataset in CAPTION_DATASETS:
            is_caption_dataset = True
        else:
            raise ValueError("Dataset not supported: {}".format(args.dataset))
        num_labels = get_num_labels(args)
        self.is_caption_dataset = is_caption_dataset
        self.model_name = args.clip_model
        if self.model_name not in self._FEATURE_DIMS:
            # This message used to format an undefined name `model`, so an
            # unsupported backbone surfaced as a NameError instead of this error.
            raise ValueError("Model not supported: {}".format(self.model_name))
        self.vit = get_vision_model(args)
        self.projection = nn.Linear(self._FEATURE_DIMS[self.model_name], num_labels)
        # Explicit dim: normalise over the class axis (implicit dim is deprecated).
        self.softmax = nn.Softmax(dim=-1)
        # Keep optimizer hyper-parameters on the module instead of reading the
        # script-level global `args` inside configure_optimizers.
        self.lr = args.lr
        self.weight_decay = args.wd
        self.val_corruption = CorruptionMetrics(args.dataset)
        self.test_corruption = CorruptionMetrics(args.dataset)

    def forward(self, pixel_values):
        """Return softmax class probabilities for a batch of images.

        Raises:
            NotImplementedError: if ``self.model_name`` has no forward path.
        """
        if self.model_name in ["clip", "huggingface_clip"]:
            # These backbones return an output object; use its pooled embedding.
            features = self.vit(pixel_values=pixel_values).pooler_output
        elif self.model_name == "biomed_clip":
            # Assumes this backbone returns the 512-d embedding tensor directly
            # (matching the Linear(512, ...) head). TODO confirm.
            features = self.vit(pixel_values)
        else:
            raise NotImplementedError()
        return self.softmax(self.projection(features))

    def get_logits_and_labels(self, batch):
        """Run the model on a collated batch.

        Returns:
            ``(probabilities, clean_labels, noisy_labels)``.
        """
        pixel_values, clean_labels, labels = batch['pixel_values'], batch['clean_labels'], batch['labels']
        logits = self(pixel_values)
        return logits, clean_labels, labels

    def common_step(self, logits, labels):
        """Return ``(loss, accuracy)`` for a batch.

        Args:
            logits: class probabilities produced by ``forward`` (already softmaxed).
            labels: integer target labels.
        """
        # forward() already applies softmax, and nn.CrossEntropyLoss applies its
        # own log-softmax, so passing these probabilities to it optimised a
        # "double softmax". log + NLL gives the ordinary cross-entropy of the
        # pre-softmax scores. The clamp keeps log() finite if a probability
        # underflows to zero.
        log_probs = torch.log(logits.clamp_min(torch.finfo(logits.dtype).tiny))
        loss = nn.NLLLoss()(log_probs, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct / len(labels)
        return loss, accuracy

    def is_mislabel(self, clean_labels, labels):
        """Return a boolean mask of examples whose noisy label is wrong.

        For caption datasets an example is mislabelled when its clean label is -1.
        Otherwise it is mislabelled when the clean and noisy labels differ.
        """
        if self.is_caption_dataset:
            return clean_labels == -1
        else:
            return clean_labels != labels

    def training_step(self, batch, batch_idx):
        logits, clean_labels, labels = self.get_logits_and_labels(batch)
        loss, accuracy = self.common_step(logits, labels)
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, clean_labels, labels = self.get_logits_and_labels(batch)
        loss, accuracy = self.common_step(logits, labels)
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)
        is_mislabel = self.is_mislabel(clean_labels, labels)
        self.val_corruption.update(is_mislabel, labels, logits)
        return loss

    def test_step(self, batch, batch_idx):
        logits, clean_labels, labels = self.get_logits_and_labels(batch)
        loss, accuracy = self.common_step(logits, labels)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_accuracy", accuracy, on_epoch=True)
        is_mislabel = self.is_mislabel(clean_labels, labels)
        self.test_corruption.update(is_mislabel, labels, logits)
        return loss

    def on_validation_epoch_end(self):
        """Log confident-learning F1/accuracy for the epoch and reset the metric."""
        val_f1, val_acc = self.val_corruption.compute()
        self.log("val_f1_corruption", val_f1)
        self.log("val_acc_corruption", val_acc)
        self.val_corruption.reset()

    def on_test_epoch_end(self):
        """Log confident-learning F1/accuracy for the test run and reset the metric."""
        test_f1, test_acc = self.test_corruption.compute()
        self.log("test_f1_corruption", test_f1)
        self.log("test_acc_corruption", test_acc)
        self.test_corruption.reset()

    def configure_optimizers(self):
        """Return AdamW over all parameters, using the CLI learning rate and weight decay."""
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        # NOTE(review): args.linear_probe is not consulted anywhere in this class, so the
        # backbone parameters are always optimised too.
        optimizer = AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        #  {(2,500, 200), (10,000, 500)}
        return optimizer

    # The hooks below return the module-level globals of the same name (set in
    # the script's __main__ section), not these methods.
    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

    def test_dataloader(self):
        return test_dataloader


def collate_fn(examples):
    """Collate ``(pixels, clean_label, noisy_label)`` examples into a batch dict.

    Returns a dict with keys ``pixel_values`` (the images stacked along a new
    leading dimension), ``labels`` (noisy labels) and ``clean_labels``. Both
    label tensors have dtype ``torch.int64``.
    """
    images, noisy, clean = [], [], []
    for pixels, clean_label, noisy_label in examples:
        images.append(pixels)
        noisy.append(noisy_label)
        clean.append(clean_label)

    batch = {"pixel_values": torch.stack(images)}
    # torch.Tensor builds a default-dtype float tensor; .type() then casts it to int64.
    batch["labels"] = torch.Tensor(noisy).type(torch.LongTensor)
    batch["clean_labels"] = torch.Tensor(clean).type(torch.LongTensor)
    return batch


def get_labels_df(model, dataset, dataloader, split, flip_type):
    """Score every example in ``dataloader`` with confident learning.

    Side effects: moves ``model`` to CUDA when available (otherwise CPU) and puts
    it in eval mode.

    Args:
        model: module returning class scores for ``pixel_values`` and exposing
            ``is_mislabel(clean_labels, noisy_labels)``.
        dataset, split, flip_type: values copied into constant columns.
        dataloader: yields dicts with ``pixel_values``, ``labels`` (noisy) and
            ``clean_labels``, as produced by ``collate_fn``.

    Returns:
        A DataFrame with columns ``raw_score``, ``label_error_pred``,
        ``clean_labels``, ``noisy_labels``, ``method``, ``dataset``, ``split``,
        ``flip_type`` and ``pred_is_mislabel``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    score_chunks, clean_chunks, noisy_chunks = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['pixel_values'].to(device)
            score_chunks.append(model(images).detach().cpu())
            clean_chunks.append(batch['clean_labels'].detach().cpu())
            noisy_chunks.append(batch['labels'].detach().cpu())

    scores = torch.cat(score_chunks)
    clean_labels = torch.cat(clean_chunks)
    noisy_labels = torch.cat(noisy_chunks)
    is_mislabel = model.is_mislabel(clean_labels, noisy_labels)
    _f1, _acc, pred_is_mislabel = get_val_confident_learning(
        is_mislabel, noisy_labels.numpy(), scores.numpy()
    )

    # NOTE(review): raw_score and label_error_pred both hold the ground-truth
    # mislabel flags, while the confident-learning prediction lands in
    # pred_is_mislabel. Confirm this is what downstream consumers expect.
    df = pd.DataFrame.from_dict({
        'raw_score': is_mislabel,
        'label_error_pred': is_mislabel,
        'clean_labels': clean_labels.numpy(),
        'noisy_labels': noisy_labels.numpy(),
    })
    metadata = (
        ('method', 'confident_learning'),
        ('dataset', dataset),
        ('split', split),
        ('flip_type', flip_type),
    )
    for column, value in metadata:
        df[column] = value
    df['pred_is_mislabel'] = pred_is_mislabel
    return df


def get_num_labels(args):
    """Return the number of output classes for ``args.dataset``.

    Fixed-label datasets use their native class count. Caption datasets use the
    k-means cluster count ``args.kmeans_k``.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    fixed_label_counts = {
        "cifar10": 10,
        "cifar100": 100,
        "mini_imagenet": 100,
        "stanford_cars": 196,
    }
    dataset_name = args.dataset
    if dataset_name in fixed_label_counts:
        return fixed_label_counts[dataset_name]
    if dataset_name in CAPTION_DATASETS:
        return args.kmeans_k
    raise ValueError("Dataset not supported: {}".format(args.dataset))


def get_vision_model(args):
    """Build the pretrained vision encoder selected by ``args.clip_model``.

    ``"clip"`` loads ``CLIPVisionModel`` from ``openai/clip-vit-base-patch32``.
    ``"huggingface_clip"`` and ``"biomed_clip"`` are built through
    ``algorithm_class_from_scratch``, and their ``vision_model`` / ``visual``
    sub-module is returned.

    Raises:
        ValueError: if ``args.clip_model`` is not one of the supported names.
            Previously an unknown name built a model and then failed with
            UnboundLocalError.
    """
    clip_model = args.clip_model
    if clip_model == "clip":
        pretrained_model_name = "openai/clip-vit-base-patch32"
        return CLIPVisionModel.from_pretrained(pretrained_model_name, num_labels=get_num_labels(args))

    # Validate before constructing anything expensive.
    if clip_model not in ("huggingface_clip", "biomed_clip"):
        raise ValueError("Model not supported: {}".format(clip_model))

    model = algorithm_class_from_scratch(
        clip_model, text_base_name='openai/clip-vit-base-patch32', img_base=None, return_tokenizer=False
    )
    if clip_model == "huggingface_clip":
        return model.vision_model
    return model.visual


def check_flip_type(args):
    """Validate that ``args.flip_type`` is compatible with ``args.dataset``.

    ``real``/``symmetric``/``asymmetric`` flips apply only to non-caption
    datasets. ``cat``/``noun``/``random`` flips apply only to caption datasets.

    Raises:
        ValueError: on an unknown flip type or an incompatible dataset. The
            original used ``assert`` for this, and asserts are stripped under
            ``python -O``.
    """
    if args.flip_type in ["real", "symmetric", "asymmetric"]:
        if args.dataset in CAPTION_DATASETS:
            raise ValueError(
                "Flip type not supported for Caption Datasets: flip_type={}, dataset={}".format(
                    args.flip_type, args.dataset
                )
            )
    elif args.flip_type in ["cat", "noun", "random"]:
        if args.dataset not in CAPTION_DATASETS:
            raise ValueError(
                "Flip type not supported for non Caption datasets: flip_type={}, dataset={}".format(
                    args.flip_type, args.dataset
                )
            )
    else:
        raise ValueError("Flip type not supported: {}".format(args.flip_type))


def get_args_parser():
    parser = argparse.ArgumentParser(description="Multimodal distance metric")
    parser.add_argument("--exp_name", type=str, default="google_vit_test", required=True)

    # training
    parser.add_argument(
        "--linear_probe", type=bool, default=False, choices=[True, False]
    )
    parser.add_argument(
        "--dataset", type=str, default="cifar10", choices=CLASSIFICATION_DATASETS + CAPTION_DATASETS
    )
    parser.add_argument(
        "--flip_type", type=str, default="real", choices=["real", "symmetric", "asymmetric", "cat", "noun", "random"]
    )
    # parser.add_argument(
    #     "--img_base_name",
    #     type=str,
    #     default="clipvisionmodelvit",
    #     choices=["clipvisionmodelvit"],
    # )  # TODO: maybe add 'clipvisionmodel' - Resnet
    parser.add_argument("--output_folder_name", type=str, default="cl_exps")
    # others
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument(
        "--data_seed",
        type=int,
        default=0,
        help='Seed for random hparams (0 for "default hparams")',
    )
    parser.add_argument("--seed", type=int, default=0, help="Seed for everything else")
    parser.add_argument("--kmeans_k", type=int, default=100, help="Seed for everything else")

    # early stopping
    parser.add_argument(
        "--es_patience",
        type=int,
        default=5,
        help="Stop after this many checkpoints w/ no improvement",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=0.01,
        help="Weight decay"
    )

    # hparams
    parser.add_argument("--lr", type=float, default=1e-3) # Google VIT default is 5e-5, CLIP is around 1e-3 maybe
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=30)

    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="debug mode - run one batch and 4 epochs",
    )
    parser.add_argument("--clip_model", type=str, default="clip", choices=["clip", 'biomed_clip', 'huggingface_clip'])
    return parser

def add_store_name(args):
    """Set ``args.store_name``, the run identifier used for output paths.

    The name joins the main hyper-parameters with underscores. It gains a
    ``_linear_probe`` suffix when linear probing is requested and a
    ``_kmeansk<k>`` suffix for caption datasets.
    """
    pieces = [
        f"{args.exp_name}",
        f"{args.dataset}",
        f"{args.clip_model}",
        f"flip{args.flip_type}",
        f"data_seed{args.data_seed}",
        f"seed{args.seed}",
        f"lr{args.lr}",
        f"wd{args.wd}",
        f"batch{args.batch_size}",
        f"epochs{args.epochs}",
        f"es_patience{args.es_patience}",
    ]
    args.store_name = "_".join(pieces)
    if args.linear_probe:
        args.store_name += "_linear_probe"
    if args.dataset in CAPTION_DATASETS:
        args.store_name += f"_kmeansk{args.kmeans_k}"


def get_trainer(args):
    """Create the Lightning ``Trainer`` with checkpointing, early stopping and W&B logging.

    Checkpoints and logger files go to
    ``<output_dir>/<output_folder_name>/<store_name>``.
    """
    output_dir = os.path.join(args.output_dir, args.output_folder_name, args.store_name)

    # Keep the single best checkpoint by validation loss, plus the last one.
    checkpoint = pl.callbacks.ModelCheckpoint(
        save_last=True,
        monitor="validation_loss",
        mode="min",
        save_top_k=1,
        dirpath=output_dir,
    )
    # Stop after `es_patience` validation checks without a lower validation loss.
    early_stop_callback = EarlyStopping(
        monitor='validation_loss',
        patience=args.es_patience,
        strict=False,
        verbose=False,
        mode='min',
    )
    logger = pl_loggers.WandbLogger(
        project="label_noise",
        name=f"{args.dataset}_{args.exp_name}_lr{args.lr}_k{args.kmeans_k}",
        save_dir=output_dir,
        config=args,
        group=f"{args.dataset}_{args.exp_name}",
    )
    # NOTE(review): accelerator='gpu' requires a CUDA device.
    return Trainer(
        logger=logger,
        accelerator='gpu',
        devices=1,
        callbacks=[early_stop_callback, checkpoint],
        num_sanity_val_steps=15,
        max_epochs=args.epochs,
    )


def create_datasets(args):
    """Return the ``(train, val, test)`` noisy-label datasets for ``args.dataset``.

    Caption datasets also request text clustering with ``args.kmeans_k``
    clusters, which provides their class labels.
    """
    extra_kwargs = {}
    if args.dataset in CAPTION_DATASETS:
        extra_kwargs = {
            'cluster_text': True,
            'cluster_kwargs': {'n_clusters': args.kmeans_k},
        }
    splits = get_dataset(
        args.dataset, args.data_seed, noisy_labels=True,
        flip_type=args.flip_type, **extra_kwargs)
    train_split, val_split, test_split = splits
    return train_split, val_split, test_split


if __name__ == "__main__":
    parser = get_args_parser()
    args = parser.parse_args()

    check_flip_type(args)
    add_store_name(args)

    model = ViTLightningModule(args=args)

    train_dataset, val_dataset, test_dataset = create_datasets(args)
    

    n_workers = 4

    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.batch_size, num_workers=n_workers)
    val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=args.batch_size, num_workers=n_workers)
    test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=args.batch_size, num_workers=n_workers)

    trainer = get_trainer(args)
    trainer.fit(model)

    results = trainer.validate(ckpt_path=trainer.checkpoint_callback.best_model_path)
    output_dir = os.path.join(args.output_dir, args.output_folder_name, args.store_name)
    pd.to_pickle(results, os.path.join(output_dir, "val_results.pkl"))

    results = trainer.test(ckpt_path=trainer.checkpoint_callback.best_model_path)
    output_dir = os.path.join(args.output_dir, args.output_folder_name, args.store_name)
    pd.to_pickle(results, os.path.join(output_dir, "test_results.pkl"))

    train_dataloader = DataLoader(train_dataset, shuffle=False, drop_last=False, collate_fn=collate_fn, batch_size=args.batch_size, num_workers=n_workers)
    val_dataloader = DataLoader(val_dataset, shuffle=False, drop_last=False, collate_fn=collate_fn, batch_size=args.batch_size, num_workers=n_workers)
    test_dataloader = DataLoader(test_dataset, shuffle=False, drop_last=False, collate_fn=collate_fn, batch_size=args.batch_size, num_workers=n_workers)
    
    dataloader = train_dataloader

    model = ViTLightningModule.load_from_checkpoint(args=args, checkpoint_path=trainer.checkpoint_callback.best_model_path)

    train_df = get_labels_df(model, args.dataset, dataloader, "train", args.flip_type)

    dataloader = val_dataloader
    val_df = get_labels_df(model, args.dataset, dataloader, "val", args.flip_type)

    dataloader = test_dataloader
    test_df = get_labels_df(model, args.dataset, dataloader, "test", args.flip_type)

    all_df = pd.concat([train_df.reset_index(drop=False), val_df.reset_index(drop=False), test_df.reset_index(drop=False)])

    train_df.to_pickle(os.path.join(output_dir, "train.pkl"))
    test_df.to_pickle(os.path.join(output_dir, "test.pkl"))
    val_df.to_pickle(os.path.join(output_dir, "val.pkl"))
    all_df.to_pickle(os.path.join(output_dir, "all.pkl"))

    print("Success")
