from dataclasses import dataclass
from typing import Any, Literal, cast, Union, TypeVar

import torch
import gym
import numpy as np

from tianshou.data import Batch, SequenceSummaryStats
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    ModelOutputBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy, DQNPolicy, PPOPolicy, SACPolicy
from tianshou.policy.base import TLearningRateScheduler, TTrainingStats, TrainingStats
from trainer.net import ActorLM, ActorLM_API, DoubleLM
from simglucose.analysis.risk import risk_index
from utils.prompts import (Conversation, obs2text, text2act, SYSTEM_PROMPT, SUMMARY_INSTRUCTION_PROMPT,
                           LLM_INFERENCE_INSTRUCTION_PROMPT, LLM_INFERENCE_RETRY_PROMPT, get_patient_info_prompt)
from utils.misc import process_tensor, unzip, truncate


@dataclass(kw_only=True)
class RLHFTrainingStats(TrainingStats):
    """Training statistics that extend ``TrainingStats`` with a policy-loss summary."""

    # Summary statistics of the policy-loss values.
    policy_loss: SequenceSummaryStats


# Type variable for annotations that accept RLHFTrainingStats or any subclass of it.
TRLHFTrainingStats = TypeVar("TRLHFTrainingStats", bound=RLHFTrainingStats)


class LLM_Policy(BasePolicy):
    """Inference-only policy that asks a language model for every action.

    The single observation in the batch is rendered to text with ``obs2text``
    and sent to the model, optionally after a summarisation round-trip. The
    reply is parsed with ``text2act``; when no reply can be parsed within the
    retry budget, an action sampled from ``action_space`` is used instead.

    :param model: language-model wrapper; ``model.forward(messages)`` is called
        with the value returned by ``Conversation.get()`` and its result is
        used as the reply text.
    :param action_space: space used to parse replies and to sample the
        fallback action.
    :param observation_space: forwarded unchanged to ``BasePolicy``.
    :param need_summary: when True, first ask the model for a summary of the
        observation and then ask for the action in the same conversation.
    :param need_meta_info: when True, append the prompt produced by
        ``get_patient_info_prompt`` (built from the ``Age``, ``CR``, ``CF``
        and ``TDI`` entries of ``info``) to the system prompt.
    :param num_try: number of replies that may be parsed before falling back
        to a random action; must be at least 1.
    :raises ValueError: if ``num_try`` is smaller than 1.
    """

    def __init__(
            self,
            model: Union[ActorLM, ActorLM_API],
            action_space: gym.Space,
            observation_space: gym.Space | None = None,
            need_summary: bool = False,
            need_meta_info: bool = False,
            num_try: int = 1,
    ) -> None:
        # Reject an invalid retry budget before any state is set up.
        if num_try < 1:
            raise ValueError("num_try should be greater than 0")
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=False,
            action_bound_method=None,
        )
        self.model = model
        self.need_summary = need_summary
        self.need_meta_info = need_meta_info
        self.num_try = num_try

        # Kept as an attribute so subclasses/callers can keep using it directly.
        self.meta_info_fn = get_patient_info_prompt if need_meta_info else lambda *args: ""

    def _patient_meta_prompt(self, sample: Any) -> str:
        """Return the patient-info prompt for ``sample``, or ``""`` when disabled.

        The ``info`` entries are only read when ``need_meta_info`` is set, so
        batches without patient metadata work when the feature is off.
        """
        if not self.need_meta_info:
            return ""
        info = sample.info
        return self.meta_info_fn(info["Age"], info["CR"], info["CF"], info["TDI"])

    def _parse_action_with_retry(self, lm: Any, messages: Conversation, action_text: Any) -> Any:
        """Parse ``action_text``; on failure re-prompt the model within the retry budget.

        At most ``num_try`` replies are parsed. The model is only re-queried
        when another parse attempt will follow, so no reply is requested and
        then discarded. Falls back to ``action_space.sample()`` when every
        attempt fails.
        """
        for attempt in range(self.num_try):
            act = text2act(action_text, self.action_space)
            if act is not None:
                return act
            if attempt + 1 < self.num_try:
                messages.insert_component("assistant", action_text, -1)
                messages.insert_component("user", LLM_INFERENCE_RETRY_PROMPT, -1)
                action_text = lm.forward(messages.get())
        # use random action if no valid action is found
        return self.action_space.sample()

    def forward(
            self,
            batch: ObsBatchProtocol,
            state: dict | BatchProtocol | np.ndarray | None = None,
            model: Literal["model", "model_old"] = "model",
            **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Decide the action for a single-observation batch.

        :param batch: batch whose ``obs`` has leading dimension 1.
        :param state: passed through unchanged into the result.
        :param model: name of the attribute on ``self`` that holds the model
            to query.
        :return: a ``Batch`` with ``act`` (one-element list) and ``state``.
        :raises ValueError: if the batch does not contain exactly one
            observation.
        """
        if batch.obs.shape[0] != 1:
            raise ValueError("LLM_Policy only supports batch size of 1 at inference time.")
        lm = getattr(self, model)
        sample = batch[0]

        obs_prompt = obs2text(sample)
        meta_prompt = self._patient_meta_prompt(sample)
        messages = Conversation()
        messages.insert_component("system", SYSTEM_PROMPT + meta_prompt, 0)

        # Summarise only when fewer than 80% of the first-feature entries equal -1
        # (NOTE(review): -1 looks like a missing/padding marker — confirm).
        if self.need_summary and (batch.obs[:, :, 0] == -1).mean() < 0.8:
            messages.insert_component("user", obs_prompt + SUMMARY_INSTRUCTION_PROMPT, -1)
            summary = lm.forward(messages.get())
            messages.insert_component("assistant", summary, -1)
            messages.insert_component("user", LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
        else:
            messages.insert_component("user", obs_prompt + LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
        action_text = lm.forward(messages.get())

        act = self._parse_action_with_retry(lm, messages, action_text)

        result = Batch(act=[act], state=state)
        return cast(ModelOutputBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats:
        """Always raise: this policy has no trainable component.

        :raises NotImplementedError: unconditionally.
        """
        raise NotImplementedError("LLM_Policy does not support learning.")


class LLM_Instruct_Policy(LLM_Policy):
    """
    Implementation of LLM Instruct policy.
    """

    def __init__(
            self,
            model: DoubleLM,
            optim: torch.optim.Optimizer,
            action_space: gym.Space,
            eps_clip: float = 0.2,
            ent_coef: float = 0.01,
            kl_coef: float = 0.01,
            max_grad_norm: float | None = None,
            gamma: float = 1,
            gae_lambda: float = 0.95,
            observation_space: gym.Space | None = None,
            need_meta_info: bool = False,
            num_try: int = 1,
            lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        """Store the training hyper-parameters and freeze everything but the instruct LM.

        :param model: model exposing ``named_parameters()``; parameters whose
            name does not contain ``"instruct_lm"`` are frozen.
        :param optim: optimizer stored on the policy.
        :param action_space: forwarded to ``LLM_Policy``.
        :param eps_clip: stored as ``self.eps_clip``.
        :param ent_coef: stored as ``self.ent_coef``.
        :param kl_coef: stored as ``self.kl_coef``.
        :param max_grad_norm: stored as ``self.max_grad_norm``.
        :param gamma: discount factor, must lie in [0, 1].
        :param gae_lambda: GAE lambda, must lie in [0, 1].
        :param observation_space: forwarded to ``LLM_Policy``.
        :param need_meta_info: forwarded to ``LLM_Policy``.
        :param num_try: forwarded to ``LLM_Policy``.
        :param lr_scheduler: optional scheduler stored on the policy.
        """
        parent_kwargs = dict(
            model=model,
            action_space=action_space,
            observation_space=observation_space,
            need_meta_info=need_meta_info,
            num_try=num_try,
        )
        super().__init__(**parent_kwargs)

        # Range checks happen right before each value is stored.
        assert 0.0 <= gamma <= 1.0, f"Discount factor gamma should be in [0, 1] but got: {gamma}"
        self.gamma = gamma
        assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}"
        self.gae_lambda = gae_lambda

        # Loss coefficients and gradient clipping threshold.
        self.eps_clip, self.ent_coef, self.kl_coef = eps_clip, ent_coef, kl_coef
        self.max_grad_norm = max_grad_norm

        # Optimisation machinery.
        self.optim, self.lr_scheduler = optim, lr_scheduler

        # Only parameters whose name mentions "instruct_lm" stay trainable.
        for param_name, parameter in self.model.named_parameters():
            if "instruct_lm" in param_name:
                continue
            parameter.requires_grad = False
    
    def kl_divergence(self, old_log_probs, new_log_probs):
        kl_div = torch.exp(old_log_probs) / torch.exp(new_log_probs) - old_log_probs / new_log_probs - 1
        return kl_div
    
    def entropy(self, log_probs):
        ent = -torch.sum(log_probs * torch.exp(log_probs), dim=1)
        return ent
    
    def compute_reward(self, bg_next):
        """Turn the next blood-glucose reading into a scalar reward.

        The negated last element returned by ``risk_index([bg_next], 1)`` is
        rescaled linearly from the range [-100, 0] to [0, 1]. Readings below
        40 receive a fixed reward of -15 instead. The insulin penalty term is
        currently always zero.

        :param bg_next: next blood-glucose value.
        :return: the reward for that reading.
        """
        risk_floor, risk_ceiling = -100, 0
        negative_risk = -risk_index([bg_next], 1)[-1]
        scaled_risk = (negative_risk - risk_floor) / (risk_ceiling - risk_floor)

        # Readings under 40 override the risk-based reward with a flat penalty.
        risk_reward = -15 if bg_next < 40 else scaled_risk

        insulin_penalty = 0
        return risk_reward + insulin_penalty

    def get_adv(self, rewards, old_values):
        returns = []
        advantages = []
        last_gae_lambda = 0
        last_value = 0
        for step in reversed(range(len(rewards))):
            delta = rewards[step] + self.gamma * last_value - old_values[step]
            last_gae_lambda = delta + self.gamma * self.gae_lambda * last_gae_lambda
            advantages.insert(0, last_gae_lambda)
            last_value = old_values[step]
            returns.insert(0, last_gae_lambda + old_values[step])

        # normalize advantages
        advantages = torch.tensor(advantages, dtype=torch.float32).to(rewards.device)
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        returns = torch.tensor(returns, dtype=torch.float32).to(rewards.device)

        return advantages, returns
    
    def forward(
            self,
            batch: ObsBatchProtocol,
            state: dict | BatchProtocol | np.ndarray | None = None,
            model: Literal["model", "model_old"] = "model",
            group_num: int | None = None,
            **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Decide the action (or a group of actions) for a single observation.

        When enough of the observation is populated, the model is first called
        in ``'instruct'`` mode to produce a summary together with its
        log-probabilities, and then in ``'actor'`` mode to produce the action
        text. Otherwise only the actor call is made and ``log_probs`` is an
        uninitialised placeholder tensor.

        :param batch: batch whose ``obs`` has leading dimension 1.
        :param state: passed through unchanged into the result.
        :param model: name of the attribute on ``self`` holding the model.
        :param group_num: when given, ``group_num`` summaries are requested
            for the same observation and one action is produced per summary;
            without a summary the single action is repeated ``group_num``
            times.
        :return: a ``Batch`` with ``act``, ``state`` and ``log_probs``. ``act``
            is ``[act]`` without grouping and ``[[act_0, ..., act_{G-1}]]``
            with grouping.
        :raises ValueError: if the batch does not contain exactly one
            observation.
        """
        if batch.obs.shape[0] != 1:
            raise ValueError("LLM_Policy only supports batch size of 1 at inference time.")
        lm = getattr(self, model)
        top_k = self.model.instruct_lm.top_k
        sample = batch[0]

        obs_prompt = obs2text(sample)
        system_prompt = SYSTEM_PROMPT + self._instruct_meta_prompt(sample)
        # Summarise only when fewer than 80% of the first-feature entries equal -1
        # (NOTE(review): -1 looks like a missing/padding marker — confirm).
        can_summarize = (batch.obs[:, :, 0] == -1).mean() < 0.8

        if not can_summarize:
            messages = self._instruct_conversation(
                system_prompt, obs_prompt + LLM_INFERENCE_INSTRUCTION_PROMPT
            )
            act = self._instruct_single_action(lm, messages)
            # NOTE(review): placeholder is uninitialised memory; learn() appears to rely on
            # process_tensor() to recognise and skip it — confirm that check is reliable.
            if group_num is None:
                log_probs = torch.empty((512, 2 * top_k))
                result = Batch(act=[act], state=state, log_probs=[log_probs])
            else:
                log_probs = torch.empty((group_num, 512, 2 * top_k))
                result = Batch(act=[[act] * group_num], state=state, log_probs=[log_probs])
            return cast(ModelOutputBatchProtocol, result)

        summary_request = obs_prompt + SUMMARY_INSTRUCTION_PROMPT

        if group_num is None:
            messages = self._instruct_conversation(system_prompt, summary_request)
            summary, log_probs = lm.forward(messages.get(), mode='instruct')
            messages.insert_component("assistant", summary, -1)
            messages.insert_component("user", LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
            act = self._instruct_single_action(lm, messages)
            result = Batch(act=[act], state=state, log_probs=[log_probs])
            return cast(ModelOutputBatchProtocol, result)

        summary_prompt = self._instruct_conversation(system_prompt, summary_request).get()
        summaries, log_probs = lm.forward([summary_prompt] * group_num, mode='instruct')

        # Every group member needs its OWN conversation: a shared object would
        # accumulate all members' summaries and retry turns.
        conversations = []
        for member in range(group_num):
            conversation = self._instruct_conversation(system_prompt, summary_request)
            conversation.insert_component("assistant", summaries[member], -1)
            conversation.insert_component("user", LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
            conversations.append(conversation)

        acts = self._instruct_group_actions(lm, conversations)
        result = Batch(act=[acts], state=state, log_probs=[log_probs])
        return cast(ModelOutputBatchProtocol, result)

    def _instruct_meta_prompt(self, sample: Any) -> str:
        """Return the patient-info prompt for ``sample``, or ``""`` when disabled."""
        if not self.need_meta_info:
            return ""
        info = sample.info
        return self.meta_info_fn(info["Age"], info["CR"], info["CF"], info["TDI"])

    @staticmethod
    def _instruct_conversation(system_prompt: str, user_prompt: str) -> Conversation:
        """Build a fresh conversation holding one system and one user message."""
        conversation = Conversation()
        conversation.insert_component("system", system_prompt, 0)
        conversation.insert_component("user", user_prompt, -1)
        return conversation

    def _instruct_single_action(self, lm: Any, messages: Conversation) -> Any:
        """Query the actor for one conversation and parse the reply, with retries.

        At most ``num_try`` replies are parsed; the actor is only re-queried
        when another parse attempt will follow. Falls back to a random action.
        """
        action_text = lm.forward(messages.get(), mode='actor')
        for attempt in range(self.num_try):
            act = text2act(action_text, self.action_space)
            if act is not None:
                return act
            if attempt + 1 < self.num_try:
                messages.insert_component("assistant", action_text, -1)
                messages.insert_component("user", LLM_INFERENCE_RETRY_PROMPT, -1)
                action_text = lm.forward(messages.get(), mode='actor')
        # use random action if no valid action is found
        return self.action_space.sample()

    def _instruct_group_actions(self, lm: Any, conversations: list) -> list:
        """Query the actor for every conversation and parse one action per member.

        Members whose reply cannot be parsed are re-prompted together, up to
        ``num_try`` parse attempts each; members that never yield a valid
        action receive a random one.
        """
        action_texts = list(lm.forward([c.get() for c in conversations], mode='actor'))
        acts = [None] * len(conversations)
        pending = list(range(len(conversations)))

        for attempt in range(self.num_try):
            unresolved = []
            for member in pending:
                act = text2act(action_texts[member], self.action_space)
                if act is None:
                    unresolved.append(member)
                else:
                    acts[member] = act
            pending = unresolved
            # Stop when everyone is resolved or no further parse attempt will follow.
            if not pending or attempt + 1 >= self.num_try:
                break
            for member in pending:
                conversations[member].insert_component("assistant", action_texts[member], -1)
                conversations[member].insert_component("user", LLM_INFERENCE_RETRY_PROMPT, -1)
            replies = lm.forward([conversations[member].get() for member in pending], mode='actor')
            for member, reply in zip(pending, replies):
                action_texts[member] = reply

        # use random action if no valid action is found
        for member in pending:
            acts[member] = self.action_space.sample()
        return acts

    def learn(  # type: ignore
        self,
        batch: RolloutBatchProtocol,
        batch_size: int | None,
        repeat: int,
        *args: Any,
        **kwargs: Any,
    ) -> TTrainingStats:
        """Update the instruct LM with a clipped-surrogate (PPO-style) objective.

        For every sample that carries summary log-probabilities, the instruct
        LM is re-run on the summarisation prompt and its log-probabilities are
        compared with each stored group member's. The per-member loss is the
        clipped surrogate minus an entropy bonus; the reward is placed on the
        last token and reduced by a KL penalty. Losses are averaged over the
        group and then over the mini-batch before one optimizer step.

        :param batch: rollout data with ``log_probs``, ``obs_next`` and the
            fields ``obs2text`` needs.
        :param batch_size: mini-batch size; ``None`` or 0 means one mini-batch.
        :param repeat: number of passes over ``batch``.
        :return: ``RLHFTrainingStats`` summarising the mini-batch losses (all
            zeros when no mini-batch produced a loss).
        """
        import warnings

        instruct_lm = self.model.instruct_lm
        top_k, vocab_size = instruct_lm.top_k, instruct_lm.vocab_size
        total_policy_losses: list[float] = []
        split_batch_size = batch_size or -1

        for _ in range(repeat):
            for mini_batch in batch.split(split_batch_size, merge_last=True):
                sample_losses = []
                for bat in mini_batch:
                    # check skipping condition
                    group_num = len(bat.log_probs)
                    if process_tensor(bat.log_probs[0]) is None:  # in the first few steps of an episode, no summaries
                        continue

                    # Re-evaluate the instruct LM on the SAME prompt that produced the stored
                    # log-probs at sampling time (system + patient info, summary instruction);
                    # otherwise the probability ratio compares different conditionings.
                    meta_prompt = ""
                    if self.need_meta_info:
                        info = bat.info
                        meta_prompt = self.meta_info_fn(info["Age"], info["CR"], info["CF"], info["TDI"])
                    messages = Conversation()
                    messages.insert_component("system", SYSTEM_PROMPT + meta_prompt, 0)
                    messages.insert_component("user", obs2text(bat) + SUMMARY_INSTRUCTION_PROMPT, -1)
                    _, raw_new_log_probs = self.model.forward(messages.get(), mode='instruct')
                    processed_new = process_tensor(raw_new_log_probs)
                    if processed_new is None:
                        continue
                    # Expand once; reusing the truncated result across members would
                    # feed already-expanded data back into unzip().
                    new_full = unzip(processed_new, top_k, vocab_size)

                    # get group rewards
                    bg_nexts = bat.obs_next[:, -1, 0]
                    rews = [self.compute_reward(bg_nexts[i]) for i in range(group_num)]

                    # group compute
                    group_loss = 0
                    for i in range(group_num):
                        old_full = unzip(process_tensor(bat.log_probs[i]), top_k, vocab_size)
                        new_lp, old_lp = truncate(new_full, old_full)

                        # NOTE(review): the arithmetic below assumes `rewards - kl_penalty`
                        # stays 1-D per token; confirm the shapes unzip()/truncate() return.
                        rewards = torch.zeros(len(old_lp), device=new_lp.device)
                        values = torch.zeros(len(old_lp), device=new_lp.device)
                        rewards[-1] = rews[i]  # credit assigned to last token
                        kl_penalty = self.kl_divergence(old_lp, new_lp) * self.kl_coef
                        rewards = rewards - kl_penalty

                        # compute policy loss
                        advantages, _ = self.get_adv(rewards, values)
                        ratio = torch.exp(new_lp - old_lp)
                        ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                        surr1 = ratio * advantages
                        surr2 = torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantages
                        entropies = self.entropy(new_lp)
                        # Keep the loss as a tensor so gradients reach the instruct LM.
                        member_loss = (-torch.min(surr1, surr2) - self.ent_coef * entropies).mean()
                        group_loss = group_loss + member_loss

                    # average objective
                    sample_losses.append(group_loss / group_num)

                if not sample_losses:
                    continue

                policy_loss = torch.stack(sample_losses).mean()
                total_policy_losses.append(policy_loss.item())

                if not policy_loss.requires_grad:
                    # Nothing to back-propagate (e.g. log-probs were produced without
                    # autograd tracking); stepping the optimizer would be a silent no-op.
                    warnings.warn(
                        "policy loss has no gradient history; skipping optimizer step",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    continue

                self.optim.zero_grad()
                policy_loss.backward()
                if self.max_grad_norm:
                    torch.nn.utils.clip_grad_norm_(instruct_lm.parameters(), self.max_grad_norm)
                self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()

        if total_policy_losses:
            loss_stats = SequenceSummaryStats.from_sequence(total_policy_losses)
        else:
            # from_sequence() is not guaranteed to accept an empty sequence.
            loss_stats = SequenceSummaryStats(mean=0.0, std=0.0, max=0.0, min=0.0)

        return RLHFTrainingStats(
            policy_loss=loss_stats,
        )