from copy import deepcopy
from glob import glob
import json
import os
from pathlib import Path
import random
from typing import Any, Dict
import gym
import numpy as np
from omegaconf import DictConfig
from temporal_task_planner.utils.gen_sess_utils import load_session_json
import wandb
import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from temporal_task_planner.trainer.early_stopping import EarlyStopping

# from temporal_task_planner.trainer.logger import Logger
from temporal_task_planner.utils.datasetpytorch_utils import (
    get_session_pairs,
    get_session_path,
)


class Evaluator:
    """Eval class that provides rollout support for any Transformer Task Planner model.

    The constructor takes the same parameters as the Learner, so the same Hydra
    config can instantiate either class and resolve the same checkpoint.
    ``session_split_filepath`` identifies the split the checkpoint was trained
    on; it is also the key under which metrics are written by ``rollout``.
    Rollouts themselves run on ``test_split_filepath``. When that is empty or
    unset, it falls back to ``session_split_filepath``.

    NOTE:
    - hydra recursively instantiates all modules in learner
    - optimizer and scheduler are partial modules as they're dependent on other modules
        (namely model.parameters(), optimizer respectively)
    - ``dataset_partial``, ``criterion``, ``optimizer_partial`` and
      ``scheduler_partial`` are accepted only for signature compatibility with
      the Learner; this class does not use them.
    """

    def __init__(
        self,
        config: DictConfig,
        dataset_partial: Any,  # torch.utils.data.Dataset,
        model: PreTrainedModel,
        criterion: torch.nn.Module,
        optimizer_partial: Any,  # ModuleWrapper for torch.optim
        scheduler_partial: Any,  # ModuleWrapper for torch.optim.lr_scheduler
        logger_partial: Any,
    ) -> None:
        """Load the test split, build the test logger, seed RNGs and place the model.

        Args:
            config: run config. Reads ``session_split_filepath``,
                ``test_split_filepath``, ``pkg_root``, ``max_test_evals``,
                ``chkpt_name``, ``seed`` and ``device``.
            dataset_partial: unused (Learner signature compatibility).
            model: model to evaluate; moved to ``config["device"]``.
            criterion: unused (Learner signature compatibility).
            optimizer_partial: unused (Learner signature compatibility).
            scheduler_partial: unused (Learner signature compatibility).
            logger_partial: callable invoked as
                ``logger_partial(session_paths=..., name="test")``; the result
                is stored as ``self.test_logger``.
        """
        # fixed for the lifetime of the evaluator
        self.config = config
        self.session_split_filepath = config.session_split_filepath
        # An empty (or None) test split means "evaluate on the checkpoint's own split".
        self.test_split_filepath = (
            config.test_split_filepath or config.session_split_filepath
        )
        self.session_paths = load_session_json(
            Path(self.config.pkg_root, self.test_split_filepath + '.json').as_posix()
        )
        name = 'test'
        max_evals = int(config[f"max_{name}_evals"])
        setattr(self, f"max_{name}_evals", max_evals)

        # Resolve every session path against the package root, then keep at most
        # `max_evals` sessions per preference for the logger. Use the int-converted
        # value so a string-valued config entry cannot break the slice.
        split = self.session_paths[name]
        subset_session_paths = {}
        for preference_name in split:
            resolved = [
                Path(self.config.pkg_root, session_name).as_posix()
                for session_name in split[preference_name]
            ]
            subset_session_paths[preference_name] = resolved[:max_evals]

        setattr(
            self,
            f"{name}_logger",
            logger_partial(session_paths=subset_session_paths, name=name),
        )

        self.chkpt_path = f"{self.config['chkpt_name']}.pt"
        self.fix_seed()
        # changes while evaluating
        self.model = model
        self.model.to(device=self.config["device"])

    def fix_seed(self) -> None:
        """Seed Python, NumPy and torch RNGs and enable deterministic torch kernels."""
        seed = self.config["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # PyTorch's reproducibility notes require this cuBLAS workspace setting
        # (CUDA >= 10.2) when deterministic algorithms are enforced. Without it,
        # affected GPU ops raise a RuntimeError. setdefault keeps any value the
        # user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True)
        if torch.cuda.is_available():
            # Seeds every visible GPU (superset of torch.cuda.manual_seed).
            torch.cuda.manual_seed_all(seed)
            # torch.backends.cudnn.deterministic = True
            # torch.backends.cudnn.benchmark = False

    def load_chkpt(self) -> None:
        """Load model weights and training counters from ``self.chkpt_path``.

        Sets ``self.epoch``, ``self.loss`` and ``self.step`` from the checkpoint
        and switches the model to eval mode.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        if not os.path.exists(self.chkpt_path):
            # Explicit raise instead of assert: asserts are stripped under `python -O`.
            raise FileNotFoundError(f"checkpoint does not exist~! ({self.chkpt_path})")
        # map_location lets a checkpoint saved on GPU be evaluated on the configured
        # device (e.g. a CPU-only machine).
        checkpoint = torch.load(self.chkpt_path, map_location=self.config["device"])  # wandb.restore(self.chkpt_path)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Evaluation should not run with dropout etc. active; eval() is idempotent
        # in case the logger also sets it.
        self.model.eval()
        self.epoch = checkpoint["epoch"]
        self.loss = checkpoint["loss"]
        self.step = checkpoint["step"]
        print(f"Loaded model={self.model}")
        print(f"epoch={self.epoch}, step={self.step}")

    def log_metrics(self, to_rollout: bool) -> Dict[str, Any]:
        """Compute test-split metrics, print them, and log them to wandb at ``self.step``.

        Requires ``load_chkpt`` to have been called (it sets ``self.step``).

        Args:
            to_rollout: forwarded to the test logger.

        Returns:
            The metrics dict produced by the test logger.
        """
        print(f"to_rollout={to_rollout}")
        metrics: Dict[str, Any] = {}
        metrics.update(self.test_logger(self.model, self.step, to_rollout))
        print(metrics)
        wandb.log(metrics, step=self.step)
        return metrics

    def rollout(self, main_record_file: str = 'summary') -> None:
        """Load the checkpoint, run the test rollouts and write a JSON summary.

        The summary is ``{<last path component of session_split_filepath>: metrics}``,
        written to ``main_record_file`` (overwriting it).
        """
        self.load_chkpt()
        to_rollout = True
        metrics = self.log_metrics(to_rollout)
        split_key = self.session_split_filepath.split('/')[-1]
        with open(main_record_file, 'w', encoding="utf-8") as f:
            json.dump({split_key: metrics}, f)

if __name__ == "__main__":
    import hydra
    from omegaconf import OmegaConf

    @hydra.main(config_path="../../../config", config_name="learner")
    def main(cfg):
        """Instantiate the Evaluator from the Hydra config and run one rollout under wandb."""
        # Hydra may change the working directory; record the launch directory so
        # relative session/checkpoint paths resolve against the package root.
        cfg.pkg_root = hydra.utils.get_original_cwd()
        evaluator = hydra.utils.instantiate(cfg.learner)
        # `construct_run_name` was referenced here but is neither defined nor
        # imported in this module (NameError at runtime). Name the run after the
        # checkpoint's split instead — the same key `rollout` writes to the summary.
        run_name = evaluator.session_split_filepath.split('/')[-1]
        with wandb.init(
            project='rollout_' + cfg.wandb.project,
            entity=cfg.wandb.entity,
            # wandb expects a plain dict; convert the OmegaConf DictConfig first.
            config=OmegaConf.to_container(cfg, resolve=True),
            save_code=False,
            name=run_name,
            resume=True,
        ):
            evaluator.rollout()

    main()
