import numpy as np
import torch

from itertools import product
from torch.utils.data import TensorDataset, DataLoader
from abc import (
    ABCMeta,
    abstractmethod,
)
import d3rlpy
from d3rlpy.datasets import (
    MDPDataset,
    ReplayBuffer,
    InfiniteBuffer,
    TransitionPickerProtocol,
)
from d3rlpy.models.encoders import PixelEncoderFactory
from d3rlpy.dataset.components import Transition
from d3rlpy.dataset.utils import create_zero_observation

from relign.helpers import get_device

from lift.environments import BaseEnvironment
from lift.network import ActionProbabilityModel, SimpleEncoderFactory
from lift.callbacks import CustomEnvironmentEvaluator
from lift.expert_policy import OptimalPolicy



import os

class Shortcuts(TransitionPickerProtocol):
    """Transition picker that may merge several consecutive steps into one.

    For a non-terminal index the picker accumulates the actions from
    ``index`` onwards, keeps the offsets whose accumulated action stays inside
    the configured bounds in every dimension and jumps to one of them
    according to ``select``.  The emitted transition carries the summed action
    of the skipped steps and the reward of the last skipped step.

    Args:
        action_low: Lower bound for every component of the accumulated action
            (a tolerance of 1e-4 is applied).
        action_height: Upper bound for every component of the accumulated
            action (a tolerance of 1e-4 is applied).
        select: Strategy used to choose the jump target. One of
            ``'random_floor'``, ``'random'``, ``'best'``, ``'weighted'`` or
            ``'dis_weighting'``.

    Attributes:
        valid: List that receives a ``0`` every time fewer than two candidate
            offsets were available and the picker fell back to a single step.
    """

    def __init__(self, action_low=-0.2, action_height=0.2, select='weighted'):
        self.action_low = action_low
        self.action_height = action_height
        self.select = select
        self.valid = []

    def _choose_offset(self, episode, index):
        """Return how many steps after ``index`` the transition should end.

        Raises:
            ValueError: If ``select`` is not a supported strategy (only checked
                when at least two candidate offsets exist).
        """
        eps = 1e-4
        cumsum_action = np.cumsum(episode.actions[index:], axis=0)
        in_bounds = (
            np.all(cumsum_action > self.action_low - eps, axis=1)
            & np.all(cumsum_action <= self.action_height + eps, axis=1)
        )
        # The first in-bounds position is always discarded.
        valid_idx = np.where(in_bounds)[0][1:]

        if len(valid_idx) < 2:
            # Not enough candidates: fall back to an ordinary one-step transition.
            self.valid.append(0)
            return 1

        if self.select == "random_floor":
            target = np.random.randint(1, episode.size() - index)
            # Position of the largest candidate that does not exceed ``target``.
            pos = int(np.searchsorted(valid_idx, target, side='right')) - 1
            if pos < 0:
                # Every candidate lies beyond the target; take a single step.
                return 1
            return int(valid_idx[pos])

        if self.select == "random":
            return int(np.random.choice(valid_idx))

        if self.select not in ("best", "weighted", "dis_weighting"):
            raise ValueError(
                "select must be 'random_floor', 'random', 'best', 'weighted', "
                "or 'dis_weighting'"
            )

        # Reward collected on the last step before each candidate target.
        valid_rewards = episode.rewards[index:][valid_idx - 1]

        if self.select == "best":
            return int(valid_idx[int(np.argmax(valid_rewards))])

        if self.select == "weighted":
            # Shift so the smallest reward still gets a tiny positive weight.
            weights = valid_rewards - valid_rewards.min() + 1e-6
        else:
            # NOTE(review): weights are only positive for strictly negative
            # rewards, and a zero reward divides by zero — confirm the reward range.
            weights = -1 / valid_rewards

        total = weights.sum()
        if total <= 0:
            probs = np.full_like(weights, 1.0 / weights.size, dtype=float)
        else:
            probs = weights / total

        return int(np.random.choice(valid_idx, p=probs.reshape(-1)))

    def __call__(self, episode, index: int) -> Transition:
        """Build the (possibly multi-step) transition starting at ``index``."""
        observation = episode.observations[index]
        is_terminal = episode.terminated and index == episode.size() - 1

        if is_terminal:
            next_observation = create_zero_observation(observation)
            next_action = np.zeros_like(episode.actions[index])
            action = episode.actions[index]
            # One past the current step, so ``next_idx - 1`` below addresses
            # the reward of this very step rather than the previous one.
            next_idx = index + 1
        else:
            next_idx = index + self._choose_offset(episode, index)
            next_observation = episode.observations[next_idx]
            action = episode.actions[index:next_idx].sum(axis=0)
            next_action = episode.actions[next_idx:].sum(axis=0)

        return Transition(
            observation=observation,
            action=action,
            reward=episode.rewards[next_idx - 1],
            next_observation=next_observation,
            next_action=next_action,
            terminal=float(is_terminal),
            interval=1,
            rewards_to_go=episode.rewards[next_idx - 1:],
        )


