from abc import ABC
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple
from ast import literal_eval
import gym
import json
import numpy as np
from requests import session
import numpy.typing as npt
from copy import deepcopy
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
import textdistance
import wandb
import itertools
from transformers import PreTrainedModel
import pandas as pd

from temporal_task_planner.constants.gen_sess_config.lookup import (
    category_vocab,
    special_instance_vocab,
    category_dict,
)
from temporal_task_planner.policy.expert_policy import ExpertPickOnlyPreferencePolicy
from temporal_task_planner.policy.learned_policy import (
    LearnedPolicy,
    PromptSituationLearnedPolicy,
    SingleSessionLearnedPolicy,
)
from temporal_task_planner.rollout import session_rollout
from temporal_task_planner.utils.data_structure_utils import get_actions_from_session
from temporal_task_planner.utils.extract_from_session import (
    get_placed_utensils_in_useraction,
)
from temporal_task_planner.utils.datasetpytorch_utils import get_temporal_context
from temporal_task_planner.utils.gen_sess_utils import load_session_json

# Fix this!
from temporal_task_planner.trainer.dataset import (
    DishwasherArrangeSavedSession,
    PromptSituationDataset,
    PromptSituationPickPlace,
    SessionPreferenceDataset,
    pad_fn,
    prompt_pad_fn,
    preference_classifier_pad_fn
)


class Logger(ABC):
    def __init__(
        self,
        env: Any,
        config: Dict,
        model: PreTrainedModel,
        criterion: torch.nn.Module,
        pick_only: bool = True,
        max_evals: int = 100,
        savefolder: str = "rollouts",
    ):
        self.env = env
        self.config = config
        self.device = self.config["device"]
        self.model = model
        self.criterion = criterion
        self.pick_only = pick_only
        self.max_evals = max_evals
        self.savefolder = savefolder
        self.categ_attr = []
        self.cont_attr = []
        self.input_attr = self.categ_attr + self.cont_attr
        self.rollout_metrics = ["TPE", "BPE", "PE", "SPL", "LC"]

    def __call__(
        self,
        model: torch.nn.Module,
        step: int,
        to_rollout: bool,
        *args: Any,
        **kwds: Any,
    ) -> Any:
        raise NotImplementedError

    def get_expert_rollout_metrics(self, session_path: str):
        data = load_session_json(session_path)
        return {
            "placed_utensils": get_placed_utensils_in_useraction(
                data, useraction_idx=-1
            ),
            "actions": get_actions_from_session(data),
        }

    def compute_all_rollout_metrics(
        self, expert_rollout_info: Dict, learned_rollout_info: Dict
    ) -> Dict:
        """Calculates metrics for rollout of the learned policy wrt
        what an expert policy would have done in that scenario.
        """
        # Packing efficiency
        expert_placed_utensils = expert_rollout_info["placed_utensils"]
        learned_placed_utensils = learned_rollout_info["placed_utensils"]

        def get_packing_metrics(name: str) -> float:
            num_correct_category = self.env.env_utils.count_by_category_preference(
                learned_placed_utensils[name],
                getattr(self.env.env_utils.preference, f"category_order_{name}"),
            )
            return num_correct_category / max(len(expert_placed_utensils[name]), num_correct_category)

        self.main_log["TPE"].append(get_packing_metrics(name="top_rack"))
        self.main_log["BPE"].append(get_packing_metrics(name="bottom_rack"))
        # self.main_log["SPE"].append(get_packing_metrics(name="sink")
        self.main_log["PE"].append(
            (self.main_log["TPE"][-1] + self.main_log["BPE"][-1]) / 2.0
        )

        # Time efficiency
        expert_actions = expert_rollout_info["actions"]
        learned_actions = learned_rollout_info["actions"]
        success = self.main_log["PE"][-1]
        self.main_log["SPL"].append(
            success * len(expert_actions) / max(len(learned_actions), len(expert_actions))
        )

        # Levenshtein Distance for the pick category order
        expert_category_order = [
            action.pick_instance.category_name for action in expert_actions
        ]
        learned_category_order = [
            action.pick_instance.category_name for action in learned_actions
        ]
        self.main_log["LC"].append(
            textdistance.levenshtein(expert_category_order, learned_category_order)
            / max(len(expert_category_order), len(learned_category_order))
        )
        return {
            'TPE': self.main_log["TPE"][-1],
            'BPE': self.main_log["BPE"][-1],
            'PE': self.main_log["PE"][-1],
            'SPL': self.main_log["SPL"][-1],
            'LC': self.main_log["LC"][-1]
        }

    def init_sess_log(self) -> Dict:
        """Aggregates the target and predicted lists for
        sklearn metrics
        """
        sess_logs = dict(
            losses=[],
            y_pred=[],
            y_true=[],
            x_pred={key: [] for key in self.input_attr},
            x_true={key: [] for key in self.input_attr},
        )
        return sess_logs

    def init_main_log(self, to_rollout: bool = False):
        """
        Records Loss, Acc, etc per session
        """
        main_log = dict(
            Loss=[],
            Acc=dict(symbol=[], attribute={key: [] for key in self.categ_attr}),
            # Precision=dict(attribute={key: [] for key in self.categ_attr}),
            # Recall=dict(attribute={key: [] for key in self.categ_attr}),
            # AvgL2=dict(attribute={key: [] for key in self.cont_attr}),
        )
        if to_rollout:
            main_log.update(
                dict(
                    TPE=[],
                    BPE=[],
                    PE=[],
                    SPL=[],
                    LC=[],
                )
            )
        return main_log

    def get_attr_based_predictions(self, inputs: Dict, targets: torch.Tensor) -> Dict:
        """Processes the attribute dictionary for input instances"""
        # if self.pick_only:
        #     attr_type = ["pick_category", "place_category"]
        #     targets_stacked = torch.stack(targets, dim=1)    
        #     # even indexed targets are pick

        #     # odd indexed targets are place 
        targets = targets.reshape(len(inputs["timestep"]), -1)
        _x = {}
        for attr_type in ["category_token"]:  # timestep, pose
            _x[attr_type] = []
            for i in range(len(targets)):
                _x[attr_type] += deepcopy(
                    inputs[attr_type][i, targets[i, :]].cpu().detach().numpy().tolist()
                )

        # # category : 3dBB
        # for i in range(len(targets)):
        #     _x["category-3dBB"] += deepcopy(
        #         inputs["category"][i, targets[i, :]].cpu().detach().numpy().tolist()
        #     )

        return _x

    def accumulate(
        self,
        loss: int,
        y_true: npt.ArrayLike,
        y_pred: npt.ArrayLike,
        x_true: Dict[str, List],
        x_pred: Dict[str, List],
    ) -> None:
        """
        Args:
            loss, inputs, out, targets : torch Tensor
        Return:
            self.sess_log : dict of (dict of) list (No numpy, no torch)
        """
        self.sess_log["losses"].append(loss.item())
        self.sess_log["y_true"] += y_true.tolist()
        self.sess_log["y_pred"] += y_pred.tolist()
        for key in list(x_true.keys()):
            self.sess_log["x_true"][key] += x_true[key]
            self.sess_log["x_pred"][key] += x_pred[key]
        return self.sess_log

    def update_main_log(self) -> None:
        """Update loss, acc and other metrics of per step predictions"""
        self.main_log["Loss"].append(np.mean(self.sess_log["losses"]))
        self.main_log["Acc"]["symbol"].append(
            accuracy_score(self.sess_log["y_true"], self.sess_log["y_pred"])
        )
        for key in ["category_token"]:  # self.categ_attr:
            self.main_log["Acc"]["attribute"][key].append(
                accuracy_score(
                    self.sess_log["x_true"][key], self.sess_log["x_pred"][key]
                )
            )
            # self.main_log["Precision"]["attribute"][key].append(
            #     precision_score(
            #         self.sess_log["x_true"][key],
            #         self.sess_log["x_pred"][key],
            #         average="weighted",
            #         zero_division=0,
            #     )
            # )
            # self.main_log["Recall"]["attribute"][key].append(
            #     recall_score(
            #         self.sess_log["x_true"][key],
            #         self.sess_log["x_pred"][key],
            #         average="weighted",
            #         zero_division=0,
            #     )
            # )
        # for key in self.cont_attr:
        #     avgL2 = np.mean(
        #         np.linalg.norm(
        #             np.array(self.sess_log["x_true"][key])
        #             - np.array(self.sess_log["x_pred"][key]),
        #             axis=-1,
        #             ord=2,
        #         )
        #     )
        #     self.main_log["AvgL2"]["attribute"][key].append(avgL2)
        return

    def get_avg_log(self, to_rollout: bool = False) -> Dict:
        """
        Recording mean of logs per session
        TODO: add std value too for error bar.
        """
        metrics = {}
        metrics.update(
            {
                f"{self.name}/Loss": np.mean(self.main_log["Loss"]),
                f"{self.name}/Acc/symbol": np.mean(self.main_log["Acc"]["symbol"]),
            }
        )
        for attr in ["category_token"]:
            metrics.update(
                {
                    f"{self.name}/Acc/{attr}": np.mean(
                        self.main_log["Acc"]["attribute"][attr]
                    ),
                    #     f"{self.name}/Precision/{attr}": np.mean(
                    #         self.main_log["Precision"]["attribute"][attr]
                    #     ),
                    #     f"{self.name}/Recall/{attr}": np.mean(
                    #         self.main_log["Recall"]["attribute"][attr]
                    #     ),
                }
            )
        # for attr in self.cont_attr:
        #     metrics.update(
        #         {
        #             f"{self.name}/AvgL2/{attr}": np.mean(
        #                 self.main_log["AvgL2"]["attribute"][attr]
        #             )
        #         }
        #     )
        if to_rollout:
            for rollout_metric in self.rollout_metrics:
                metrics.update(
                    {
                        f"{self.name}/{rollout_metric}": np.mean(
                            self.main_log[f"{rollout_metric}"]
                        )
                    }
                )
        return metrics

    def print(self, metrics: Dict, step: int = None) -> Dict:
        """prints the metrics;
        adds a confusion matrix if it
        TODO: print to file"""
        if step is None:
            step = "final"
        print(
            step,
            [
                metrics[key]
                for key in [
                    f"{self.name}/Loss",
                    f"{self.name}/Acc/symbol",
                    f"{self.name}/Acc/category_token",
                ]
            ],
        )
        return metrics

    def compute_confusion_matrix(self) -> wandb.plot.confusion_matrix:
        """Returns a wandb confusion matrix for current sess_log true and pred seq"""
        return wandb.plot.confusion_matrix(
            probs=None,
            y_true=self.sess_log["x_true"]["category_token"],
            preds=self.sess_log["x_pred"]["category_token"],
            class_names=list(category_dict.values()),
        )

    def log_true_pred_tokens(self, step: int) -> List:
        """Processes the values of True and Predicted vocab tokens
        per step (along with loss at that step) for failure case analysis"""
        target_token_names = self.sess_log["x_true"]["category_token"]
        pred_token_names = self.sess_log["x_pred"]["category_token"]

        data = {
            "Step": step,
            "Loss": self.main_log["Loss"],
            "True": [
                category_vocab.index2word(target_token_names[i])
                for i in range(len(target_token_names))
            ],
            "Predicted": [
                category_vocab.index2word(pred_token_names[i])
                for i in range(len(pred_token_names))
            ],
        }
        return list(data.values())

class SingleModelLogger(Logger):
    def __init__(
        self,
        env: Any,  # gym.Env,
        session_paths: List[str],
        config: Dict,
        model: PreTrainedModel,
        criterion: torch.nn.Module,
        name: str,
        pick_only: bool = True,
        max_evals: int = 10,
        savefolder: str = "rollouts",
    ):
        super().__init__(env, config, model, criterion, pick_only, savefolder)
        self.session_paths = session_paths  # deepcopy(self.dataset.session_paths)
        if not os.path.exists(self.savefolder):
            os.mkdir(self.savefolder)

        self.expert_rollout_info = [
            self.get_expert_rollout_metrics(self.session_paths[session_id])
            for session_id in range(max_evals)
        ]
        #     self.env.reset(ref_sess=self.session_paths[session_id])
        #     self.expert_rollout_info.append(
        #         session_rollout(
        #             self.env,
        #             expert_policy,
        #             savepath=None
        #             # Path(self.savefolder, "expert_rollout.json").as_posix(),
        #         )
        #     )

        # init env with a session from the session paths
        # TODO: compute over all session paths
        self.context_history = self.config["context_history"]
        self.name = name
        if not self.pick_only:
            self.pickplacelogs = dict(
                PickAcc=[],
                PlaceAcc=[],
            )
        self.categ_attr = ["timestep", "category_token"]
        self.cont_attr = ["category-3dBB", "pose"]
        self.input_attr = self.categ_attr + self.cont_attr
        self.rollout_metrics = ["TPE", "BPE", "PE", "SPL", "LC"]
        self.main_log = None  # self.init_main_log()
        self.sess_log = None

    def __call__(
        self, model, step: int, to_rollout: bool, *args: Any, **kwds: Any
    ) -> Dict:
        """Calculates metrics per session and averages them"""
        print(f"Metric computation: to_rollout={to_rollout}")
        self.main_log = self.init_main_log(to_rollout)
        # New params of the Model, in eval mode
        self.model = model
        self.model.eval()
        if not os.path.exists("rollouts"):
            os.mkdir("rollouts")
        # Sample `max_evals` sessions from session_paths
        for session_id in range(self.max_evals):
            self.get_per_step_metrics(session_id)
            if to_rollout:
                savepath = (
                    f"rollouts/{self.name}_step-{step}_session_id-{session_id}.json"
                )
                self.get_rollout_metrics(session_id, savepath)
        # Format metrics and upload to wandb
        metrics = self.get_avg_log(to_rollout)
        self.print(metrics, step=step)
        return metrics

    @torch.no_grad()
    def get_per_step_metrics(self, session_id: int) -> None:
        """Inits the dataset class wrapper to process each session
        independently for the computing metrics"""
        self.sess_log = self.init_sess_log()
        assert session_id < len(self.session_paths), print(
            session_id, len(self.session_paths)
        )
        session_dataset = DishwasherArrangeSavedSession(
            session_paths=[self.session_paths[session_id]],
            context_history=self.context_history,
            num_sessions_limit=1,
            pick_only=self.pick_only,
        )
        dataloader = DataLoader(
            session_dataset, batch_size=len(session_dataset) + 1, collate_fn=pad_fn
        )
        for inputs, targets in dataloader:
            out = self.model(**inputs, device=self.device)
            loss = self.criterion(out, targets)
            y_true = (
                torch.cat(targets["action_instance"], dim=0).cpu().numpy()
            )  # .tolist()
            y_pred = (
                torch.argmax(out["pick"], dim=-1).reshape(-1).cpu().detach().numpy()
            )  # .tolist()
            x_true = self.get_attr_based_predictions(inputs, y_true)
            x_pred = self.get_attr_based_predictions(inputs, y_pred)
            self.accumulate(loss, y_true, y_pred, x_true, x_pred)
            if not self.pick_only:
                # pick-place acc metrics
                y_true = (
                    torch.stack(targets["action_instance"], dim=0).cpu().numpy()
                )  # .tolist()
                y_pred = (
                    torch.argmax(out["pick"], dim=-1)
                    .reshape(-1, 2)
                    .cpu()
                    .detach()
                    .numpy()
                )  # .tolist()
                self.pickplacelogs["PickAcc"] = accuracy_score(
                    y_true[:, 0], y_pred[:, 0]
                )
                self.pickplacelogs["PlaceAcc"] = accuracy_score(
                    y_true[:, 1], y_pred[:, 1]
                )
                print(self.pickplacelogs)
        self.update_main_log()

    def rollout_learned_policy(self, session_id: int, savepath: str) -> Dict:
        """Env is initialized to startframe of session.
         learned policy are rolled out on the env
        Returns : Dict
        """
        obs = self.env.reset(ref_sess=self.session_paths[session_id])
        learned_policy = SingleSessionLearnedPolicy(
            self.model,
            self.config["context_history"],
            self.config["device"],
            self.config["pick_only"],
        )
        learned_policy.reset()
        learned_rollout_info = session_rollout(obs, self.env, learned_policy, savepath)
        return learned_rollout_info

    def get_rollout_metrics(self, session_id: int, savepath: str) -> None:
        expert_rollout_info = self.expert_rollout_info[session_id]
        learned_rollout_info = self.rollout_learned_policy(session_id, savepath)
        self.compute_all_rollout_metrics(expert_rollout_info, learned_rollout_info)
        return
 

