import os
import torch
import numpy as np
from typing import List

import hydra
from omegaconf import DictConfig
from hydra.utils import instantiate

from pytorch_lightning import (
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import Logger

from src import utils

# Module-level logger obtained from the project's ``utils.get_logger`` helper.
log = utils.get_logger(__name__)


# Fully-qualified ``_target_`` strings of the datamodules that need special handling.
_LETTERS_DATAMODULE = "src.datamodule.letters_dataloader.LettersDatamodule"
_TRELLIS_DATAMODULE = "src.datamodule.trellis_dataloader.TrellisDatamodule"

# Positions in the train ``trainer.predict`` output that are plotted.
_TRAIN_PLOT_INDICES = [2, 11, 23, 31, 45, 52]
# Number of leading test prediction batches that are plotted.
_NUM_TEST_PLOTS = 6


def _instantiate_model(cfg: DictConfig, datamodule: LightningDataModule):
    """Instantiate ``cfg.model``, adding datamodule-dependent kwargs where required.

    Args:
        cfg: Full Hydra config (reads ``cfg.model`` and ``cfg.datamodule``).
        datamodule: The already-instantiated datamodule; its ``train_dataset`` supplies
            ``num_conditions`` (letters, conditional) or ``pca`` (trellis, PCA space).

    Returns:
        The instantiated model.
    """
    target = cfg.datamodule._target_
    if target == _LETTERS_DATAMODULE and cfg.datamodule.conditional:
        return instantiate(
            cfg.model, num_conditions=datamodule.train_dataset.num_conditions
        )
    if target == _TRELLIS_DATAMODULE and cfg.datamodule.num_components is not None:
        return instantiate(
            cfg.model,
            dim=cfg.datamodule.num_components,
            pca=datamodule.train_dataset.pca,
            pca_space=True,
        )
    return instantiate(cfg.model)


def _load_checkpoint_weights(model, ckpt_path: str) -> None:
    """Copy the entries of the checkpoint's ``state_dict`` whose keys exist in ``model``.

    Keys only present in the checkpoint are ignored, and model keys absent from the
    checkpoint keep their freshly initialised values. Both cases are logged as warnings.

    Args:
        model: Module whose ``state_dict`` is (partially) overwritten.
        ckpt_path: Path of a checkpoint containing a ``"state_dict"`` entry.
    """
    # map_location="cpu": the freshly instantiated model lives on CPU, and this lets a
    # checkpoint saved on GPU load on CPU-only machines.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # torch>=2.6 defaults to weights_only=True, which may reject checkpoints that pickle
    # non-tensor objects; confirm against the installed torch version.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    ckpt_state = checkpoint["state_dict"]
    model_state = model.state_dict()

    matched = {k: v for k, v in ckpt_state.items() if k in model_state}
    skipped = sorted(set(ckpt_state) - set(matched))
    missing = sorted(set(model_state) - set(matched))
    if skipped:
        log.warning(
            f"Ignoring {len(skipped)} checkpoint keys absent from the model, e.g. {skipped[:5]}"
        )
    if missing:
        log.warning(
            f"{len(missing)} model keys not in the checkpoint keep their initial values, e.g. {missing[:5]}"
        )

    model_state.update(matched)
    model.load_state_dict(model_state)


def _collect_plot_data(preds, indices, split: str):
    """Gather ``(idx, source, true)`` samples and trajectories from ``trainer.predict`` output.

    Args:
        preds: List returned by ``trainer.predict``. Each element is read as
            ``[idx, traj, source, pred, true, ...]``; ``pred`` is not used.
            (Layout taken from how this module indexes it; TODO confirm against the
            model's ``predict_step``.)
        indices: Non-negative positions in ``preds`` to collect. Positions beyond the end
            are skipped with a warning instead of raising ``IndexError``.
        split: Split name, used only in the warning message.

    Returns:
        ``(samples, trajs)`` lists, aligned by position.
    """
    samples, trajs = [], []
    for i in indices:
        if i >= len(preds):
            log.warning(
                f"Skipping {split} prediction batch {i}: only {len(preds)} batches available."
            )
            continue
        p = preds[i]
        idx, traj, source, true = p[0], p[1], p[2], p[4]
        samples.append((idx, source[:, :, :], true[:, :, :]))
        trajs.append(traj[:, :, :, :])
    return samples, trajs


def _plot(model, trajs, samples, tag: str) -> None:
    """Call ``model.plot`` with one row per sample; skip the call if nothing was collected."""
    if not samples:
        log.warning(f"No samples to plot for tag '{tag}'; skipping.")
        return
    model.plot(
        trajs,
        samples,
        num_row=len(samples),
        num_step=3,
        tag=tag,
        shuffle=True,
    )


def _plot_letters_predictions(trainer: Trainer, model, datamodule: LightningDataModule) -> None:
    """Run prediction on the train and test loaders and plot selected trajectories."""
    log.info("Starting prediction on train!")
    preds_train = trainer.predict(model, datamodule.train_dataloader())
    log.info(f"Got {len(preds_train)} prediction batches on train.")
    samples_train, trajs_train = _collect_plot_data(
        preds_train, _TRAIN_PLOT_INDICES, "train"
    )
    _plot(model, trajs_train, samples_train, "fm_letters_train")

    log.info("Starting prediction on test!")
    # Reset the model's prediction counter before the second predict run.
    # (Presumably it numbers predict batches; TODO confirm in the model.)
    model.predict_count = 0
    preds_test = trainer.predict(model, datamodule.test_dataloader())
    samples_test, trajs_test = _collect_plot_data(
        preds_test, range(min(_NUM_TEST_PLOTS, len(preds_test))), "test"
    )
    _plot(model, trajs_test, samples_test, "fm_letters_test")

    # Combined figure: the 4th and 5th collected samples of each split.
    _plot(
        model,
        trajs_train[3:5] + trajs_test[3:5],
        samples_train[3:5] + samples_test[3:5],
        "fm_letters_main",
    )


def test(cfg: DictConfig) -> None:
    """Evaluate a trained checkpoint on the validation and test sets.

    For the letters datamodule, this also runs prediction on the train and test
    loaders and plots selected trajectories with ``model.plot``.

    Args:
        cfg (DictConfig): Configuration composed by Hydra. It must provide ``ckpt_path``,
            ``ckpt``, ``datamodule``, ``model`` and ``trainer``. ``seed`` and ``logger``
            are optional.

    Returns:
        None
    """
    # ``cfg.ckpt`` is appended verbatim (no separator is inserted), so ``cfg.ckpt_path``
    # must already end with a path separator or be intended as a filename prefix.
    cfg.ckpt_path = cfg.ckpt_path + cfg.ckpt

    # Seed pytorch, numpy and python.random. The check is against None so seed=0 is honoured.
    if cfg.get("seed") is not None:
        seed_everything(cfg.seed, workers=True)

    # Hydra may change the working directory, so resolve relative paths against the
    # directory the job was launched from.
    if not os.path.isabs(cfg.ckpt_path):
        cfg.ckpt_path = os.path.join(hydra.utils.get_original_cwd(), cfg.ckpt_path)

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = instantiate(cfg.datamodule)

    model = _instantiate_model(cfg, datamodule)
    _load_checkpoint_weights(model, cfg.ckpt_path)

    # Instantiate every logger entry in the config that declares a ``_target_``.
    logger: List[Logger] = []
    if "logger" in cfg:
        for lg_conf in cfg.logger.values():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(instantiate(lg_conf))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = instantiate(cfg.trainer, logger=logger)

    # Record which checkpoint was evaluated. ``trainer.logger`` is None when no logger is
    # configured and only refers to the first one otherwise, so iterate over all loggers.
    for lg in trainer.loggers:
        lg.log_hyperparams({"ckpt_path": cfg.ckpt_path})

    if cfg.datamodule._target_ == _LETTERS_DATAMODULE:
        _plot_letters_predictions(trainer, model, datamodule)

    # The weights were already loaded above (only keys known to the model), so
    # ``ckpt_path`` is deliberately not passed to validate/test.
    log.info("Starting validating!")
    trainer.validate(model=model, datamodule=datamodule)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule)