import torch
import numpy as np
import argparse
import os
from utils import Logger, compute_accuracy, limit_tensorflow_memory_usage, CsvWriter, predict_by_max_logit,\
    compute_accuracy_from_predictions, get_mean_percent_and_95_confidence_interval
from image_folder_reader import ImageFolderReader
from tf_dataset_reader import TfDatasetReader
from bit_resnet import KNOWN_MODELS
from dataset import vtab_datasets


def main():
    """Script entry point: build a Learner from the command line and run the evaluation."""
    Learner().run()


class Learner:
    """Evaluate fine-tuned BiT checkpoints on the enabled VTAB datasets.

    For every enabled entry of ``vtab_datasets``, ``NUM_MODELS`` checkpoints are loaded
    and evaluated on the dataset's target split. Each run's accuracy is logged, and one
    summary row (per-run accuracies plus mean +/- 95% confidence interval) is written to
    ``results.csv`` in the checkpoint directory.
    """

    # Number of checkpoints evaluated per dataset. Checkpoint files are named
    # "<model>-run<i>-<dataset model_name>.npz" for i in range(NUM_MODELS); the CSV
    # header and rows are derived from this value so the three always agree.
    NUM_MODELS = 3

    def __init__(self):
        self.args = self.parse_command_line()
        self.logger = Logger(self.args.checkpoint_dir, 'log.txt')
        self.logger.print_and_log("Options: %s\n" % self.args)
        self.logger.print_and_log("Checkpoint Directory: %s\n" % self.args.checkpoint_dir)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.accuracy_fn = compute_accuracy
        self.csv_writer = CsvWriter(
            file_path=os.path.join(self.args.checkpoint_dir, 'results.csv'),
            header=['dataset'] + ['run {}'.format(i + 1) for i in range(self.NUM_MODELS)] + ['mean']
        )

    def parse_command_line(self):
        """Command line parser.

        Returns:
            argparse.Namespace holding the parsed options.
        """
        parser = argparse.ArgumentParser()

        parser.add_argument("--model", choices=['BiT-M-R50x1'], default='BiT-M-R50x1', help="Model to test.")
        parser.add_argument("--path_to_models", default=None, help="Path to pretrained BiT models.")
        parser.add_argument("--download_path_for_tensorflow_datasets", default=None,
                            help="Path to download the tensorflow datasets.")
        parser.add_argument("--download_path_for_sun397_dataset", default=None,
                            help="Path to download the sun397 dataset.")
        parser.add_argument("--checkpoint_dir", "-c", default='../checkpoints',
                            help="Directory to save checkpoint to.")
        parser.add_argument("--batch_size", "-b", type=int, default=500, help="Batch size.")
        return parser.parse_args()

    def init_model(self, model_name, num_classes):
        """Build a ``--model`` network and load its weights from an .npz checkpoint.

        Args:
            model_name: file name of the checkpoint inside ``--path_to_models``.
            num_classes: size of the classification head.

        Returns:
            The model with the checkpoint weights loaded. The caller moves it to the device.

        Raises:
            ValueError: if ``--path_to_models`` was not supplied.
        """
        if self.args.path_to_models is None:
            raise ValueError("--path_to_models must be specified to load the pretrained models.")
        model = KNOWN_MODELS[self.args.model](head_size=num_classes)
        # np.load on an .npz returns an NpzFile that keeps the file open; close it once
        # load_from has consumed the weights.
        # NOTE(review): assumes load_from copies the arrays eagerly rather than keeping
        # a reference to the NpzFile -- confirm in bit_resnet.
        with np.load(os.path.join(self.args.path_to_models, model_name)) as weights:
            model.load_from(weights)
        return model

    def run(self):
        limit_tensorflow_memory_usage(1024)
        self.test_vtab()

    def test_vtab(self):
        """Evaluate NUM_MODELS checkpoints on every enabled VTAB dataset and record the results.

        Each run's accuracy is logged as a percentage. Each dataset then gets one row in
        results.csv with the per-run accuracies and the mean with its 95% confidence interval.
        """
        self.logger.print_and_log("")  # add a blank line

        # Gradients are never needed here. This also covers checkpoint loading in init_model.
        with torch.no_grad():
            for dataset in vtab_datasets:
                if dataset['enabled'] is False:
                    continue

                accuracies = []
                for model_index in range(self.NUM_MODELS):
                    model = self.init_model(
                        model_name="{}-run{}-{}.npz".format(self.args.model, model_index, dataset['model_name']),
                        num_classes=dataset['num_classes']
                    )
                    # Re-created for every run (the readers may be stateful iterators),
                    # which keeps the runs independent of each other.
                    dataset_reader = self._create_dataset_reader(dataset)
                    accuracy = self._evaluate(model, dataset_reader)
                    accuracies.append(accuracy.cpu())
                    self._log_accuracy(dataset, accuracy)

                mean, confidence = get_mean_percent_and_95_confidence_interval(accuracies)
                self.csv_writer.write_row(self._format_csv_row(dataset['name'], accuracies, mean, confidence))

    def _create_dataset_reader(self, dataset):
        """Build the reader that serves the target split of `dataset` at its BiT image size."""
        context_set_size = 1  # don't need the context set, but make the reader happy
        image_size = dataset['bit_image_size']

        if dataset['name'] == "sun397":  # use the image folder reader as the tf reader is broken for sun397
            return ImageFolderReader(
                path_to_images=self.args.download_path_for_sun397_dataset,
                context_batch_size=context_set_size,
                target_batch_size=self.args.batch_size,
                image_size=image_size,
                device=self.device,
                osr=False)

        # use the tensorflow dataset reader
        return TfDatasetReader(
            dataset=dataset['name'],
            task=dataset['task'],
            context_batch_size=context_set_size,
            target_batch_size=self.args.batch_size,
            path_to_datasets=self.args.download_path_for_tensorflow_datasets,
            num_classes=dataset['num_classes'],
            image_size=image_size,
            device=self.device,
            osr=False
        )

    def _evaluate(self, model, dataset_reader):
        """Return the accuracy of `model` over the full target split served by `dataset_reader`.

        Expected to be called inside torch.no_grad() (see test_vtab).
        """
        model.to(self.device)
        model.eval()
        test_set_size = dataset_reader.get_target_dataset_length()
        num_batches = int(np.ceil(float(test_set_size) / float(self.args.batch_size)))

        labels = []
        predictions = []
        for _ in range(num_batches):
            batch_images, batch_labels = dataset_reader.get_target_batch()
            logits = model(batch_images)
            predictions.append(predict_by_max_logit(logits))
            labels.append(batch_labels)
            del logits  # release the logits before fetching the next batch
        return compute_accuracy_from_predictions(torch.hstack(predictions), torch.hstack(labels))

    def _log_accuracy(self, dataset, accuracy):
        """Log one run's accuracy as a percentage, qualified by the task name when there is one."""
        if dataset['task'] is None:
            self.logger.print_and_log('{0:}: {1:3.1f}'.format(dataset['name'], accuracy * 100.0))
        else:
            self.logger.print_and_log('{0:} {1:}: {2:3.1f}'.format(dataset['name'], dataset['task'], accuracy * 100.0))

    @staticmethod
    def _format_csv_row(dataset_name, accuracies, mean, confidence):
        """Format one results.csv row: dataset name, each run's accuracy, then mean +/- 95% CI."""
        return (["{0:}".format(dataset_name)]
                + ["{0:1.3f}".format(accuracy) for accuracy in accuracies]
                + ["{0:3.1f}+/-{1:2.1f}".format(mean, confidence)])


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