class DualModelLogger(Logger):
    def __init__(
        self,
        env: Any,
        session_paths: Dict, # List[Tuple[str, str]],
        config: Dict,
        model: PreTrainedModel,
        criterion: torch.nn.Module,
        name: str,
        pick_only: bool = True,
        max_evals: int = 10,
        savefolder: str = "rollouts",
    ) -> None:
        super().__init__(
            env, config, model, criterion, pick_only, max_evals, savefolder
        )
        self.session_paths = session_paths
        if isinstance(self.session_paths, str):
            self.session_paths = literal_eval(self.session_paths)
        # self.session_pairs = list(itertools.product(self.session_paths, self.session_paths))
        self.savepath = f"{savefolder}_{self.config['test_split_filepath'].split('/')[-1]}"
        self.name = name
        if self.savepath != '':
            if not os.path.exists(Path(self.savepath).as_posix()):
                os.mkdir(Path(self.savepath).as_posix())
            if not os.path.exists(Path(self.savepath, self.name).as_posix()):
                os.mkdir(Path(self.savepath, self.name).as_posix())
        self.categ_attr = ["category_token"]
        self.input_attr = self.categ_attr
        self.prompt = (
            None  # update in get_per_session_metrics for rollout learned policy
        )
        self.pick_only = self.config.pick_only
        self.experiment_record = []
    
    def init_main_log(self, to_rollout):
        main_log = super().init_main_log(to_rollout)
        main_log.update({
            'prompt': [],
            'situation': []
        })
        return main_log

    def __call__(self, model: torch.nn.Module, step: int, to_rollout: bool) -> Dict:
        self.main_log = self.init_main_log(to_rollout)
        self.model = model
        self.model.eval()
        self.get_per_session_metrics()
        if to_rollout:
            session_pairs = []
            for preference_folder, session_list in self.session_paths.items():
                # session_list = get_session_list(preference_folder, list_num_objects_per_rack, num_sessions_limit)
                # temporal_contexts = []
                # for session_path in session_list:
                    # temporal_context = get_temporal_context(session_path)
                session_pairs += list(itertools.product(session_list, session_list))
            for idx, session_pair in enumerate(session_pairs):        
                if self.savepath != '':
                    savepath = Path(self.savepath, self.name, f"pair_{idx}.json").as_posix()
                else: 
                    savepath = None
                self.get_rollout_metrics(session_pair, savepath)
        # dump experiment metrics to a json file (for reading as a pandas dataframe later)
        if self.config['log_per_session_metrics']:
            dfItem = pd.DataFrame({k: self.main_log[k] for k in ['prompt', 'situation', 'PE', 'TPE', 'BPE', 'SPL', 'LC']})
            dfItem.to_csv(Path(f"{self.name}_{self.config['test_split_filepath'].split('/')[-1]}_per_session_metrics.csv").as_posix()) #, index=False)
        
        metrics = self.get_avg_log(to_rollout)
        self.print(metrics, step=step)
        return metrics

    # @torch.no_grad()
    def get_per_session_metrics(self) -> None:
        self.sess_log = self.init_sess_log()
        self.model.eval()
        dataset = PromptSituationPickPlace(self.session_paths, self.pick_only)
        dataloader = DataLoader(dataset, batch_size=1, collate_fn=prompt_pad_fn, shuffle=False)
        with torch.no_grad():
            for inputs, targets in dataloader:
                # self.prompt = inputs["prompt"]
                out = self.model(**inputs, device=self.device)
                loss = self.criterion(out, targets)
                y_true = torch.cat(targets["action_instance"], dim=0)
                y_pred = torch.max(out["pick"], dim=1)[1]
                x_true = self.get_attr_based_predictions(inputs["situation"], y_true)
                x_pred = self.get_attr_based_predictions(inputs["situation"], y_pred)
                self.accumulate(loss, y_true, y_pred, x_true, x_pred)
            self.update_main_log()
        return

    def rollout_learned_policy(self, session_paths: Tuple[str, str], savepath: str):
        prompt, situation = session_paths
        obs = self.env.reset(ref_sess=situation)
        self.prompt_temporal_context = get_temporal_context(prompt)
        self.prompt_temporal_context.pick_only = True
        self.prompt = self.prompt_temporal_context.process_states()
        for key, val in self.prompt.items():
            self.prompt[key] = torch.tensor(val).unsqueeze(0)
        prompt_input_len = len(self.prompt["timestep"][0])
        self.prompt["src_key_padding_mask"] = (
            torch.zeros(prompt_input_len).bool().unsqueeze(0)
        )
        learned_policy = PromptSituationLearnedPolicy(
            model=self.model,
            pick_only=False, # self.config["pick_only"],
            device=self.config["device"],
        )
        learned_policy.reset(self.prompt)
        learned_rollout_info = session_rollout(obs, self.env, learned_policy, savepath)
        return learned_rollout_info

    def get_rollout_metrics(
        self, session_pair: List[Tuple[str, str]], savepath: str
    ) -> Dict:
        situation_session = session_pair[1]
        expert_rollout_info = self.get_expert_rollout_metrics(situation_session)
        learned_rollout_info = self.rollout_learned_policy(session_pair, savepath)
        metrics_record = self.compute_all_rollout_metrics(expert_rollout_info, learned_rollout_info)
        print(metrics_record)
        self.main_log['prompt'].append(session_pair[0])
        self.main_log['situation'].append(session_pair[1])
        # metrics_record.update({
            # 'prompt': session_pair[0],
            # 'situation': session_pair[1]
        # })
        # dump the rollout info per session in a file record for later processing
        # self.experiment_record.append(metrics_record)
        return metrics_record


