import pickle
from dataclasses import dataclass
from typing import List, Tuple, Union, Dict, Optional

import numpy as np
import pandas as pd
import torch
from pandas.core.groupby import DataFrameGroupBy


@dataclass
class Operator:
    """
    A symbolic operator that can be rendered as a PDDL ``:action``.

    The precondition and every effect only need to expose ``to_pddl()``
    returning a list of proposition strings; effects are compared with the
    precondition position by position.
    """
    # action name used in the ``(:action ...)`` header
    name: str
    # abstract state that must hold before the operator is applied
    precondition: 'AbstractState'
    # one abstract state per possible outcome
    effects: List['AbstractState']
    # probability of each outcome (same order as ``effects``), or None
    probs: Optional[List[float]]
    # reward of each outcome (same order as ``effects``), or None
    rewards: Optional[List[float]]

    def to_pddl(self) -> str:
        """
        Render the operator as a PDDL action definition.

        :return: the ``(:action ...)`` string with precondition and effect
        """
        props = [f"({x})" for x in self.precondition.to_pddl()]
        if len(props) == 1:
            precond = props[0]
        else:
            # also covers an empty precondition, rendered as "(and )" the
            # same way empty effects are rendered in get_effects
            precond = f"(and {' '.join(props)})"
        nl = "\n"
        return f"(:action {self.name}" + nl + \
               f"    :precondition {precond}" + nl + \
               f"    :effect {self.get_effects()})"

    def get_effects(self):
        """
        Build the PDDL effect expression of this operator.

        For every outcome, propositions that differ from the precondition at
        the same position are added and the precondition's proposition is
        negated; unchanged propositions are omitted. If ``rewards`` is set,
        an ``(increase (reward) r)`` term is appended per outcome. If
        ``probs`` is set, the outcomes are wrapped in a ``(probabilistic ...)``
        block with two-decimal probabilities that sum to exactly 1.00;
        otherwise only the first outcome is returned.

        :return: the effect expression as a string
        """
        filtered_effects = []
        precond = self.precondition.to_pddl()
        if self.rewards is not None:
            assert len(self.rewards) == len(self.effects)

        # remove predicates that are already in the precondition
        # and negate predicates that are over the same factor
        for i, eff in enumerate(self.effects):
            effs = []
            for s1, s2 in zip(precond, eff.to_pddl()):
                if s1 != s2:
                    effs.append(f"({s2})")
                    effs.append(f"(not ({s1}))")
            # add rewards if they exist
            if self.rewards is not None:
                effs.append(f"(increase (reward) {self.rewards[i]:.2f})")
            if len(effs) == 1:
                filtered_effects.append(effs[0])
            else:
                # zero terms gives "(and )", several give "(and t1 t2 ...)"
                filtered_effects.append(f"(and {' '.join(effs)})")

        # without probabilities only the first outcome is emitted
        if self.probs is None:
            return filtered_effects[0]

        assert len(filtered_effects) == len(self.probs)
        # Round to whole percents. Truncating with int() turns e.g.
        # 0.29 * 100 == 28.999999999999996 into 28 and then shifts the
        # missing percent onto the most likely outcome.
        int_probs = [int(round(float(x) * 100)) for x in self.probs]
        # push any remaining rounding error onto the most likely outcome so
        # that the printed probabilities sum to exactly 1.00
        denom = sum(int_probs)
        if denom != 100:
            idx = int_probs.index(max(int_probs))
            int_probs[idx] += 100 - denom

        lines = [f"{p / 100:.2f} {eff}" for p, eff in zip(int_probs, filtered_effects)]
        nl = "\n"
        return f"(probabilistic{nl}        " + "\n        ".join(lines) + ")"