class PolicyAugmentation(metaclass=ABCMeta):
    """Base class for rollouts in which some policy actions get augmented.

    An episode is played with ``policy`` in ``env``.  At every step the
    subclass decides through :meth:`follow_policy` whether the proposed action
    is used as is, or blended with the result of :meth:`augment`.  All visited
    steps are accumulated on the instance across calls to :meth:`run`.
    """

    def __init__(self, policy, env, env_eval, augment_ratio=1, n_augmentations=1):
        assert 0 <= augment_ratio <= 1, 'augment_ratio must be in [0, 1]'
        if isinstance(env, BaseEnvironment):
            raise ValueError('Gymn env required')

        self.policy = policy
        self.augment_ratio = augment_ratio
        self.n_augmentations = n_augmentations
        self.env = env
        self.env_eval = env_eval

        self.init()

    @property
    def n_actions(self):
        """Size of the first axis of the environment's action space."""
        action_shape = self.env.action_space.shape
        return action_shape[0]

    def init(self):
        """Reset every recorded buffer to an empty array."""
        for buffer_name in (
            'observations',
            'actions',
            'rewards',
            'terminals',
            'positions',
            'is_augmented',
        ):
            setattr(self, buffer_name, np.empty(shape=0))

    @abstractmethod
    def follow_policy(self, obs, action):
        """
        Decide for the given observation and proposed action whether the policy's action is
        executed unchanged (`True`) or an augmentation is requested (`False`)
        """
        return True

    def augment(self, obs, action_proposed):
        """
        Return the augmentation for an observation / proposed action pair; the base version
        hands the proposed action back untouched
        """
        return action_proposed

    def update(self):
        """
        Hook invoked at the end of every `run`; does nothing in the base class
        """

    def combine_actions(self, action_augmentation, action_proposed, with_norm=True):
        """
        Blend the augmentation with the proposed action using `augment_ratio`. With `with_norm`
        the direction and the length are averaged separately, otherwise the vectors are averaged
        """
        ratio = self.augment_ratio
        if not with_norm:
            return ratio * action_augmentation + (1 - ratio) * action_proposed

        def floored_norm(vector):
            # Lower-bounded length so that the divisions below never hit zero
            return np.max([np.linalg.norm(vector), 1e-6])

        length_augmentation = floored_norm(action_augmentation)
        length_proposed = floored_norm(action_proposed)
        blended_length = ratio * length_augmentation + (1 - ratio) * length_proposed

        unit_augmentation = action_augmentation / length_augmentation
        unit_proposed = action_proposed / length_proposed

        blended_direction = ratio * unit_augmentation + (1 - ratio) * unit_proposed
        blended_direction = blended_direction / floored_norm(blended_direction)

        return blended_length * blended_direction

    def _clip(self, action):
        """Clamp an action to the bounds of the environment's action space."""
        space = self.env.action_space
        return np.clip(action, space.low, space.high)

    def record(self, observations, actions, rewards, terminals, positions, is_augmented):
        """Append one episode worth of data to the buffers kept on the instance."""
        # Convert everything up front, before any buffer is modified
        episode = {
            'observations': np.array(observations),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'terminals': np.array(terminals),
            'positions': np.array(positions),
            'is_augmented': np.array(is_augmented),
        }
        # Observations and actions are stacked row-wise, the rest is concatenated
        joiners = {'observations': np.vstack, 'actions': np.vstack}

        for buffer_name, fresh in episode.items():
            stored = getattr(self, buffer_name)
            if len(stored) == 0:
                merged = fresh
            else:
                merged = joiners.get(buffer_name, np.concatenate)([stored, fresh])
            setattr(self, buffer_name, merged)

    def run(self, obs):
        """Play one episode starting from `obs`, record it and call `update`.

        Returns the per-step lists `(observations, actions, rewards, terminals)` of this episode.
        """
        observations, actions, rewards = [], [], []
        terminals, positions, is_augmented = [], [], []
        augmentations_used = 0

        self.policy.restart(obs)

        done = False
        while not done:
            observations.append(obs)
            positions.append(self.env.get_position_diff_to_optimum())
            action_proposed = self.policy.predict(obs)

            # `follow_policy` is always consulted first, the budget check comes second
            augment_now = bool(
                not self.follow_policy(obs, action_proposed)
                and augmentations_used < self.n_augmentations
            )

            if augment_now:
                blended = self.combine_actions(self.augment(obs, action_proposed), action_proposed)
                executed = self._clip(blended)
                augmentations_used += 1
            else:
                executed = self._clip(action_proposed)

            obs, reward, truncated, terminated, _ = self.env.step(executed)
            actions.append(executed)
            if augment_now:
                # The policy is restarted from the observation reached by the augmented step
                self.policy.restart(obs)
            is_augmented.append(augment_now)

            done = truncated or terminated
            rewards.append(reward)
            terminals.append(done)

        self.record(observations, actions, rewards, terminals, positions, is_augmented)
        self.update()

        return observations, actions, rewards, terminals