class SessionPreferenceClassifierLogger(Logger):
    def __init__(
        self,
        env: Any,
        session_paths: List[str],
        config: Dict,
        model: PreTrainedModel,
        criterion: torch.nn.Module,
        name: str,
        pick_only: bool = True,
        max_evals: int = 10,
        savefolder: str = "rollouts",
    ) -> None:
        super().__init__(
            env, config, model, criterion, pick_only, max_evals, savefolder
        )
        self.session_paths = session_paths
        self.pick_only = pick_only
        self.name = name

    def __call__(self, model: torch.nn.Module, step: int, to_rollout: bool) -> Dict:
        self.main_log = self.init_main_log(to_rollout)
        self.model = model
        self.model.eval()
        dataset = SessionPreferenceDataset(self.session_paths, self.pick_only)
        dataloader = DataLoader(dataset, batch_size=16, collate_fn=preference_classifier_pad_fn, shuffle=True)
        y_true = []
        y_pred = []
        with torch.no_grad():
            for inputs, targets in dataloader:
                out = self.model(**inputs, device=self.device)
                loss = self.criterion(out, targets)
                y_true.append(targets)
                y_pred.append(torch.max(out, dim=1)[1])
                self.main_log['Loss'].append(loss.item())
        true = torch.cat(y_true).cpu().detach().numpy()
        pred = torch.cat(y_pred).cpu().detach().numpy()
        metrics = {
            f"{self.name}/Loss": np.mean(self.main_log['Loss']),
            f"{self.name}/Accuracy": accuracy_score(true, pred),
            f"{self.name}/ConfusionMatrix": wandb.plot.confusion_matrix(
                probs=None,
                y_true=true,
                preds=pred,
                class_names=list(dataset.preference_label_vocab.index2word(i) for i in range(len(dataset.preference_label_vocab))),
            )  #confusion_matrix(true, pred)
        }
        # self.print(metrics, step=step)
        return metrics