from dataclasses import dataclass
from typing import Any, Literal, cast, Union, TypeVar

import torch
import gym
import numpy as np

from tianshou.data import Batch, SequenceSummaryStats
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
    ModelOutputBatchProtocol,
    ObsBatchProtocol,
    RolloutBatchProtocol,
)
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.policy.base import TLearningRateScheduler, TTrainingStats, TrainingStats
from trainer.net import ActorLM, ActorLM_API, DoubleLM
from utils.prompts import (Conversation, obs2text, text2act, SYSTEM_PROMPT, SUMMARY_INSTRUCTION_PROMPT,
                           LLM_INFERENCE_INSTRUCTION_PROMPT, LLM_INFERENCE_RETRY_PROMPT, get_patient_info_prompt)
from utils.misc import process_tensor, unzip, truncate


@dataclass(kw_only=True)
class RLHFTrainingStats(TrainingStats):
    """Loss statistics produced by one call of ``LLM_Instruct_Policy.learn``.

    All fields are keyword-only because of ``kw_only=True``.

    NOTE(review): ``LLM_Instruct_Policy.learn`` currently passes plain lists of
    per-mini-batch floats for both fields, not ``SequenceSummaryStats`` instances.
    """

    # policy (actor) loss of each optimizer step
    policy_loss: SequenceSummaryStats
    # value (critic) loss of each optimizer step
    value_loss: SequenceSummaryStats


# Type variable for generic code that accepts RLHFTrainingStats or any subclass of it.
TRLHFTrainingStats = TypeVar("TRLHFTrainingStats", bound=RLHFTrainingStats)


