import ast
from pathlib import Path

import ruamel.yaml as yaml
from lutils import openf, writef


class Config:
    """Display-ready summary of one training run found under ``output/``.

    A run is identified by the name of the directory containing its
    ``config.yaml``. The constructor loads that YAML and derives the labels
    (architecture, dataset, edit, model, ...) that :func:`main` writes out.
    """

    def __init__(self, config_name):
        # Index every config.yaml below output/ by its parent directory name.
        # Depth-2 matches are listed a second time after the recursive glob,
        # so on a name collision the output/*/<name>/config.yaml entry wins.
        configs = list(Path("output").glob("**/config.yaml")) + list(
            Path("output").glob("*/*/config.yaml")
        )
        self.name2pth = {pth.parent.name: pth for pth in configs}
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if config_name not in self.name2pth:
            raise ValueError(f"{config_name} not found")
        self.config_path = self.name2pth[config_name]

        self.config = self.open_config(self.config_path)
        self.epoch = self.get_epoch(self.config_path.parent / "log.txt")

        # First character of config["vit"], upper-cased; used as the letter in
        # the "ViT-<letter>/16" architecture label.
        self.vit = self.config["vit"][0].upper()
        self.name = config_name
        # Needs self.name, self.vit and self.config, so it must come after them.
        self.architecture = self.get_architecture()
        self.checkpoint = self.config["pretrained"]
        self.pretraining = self.get_pretraining(self.checkpoint)
        self.pretrained = self.get_pretrained_name(self.checkpoint)
        self.model = self.get_model(self.checkpoint)
        self.dataset, self.edit = self.get_data(self.config)
        self.batch = self.config["batch_size_train"]
        self.lr = self.config["init_lr"]
        # Optional keys: reported as "" when absent from the config.
        self.beta = self.config.get("beta", "")
        self.frame = self.config.get("frame", "")

    @staticmethod
    def open_config(pth):
        """Load and return the YAML document stored at *pth*."""
        # NOTE(review): `yaml.load(..., Loader=yaml.Loader)` is ruamel.yaml's
        # legacy, non-safe API (removed in ruamel.yaml 0.18). Acceptable for
        # local run outputs, not for untrusted files; confirm the pinned version.
        with open(pth, "r") as f:
            return yaml.load(f, Loader=yaml.Loader)

    def get_architecture(self):
        """Return the architecture label, e.g. ``"ViT-B/16"``.

        For runs whose name contains ``"fViT"``, the config must define
        ``train_vit``. The label then states whether the ViT was trained
        ("target embeddings") or frozen ("freeze ViT").

        Raises:
            ValueError: if an ``fViT`` run's config lacks ``train_vit``.
        """
        label = f"ViT-{self.vit}/16"
        if "fViT" not in self.name:
            return label
        if "train_vit" not in self.config:
            raise ValueError("train_vit not in config")
        # YAML loads an unquoted `True` as a bool and a quoted one as a str.
        # Normalise through str() so both spellings count as "trained".
        if str(self.config["train_vit"]) == "True":
            return f"{label} target embeddings"
        return f"{label} freeze ViT"

    @staticmethod
    def is_url(url_or_filename):
        """Return True if *url_or_filename* has an ``http`` or ``https`` scheme."""
        from urllib.parse import urlparse

        parsed = urlparse(url_or_filename)
        return parsed.scheme in ("http", "https")

    @staticmethod
    def get_pretraining(checkpoint):
        """Return the parent directory name of a local *checkpoint*, or "" for a URL.

        That directory name is what :meth:`get_model` treats as the name of the
        run that produced the checkpoint.
        """
        if Config.is_url(checkpoint):
            return ""
        return Path(checkpoint).parent.name

    @staticmethod
    def get_model(checkpoint):
        """Return the base model name behind *checkpoint*.

        For a URL, this is the text before the first ``.`` in the last path
        segment (``.../model_base.pth`` -> ``model_base``). For a local path,
        it is resolved recursively by building the ``Config`` of the run named
        after the checkpoint's parent directory.
        """
        if Config.is_url(checkpoint):
            return checkpoint.split("/")[-1].split(".")[0]
        # NOTE: does not terminate if runs reference each other in a cycle.
        return Config(Path(checkpoint).parent.name).model

    @staticmethod
    def get_pretrained_name(checkpoint):
        """Return the run name a local *checkpoint* came from, or "" for a URL."""
        # Same derivation as get_pretraining; kept as a separate public name.
        return Config.get_pretraining(checkpoint)

    @staticmethod
    def get_data(config):
        """Return ``(dataset_label, edit_label)`` derived from *config*.

        ``config["data"]`` is split at its first ``-`` into a size part and an
        edit part. For example, ``"2M-foo-bar"`` gives size ``"2M"`` and edit
        ``"foo-bar"``, and the size is appended to the dataset label as
        ``"<Dataset> - <size>"``. Without a ``-``, the dataset label stays bare
        and the edit is ``""``. If an ``iterate`` key is present,
        ``" - Iter <iterate>"`` is appended to the edit.

        Raises:
            ValueError: if ``config["dataset"]`` is neither ``cirr`` nor ``webvid``.
        """
        labels = {"cirr": "CIRR", "webvid": "WebVid"}
        dataset = config["dataset"]
        if dataset not in labels:
            raise ValueError(f"Dataset {dataset} not supported")
        dataset = labels[dataset]

        data = config["data"]
        if "-" in data:
            # maxsplit=1 covers both the single- and multi-dash cases.
            size, edit = data.split("-", 1)
            dataset = f"{dataset} - {size}"
        else:
            edit = ""

        if "iterate" in config:
            edit = f"{edit} - Iter {config['iterate']}"

        return dataset, edit

    @staticmethod
    def get_epoch(log_pth):
        """Return ``best_epoch`` from the last entry of the log at *log_pth*.

        The last entry is parsed with ``ast.literal_eval`` and must be a
        mapping literal with a ``best_epoch`` key. Returns "" when the log
        file is missing or ``openf`` yields nothing.
        """
        if not log_pth.exists():
            return ""
        # Presumably openf returns the file's lines as a sequence — confirm in lutils.
        logs = openf(log_pth)
        if not logs:
            return ""
        return ast.literal_eval(logs[-1])["best_epoch"]


def main(experiment):
    """Write a one-field-per-line summary of run *experiment* via ``writef``.

    The target is ``results.txt`` in the same directory as the run's
    ``config.yaml``. Each line has the form ``"<Label>: <value>"``.
    """
    cfg = Config(experiment)
    summary = (
        ("Experiment", cfg.name),
        ("Architecture", cfg.architecture),
        ("Pretraining", cfg.pretraining),
        ("Dataset", cfg.dataset),
        ("Edit", cfg.edit),
        ("Batch", cfg.batch),
        ("Learning rate", cfg.lr),
        ("Model", cfg.model),
        ("Pretrained", cfg.pretrained),
        ("Frame", cfg.frame),
        ("Epoch", cfg.epoch),
        ("Beta", cfg.beta),
    )
    lines = [f"{label}: {value}" for label, value in summary]
    writef(lines, cfg.config_path.parent / "results.txt")


if __name__ == "__main__":
    import argparse

    # CLI: a single positional argument naming the run directory to summarise.
    cli = argparse.ArgumentParser(description="")
    cli.add_argument("experiment_name", type=str, help="")
    main(cli.parse_args().experiment_name)
