import pyrootutils
# Locate the repository root by searching upward from this file for a `.git`
# marker. Per the pyrootutils docs, `pythonpath=True` adds that root to
# `sys.path` (needed for the `src.*` project import further down) and
# `dotenv=True` loads a `.env` file from the root if one exists. This call must
# stay above the project imports.
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git"],
    pythonpath=True,
    dotenv=True,
)

import os
from tqdm import tqdm
import pandas as pd
import torch
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra.utils import instantiate
from lightning.pytorch import seed_everything
from lightning.pytorch.utilities import rank_zero_only

# from ignite.metrics import PSNR
# from torcheval.metrics import PeakSignalNoiseRatio as PSNR
# from torcheval.metrics import Mean
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.image.psnr import PeakSignalNoiseRatio as PSNR

from src.utils.animate import load_config_and_model


def zeroto1(image, mean=128, std=64):
    """Map a normalised image tensor back into the [0, 1] range.

    Computes ``(image * std + mean) / 255`` and clamps the result to [0, 1].
    This presumably inverts a ``(x - mean) / std`` normalisation of 8-bit pixel
    values (TODO confirm against the dataset transform).

    Args:
        image: Tensor of normalised pixel values (any shape).
        mean: Offset added after scaling (default 128).
        std: Scale factor applied first (default 64).

    Returns:
        Tensor of the same shape with values clamped to [0, 1].
    """
    denormalised = mean + std * image
    unit_range = denormalised / 255.0
    return torch.clamp(unit_range, min=0, max=1)

# def m1to1(image):
#     image = zeroto1(image) * 2 - 1
#     return torch.clamp(image, -1, 1)

def _evaluate_variant(model, config, variant, device, burnin, delta_t, num_actions):
    """Compute PSNR, deltaPSNR and LPIPS for ``model`` on one dataset variant.

    The loader is built from ``config.data`` with its dataset overridden to the
    ``'eval'`` split of ``variant``. All three metrics accumulate over every
    batch of that loader before being computed. The caller must already have
    moved ``model`` to ``device``, put it in eval mode, and wrapped this call
    in ``torch.no_grad()``.

    Args:
        model: Object exposing ``preprocess(batch)`` (unpacked as
            ``(fullseq, actions, _)``) and
            ``loss_finetune((fullseq, actions, None))`` (unpacked as
            ``(_, (target, pred))``).
        config: Merged config; only ``config.data`` is read here.
        variant: Passed as ``env`` to the eval dataset.
        device: Device that the metrics and model inputs are moved to.
        burnin: Index along dim 0 of ``actions`` from which actions are
            replaced by random ones for the deltaPSNR baseline (dim 0 is
            presumably time - TODO confirm).
        delta_t: Index along dim 0 of ``pred``/``target`` at which deltaPSNR
            is measured.
        num_actions: Random actions are drawn uniformly from
            ``[0, num_actions)``.

    Returns:
        ``(psnr, delta_psnr, lpips)`` as Python floats. ``delta_psnr`` is the
        PSNR at ``delta_t`` using the true actions minus the PSNR at
        ``delta_t`` using the randomised actions.
    """
    psnr = PSNR(data_range=1.0).to(device)
    psnr_t = PSNR(data_range=1.0).to(device)
    psnr_tr = PSNR(data_range=1.0).to(device)
    # normalize=True: LPIPS receives images in [0, 1] (produced by zeroto1).
    lpips = LPIPS(normalize=True).to(device)
    eval_loader = instantiate(config.data, dataset={'split': 'eval', 'env': variant})

    for batch in eval_loader:
        fullseq, actions, _ = model.preprocess(batch)
        fullseq = fullseq.to(device)
        actions = actions.to(device)

        _, (target, pred) = model.loss_finetune((fullseq, actions, None))
        # zeroto1 is elementwise, so normalise once and index the result later.
        pred01 = zeroto1(pred)
        target01 = zeroto1(target)
        psnr.update(pred01, target01)
        # Merge the first two dims into one image batch for LPIPS.
        lpips.update(pred01.flatten(0, 1), target01.flatten(0, 1))

        # Clone instead of writing in place. On CPU, `.to(device)` returns the
        # same tensor, so an in-place write would mutate preprocess's output.
        random_actions = actions.clone()
        random_actions[burnin:] = torch.randint_like(actions[burnin:], 0, num_actions)
        _, (_, random_pred) = model.loss_finetune((fullseq, random_actions, None))
        psnr_t.update(pred01[delta_t], target01[delta_t])
        psnr_tr.update(zeroto1(random_pred[delta_t]), target01[delta_t])

    psnr_value = psnr.compute().item()
    delta_value = (psnr_t.compute() - psnr_tr.compute()).item()
    lpips_value = lpips.compute().item()
    return psnr_value, delta_value, lpips_value


@hydra.main(version_base=None, config_path='../config', config_name='eval')
def main(config: DictConfig):
    """Evaluate a fine-tuned model on every dataset variant and save metrics.

    Loads the model at ``config.model_path``, merges the config stored with it
    with the Hydra ``eval`` config (eval values take precedence), and computes
    PSNR, deltaPSNR and LPIPS for each entry of ``dataset.variants``. The
    results are written in long format (columns ``env``, ``method``,
    ``metric``, ``value``) to ``<config.save_dir>/calm.csv``.

    Args:
        config: Hydra config composed from ``eval`` under ``../config``.

    Returns:
        The evaluated model, moved to ``config.device`` and in eval mode.
    """
    print('load model: ', config.model_path)
    config_model, model = load_config_and_model(config.model_path, finetune=True)
    # Clear the struct flag on both configs so the merge may combine keys that
    # exist in only one of them.
    OmegaConf.set_struct(config, None)
    OmegaConf.set_struct(config_model, None)
    # Later arguments win in OmegaConf.merge: eval settings override the
    # config stored alongside the model.
    config = OmegaConf.merge(config_model, config)

    if config.seed is not None:
        seed_everything(config.seed, workers=True)

    # NOTE(review): the variant list comes from the 'test' split, but metrics
    # are computed on the 'eval' split (see _evaluate_variant). Confirm this
    # is intended.
    dataset = instantiate(config.data.dataset, split='test')

    burnin = 4
    # Index into the prediction sequence. Per the original author's note this
    # gives deltaPSNR at t=8 (presumably burnin + delta_t). TODO confirm.
    delta_t = 4
    num_actions = 15

    device = config.device
    model = model.to(device)
    model.eval()

    rows = []
    with torch.no_grad():
        for variant in tqdm(dataset.variants):
            psnr, delta_psnr, lpips = _evaluate_variant(
                model, config, variant, device, burnin, delta_t, num_actions
            )
            base = {'env': variant, 'method': 'calm'}
            for metric, value in (('PSNR', psnr), ('deltaPSNR', delta_psnr), ('LPIPS', lpips)):
                rows.append({**base, 'metric': metric, 'value': value})

    save_dir = config.save_dir
    # Create the output directory if it does not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    pd.DataFrame(rows).to_csv(os.path.join(save_dir, 'calm.csv'), index=False)

    return model


# Script entry point. `main` is wrapped by `@hydra.main`, which composes and
# supplies the config, so it is called here without arguments.
if __name__ == '__main__':
    main()