class LLM_Policy(BasePolicy):
    """
    Implementation of pure LLM policy.

    The observation is rendered as a chat prompt, the language model is queried, and
    its reply is parsed into an action. If no reply can be parsed after ``num_try``
    attempts, a random action is sampled from the action space. This policy does not learn.
    """

    def __init__(
            self,
            model: Union[ActorLM, ActorLM_API],
            action_space: gym.Space,
            observation_space: gym.Space | None = None,
            need_summary: bool = False,
            need_meta_info: bool = False,
            num_try: int = 1,
    ) -> None:
        """
        :param model: language model; its ``forward`` takes the conversation messages and returns reply text.
        :param action_space: action space used to parse replies and to sample a fallback action.
        :param observation_space: optional observation space forwarded to ``BasePolicy``.
        :param need_summary: if True, the model is first asked to summarize the observation and then asked
            for the action (only when enough history is present, see ``forward``).
        :param need_meta_info: if True, the prompt from ``get_patient_info_prompt`` (built from
            ``info["Age"]``, ``info["CR"]``, ``info["CF"]``, ``info["TDI"]``) is appended to the system prompt.
        :param num_try: number of model replies parsed before falling back to a random action; must be >= 1.
        :raises ValueError: if ``num_try < 1``.
        """
        # Validate before doing any work so a bad configuration fails fast.
        if num_try < 1:
            raise ValueError("num_try should be greater than 0")
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            action_scaling=False,
            action_bound_method=None,
        )
        self.model = model
        self.need_summary = need_summary
        self.need_meta_info = need_meta_info
        self.num_try = num_try

        self.meta_info_fn = get_patient_info_prompt if need_meta_info else lambda *args: ""

    def forward(
            self,
            batch: ObsBatchProtocol,
            state: dict | BatchProtocol | np.ndarray | None = None,
            model: Literal["model", "model_old"] = "model",
            **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Decide action over the given batch data.

        :param batch: batch holding exactly one observation (with its ``info``).
        :param state: passed through unchanged into the result.
        :param model: name of the attribute holding the language model to query.
        :return: ``Batch(act=[action], state=state)``.
        :raises ValueError: if the batch holds more than one observation.
        """
        if batch.obs.shape[0] != 1:
            raise ValueError("LLM_Policy only supports batch size of 1 at inference time.")
        lm = getattr(self, model)
        sample = batch[0]

        obs_prompt = obs2text(sample)
        # Only touch the info keys when meta information is actually requested.
        if self.need_meta_info:
            info = sample.info
            meta_prompt = self.meta_info_fn(info["Age"], info["CR"], info["CF"], info["TDI"])
        else:
            meta_prompt = ""

        messages = Conversation()
        messages.insert_component("system", SYSTEM_PROMPT + meta_prompt, 0)
        # Summarize only when fewer than 80% of the entries in obs[:, :, 0] equal -1
        # (NOTE(review): -1 presumably marks empty history slots -- confirm against the env).
        if self.need_summary and (batch.obs[:, :, 0] == -1).mean() < 0.8:
            messages.insert_component("user", obs_prompt + SUMMARY_INSTRUCTION_PROMPT, -1)
            summary = lm.forward(messages.get())
            messages.insert_component("assistant", summary, -1)
            messages.insert_component("user", LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
        else:
            messages.insert_component("user", obs_prompt + LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
        action_text = lm.forward(messages.get())

        act = None
        for attempt in range(self.num_try):
            act = text2act(action_text, self.action_space)
            if act is not None:
                break
            # Ask again only if another parse attempt remains; a reply that would never
            # be parsed is a wasted (possibly paid) model call.
            if attempt + 1 < self.num_try:
                messages.insert_component("assistant", action_text, -1)
                messages.insert_component("user", LLM_INFERENCE_RETRY_PROMPT, -1)
                action_text = lm.forward(messages.get())

        # use random action if no valid action is found
        if act is None:
            act = self.action_space.sample()

        result = Batch(act=[act], state=state)
        return cast(ModelOutputBatchProtocol, result)

    def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TTrainingStats:
        """Not supported: this policy has no trainable objective.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("LLM_Policy does not support learning.")


class LLM_Instruct_Policy(LLM_Policy):
    """
    Implementation of LLM Instruct policy.

    ``model`` is a ``DoubleLM`` queried in two modes. ``mode='instruct'`` produces an
    observation summary together with log-probabilities and values. ``mode='actor'``
    turns the conversation into action text. ``learn`` trains the instruct LM with a
    PPO-style objective driven by the transition reward. All parameters whose name
    does not contain ``"instruct_lm"`` are frozen.
    """

    # Row count of the placeholder log-prob / value tensors stored when no summary is produced.
    # NOTE(review): presumably the maximum summary length in tokens -- confirm against DoubleLM.
    _NO_SUMMARY_LEN = 512

    def __init__(
            self,
            model: DoubleLM,
            actor_optim: torch.optim.Optimizer,
            critic_optim: torch.optim.Optimizer,
            action_space: gym.Space,
            eps_clip: float = 0.2,
            vf_coef: float = 0.5,
            ent_coef: float = 0.01,
            kl_coef: float = 0.01,
            max_grad_norm: float | None = None,
            gamma: float = 1,
            gae_lambda: float = 0.95,
            observation_space: gym.Space | None = None,
            need_meta_info: bool = False,
            num_try: int = 1,
            lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        """
        :param model: double LM with an ``instruct_lm`` (trained) and an actor mode (frozen).
        :param actor_optim: optimizer stepped with the policy loss.
        :param critic_optim: optimizer stepped with the value loss.
        :param action_space: action space used to parse replies / sample a fallback action.
        :param eps_clip: PPO ratio clipping range.
        :param vf_coef: stored only; the value loss has its own optimizer, so it is not used in ``learn``.
        :param ent_coef: weight of the entropy bonus in the policy loss.
        :param kl_coef: weight of the per-token KL penalty subtracted from the rewards.
        :param max_grad_norm: if truthy, gradient-norm clipping threshold for actor and critic parameters.
        :param gamma: discount factor in [0, 1].
        :param gae_lambda: GAE lambda in [0, 1].
        :param observation_space: optional observation space forwarded to ``BasePolicy``.
        :param need_meta_info: if True, meta information from ``info`` is added to the system prompt.
        :param num_try: number of actor replies parsed before falling back to a random action.
        :param lr_scheduler: optional scheduler stepped after every optimizer update.
        """
        super().__init__(
            model=model,
            action_space=action_space,
            observation_space=observation_space,
            need_meta_info=need_meta_info,
            num_try=num_try,
        )
        assert 0.0 <= gamma <= 1.0, f"Discount factor gamma should be in [0, 1] but got: {gamma}"
        self.gamma = gamma
        assert 0.0 <= gae_lambda <= 1.0, f"GAE lambda should be in [0, 1] but got: {gae_lambda}"
        self.gae_lambda = gae_lambda
        self.eps_clip = eps_clip
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.kl_coef = kl_coef
        self.max_grad_norm = max_grad_norm

        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.lr_scheduler = lr_scheduler

        # freeze all model parameters except instruct_lm
        for name, param in self.model.named_parameters():
            if "instruct_lm" not in name:
                param.requires_grad = False

    def kl_divergence(self, old_log_probs: torch.Tensor, new_log_probs: torch.Tensor) -> torch.Tensor:
        """Row-wise ``sum_j exp(old) * (old - new)`` over dim 1, i.e. KL(old || new) of log-prob rows."""
        kl_div = torch.sum(torch.exp(old_log_probs) * (old_log_probs - new_log_probs), dim=1)
        return kl_div

    def entropy(self, log_probs: torch.Tensor) -> torch.Tensor:
        """Row-wise entropy ``-sum_j p_j * log p_j`` over dim 1 of a log-probability tensor."""
        ent = -torch.sum(log_probs * torch.exp(log_probs), dim=1)
        return ent

    def get_adv(self, rewards: torch.Tensor, old_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE advantages and returns for one token sequence.

        The sequence is treated as a complete episode, so the value after the last step is 0.
        Both inputs are treated as constants (detached) and flattened.

        :param rewards: per-step rewards (T elements).
        :param old_values: per-step value estimates from the rollout (at least T elements).
        :return: ``(advantages, returns)``, float32 tensors of length T on ``rewards.device``.
            Returns use the raw advantages; the returned advantages are standardized when T > 1
            (a single element has no defined std).
        """
        device = rewards.device
        rew = rewards.detach().reshape(-1).tolist()
        val = old_values.detach().reshape(-1).tolist()[: len(rew)]

        adv = [0.0] * len(rew)
        gae, next_value = 0.0, 0.0
        for step in reversed(range(len(rew))):
            delta = rew[step] + self.gamma * next_value - val[step]
            gae = delta + self.gamma * self.gae_lambda * gae
            adv[step] = gae
            next_value = val[step]

        advantages = torch.tensor(adv, dtype=torch.float32, device=device)
        returns = advantages + torch.tensor(val, dtype=torch.float32, device=device)

        # normalize advantages
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return advantages, returns

    def _meta_prompt(self, sample: Any) -> str:
        """Meta-info suffix for the system prompt; info keys are read only when meta info is enabled."""
        if not self.need_meta_info:
            return ""
        info = sample.info
        return self.meta_info_fn(info["Age"], info["CR"], info["CF"], info["TDI"])

    def _system_conversation(self, sample: Any) -> Conversation:
        """Conversation containing only the system prompt (plus meta info) for ``sample``."""
        messages = Conversation()
        messages.insert_component("system", SYSTEM_PROMPT + self._meta_prompt(sample), 0)
        return messages

    def _summary_conversation(self, sample: Any) -> Conversation:
        """Conversation fed to the instruct LM; shared by rollout and ``learn`` so both score the same context."""
        messages = self._system_conversation(sample)
        messages.insert_component("user", obs2text(sample) + SUMMARY_INSTRUCTION_PROMPT, -1)
        return messages

    def forward(
            self,
            batch: ObsBatchProtocol,
            state: dict | BatchProtocol | np.ndarray | None = None,
            model: Literal["model", "model_old"] = "model",
            **kwargs: Any,
    ) -> ModelOutputBatchProtocol:
        """Decide action over the given batch data.

        :param batch: batch holding exactly one observation (with its ``info``).
        :param state: passed through unchanged into the result.
        :param model: name of the attribute holding the DoubleLM to query.
        :return: ``Batch(act=[action], state=state, log_probs=[...], vals=[...])``. When no summary
            is generated, ``log_probs``/``vals`` are all-zero placeholders.
        :raises ValueError: if the batch holds more than one observation.
        """
        if batch.obs.shape[0] != 1:
            raise ValueError("LLM_Policy only supports batch size of 1 at inference time.")
        lm = getattr(self, model)
        sample = batch[0]
        # Zero placeholders (torch.empty would store uninitialized memory in the buffer).
        log_probs = torch.zeros((self._NO_SUMMARY_LEN, 2 * self.model.instruct_lm.top_k))
        vals = torch.zeros((self._NO_SUMMARY_LEN, 1))

        # Summarize only when fewer than 80% of the entries in obs[:, :, 0] equal -1
        # (NOTE(review): -1 presumably marks empty history slots -- confirm against the env).
        if (batch.obs[:, :, 0] == -1).mean() < 0.8:
            messages = self._summary_conversation(sample)
            summary, log_probs, vals = lm.forward(messages.get(), mode='instruct')
            messages.insert_component("assistant", summary, -1)
            messages.insert_component("user", LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
        else:
            messages = self._system_conversation(sample)
            messages.insert_component("user", obs2text(sample) + LLM_INFERENCE_INSTRUCTION_PROMPT, -1)
        action_text = lm.forward(messages.get(), mode='actor')

        act = None
        for attempt in range(self.num_try):
            act = text2act(action_text, self.action_space)
            if act is not None:
                break
            # Ask again only if another parse attempt remains.
            if attempt + 1 < self.num_try:
                messages.insert_component("assistant", action_text, -1)
                messages.insert_component("user", LLM_INFERENCE_RETRY_PROMPT, -1)
                action_text = lm.forward(messages.get(), mode='actor')

        # use random action if no valid action is found
        if act is None:
            act = self.action_space.sample()

        result = Batch(act=[act], state=state, log_probs=[log_probs], vals=[vals])
        return cast(ModelOutputBatchProtocol, result)

    @staticmethod
    def _trainable_params(optim: torch.optim.Optimizer) -> list[torch.Tensor]:
        """Parameters managed by ``optim`` that currently require gradients."""
        return [p for group in optim.param_groups for p in group["params"] if p.requires_grad]

    def _sample_losses(self, bat: Any) -> tuple[torch.Tensor, torch.Tensor] | None:
        """Compute graph-connected ``(policy_loss, value_loss)`` for one stored transition.

        Returns ``None`` when the stored or re-evaluated output carries no summary data.
        """
        old_log_probs, old_values = process_tensor(bat.log_probs), process_tensor(bat.vals)
        if old_log_probs is None or old_values is None:  # in the first few steps of an episode, no summaries
            return None

        # Re-run the instruct LM on the same conversation used during rollout.
        # NOTE(review): whether mode='instruct' re-scores the stored summary or samples a new one
        # cannot be seen here; the PPO ratio assumes the former.
        messages = self._summary_conversation(bat)
        _, new_log_probs, new_values = self.model.forward(messages.get(), mode='instruct')
        new_log_probs, new_values = process_tensor(new_log_probs), process_tensor(new_values)
        if new_log_probs is None or new_values is None:  # in the first few steps of an episode, no summaries
            return None

        instruct_lm = self.model.instruct_lm
        new_log_probs, old_log_probs = truncate(
            unzip(new_log_probs, instruct_lm.top_k, instruct_lm.vocab_size),
            unzip(old_log_probs, instruct_lm.top_k, instruct_lm.vocab_size),
        )
        new_values, old_values = truncate(new_values, old_values)
        if len(old_log_probs) == 0:
            return None
        # Rollout quantities are constants of the objective and must share the model's device.
        old_log_probs = old_log_probs.detach().to(new_log_probs.device)
        old_values = old_values.detach().to(new_values.device)

        # compute rewards: credit assigned to last token, minus a (constant) per-token KL penalty
        rewards = torch.zeros(len(old_log_probs), device=old_log_probs.device)
        rewards[-1] = bat.rew.item()
        kl_penalty = (self.kl_divergence(old_log_probs, new_log_probs) * self.kl_coef).detach()
        rewards = rewards - kl_penalty

        # compute policy loss
        advantages, returns = self.get_adv(rewards, old_values)
        ratio = torch.exp(new_log_probs - old_log_probs)
        ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantages
        entropies = self.entropy(new_log_probs)
        policy_loss = (-torch.min(surr1, surr2) - self.ent_coef * entropies).mean()

        # compute value loss (values flattened so it matches the 1-D returns elementwise)
        value_loss = (returns - new_values.reshape(-1)).pow(2).mean()
        return policy_loss, value_loss

    def _backward_sample(
        self,
        bat: Any,
        actor_params: list[torch.Tensor],
        critic_params: list[torch.Tensor],
    ) -> tuple[float, float] | None:
        """Accumulate gradients of one transition's losses; return the loss values, or ``None`` if skipped."""
        losses = self._sample_losses(bat)
        if losses is None:
            return None
        policy_loss, value_loss = losses
        if not (policy_loss.requires_grad and value_loss.requires_grad):
            raise RuntimeError(
                "LLM_Instruct_Policy.learn: the losses are not connected to the model's autograd "
                "graph, so no parameter would be updated."
            )
        # `inputs=` keeps the policy gradient in the actor optimizer's parameters and the value
        # gradient in the critic optimizer's. Both backward passes run before any optimizer step,
        # so no parameter is modified in place while the (possibly shared) graph is still needed.
        if actor_params:
            policy_loss.backward(inputs=actor_params, retain_graph=True)
        if critic_params:
            value_loss.backward(inputs=critic_params)
        # Returning floats lets this sample's graph be freed before the next forward pass.
        return policy_loss.item(), value_loss.item()

    def learn(  # type: ignore
        self,
        batch: RolloutBatchProtocol,
        batch_size: int | None,
        repeat: int,
        *args: Any,
        **kwargs: Any,
    ) -> TTrainingStats:
        """Learn function to update the InstructLM using PPO based on the rewards from the batch.

        Gradients are accumulated per transition and averaged over the valid transitions of each
        mini-batch. Then the actor and critic optimizers (and the LR scheduler, if any) are stepped.

        :param batch: rollout batch whose entries carry ``log_probs``, ``vals``, ``rew`` and ``obs``.
        :param batch_size: mini-batch size; ``None``/0 means the whole batch at once.
        :param repeat: number of passes over the batch.
        :return: ``RLHFTrainingStats`` whose fields are lists with one mean loss per optimizer step.
        """
        actor_params = self._trainable_params(self.actor_optim)
        critic_params = self._trainable_params(self.critic_optim)
        # keyed by id so a parameter shared by both optimizers is rescaled only once
        all_params = list({id(p): p for p in actor_params + critic_params}.values())
        total_policy_losses, total_value_losses = [], []
        split_batch_size = batch_size or -1

        for _ in range(repeat):
            for mini_batch in batch.split(split_batch_size, merge_last=True):
                self.actor_optim.zero_grad()
                self.critic_optim.zero_grad()
                policy_sum = value_sum = 0.0
                num_valid = 0
                for bat in mini_batch:
                    sample_losses = self._backward_sample(bat, actor_params, critic_params)
                    if sample_losses is None:
                        continue
                    policy_sum += sample_losses[0]
                    value_sum += sample_losses[1]
                    num_valid += 1
                if num_valid == 0:
                    continue

                # Gradients were summed over samples; rescale so the step uses the mini-batch mean loss.
                for param in all_params:
                    if param.grad is not None:
                        param.grad.mul_(1.0 / num_valid)
                if self.max_grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.model.instruct_lm.actor_parameters(), self.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(self.model.instruct_lm.critic_parameters(), self.max_grad_norm)
                self.actor_optim.step()
                self.critic_optim.step()

                total_policy_losses.append(policy_sum / num_valid)
                total_value_losses.append(value_sum / num_valid)

                if self.lr_scheduler:
                    self.lr_scheduler.step()

        return RLHFTrainingStats(
            policy_loss=total_policy_losses,
            value_loss=total_value_losses,
        )