import sys
import argparse
import os
import torch
import shutil
from torch.utils.tensorboard import SummaryWriter
from utils.utils import set_logger

class Arguments:
    """Validated, logged view of the parsed SOLV command-line arguments.

    Copies every field of an ``argparse.Namespace`` onto ``self``, then
    validates/normalises them (``assertions``), creates the output directory
    with a text logger and a TensorBoard ``SummaryWriter`` (``set_loggers``),
    and writes the configuration to the text log (``print_args``).
    """

    def __init__(self, args):
        """Copy the fields of ``args``, then validate them and set up logging.

        Args:
            args: the ``argparse.Namespace`` produced by ``get_args``.

        Raises:
            AssertionError: if any check in ``assertions`` fails.
        """
        # === Data Related Parameters ===
        self.root = args.root
        self.dataset = args.dataset
        self.train_splits = args.train_splits

        self.resize_to = args.resize_to

        # === ViT Related Parameters ===
        self.arch = args.arch
        self.patch_size = args.patch_size
        self.dinov2 = args.dinov2
        self.supervised = args.supervised

        # === Slot Attention Related Parameters ===
        self.num_slots = args.num_slots
        self.slot_att_iter = args.slot_att_iter
        self.slot_dim = args.slot_dim
        self.merge_slots = args.merge_slots
        self.slot_merge_coeff = args.slot_merge_coeff
        self.pass_transformer = args.pass_transformer
        self.original_slot_attention = args.original_slot_attention

        # === Decoder Related Parameters ===
        self.decoder_hidden_dim = args.decoder_hidden_dim
        self.decoder_training = args.decoder_training

        # === Model Related Parameters ===
        self.N = args.N

        # === Training Related Parameters ===
        self.finetuning = args.finetuning
        self.token_drop_ratio = args.token_drop_ratio
        self.learning_rate = args.learning_rate
        self.batch_size = args.batch_size
        self.num_epochs = args.num_epochs

        # === Misc ===
        self.calculate_memory = args.calculate_memory
        self.use_checkpoint = args.use_checkpoint
        self.checkpoint_path = args.checkpoint_path
        self.validation_epoch = args.validation_epoch

        self.validate = args.validate
        self.validate_k_times = args.validate_k_times

        self.workers = args.workers
        self.seed = args.seed
        self.model_save_path = args.model_save_path

        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

        self.assertions()
        self.set_loggers()

        self.print_args()

    def assertions(self):
        """Validate the arguments in place and normalise ``resize_to``.

        A single ``resize_to`` value is duplicated, so ``[s]`` becomes
        ``[s, s]``.

        Raises:
            AssertionError: on the first failed check, with a message naming it.
        """
        assert os.path.exists(self.root), f"root does not exist: {self.root}"

        if self.use_checkpoint:
            assert self.checkpoint_path is not None, (
                "--use_checkpoint requires --checkpoint_path"
            )
            assert os.path.exists(self.checkpoint_path), (
                f"checkpoint_path does not exist: {self.checkpoint_path}"
            )

        if len(self.resize_to) == 1:
            self.resize_to = self.resize_to * 2
        assert len(self.resize_to) == 2, (
            f"resize_to expects 1 or 2 values, got {self.resize_to}"
        )

        assert self.N >= 1, f"N must be >= 1, got {self.N}"

        assert "train" in self.train_splits, "train_splits must include 'train'"
        # Only ytvis19 accepts the "test" split for training.
        if self.dataset == "ytvis19":
            allowed_splits = ["train", "valid", "test"]
        else:
            allowed_splits = ["train", "valid"]
        for split in self.train_splits:
            assert split in allowed_splits, (
                f"split {split!r} is not allowed for dataset {self.dataset!r}; "
                f"choose from {allowed_splits}"
            )

        if self.finetuning:
            assert self.use_checkpoint, "--finetuning requires --use_checkpoint"

    def set_loggers(self):
        """Create ``model_save_path`` and attach ``self.logger`` and ``self.writer``.

        The text log is written to ``<model_save_path>/train_log.log``. The
        TensorBoard log directory ``<model_save_path>/writer.log`` is removed
        first, so events from an earlier run in the same directory are discarded.
        """
        # makedirs also creates missing parent directories and tolerates an
        # already-existing directory.
        os.makedirs(self.model_save_path, exist_ok=True)

        logger_path = os.path.join(self.model_save_path, "train_log.log")
        self.logger = set_logger(logger_path)

        # Despite the ".log" suffix this path is used as SummaryWriter's log_dir.
        writer_path = os.path.join(self.model_save_path, "writer.log")
        if os.path.isdir(writer_path):
            shutil.rmtree(writer_path)
        elif os.path.exists(writer_path):
            # rmtree cannot remove a plain file left at this path.
            os.remove(writer_path)
        # No ``comment=`` here: SummaryWriter ignores it when log_dir is given.
        self.writer = SummaryWriter(log_dir=writer_path)

    def print_args(self):
        """Write the run configuration to ``self.logger``, one ``name: value`` per line.

        Argument groups are separated by blank lines.
        """
        assert self.logger is not None
        self.logger.info("====== Arguments ======")
        # normpath strips a trailing separator, so "runs/exp/" still yields "exp".
        training_name = os.path.basename(os.path.normpath(self.model_save_path))
        self.logger.info(f"training name: {training_name}\n")

        self.logger.info(f"arch: {self.arch}")
        self.logger.info(f"patch_size: {self.patch_size}")
        self.logger.info(f"dinov2: {self.dinov2}")
        self.logger.info(f"supervised: {self.supervised}\n")

        self.logger.info(f"dataset: {self.dataset}")
        self.logger.info(f"root: {self.root}")
        self.logger.info(f"train_splits: {self.train_splits}\n")

        self.logger.info(f"resize_to: {self.resize_to}\n")

        self.logger.info(f"num_slots: {self.num_slots}")
        self.logger.info(f"slot_att_iter: {self.slot_att_iter}")
        self.logger.info(f"slot_dim: {self.slot_dim}")
        self.logger.info(f"merge_slots: {self.merge_slots}")
        self.logger.info(f"slot_merge_coeff: {self.slot_merge_coeff}")
        self.logger.info(f"pass_transformer: {self.pass_transformer}")
        self.logger.info(f"original_slot_attention: {self.original_slot_attention}\n")

        self.logger.info(f"decoder_hidden_dim: {self.decoder_hidden_dim}")
        self.logger.info(f"decoder_training: {self.decoder_training}\n")

        self.logger.info(f"N: {self.N}\n")

        self.logger.info(f"finetuning: {self.finetuning}")
        self.logger.info(f"token_drop_ratio: {self.token_drop_ratio}")
        self.logger.info(f"learning_rate: {self.learning_rate}")
        self.logger.info(f"batch_size: {self.batch_size}")
        self.logger.info(f"num_epochs: {self.num_epochs}\n")

        self.logger.info(f"calculate_memory: {self.calculate_memory}")
        self.logger.info(f"use_checkpoint: {self.use_checkpoint}")
        self.logger.info(f"checkpoint_path: {self.checkpoint_path}")
        self.logger.info(f"workers: {self.workers}")
        self.logger.info(f"seed: {self.seed}")
        self.logger.info(f"device: {self.device}\n")

        self.logger.info(f"validate: {self.validate}")
        self.logger.info(f"validate_k_times: {self.validate_k_times}")
        self.logger.info(f"validation_epoch: {self.validation_epoch}")
        self.logger.info("====== ======= ======\n")