# look into dataclass if need be later
@dataclass
class AbstractState:
    """
    An abstract state described by one value per factor.

    Equality and hashing are both defined over the factor values only, so
    two states with the same factor values are interchangeable as dict keys
    and set members regardless of their ``id``.
    """
    # numeric identifier; used for ordering and for str()/repr()
    id: int
    initiation: Tuple[int, ...]
    # ensure that the order is same for other states
    factors: List['FactorValue']
    # samples belonging to this state; indexed along the first dimension
    data: torch.Tensor
    action_counts: np.ndarray
    info: Optional[List[Dict]]

    def sample(self, n: int = 1) -> torch.Tensor:
        """
        Draw ``n`` samples (with replacement) from ``data``.

        :param n: number of samples to draw
        :return: a cloned tensor holding the selected rows of ``data``
        """
        idx = torch.randint(0, len(self.data), (n,))
        return self.data[idx].clone()

    def to_pddl(self) -> list[str]:
        """
        Return the string form of every factor value, in factor order.
        """
        return [str(fval) for fval in self.factors]

    def __hash__(self):
        # Hash exactly what __eq__ compares. Mixing in self.id would let two
        # equal states hash differently, which breaks dict/set lookups.
        f_ids = []
        for fval in self.factors:
            f_ids.extend([fval.factor.name, fval.id, fval.refinement])
        return hash(tuple(f_ids))

    def __eq__(self, other):
        if not isinstance(other, AbstractState):
            return False
        # zip() below would silently ignore trailing factors
        if len(self.factors) != len(other.factors):
            return False

        for fval1, fval2 in zip(self.factors, other.factors):
            # invariant: factors are stored in the same order in all states
            assert fval1.factor.name == fval2.factor.name
            if fval1 != fval2:
                return False
        return True

    def __str__(self):
        return f"{self.id}"

    def __repr__(self):
        return self.__str__()

    def __lt__(self, other):
        # states are ordered by their id
        return self.id < other.id


@dataclass
class FactorValue:
    """
    One value of a factor, identified by ``(id, refinement)`` within the
    factor it belongs to.
    """
    id: int
    refinement: int
    # the factor this value belongs to; only its ``name`` is read here
    factor: 'Factor'

    def __hash__(self):
        # the factor name is left out of the hash; equal values still hash
        # equally because __eq__ also requires id and refinement to match
        key = (self.id, self.refinement)
        return hash(key)

    def __eq__(self, other):
        if not isinstance(other, FactorValue):
            return False
        if self.id != other.id:
            return False
        if self.refinement != other.refinement:
            return False
        return self.factor.name == other.factor.name

    def __str__(self):
        return "_".join([f"{self.factor.name}", f"{self.id}", f"{self.refinement}"])

    def __repr__(self):
        return str(self)


@dataclass
class Factor:
    """
    A named group of state variables.

    Identity is defined by the set of variable indices only: two factors
    over the same variables are equal (and hash equally) regardless of
    their name or of the order in which the variables are listed.
    """
    name: str
    # indices of the state variables that make up this factor
    variables: List[int]

    def __hash__(self):
        variables = tuple(sorted(self.variables))
        return hash(variables)

    def __eq__(self, other):
        # Compare the variables themselves instead of their hashes: hash
        # equality would also accept collisions and unrelated objects such
        # as a plain tuple holding the same numbers.
        if not isinstance(other, Factor):
            return False
        return sorted(self.variables) == sorted(other.variables)

    def __str__(self):
        # abbreviate long variable lists to "[first, ..., last]"
        if len(self.variables) > 3:
            var_str = f"[{self.variables[0]}, ..., {self.variables[-1]}]"
        else:
            var_str = str(self.variables)
        return f"{self.name}({var_str})"

    def __repr__(self):
        return self.__str__()