class VoidAugmentation(PolicyAugmentation):
    """Augmentation that never augments: the policy's action is always executed."""

    def follow_policy(self, *args):
        """Ignore all arguments and always report that the policy is followed."""
        keep_policy_action = True
        return keep_policy_action

class StepSizeAugmentation(PolicyAugmentation):
    """Augmentation that rescales each component of the proposed action by log-normal noise.

    `p` is the probability with which a step asks for augmentation and `v` is the variance of
    the Gaussian noise that is drawn in log space.
    """

    def __init__(self, *args, p=0.1, v=0.12, **kwargs):
        super().__init__(*args, **kwargs)
        self.sigma = np.sqrt(v)
        self.p = p

    def follow_policy(self, *args):
        """Follow the policy unless a uniform draw from [0, 1) does not exceed `p`."""
        draw = np.random.uniform(0, 1)
        return draw > self.p

    def augment(self, obs, action_proposed):
        """Multiply the proposed action by `exp(eta)` and then by a constant factor of 2."""
        # With the default variance of 0.12 (sigma ≈ 0.347) two standard deviations are ≈ 0.693,
        # i.e. log(2), so roughly 95% of the exp(eta) samples lie between 0.5 and 2.0.
        eta = np.random.normal(loc=0.0, scale=self.sigma, size=self.n_actions)
        noisy_action = action_proposed * np.exp(eta)
        return noisy_action * 2


