from copy import deepcopy
from glob import glob
import json
import os
from pathlib import Path
import random
from typing import Any, Dict
import gym
import numpy as np
from omegaconf import DictConfig
from temporal_task_planner.utils.gen_sess_utils import load_session_json
import wandb
import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedModel
from temporal_task_planner.trainer.early_stopping import EarlyStopping

# from temporal_task_planner.trainer.logger import Logger
from temporal_task_planner.utils.datasetpytorch_utils import (
    get_session_pairs,
    get_session_path,
)


class Learner:
    """Training driver for a Transformer Task Planner model.

    Builds one dataset per split, a logger for the test split, the optimizer,
    the LR scheduler and the early-stopping helper, then runs the optimisation
    loop with periodic evaluation and checkpointing.

    NOTE:
    - wandb calls are in `log_metrics` and `train_batch`
    - hydra recursively instantiates all modules in learner
    - optimizer and scheduler are partial modules as they're dependent on other modules
        (namely model.parameters(), optimizer respectively)
    - only the test split gets a logger; no train/val loggers are built
    """

    # Splits that must be present in the session split file.
    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        config: "DictConfig",
        dataset_partial: Any,  # torch.utils.data.Dataset,
        model: "PreTrainedModel",
        criterion: torch.nn.Module,
        optimizer_partial: Any,  # ModuleWrapper for torch.optim
        scheduler_partial: Any,  # ModuleWrapper for torch.optim.lr_scheduler
        logger_partial: Any,
    ) -> None:
        """Wire up datasets, logger, model, optimizer, scheduler and early stopping.

        Args:
            config: Run configuration. Keys read by this class:
                `session_split_filepath`, `pkg_root`, `max_train_evals`,
                `max_val_evals`, `max_test_evals`, `chkpt_name`, `seed`,
                `batch_size`, `num_workers`, `padding`, `device`, `patience`,
                `load_chkpt`, `to_train`, `logging_interval`,
                `rollout_interval`, `max_steps` (plus `num_targets_per_step`
                and `context_history` when `batch_size` is 0).
            dataset_partial: Callable invoked as
                `dataset_partial(session_paths=...)` once per split.
            model: Model to train; moved to `config["device"]`.
            criterion: Called as `criterion(out, target)` to produce the loss.
            optimizer_partial: Callable invoked as
                `optimizer_partial(params=model.parameters())`.
            scheduler_partial: Callable invoked as
                `scheduler_partial(optimizer=...)`.
            logger_partial: Callable invoked as
                `logger_partial(session_paths=..., name="test")`.

        Raises:
            KeyError: If the split file lacks a "train", "val" or "test" entry.
            ValueError: If the resulting batch size is `None` or 0.
        """
        # fixed while training
        self.config = config
        self.session_split_filepath = config.session_split_filepath
        self.session_paths = load_session_json(
            Path(self.config.pkg_root, self.session_split_filepath + ".json").as_posix()
        )
        missing = [split for split in self._SPLITS if split not in self.session_paths]
        if missing:
            raise KeyError(f"session split file is missing splits: {missing}")

        resolved = {split: self._resolve_session_paths(split) for split in self._SPLITS}
        for split, paths in resolved.items():
            setattr(self, f"{split}_dataset", dataset_partial(session_paths=paths))

        # The test logger only sees the first `max_test_evals` sessions of
        # every preference.
        test_subset = {
            preference: paths[: self.config["max_test_evals"]]
            for preference, paths in resolved["test"].items()
        }
        self.test_logger = logger_partial(session_paths=test_subset, name="test")
        # Creates both `rollouts/` and `rollouts/test/` (relative to the cwd)
        # without failing when they already exist.
        os.makedirs(Path("rollouts", "test").as_posix(), exist_ok=True)

        self.chkpt_path = f"{self.config['chkpt_name']}.pt"
        self.fix_seed()
        if self.config["batch_size"] == 0:
            # A configured batch size of 0 means "derive it from the config".
            self.set_adaptative_batch()
        else:
            self.batch_size = self.config["batch_size"]
        if not self.batch_size:
            raise ValueError(
                f"batch size must be a positive integer, got {self.batch_size!r}"
            )
        self.train_loader = self.get_dataloader("train", self.batch_size, shuffle=True)
        self.max_train_evals = int(config["max_train_evals"])
        self.max_val_evals = int(config["max_val_evals"])

        # changes while training
        self.model = model
        self.model.to(device=self.config["device"])
        self.criterion = criterion
        self.optimizer = optimizer_partial(params=model.parameters())
        self.scheduler = scheduler_partial(optimizer=self.optimizer)
        self.early_stopping = EarlyStopping(
            patience=self.config["patience"],
            verbose=True,
            path=self.chkpt_path,
        )

    def _resolve_session_paths(self, name):
        """Return `{preference: [session path joined onto pkg_root, ...]}` for split `name`."""
        split = self.session_paths[name]
        pkg_root = self.config.pkg_root
        return {
            preference: [
                Path(pkg_root, session_name).as_posix()
                for session_name in split[preference]
            ]
            for preference in split
        }

    def fix_seed(self):
        """Seed `random`, NumPy and torch (CPU and, if available, CUDA) from `config["seed"]`.

        Also switches torch to deterministic algorithms.
        """
        seed = self.config["seed"]
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.use_deterministic_algorithms(True)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def set_adaptative_batch(self):
        """Derive `batch_size` as `num_targets_per_step // (context_history + 1)`."""
        self.batch_size = self.config["num_targets_per_step"] // (
            self.config["context_history"] + 1
        )

    def load_chkpt(self) -> None:
        """Restore model, optimizer, scheduler and counters from `self.chkpt_path`.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if not os.path.exists(self.chkpt_path):
            raise FileNotFoundError(f"checkpoint does not exist: {self.chkpt_path}")
        # map_location lets a checkpoint written on one device be restored on
        # the device this run is configured for.
        checkpoint = torch.load(self.chkpt_path, map_location=self.config["device"])
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.loss = checkpoint["loss"]
        self.step = checkpoint["step"]
        # Checkpoints written by `train_batch` carry no example count; default
        # to 0 so the attribute exists on both initialisation paths.
        self.example_ct = checkpoint.get("example_ct", 0)
        print(f"Loaded model={self.model}, optimizer={self.optimizer}")
        print(f"epoch={self.epoch}, step={self.step}")

    def init_training(self) -> None:
        """Resume from the checkpoint if requested and present, else start from zero."""
        if self.config["load_chkpt"] and os.path.exists(self.chkpt_path):
            print("Loading checkpoint!")
            self.load_chkpt()
            return
        self.epoch = 0
        self.step = 0
        self.example_ct = 0
        print(
            f"No checkpoint found. Starting with epoch={self.epoch}, step={self.step}, example_ct={self.example_ct}"
        )

    def get_dataloader(
        self, dataset_type: str = "train", batch_size=1, shuffle=False
    ) -> DataLoader:
        """Build a `DataLoader` over the `<dataset_type>_dataset` attribute.

        Args:
            dataset_type: Split name; `self.<dataset_type>_dataset` must exist.
            batch_size: Batch size handed to the loader.
            shuffle: Whether the loader shuffles the data.

        Returns:
            A loader using `config["num_workers"]` workers and
            `config["padding"]` as its collate function.
        """
        data = getattr(self, f"{dataset_type}_dataset")
        print(f"{dataset_type} data sequences: {len(data)}")
        return DataLoader(
            data,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.config["num_workers"],
            collate_fn=self.config["padding"],
        )

    def log_metrics(self, to_rollout: bool) -> None:
        """Run the test logger and push its metrics plus the learning rate to wandb.

        Args:
            to_rollout: Forwarded to the test logger as its third argument.
        """
        print(f"to_rollout={to_rollout}")
        metrics = {}
        metrics.update(self.test_logger(self.model, self.step, to_rollout))
        metrics.update({"LearningRate": self.scheduler.get_last_lr()[0]})
        print(metrics)
        wandb.log(metrics, step=self.step)

    def _optimize_step(self, data):
        """Run one forward/backward/optimizer step on `data` and return the loss."""
        self.model.train()
        self.optimizer.zero_grad()
        inputs, target = data
        out = self.model(**inputs, device=self.config["device"])
        loss = self.criterion(out, target)
        if self.step % self.config["logging_interval"] == 0:
            wandb.log({"RealTrainLoss": loss.item()}, self.step)
        loss.backward()
        self.optimizer.step()
        return loss

    def _save_checkpoint(self, loss) -> None:
        """Feed `loss` to early stopping, then write and upload the latest checkpoint."""
        save_dict = {
            "step": self.step + 1,
            "epoch": self.epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            # Detached so the checkpoint does not hold the autograd graph.
            "loss": loss.detach(),
        }
        # NOTE(review): EarlyStopping receives the same dict and was built with
        # the same path; presumably it persists the best checkpoint itself, in
        # which case the unconditional save below overwrites it - confirm.
        self.early_stopping(loss.item(), save_dict=save_dict)
        torch.save(save_dict, self.chkpt_path)
        wandb.save(self.chkpt_path)

    def train_batch(self, data):
        """Process one batch: optimise (if enabled), then periodically log and checkpoint.

        Args:
            data: An `(input, target)` pair; `input` is unpacked as keyword
                arguments to the model and `target` goes to the criterion.

        Returns:
            The model state dict when early stopping triggers, otherwise
            `None`. `self.step` is not advanced on the early-stop return.
        """
        loss = None
        if self.config["to_train"]:
            loss = self._optimize_step(data)

        rollout_interval = self.config["rollout_interval"]
        to_rollout = self.step % rollout_interval == rollout_interval - 1
        if self.step % self.config["logging_interval"] == 0:
            self.log_metrics(to_rollout)
            # Scheduler, checkpoint and early stopping all depend on a training
            # loss, so they are skipped when `to_train` is off.
            if loss is not None:
                self.scheduler.step()
                self._save_checkpoint(loss)
                if self.early_stopping.early_stop:
                    print("Early stopping")
                    return self.model.state_dict()

        self.step += 1
        return None

    def train(self):
        """Train until `config["max_steps"]` is exceeded or early stopping triggers.

        Returns:
            The model state dict at the end of training.

        Raises:
            RuntimeError: If the train loader yields no batches, which would
                otherwise spin forever without advancing `self.step`.
        """
        self.init_training()
        max_steps = self.config["max_steps"]
        while self.step <= max_steps:
            saw_batch = False
            for data in self.train_loader:
                saw_batch = True
                stopped_state = self.train_batch(data)
                if stopped_state is not None:
                    # Early stopping fired inside train_batch.
                    return stopped_state
                if self.step > max_steps:
                    break
            if not saw_batch:
                raise RuntimeError("train loader yielded no batches")
            self.epoch += 1
        return self.model.state_dict()

if __name__ == "__main__":
    import hydra
    from omegaconf import OmegaConf

    @hydra.main(config_path="../../../config", config_name="learner")
    def main(cfg):
        """Instantiate the learner from the hydra config and train inside a wandb run."""
        cfg.pkg_root = hydra.utils.get_original_cwd()
        learner = hydra.utils.instantiate(cfg.learner)
        # NOTE(review): `construct_run_name` is neither defined nor imported in
        # this module. Use it when it is available; otherwise pass no name so
        # wandb picks one instead of the script dying with a NameError.
        name_fn = globals().get("construct_run_name")
        run_name = name_fn(cfg) if callable(name_fn) else None
        with wandb.init(
            project=cfg.wandb.project,
            entity=cfg.wandb.entity,
            # wandb expects a plain container, not an OmegaConf node.
            config=OmegaConf.to_container(cfg, resolve=True),
            save_code=False,
            name=run_name,
            resume=True,
        ):
            learner.train()

    main()