class TransitionData:
    """
    A table of transitions backed by a pandas DataFrame, one row per
    transition, with the columns listed in ``__init__``.
    """

    def __init__(self, load_from_pickle: Optional[str] = None):
        """
        Create an empty table, or load one previously written by ``to_pickle``.

        :param load_from_pickle: path of a gzip-compressed pickle to load, or
            None to start with an empty frame
        """
        if load_from_pickle is not None:
            # NOTE: unpickling can execute arbitrary code; only load files
            # that come from a trusted source.
            self._data = pd.read_pickle(load_from_pickle, compression='gzip')
        else:
            self._data = pd.DataFrame(
                columns=['episode', 'state', 'option', 'reward', 'next_state', 'done', 'mask', 'steps',
                         'options_available',
                         'next_options_available',
                         "state_info",
                         "next_state_info"])

    def add(self, episode: int, state: np.ndarray, option: int, reward: float, next_state: np.ndarray,
            done: bool, mask: np.ndarray, steps: int, options_available: np.ndarray,
            next_options_available: np.ndarray,
            state_info=None, next_state_info=None) -> None:
        """
        Add a transition to our frame
        :param episode: the episode number
        :param state: the state
        :param option: the option
        :param reward: the reward
        :param next_state: the next state
        :param done: whether the episode is terminated
        :param mask: the state variables that differ between s and s'
        :param steps: the number of steps taken by the option
        :param options_available: the set of options available at the state
        :param next_options_available: the set of options available at the next state
        :param state_info: optional extra information stored with the state
        :param next_state_info: optional extra information stored with the next state
        """
        # the new row is stored under the label equal to the current row count
        self._data.loc[self._data.shape[0]] = [episode, state, option,
                                               reward, next_state, done, mask, steps, options_available,
                                               next_options_available, state_info, next_state_info]

    def rows(self, all_data=False):
        """
        Iterate over the stored transitions.

        :param all_data: if True yield the full row, otherwise yield the tuple
            (state, option, reward, next_state, steps)
        """
        for _, row in self._data.iterrows():
            if all_data:
                yield row
            else:
                yield row['state'], row['option'], row['reward'], row['next_state'], row['steps']

    def get_initial_states(self) -> pd.DataFrame:
        """
        Return the initial states and options available at those states
        """
        return self._data.groupby('episode').nth(0)  # type: ignore

    def groupby(self, column_name: Union[str, List[str]]) -> DataFrameGroupBy:
        """
        Group the underlying frame by the given column(s).
        """
        return self._data.groupby(column_name)

    @property
    def data(self):
        """The underlying DataFrame (not a copy)."""
        return self._data

    def get_option_data(self, option: int) -> pd.DataFrame:
        """
        Return the transitions of a single option, re-indexed from zero.
        """
        return self._data.loc[self._data['option'] == option].reset_index(drop=True)

    def append(self, other: 'TransitionData'):
        """
        Append the rows of ``other`` to this table in place.
        """
        self._data = pd.concat([self._data, other._data], ignore_index=True)

    def __len__(self):
        return len(self._data)

    @staticmethod
    def concat(frames: List['TransitionData']):
        """
        Combine several tables into a new one.

        The inputs are left untouched; an empty list yields an empty table.

        :param frames: the tables to combine, in order
        :return: a new TransitionData holding all rows
        """
        merged = TransitionData()
        if frames:
            # a single concat avoids re-copying the accumulated rows for
            # every additional frame
            merged._data = pd.concat([frame._data for frame in frames], ignore_index=True)
        return merged

    def to_pickle(self, filepath: str):
        """
        Write the table to ``filepath`` as a gzip-compressed pickle.
        """
        self._data.to_pickle(filepath, compression='gzip')