def get_args(argv=None):
    """Parse the SOLV command line and return a validated ``Arguments`` object.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        Arguments: built from the parsed namespace. Constructing it validates
        the values and sets up logging under ``--model_save_path``.
    """
    parser = argparse.ArgumentParser("SOLV")

    # === Data Related Parameters ===
    parser.add_argument('--root', type=str, required=True)

    parser.add_argument('--dataset', type=str, default="ytvis19", choices=["davis19", "ytvis19", "movi"])
    # Default is resolved after parsing because it depends on --dataset.
    parser.add_argument('--train_splits', nargs='+', type=str, default=None,
                        help='splits used for training (default: "train valid test" '
                             'for ytvis19, "train valid" otherwise)')

    parser.add_argument('--resize_to', nargs='+', type=int, default=[336, 504])

    # === ViT Related Parameters ===
    parser.add_argument('--arch', type=str, default="vit_base")
    parser.add_argument('--patch_size', type=int, default=14)
    parser.add_argument('--dinov2', action="store_true")
    parser.add_argument('--supervised', action="store_true")

    # === Slot Attention Related Parameters ===
    parser.add_argument('--num_slots', type=int, default=8)
    parser.add_argument('--slot_att_iter', type=int, default=3)
    parser.add_argument('--slot_dim', type=int, default=128)
    parser.add_argument('--merge_slots', action="store_true")
    parser.add_argument('--slot_merge_coeff', type=float, default=0.12)
    parser.add_argument('--pass_transformer', action="store_true")
    parser.add_argument('--original_slot_attention', action="store_true")

    # === Decoder Related Parameters ===
    parser.add_argument('--decoder_hidden_dim', type=int, default=1024)
    parser.add_argument('--decoder_training', action="store_true")

    # === Model Related Parameters ===
    parser.add_argument('--N', type=int, default=2)

    # === Training Related Parameters ===
    parser.add_argument('--finetuning', action="store_true")
    parser.add_argument('--token_drop_ratio', type=float, default=0.5)
    parser.add_argument('--learning_rate', type=float, default=4e-4)
    parser.add_argument('--batch_size', type=int, default=48)
    parser.add_argument('--num_epochs', type=int, default=180)

    # === Misc ===
    parser.add_argument('--calculate_memory', action="store_true")
    parser.add_argument('--use_checkpoint', action="store_true")
    parser.add_argument('--checkpoint_path', type=str, default=None)
    parser.add_argument('--validation_epoch', type=int, default=1)

    parser.add_argument('--validate', action="store_true")
    parser.add_argument('--validate_k_times', type=int, default=5)

    parser.add_argument('--workers', type=int, default=1)
    parser.add_argument('--seed', type=int, default=32)
    parser.add_argument('--model_save_path', type=str, required=True)

    args = parser.parse_args(argv)

    # Arguments.assertions only accepts a "test" training split for ytvis19,
    # so a single static default containing "test" would be rejected for the
    # other datasets.
    if args.train_splits is None:
        if args.dataset == "ytvis19":
            args.train_splits = ["train", "valid", "test"]
        else:
            args.train_splits = ["train", "valid"]

    arg_object = Arguments(args)
    return arg_object