class GaussianAugmentation(PolicyAugmentation):
    """Augmentation that perturbs the proposed action with additive Gaussian noise.

    Args:
        p: Probability with which a step asks for augmentation.
        scale: Standard deviation of the zero-mean noise added to every action
            component. Defaults to 0.05, the value that used to be hard-coded.
        *args, **kwargs: Forwarded to ``PolicyAugmentation``.
    """

    def __init__(self, *args, p=0.1, scale=0.05, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p
        self.scale = scale

    def follow_policy(self, *args):
        """Follow the policy unless a uniform draw from [0, 1) does not exceed ``p``."""
        return np.random.uniform(0, 1) > self.p

    def augment(self, obs, action_proposed):
        """Return the proposed action plus independent N(0, scale**2) noise per component."""
        noise = np.random.normal(loc=0, scale=self.scale, size=self.n_actions)
        return action_proposed + noise



class RandomAugmentation(PolicyAugmentation):
    """Augmentation that swaps in an action sampled from the environment's action space.

    `p` is the probability with which a step asks for augmentation.
    """

    def __init__(self, *args, p=0.2, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p

    def follow_policy(self, *args):
        """Follow the policy unless a uniform draw from [0, 1) does not exceed `p`."""
        draw = np.random.uniform(0, 1)
        return draw > self.p

    def augment(self, obs, action_proposed):
        """Ignore both arguments and return a sample of the action space."""
        action_space = self.env.action_space
        return action_space.sample()

class OptimalAugmentation(PolicyAugmentation):
    """Augmentation that steers towards the action of an ``OptimalPolicy``.

    Args:
        p: Probability with which a step asks for augmentation.
        step_norm_factor: When positive, the optimal action is rescaled to
            ``step_norm_factor`` times the length of the proposed action; when
            zero or negative the optimal action is returned unscaled.
        *args, **kwargs: Forwarded to ``PolicyAugmentation``.
    """

    def __init__(self, *args, p=0.2, step_norm_factor=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p
        self.step_norm_factor = step_norm_factor

        self.optimal_policy = OptimalPolicy(self.env, max_step_length=np.inf)

    def follow_policy(self, *args):
        """Follow the policy unless a uniform draw from [0, 1) does not exceed ``p``."""
        return np.random.uniform(0, 1) > self.p

    def augment(self, obs, action_proposed):
        """Return the optimal policy's action, optionally rescaled to the proposed step length."""
        action = self.optimal_policy.predict(obs)
        if self.step_norm_factor > 0:
            action_norm = np.linalg.norm(action)
            if action_norm == 0:
                # A zero action has no direction; dividing by its norm would
                # produce NaNs, so hand the zero action back unchanged.
                return action
            return self.step_norm_factor * np.linalg.norm(action_proposed) * action / action_norm
        return action


class IORLAugmentation(PolicyAugmentation):
    """Augmentation driven by an `ActionProbabilityModel` fitted on the recorded rollouts.

    Once the model exists, augmented actions are drawn from a fixed action grid with weights
    that decrease with the probability the model assigns to each grid action.
    """

    def __init__(
            self,
            *args,
            learning_rate=0.001,
            train_epochs=20,
            min_data=50,
            p=0.4,
            observation_shape=(1, 50, 50),
            hidden_dim=128,
            device=None,
            **kwargs,
        ):
        # Everything must be in place before the base constructor runs, because it calls
        # `init`, which already needs `self.device`
        self.learning_rate = learning_rate
        self.train_epochs = train_epochs
        self.min_data = min_data
        self.observation_shape = observation_shape
        self.hidden_dim = hidden_dim
        self.p = p
        self.device = get_device() if device is None else device

        super().__init__(*args, **kwargs)

    def init(self):
        """Reset the recorded buffers, drop the model and rebuild the action grid."""
        super().init()

        self.model = None
        grid = self._generate_action_grid()
        self.action_grid = grid.to(self.device)
        self.losses = []

    def _generate_action_grid(self, p=0.05):
        """Return all combinations of grid values (spacing `p`) over `n_actions` dimensions.

        The value range is taken from the first component of the action-space bounds.
        """
        space = self.env.action_space
        lower = space.low[0]
        # Half a step is added so that the upper bound itself is part of the grid
        upper = space.high[0] + p/2
        values = np.arange(lower, upper, p)
        combinations = list(product(values, repeat=self.n_actions))
        return torch.Tensor(np.array(combinations))

    def _train_model(self):
        """Create a fresh model and fit it on all recorded observation/action pairs."""
        self.model = ActionProbabilityModel(
            n_actions=self.n_actions,
            observation_shape=self.observation_shape,
            hidden_dim=self.hidden_dim,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        self.model.train()

        actions = torch.Tensor(self.actions).to(self.device)
        observations = torch.Tensor(self.observations).to(self.device)
        loader = DataLoader(TensorDataset(observations, actions), batch_size=10, shuffle=True)

        epoch_losses = []
        for _ in range(self.train_epochs):
            n_batches = 0
            loss_sum = 0

            for obs_batch, action_batch in loader:
                n_batches = n_batches + 1

                self.optimizer.zero_grad()

                # Negative mean log-likelihood of the recorded actions
                loss = -self.model.log_prob(obs_batch, action_batch).mean()
                loss.backward()
                self.optimizer.step()

                loss_sum = loss_sum + loss.detach().cpu().item()

            epoch_mean = loss_sum/n_batches
            epoch_losses.append(epoch_mean)
            print(epoch_mean)

        self.losses.append(epoch_losses)

    def update(self):
        """Fit the model when the sum of recorded terminal flags equals `min_data`."""
        n_terminals = self.terminals.sum()
        if n_terminals == self.min_data:
            self._train_model()

    def augment(self, obs, action_proposed):
        """Sample a grid action, favouring actions the model rates as unlikely.

        Without a model the proposed action is returned unchanged.
        """
        obs = torch.Tensor(obs).unsqueeze(0).to(self.device)
        if self.model is None:
            return action_proposed

        with torch.no_grad():
            mean, log_std = self.model(obs)
            grid_log_probs = self.model.gmm.log_prob(self.action_grid, mean, log_std)
            grid_log_probs = grid_log_probs.detach().cpu().numpy()
            grid_log_probs = np.clip(grid_log_probs, -5, 999)

        # The negative exponent turns high model probability into a low sampling weight
        weights = np.exp(grid_log_probs)**(-0.1)
        weights = weights / weights.sum()

        candidates = np.arange(len(weights))
        picked = np.random.choice(candidates, p=weights)
        return self.action_grid[picked].detach().cpu().numpy()

    def follow_policy(self, obs, action):
        """
        Always follow the policy while no model exists; afterwards follow it unless a uniform
        draw from [0, 1) does not exceed `p`
        """
        if self.model is None:
            return True
        return np.random.uniform(0, 1) > self.p


class LIFTAugmentation(PolicyAugmentation):
    """Augmentation driven by a d3rlpy agent trained on the recorded rollouts.

    Args:
        learning_rate: Stored on the instance; not used by the code in this class.
        min_data: Sum of recorded terminal flags at which the first training is triggered.
        p: When positive, probability with which a step asks for augmentation. Otherwise the
            decision is taken by comparing Q-values (see ``follow_policy``).
        model_class: Which d3rlpy algorithm to build: ``'CQL'``, ``'SAC'``, ``'IQL'`` or ``'DQN'``.
        load_model_fn: Stored on the instance; not used by the code in this class.
        use_shortcuts: Whether the replay buffer uses the ``Shortcuts`` transition picker.
        train_epochs: Number of epochs (100 steps each) of the first training.
        train_once: When ``False``, one further epoch is trained after every later episode.
        step_norm_factor: When positive, the model's action is rescaled to ``step_norm_factor``
            times the length of the proposed action.
        *args, **kwargs: Forwarded to ``PolicyAugmentation``.
    """

    def __init__(self, *args, learning_rate=0.001, min_data=10, p=-1,
                 model_class='SAC',
                 load_model_fn=None, use_shortcuts=False, train_epochs=40, train_once=True, step_norm_factor=1, **kwargs):
        self.learning_rate = learning_rate
        self.device = get_device()
        self.p = p
        self.use_shortcuts = use_shortcuts
        self.min_data = min_data
        self.load_model_fn = load_model_fn
        self.train_epochs = train_epochs
        self.train_once = train_once
        self.step_norm_factor = step_norm_factor
        self.model_class = model_class

        super().__init__(*args, **kwargs)

    def _make_model(self):
        """Build the d3rlpy algorithm selected by ``model_class``.

        Raises:
            ValueError: If ``model_class`` is not one of the supported names.
        """
        # Two-dimensional observation buffers get the simple encoder, everything else is
        # treated as image input.
        # NOTE(review): the device is overwritten with hard-coded values here, and only the
        # CQL branch actually uses it; the other algorithms are created on the CPU — confirm.
        if len(self.observations.shape) == 2:
            encoder_factory = SimpleEncoderFactory(10)
            self.device = 'cpu'
        else:
            encoder_factory = PixelEncoderFactory()
            self.device = 'cuda:0'

        if self.model_class == 'CQL':
            model = d3rlpy.algos.CQLConfig(
                actor_encoder_factory=encoder_factory,
                critic_encoder_factory=encoder_factory,
                batch_size=500,
                actor_learning_rate=1e-3,
                critic_learning_rate=1e-3,
                alpha_threshold=10,
                conservative_weight=5,
            ).create(device=self.device)
        elif self.model_class == 'SAC':
            model = d3rlpy.algos.SACConfig(
                actor_encoder_factory=encoder_factory,
                critic_encoder_factory=encoder_factory,
                batch_size=400,
                actor_learning_rate=0.001,
                critic_learning_rate=0.003,
            ).create(device='cpu')
        elif self.model_class == 'IQL':
            model = d3rlpy.algos.IQLConfig(
                actor_encoder_factory=encoder_factory,
                critic_encoder_factory=encoder_factory,
                batch_size=400,
                actor_learning_rate=0.001,
                critic_learning_rate=0.003,
            ).create(device='cpu')
        elif self.model_class == 'DQN':
            model = d3rlpy.algos.DQNConfig(
                batch_size=400,
            ).create(device='cpu')
        else:
            raise ValueError(f'Unknown model class: {self.model_class}')
        return model

    def init(self):
        """Reset the recorded buffers and drop any trained model."""
        super().init()
        self.threshold = torch.inf
        self.model = None

    def update(self):
        """Train after an episode, depending on how many terminal flags were recorded.

        The full training runs when the count equals ``min_data``; beyond that a single extra
        epoch is trained per episode unless ``train_once`` is set.
        """
        n_terminals = self.terminals.sum()
        if n_terminals == self.min_data:
            self._train_model(train_epochs=self.train_epochs)
        elif n_terminals > self.min_data and not self.train_once:
            self._train_model(train_epochs=1)

    def _train_model(self, train_epochs):
        """Fit the agent (creating it on first use) on everything recorded so far.

        Logs are written below ``$WORKING_DIR/lift``; a missing ``WORKING_DIR`` environment
        variable raises ``KeyError``.
        """
        if self.model is None:
            self.model = self._make_model()

        dataset = MDPDataset(
            observations=self.observations,
            actions=self.actions,
            rewards=self.rewards,
            terminals=self.terminals,
        )

        dataset = ReplayBuffer(
            buffer=InfiniteBuffer(),
            transition_picker=Shortcuts() if self.use_shortcuts else None,
            episodes=dataset.episodes,
        )
        n_steps_per_epoch = 100

        log_root = os.path.join(os.environ["WORKING_DIR"], 'lift')
        logger_adapter = d3rlpy.logging.CombineAdapterFactory([
            d3rlpy.logging.FileAdapterFactory(root_dir=os.path.join(log_root, 'd3rlpy_logs')),
            d3rlpy.logging.TensorboardAdapterFactory(root_dir=os.path.join(log_root, 'tf_augmentation')),
        ])

        self.model.fit(
            dataset,
            logger_adapter=logger_adapter,
            n_steps=train_epochs * n_steps_per_epoch,
            n_steps_per_epoch=n_steps_per_epoch,
            show_progress=False,
            evaluators={
                'environment': CustomEnvironmentEvaluator(self.env_eval, steps=15, n_trials=5)
            },
        )

    def augment(self, obs, action_proposed):
        """Return the agent's action, optionally rescaled to the proposed step length.

        Without a trained model the proposed action is returned unchanged.
        """
        if self.model is None:
            return action_proposed

        action = self.model.predict(obs[np.newaxis, ...])[0]
        if self.step_norm_factor > 0:
            action_norm = np.linalg.norm(action)
            if action_norm == 0:
                # A zero action has no direction; dividing by its norm would
                # produce NaNs, so hand the zero action back unchanged.
                return action
            return self.step_norm_factor * np.linalg.norm(action_proposed) * action / action_norm

        return action

    def follow_policy(self, obs, action):
        """Decide whether the policy's proposed action is executed unchanged.

        Without a model the policy is always followed. With ``p > 0`` the policy is followed
        unless a uniform draw from [0, 1) does not exceed ``p``. Otherwise the agent's value
        estimate of the proposed action is compared with that of the agent's own action.
        """
        if self.model is None:
            return True

        if self.p > 0:
            return np.random.uniform(0, 1) > self.p

        batch_obs = obs[np.newaxis, ...]
        action_predict = self.model.predict(batch_obs)[0]
        q_policy = self.model.predict_value(batch_obs, action[np.newaxis, ...])[0]
        q_sac = self.model.predict_value(batch_obs, action_predict[np.newaxis, ...])[0]
        # NOTE(review): the policy is followed when its action is valued LOWER than the
        # agent's action — confirm that this direction is intended.
        return q_policy < q_sac