class AbstractStateDict:
    """
    Index between ground transitions and abstract states.

    Comprises three relations:
    abstract_state_key -> state_indices, next_state_indices,
    state_indices -> abstract_state_key, next_abstract_state_key,
    factor_values -> state_indices, next_state_indices.

    ``data`` is indexed as a 2-D array with one row per transition. Rows
    added through ``add_to_buffer`` use the column order
    (s, a, r, sn, d, t, iv, ivn, inf, infn); this class reads column 1 as the
    action, columns 2 and 5 as quantities to average, and columns 6 and 7 as
    the hashable values that identify new abstract states.
    An abstract state key is a tuple with one entry per name in
    ``key_names``; entries minted here are ``(id, refinement)`` pairs.
    """
    def __init__(self, key_names, key_values, next_key_values, data):
        """
        :param key_names: factor names, one per entry of a key
        :param key_values: abstract state key of each row of ``data``
        :param next_key_values: next abstract state key of each row of ``data``
        :param data: the ground transitions, one row per transition
        """
        self._key_names = key_names
        self._data = data
        # key -> [row indices as state, row indices as next state]
        self._abs_to_ground = {}
        # key -> action -> row indices starting in that state
        self._abs_w_act_to_ground = {}
        # row index -> (key, next key)
        self._ground_to_abs = np.full((len(data), 2), None)
        # factor name -> factor value -> [row indices as state, as next state]
        self._factor_to_ground = {k: {} for k in key_names}
        self.initialize(key_values, next_key_values)
        self._abs_buffer = []
        self._ground_buffer = []

    def copy(self):
        """
        Return an independent copy of the index.

        The index containers are copied down to the row-index lists, so
        flushing new transitions into the copy does not alter this object.
        The ground data array itself is shared (it is rebound, not modified
        in place, by ``flush_buffer``). Pending buffer entries are not copied.
        """
        copydict = AbstractStateDict(
            key_names=self._key_names.copy(),
            key_values={},
            next_key_values={},
            data=self._data
        )
        copydict._abs_to_ground = {
            key: [list(idx), list(next_idx)]
            for key, (idx, next_idx) in self._abs_to_ground.items()
        }
        copydict._abs_w_act_to_ground = {
            key: {a: list(rows) for a, rows in per_action.items()}
            for key, per_action in self._abs_w_act_to_ground.items()
        }
        copydict._ground_to_abs = self._ground_to_abs.copy()
        for f in self._key_names:
            copydict._factor_to_ground[f] = {
                fv: [list(idx), list(next_idx)]
                for fv, (idx, next_idx) in self._factor_to_ground[f].items()
            }
        return copydict

    def add_to_buffer(self, s, a, r, sn, d, t, iv, ivn,
                      inf, infn, s_bar, sn_bar):
        """
        Queue one transition; it becomes visible after ``flush_buffer``.

        :param s_bar: abstract state of ``s``, or -1 if it is not known yet
        :param sn_bar: abstract state of ``sn``, or -1 if it is not known yet

        The remaining arguments are stored as one row of ground data in the
        order given.
        """
        self._ground_buffer.append(
            np.array([s, a, r, sn, d, t, iv, ivn, inf, infn], dtype=object)
        )
        self._abs_buffer.append((s_bar, sn_bar))

    def flush_buffer(self):
        """
        Append all queued transitions to the data and update every index.

        Abstract states given as -1 are assigned a freshly created key;
        within one flush, rows sharing the same column 6/7 value share the
        same new key. Does nothing if the buffer is empty.
        """
        if not self._ground_buffer:
            # np.concatenate cannot join an empty 1-D array with the 2-D data
            return

        ground_buffer = np.array(self._ground_buffer)
        offset = self._data.shape[0]
        self._data = np.concatenate([self._data, ground_buffer], axis=0)
        del self._ground_buffer[:]
        _ground_to_abs = np.full((len(self._abs_buffer), 2), None)
        new_states = {}
        for i, (s, s_) in enumerate(self._abs_buffer):
            idx = i + offset
            a = self._data[idx][1]

            # The state side is registered before the next-state key is
            # resolved: a newly created key must be visible in the factor
            # index so that the next new key gets a different id.
            s_key = self._buffered_state_key(s, self._data[idx, 6], new_states)
            self._register_ground_index(s_key, idx, 0)
            _ground_to_abs[i, 0] = s_key

            self._abs_w_act_to_ground.setdefault(s_key, {}).setdefault(a, []).append(idx)

            sn_key = self._buffered_state_key(s_, self._data[idx, 7], new_states)
            self._register_ground_index(sn_key, idx, 1)
            _ground_to_abs[i, 1] = sn_key

        self._ground_to_abs = np.concatenate([self._ground_to_abs, _ground_to_abs], axis=0)
        del self._abs_buffer[:]

    def _buffered_state_key(self, abstract_state, initiation, new_states):
        """Return the key of a buffered abstract state, creating one for -1."""
        if abstract_state != -1:
            return tuple([(f.id, f.refinement) for f in abstract_state.factors])
        # a new abstract state, identified by its initiation value
        if initiation in new_states:
            return new_states[initiation]
        key = self._get_new_state_key()
        new_states[initiation] = key
        self._abs_to_ground[key] = [[], []]
        return key

    def _register_ground_index(self, key, idx, slot):
        """Record row ``idx`` under ``key`` as state (slot 0) or next state (slot 1)."""
        self._abs_to_ground[key][slot].append(idx)
        for j, f in enumerate(self._key_names):
            fv = key[j]
            if fv not in self._factor_to_ground[f]:
                self._factor_to_ground[f][fv] = [[], []]
            self._factor_to_ground[f][fv][slot].append(idx)

    def get_keys(self):
        """Return all abstract state keys, in insertion order."""
        return list(self._abs_to_ground.keys())

    def get_ground_samples(self, abstract_state_key):
        """
        Return (rows where the key is the state, rows where it is the next state).
        """
        idx, next_idx = self._abs_to_ground[abstract_state_key]
        return self._data[idx], self._data[next_idx]

    def get_abstract_states(self, ground_indices):
        """Return the (key, next key) entries for the given row indices."""
        return self._ground_to_abs[ground_indices]

    def get_transitions(self, abstract_state_key):
        """
        Return ``{action: rows}`` for the transitions starting in the given
        abstract state; empty if the state has no recorded transitions.
        """
        per_action = self._abs_w_act_to_ground.get(abstract_state_key)
        if per_action is None:
            return {}
        return {a: self._data[rows] for a, rows in per_action.items()}

    def get_abstract_transitions(self, abstract_state_key, min_samples=10):
        """
        Estimate the abstract transition model of a state.

        :param abstract_state_key: key of the source abstract state
        :param min_samples: actions with fewer rows than this are skipped
        :return: ``{action: {next_key: [p, mean_c2, mean_c5]}}`` where ``p``
            is the fraction of the action's rows that reach ``next_key`` and
            the means are over data columns 2 and 5 of those rows
            (reward and ``t`` in ``add_to_buffer`` order)
        """
        transitions = {}
        per_action = self._abs_w_act_to_ground.get(abstract_state_key)
        if per_action is None:
            return transitions

        for a, rows in per_action.items():
            if len(rows) < min_samples:
                continue

            outcomes = {}
            for i in rows:
                row = self._data[i]
                next_state = self._ground_to_abs[i][1]
                stats = outcomes.setdefault(next_state, [0, 0, 0])
                stats[0] += 1
                stats[1] += row[2]
                stats[2] += row[5]
            transitions[a] = outcomes

        # normalize: sums -> means, counts -> probabilities
        for outcomes in transitions.values():
            denom = sum([x[0] for x in outcomes.values()])
            for stats in outcomes.values():
                c = stats[0]
                stats[1] /= c
                stats[2] /= c
                stats[0] /= denom
        return transitions

    def get_factor_values(self, factor, value):
        """
        Return (rows where the factor has the value in the state, rows where
        it has the value in the next state).
        """
        idx, next_idx = self._factor_to_ground[factor][value]
        return self._data[idx], self._data[next_idx]

    def initialize(self, key_values, next_key_values):
        """
        Build all indices from the per-row keys of the initial data.

        :param key_values: abstract state key of row i at position i
        :param next_key_values: next abstract state key of row i at position i
        """
        assert len(key_values) == len(next_key_values)
        for i, (key, next_key) in enumerate(zip(key_values, next_key_values)):
            action = self._data[i][1]
            # create both entries before appending so that the key is
            # inserted ahead of the next key (get_keys order)
            state_rows = self._abs_to_ground.setdefault(key, [[], []])
            next_state_rows = self._abs_to_ground.setdefault(next_key, [[], []])
            action_rows = self._abs_w_act_to_ground.setdefault(key, {}).setdefault(action, [])

            state_rows[0].append(i)
            next_state_rows[1].append(i)
            action_rows.append(i)
            self._ground_to_abs[i][0] = key
            self._ground_to_abs[i][1] = next_key

            for j, (k_j, nk_j) in enumerate(zip(key, next_key)):
                by_value = self._factor_to_ground[self._key_names[j]]
                value_rows = by_value.setdefault(k_j, [[], []])
                next_value_rows = by_value.setdefault(nk_j, [[], []])
                value_rows[0].append(i)
                next_value_rows[1].append(i)

    def save(self, path):
        """Pickle this object to ``path``."""
        with open(path, "wb") as f:
            pickle.dump(self, f)

    def _get_new_state_key(self):
        """
        Create a key for a new abstract state: for every factor, one more
        than the largest known value id (0 if none), with refinement 1.
        """
        key = []
        for f in self._key_names:
            max_id = max([-1] + [fv[0] for fv in self._factor_to_ground[f]])
            key.append((max_id+1, 1))
        return tuple(key)
