import hydra
import logging
import sys
import os
import json
import easydict

from utils.hydra_utils import initialize_config, infer_and_assert_hydra_config

import pytorch_lightning as pl
import torch

from dataset import get_dataset, get_collate_fn
from model import get_shot_encoder, get_contextual_relation_network
from loss import get_loss

from pretrain_wrapper import PretrainingWrapper
from pytorch_lightning.strategies import DDPStrategy

from finetune_wrapper import FinetuningWrapper

def init_hydra_config(mode: str):
    """Compose the Hydra configuration for ``mode`` and post-process it.

    Side effect: the ROOT logger level is set to DEBUG.

    Every command-line argument after the program name (``sys.argv[1:]``)
    is forwarded to Hydra as an override. Configs are looked up in the
    ``cfg`` config module; the "extract" mode composes the "pretrain"
    config, every other mode composes the config named after itself.

    Args:
        mode: run mode, e.g. "pretrain", "extract" or "finetune".

    Returns:
        The composed config after ``initialize_config(..., mode=mode)``.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)

    cli_overrides = sys.argv[1:]
    logging.info("##### overrides: {}".format(cli_overrides))

    # "extract" reuses the pretraining config; other modes map 1:1 to a config name.
    config_name = "pretrain" if mode == "extract" else mode
    with hydra.initialize_config_module(config_module="cfg", version_base="1.1"):
        composed_cfg = hydra.compose(config_name, overrides=cli_overrides)

    return initialize_config(composed_cfg, mode=mode)


def apply_random_seed(cfg):
    """Seed every RNG via Lightning when the config requests it.

    Seeding happens only if ``cfg`` contains a ``SEED`` entry whose value is
    ``>= 0``; otherwise this is a no-op. ``workers=True`` is passed so that
    DataLoader worker processes are seeded as well.

    Reproducibility reference for worker seeding:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.get_worker_info
    """
    if "SEED" not in cfg:
        return
    if not cfg.SEED >= 0:
        # Negative seed means "do not seed".
        return
    pl.seed_everything(cfg.SEED, workers=True)


def _read_checkpoint_config(ckpt_root, load_from):
    """Read ``<ckpt_root>/<load_from>/config.json`` and wrap it in an EasyDict."""
    config_path = os.path.join(ckpt_root, load_from, "config.json")
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(config_path, "r", encoding="utf-8") as fopen:
        return easydict.EasyDict(json.load(fopen))


def load_pretrained_config(cfg):
    """Merge the configuration stored next to a checkpoint into the run config.

    Behaviour depends on ``cfg.MODE``:

    * "pretrain" / "extract": the checkpoint's ``config.json`` (read from
      ``CKPT_PATH/LOAD_FROM``) REPLACES ``cfg``. Run-specific fields are
      copied over from the current ``cfg``: MODE, LOAD_FROM, DATASET,
      DATA_PATH, FEAT_PATH, IMG_PATH and DISTRIBUTED.NUM_PROC_PER_NODE.
      DISTRIBUTED.NUM_NODES is forced to 1. The result is passed through
      ``infer_and_assert_hydra_config``, after which
      TEST.BATCH_SIZE.effective_batch_size is copied from the current ``cfg``.
    * "finetune": ``cfg`` is kept, but MODEL is taken from the pre-trained
      config (read from ``PRETRAINED_CKPT_PATH/PRETRAINED_LOAD_FROM``). The
      contextual relation network is enabled. The neighbour sizes of the
      pre-trained sampling method are copied into ``cfg.LOSS``.
    * any other mode: ``cfg`` is returned unchanged.

    Returns:
        The (possibly replaced) config object.
    """
    if cfg.MODE in ["pretrain", "extract"]:
        pretrained_cfg = _read_checkpoint_config(cfg.CKPT_PATH, cfg.LOAD_FROM)

        # Fields describing *this* run override whatever the checkpoint saved.
        pretrained_cfg.MODE = cfg.MODE
        pretrained_cfg.DISTRIBUTED.NUM_NODES = 1
        pretrained_cfg.LOAD_FROM = cfg.LOAD_FROM
        pretrained_cfg.DATASET = cfg.DATASET
        pretrained_cfg.DATA_PATH = cfg.DATA_PATH
        pretrained_cfg.FEAT_PATH = cfg.FEAT_PATH
        pretrained_cfg.IMG_PATH = cfg.IMG_PATH
        pretrained_cfg.DISTRIBUTED.NUM_PROC_PER_NODE = cfg.DISTRIBUTED.NUM_PROC_PER_NODE

        pretrained_cfg = infer_and_assert_hydra_config(pretrained_cfg)

        # NOTE(review): this is copied AFTER infer_and_assert_hydra_config, so any
        # value that helper derives from the test effective batch size (e.g. a
        # per-process size) still reflects the checkpoint's value — confirm intended.
        pretrained_cfg.TEST.BATCH_SIZE.effective_batch_size = (
            cfg.TEST.BATCH_SIZE.effective_batch_size
        )
        return pretrained_cfg

    if cfg.MODE in ["finetune"]:
        pretrained_cfg = _read_checkpoint_config(
            cfg.PRETRAINED_CKPT_PATH, cfg.PRETRAINED_LOAD_FROM
        )

        # Fine-tune on top of the pre-trained architecture, with the CRN switched on.
        cfg.MODEL = pretrained_cfg.MODEL
        cfg.MODEL.contextual_relation_network.enabled = True

        # Keep the input shot-sequence length identical to the one used in pre-training.
        sampling = pretrained_cfg.LOSS.sampling_method.name
        pretrained_params = pretrained_cfg.LOSS.sampling_method.params[sampling]
        if sampling == "asymmetric":
            target_params = cfg.LOSS.sampling_method.params[sampling]
            target_params["neighbor_left"] = pretrained_params["neighbor_left"]
            target_params["neighbor_right"] = pretrained_params["neighbor_right"]
        else:
            cfg.LOSS.sampling_method.params["sbd"]["neighbor_size"] = pretrained_params[
                "neighbor_size"
            ]

    return cfg


def init_data_loader(cfg, mode, is_train, is_test):
    """Build the DataLoader for one split.

    The batch size is ``TRAIN.BATCH_SIZE.batch_size_per_proc`` when
    ``is_train`` is true, otherwise ``TEST.BATCH_SIZE.batch_size_per_proc``.
    ``num_workers`` and ``pin_memory`` are always read from ``cfg.TRAIN``.
    Shuffling and ``drop_last`` are enabled only for training.

    When ``is_train`` is true, ``cfg.TRAIN.TRAIN_ITERS_PER_EPOCH`` is set
    (mutating ``cfg``) to ``len(dataset) // TRAIN.BATCH_SIZE.effective_batch_size``.

    Returns:
        ``(cfg, data_loader)``.
    """
    split_cfg = cfg.TRAIN if is_train else cfg.TEST
    per_proc_batch = split_cfg.BATCH_SIZE.batch_size_per_proc

    loader = torch.utils.data.DataLoader(
        dataset=get_dataset(cfg, mode=mode, is_train=is_train, is_test=is_test),
        batch_size=per_proc_batch,
        num_workers=cfg.TRAIN.NUM_WORKERS,
        pin_memory=cfg.TRAIN.PIN_MEMORY,
        drop_last=is_train,
        shuffle=is_train,
        collate_fn=get_collate_fn(cfg),
    )

    if not is_train:
        return cfg, loader

    # Divides by effective_batch_size (not batch_size_per_proc).
    num_samples = len(loader.dataset)
    cfg.TRAIN.TRAIN_ITERS_PER_EPOCH = num_samples // cfg.TRAIN.BATCH_SIZE.effective_batch_size
    return cfg, loader


def init_model(cfg):
    """Build the shot encoder and contextual relation network and wrap them.

    * "pretrain" / "extract": a ``PretrainingWrapper`` using the loss from
      ``get_loss(cfg)``. Weights are restored (``strict=False``) from
      ``CKPT_PATH/LOAD_FROM/model.ckpt`` when ``LOAD_FROM`` is non-empty.
    * "finetune": a ``FinetuningWrapper``. Weights are restored
      (``strict=False``) from the first non-empty source, in this order:

        1. ``CKPT_PATH/LOAD_FROM/model.ckpt``
        2. ``FINETUNED_LOAD_FROM`` (used as a full checkpoint path)
        3. ``PRETRAINED_CKPT_PATH/PRETRAINED_LOAD_FROM/model.ckpt``

      If none is set, the wrapper is freshly constructed.

    Returns:
        ``(cfg, model)``.

    Raises:
        NotImplementedError: if ``cfg.MODE`` is not a supported mode.
    """

    def _non_empty(key):
        # A key counts as "set" only if it exists and is non-empty.
        return key in cfg and len(getattr(cfg, key)) > 0

    shot_encoder = get_shot_encoder(cfg)
    crn = get_contextual_relation_network(cfg)

    if cfg.MODE in ["pretrain", "extract"]:
        loss = get_loss(cfg)
        if _non_empty("LOAD_FROM"):
            model = PretrainingWrapper.load_from_checkpoint(
                cfg=cfg,
                shot_encoder=shot_encoder,
                loss=loss,
                crn=crn,
                checkpoint_path=os.path.join(cfg.CKPT_PATH, cfg.LOAD_FROM, "model.ckpt"),
                strict=False,
            )
        else:
            model = PretrainingWrapper(cfg, shot_encoder, loss, crn)

    elif cfg.MODE in ["finetune"]:
        if _non_empty("LOAD_FROM"):
            checkpoint_path = os.path.join(cfg.CKPT_PATH, cfg.LOAD_FROM, "model.ckpt")
        elif _non_empty("FINETUNED_LOAD_FROM"):
            # Previously guarded by len(PRETRAINED_LOAD_FROM) — checked the wrong key.
            checkpoint_path = cfg.FINETUNED_LOAD_FROM
        elif _non_empty("PRETRAINED_LOAD_FROM"):
            checkpoint_path = os.path.join(
                cfg.PRETRAINED_CKPT_PATH, cfg.PRETRAINED_LOAD_FROM, "model.ckpt"
            )
        else:
            checkpoint_path = None

        if checkpoint_path is None:
            model = FinetuningWrapper(cfg, shot_encoder, crn)
        else:
            model = FinetuningWrapper.load_from_checkpoint(
                cfg=cfg,
                shot_encoder=shot_encoder,
                crn=crn,
                checkpoint_path=checkpoint_path,
                strict=False,
            )

    else:
        raise NotImplementedError(f"Unsupported MODE for init_model: {cfg.MODE}")

    return cfg, model

def init_trainer(cfg):
    """Create the Lightning ``Trainer`` for ``cfg.MODE``.

    Shared setup for every supported mode:

    * a TensorBoard logger under ``LOG_PATH/EXPR_NAME`` (``version=0``);
    * the checkpoint directory ``CKPT_PATH/EXPR_NAME``;
    * a ``LearningRateMonitor`` logging every step.

    Mode-specific setup:

    * "pretrain" / "extract": a ``ModelCheckpoint`` with no monitored
      metric, ``save_top_k=-1`` and filename ``model-{epoch:02d}``. Training
      uses ``DDPStrategy(find_unused_parameters=False)``.
    * "finetune": a ``ModelCheckpoint`` monitoring ``sbd_test/ap``
      (``mode='max'``), with filename ``model``.

    ``cfg.TRAINER`` is expanded as keyword arguments to ``pl.Trainer``.

    Returns:
        ``(cfg, trainer)``; ``cfg`` is returned unchanged.

    Raises:
        NotImplementedError: for any other MODE. It is raised before any
            directory or logger is created.
    """
    is_pretrain = cfg.MODE in ["pretrain", "extract"]
    if not is_pretrain and cfg.MODE != "finetune":
        # Validate first so an unsupported mode leaves no stray log/ckpt dirs behind.
        raise NotImplementedError(f"Unsupported MODE for init_trainer: {cfg.MODE}")

    logs_path = os.path.join(cfg.LOG_PATH, cfg.EXPR_NAME)
    os.makedirs(logs_path, exist_ok=True)
    logger = pl.loggers.TensorBoardLogger(logs_path, version=0)

    ckpt_path = os.path.join(cfg.CKPT_PATH, cfg.EXPR_NAME)
    os.makedirs(ckpt_path, exist_ok=True)

    callbacks = [pl.callbacks.LearningRateMonitor(logging_interval="step")]

    ### https://lightning.ai/docs/pytorch/stable/common/trainer.html
    if is_pretrain:
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                dirpath=ckpt_path, monitor=None, filename="model-{epoch:02d}", save_top_k=-1,
            )
        )
        trainer = pl.Trainer(
            **cfg.TRAINER,
            callbacks=callbacks,
            logger=logger,
            strategy=DDPStrategy(find_unused_parameters=False),
        )
    else:
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                dirpath=ckpt_path, monitor="sbd_test/ap", mode="max", filename="model",
            )
        )
        trainer = pl.Trainer(**cfg.TRAINER, callbacks=callbacks, logger=logger)

    return cfg, trainer