import os
import torch
import torch.optim as optim
import wandb
from tianshou.data import VectorReplayBuffer, ReplayBuffer
from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
from trainer.base_obj import RLObjective
from trainer.RLHparams import OffPolicyRLHyperParameterSpace, OnPolicyRLHyperParameterSpace, LLM_HyperParams, LLM_Instruct_HyperParams
from utils.network import define_single_network, define_continuous_critic, Net
from tianshou.utils.net.continuous import Critic, ActorProb
from tianshou.utils.net.common import ActorCritic
from torch.distributions import Distribution, Independent, Normal
from trainer.GRPOTrainer import GRPOTrainer
from trainer.collector import GlucoseCollector as Collector
from trainer.base_policy import RandomPolicy
from trainer.RLPolicy import DQNPolicy, PPOPolicy, SACPolicy, LLM_Policy, LLM_Instruct_Policy
from trainer.net import ActorLM, ActorLM_API,  InstructLM, DoubleLM
from utils.misc import set_global_seed
from utils.wandb import WandbLogger


class DQNObjective(RLObjective):
    """Hyper-parameter objective that builds and trains a DQN policy (optionally double / dueling)."""

    def __init__(self, env_name, env_args, hparam_space: OffPolicyRLHyperParameterSpace, device, **kwargs):
        super().__init__(env_name, env_args, hparam_space, device, **kwargs)

    def define_policy(self,
                      # general hp
                      gamma,
                      lr,
                      obs_mode,
                      linear,

                      # dqn hp
                      n_step,
                      target_update_freq,
                      is_double,
                      use_dueling,
                      *args,
                      **kwargs
                      ):
        """Build a DQNPolicy around a single Q-network.

        Args:
            gamma: discount factor.
            lr: Adam learning rate.
            obs_mode: "cat" builds a non-recurrent network with ``cat_num=meta_param["obs_window"]``;
                "stack" builds a recurrent network with ``cat_num=1``.
            linear: forwarded to ``define_single_network``.
            n_step: n-step return horizon (passed as ``estimation_step``).
            target_update_freq: target-network update frequency.
            is_double: whether the policy uses the double-DQN target.
            use_dueling: forwarded to ``define_single_network``.

        Returns:
            The constructed DQNPolicy.

        Raises:
            NotImplementedError: if ``obs_mode`` is neither "cat" nor "stack".
        """
        # define model
        if obs_mode == "cat": cat_num, use_rnn = self.meta_param["obs_window"], False
        elif obs_mode == "stack": cat_num, use_rnn = 1, True
        else: raise NotImplementedError("obs_mode not supported")

        net = define_single_network(self.state_shape, self.action_shape, use_dueling=use_dueling,
                                    use_rnn=use_rnn, device=self.device, linear=linear, cat_num=cat_num)
        # named `optimizer` so the module-level `optim` import is not shadowed
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        # define policy
        policy = DQNPolicy(
            model=net,
            optim=optimizer,
            discount_factor=gamma,
            estimation_step=n_step,
            target_update_freq=target_update_freq,
            is_double=is_double,  # double-Q target toggled by the `is_double` hyper-parameter
            action_space=self.action_space,
            observation_space=self.state_space,
            clip_loss_grad=True,
        )
        return policy

    def run(self, policy,
            eps_test,
            eps_train,
            eps_train_final,
            step_per_collect,
            update_per_step,
            batch_size,
            start_timesteps,
            **kwargs
            ):
        """Train ``policy`` with tianshou's OffpolicyTrainer and reload the best checkpoint.

        Epsilon is annealed linearly from ``eps_train`` to ``eps_train_final`` over the first 95%
        of the total training env steps (``epoch * step_per_epoch``) and held at
        ``eps_train_final`` afterwards. ``eps_test`` is used during evaluation.

        Returns:
            ``(policy, test_fn)``: ``policy`` with the weights from ``best_policy.pth`` in
            ``self.log_path``, and ``test_fn`` which sets epsilon to ``eps_test``.
        """
        def save_best_fn(policy):
            torch.save(policy.state_dict(), os.path.join(self.log_path, "best_policy.pth"))

        # number of env steps over which epsilon is linearly annealed (95% of the training budget)
        decay_steps = self.meta_param["epoch"] * self.meta_param["step_per_epoch"] * 0.95

        def train_fn(epoch, env_step):
            # guard decay_steps > 0 so a zero training budget does not divide by zero
            if decay_steps > 0 and env_step <= decay_steps:
                eps = eps_train - env_step / decay_steps * (eps_train - eps_train_final)
            else:
                eps = eps_train_final
            policy.set_eps(eps)
            if env_step % 1000 == 0:
                self.logger.write("train/env_step", env_step, {"train/eps": eps})

        def test_fn(epoch, env_step):
            policy.set_eps(eps_test)

        # replay buffer: frame stacking is done inside the env, so the buffer keeps full
        # observations (stack_num=1, save_only_last_obs=False) and stores obs_next explicitly
        if self.meta_param["training_num"] > 1:
            buffer = VectorReplayBuffer(
                self.meta_param["buffer_size"],
                buffer_num=len(self.train_envs),
                ignore_obs_next=False,
                save_only_last_obs=False,
                stack_num=1  # stack is implemented in the env
            )
        else:
            buffer = ReplayBuffer(self.meta_param["buffer_size"],
                                  ignore_obs_next=False,
                                  save_only_last_obs=False,
                                  stack_num=1)
        # collectors; exploration_noise=True at test time so the eps_test set by test_fn applies
        train_collector = Collector(policy, self.train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, self.test_envs, exploration_noise=True)

        if start_timesteps > 0:
            print(f"start to warmup with random policy for {start_timesteps} steps..")
            train_collector.collect(n_step=start_timesteps, random=True)

        OffpolicyTrainer(
            policy,
            max_epoch=self.meta_param["epoch"],
            batch_size=batch_size,
            train_collector=train_collector,
            test_collector=test_collector,
            step_per_epoch=self.meta_param["step_per_epoch"],
            step_per_collect=step_per_collect,
            episode_per_test=self.meta_param["test_num"],
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=self.early_stop_fn,
            save_best_fn=save_best_fn,
            logger=self.logger,
            update_per_step=update_per_step,
            save_checkpoint_fn=self.save_checkpoint_fn,
        ).run()

        # load the best policy to test again
        policy.load_state_dict(torch.load(os.path.join(self.log_path, "best_policy.pth")))
        return policy, test_fn


class PPOObjective(RLObjective):
    """Hyper-parameter objective that builds and trains a PPO policy with a Gaussian actor."""

    def __init__(self, env_name, env_args, hparam_space: OnPolicyRLHyperParameterSpace, device, **kwargs):
        super().__init__(env_name, env_args, hparam_space, device, **kwargs)

    def define_policy(self, gamma, lr, gae_lambda, vf_coef, ent_coef, eps_clip, value_clip, dual_clip,
                      advantage_normalization, recompute_advantage, obs_mode, linear, **kwargs):
        """Build a PPOPolicy from an ActorProb actor and a separate critic.

        Args:
            gamma, gae_lambda, vf_coef, ent_coef, eps_clip, value_clip, dual_clip,
            advantage_normalization, recompute_advantage: forwarded to PPOPolicy.
            lr: Adam learning rate shared by actor and critic.
            obs_mode: "cat" builds non-recurrent nets with ``cat_num=meta_param["obs_window"]``;
                "stack" builds recurrent nets with ``cat_num=1``.
            linear: forwarded to the critic constructor.

        Returns:
            The constructed PPOPolicy.

        Raises:
            NotImplementedError: if ``obs_mode`` is neither "cat" nor "stack".
        """
        if obs_mode == "cat": cat_num, use_rnn = self.meta_param["obs_window"], False
        elif obs_mode == "stack": cat_num, use_rnn = 1, True
        else: raise NotImplementedError("obs_mode not supported")

        # NOTE(review): `linear` is forwarded to the critic only, not to the actor network -- confirm intended.
        net_a = define_single_network(self.state_shape, self.action_shape, use_dueling=False, num_layer=3, hidden_size=128,
                                      use_rnn=use_rnn, device=self.device, cat_num=cat_num)
        actor = ActorProb(net_a, self.action_shape, unbounded=True, device=self.device).to(self.device)
        critic = define_continuous_critic(self.state_shape, self.action_shape, linear=linear, use_rnn=use_rnn,
                                          cat_num=cat_num, use_action_net=False, state_net_hidden_size=128,
                                          device=self.device)
        actor_critic = ActorCritic(actor, critic)
        # named `optimizer` so the module-level `optim` import is not shadowed
        optimizer = torch.optim.Adam(actor_critic.parameters(), lr=lr)

        # do last policy layer scaling, this will make initial actions have (close to)
        # 0 mean and std, and will help boost performances,
        # see https://arxiv.org/abs/2006.05990, Fig.24 for details
        for m in actor.mu.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.zeros_(m.bias)
                m.weight.data.copy_(0.01 * m.weight.data)

        def dist(*loc_scale) -> Distribution:
            """Build a diagonal Gaussian from the actor's ``(loc, scale)`` output.

            Accepts both ``dist(loc, scale)`` and ``dist((loc, scale))``, so it works whether or
            not the policy unpacks the actor output before calling ``dist_fn``.
            """
            if len(loc_scale) == 1:
                (loc_scale,) = loc_scale
            loc, scale = loc_scale
            return Independent(Normal(loc, scale), 1)

        policy: PPOPolicy = PPOPolicy(
            actor=actor,
            critic=critic,
            optim=optimizer,
            dist_fn=dist,
            discount_factor=gamma,
            gae_lambda=float(gae_lambda),
            vf_coef=vf_coef,
            ent_coef=ent_coef,
            max_grad_norm=1.,
            action_scaling=True,
            action_bound_method='clip',
            action_space=self.action_space,
            eps_clip=eps_clip,
            value_clip=value_clip,
            dual_clip=dual_clip,
            advantage_normalization=advantage_normalization,
            recompute_advantage=recompute_advantage,
        )
        return policy

    def run(self, policy, step_per_collect, repeat_per_collect, batch_size, start_timesteps, **kwargs):
        """Train ``policy`` with tianshou's OnpolicyTrainer and reload the best checkpoint.

        If ``start_timesteps > 0``, the buffer is first filled with transitions from a RandomPolicy.

        Returns:
            ``(policy, test_fn)``: ``policy`` with the weights from ``best_policy.pth`` in
            ``self.log_path``, and a no-op ``test_fn``.
        """
        def save_best_fn(policy):
            torch.save(policy.state_dict(), os.path.join(self.log_path, "best_policy.pth"))

        # replay buffer (one sub-buffer per training env when running several envs)
        if self.meta_param["training_num"] > 1:
            buffer = VectorReplayBuffer(
                self.meta_param["buffer_size"],
                buffer_num=len(self.train_envs),
                ignore_obs_next=False,
                save_only_last_obs=False,
                stack_num=1
            )
        else:
            buffer = ReplayBuffer(self.meta_param["buffer_size"],
                                  ignore_obs_next=False,
                                  save_only_last_obs=False,
                                  stack_num=1)

        # collectors
        train_collector = Collector(policy, self.train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, self.test_envs, exploration_noise=False)
        if start_timesteps > 0:
            print(f"warmup with random policy for {start_timesteps} steps..")
            warmup_policy = RandomPolicy(min_act=0, max_act=2 if self.env_args["discrete"] else 0.1,
                                         action_space=self.action_space)
            # writes into the same buffer as train_collector
            warmup_collector = Collector(warmup_policy, self.train_envs, buffer, exploration_noise=True)
            warmup_collector.collect(n_step=start_timesteps)

        OnpolicyTrainer(
            policy,
            batch_size=batch_size,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=self.meta_param["epoch"],
            step_per_epoch=self.meta_param["step_per_epoch"],
            step_per_collect=step_per_collect,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=self.meta_param["test_num"],
            train_fn=None,
            test_fn=None,
            stop_fn=self.early_stop_fn,
            save_best_fn=save_best_fn,
            logger=self.logger,
            save_checkpoint_fn=self.save_checkpoint_fn,
        ).run()

        # load the best policy to test again
        policy.load_state_dict(torch.load(os.path.join(self.log_path, "best_policy.pth")))
        return policy, lambda epoch, env_step: None


class SACObjective(RLObjective):
    """Hyper-parameter objective that builds and trains a SAC policy with twin critics."""

    def __init__(self, env_name, env_args, hparam_space: OffPolicyRLHyperParameterSpace, device,
                 **kwargs):
        # Same constructor contract as the other objectives in this file. The previous signature
        # omitted `env_args`, which shifted `hparam_space`/`device` into the parent's
        # `env_args`/`hparam_space` positions.
        super().__init__(env_name, env_args, hparam_space, device, **kwargs)

    def define_policy(self,
                      gamma,
                      lr,
                      alpha,
                      n_step,
                      tau,
                      obs_mode,
                      **kwargs, ):
        """Build a SACPolicy with an ActorProb actor and two Q critics.

        Args:
            gamma: discount factor.
            lr: Adam learning rate for actor and both critics.
            alpha: entropy coefficient forwarded to SACPolicy.
            n_step: n-step return horizon (passed as ``estimation_step``).
            tau: soft target-update coefficient.
            obs_mode: "cat" uses ``cat_num=meta_param["obs_window"]``; "stack" uses ``cat_num=1``.

        Returns:
            The constructed SACPolicy.

        Raises:
            NotImplementedError: if ``obs_mode`` is neither "cat" nor "stack".
        """
        hidden_sizes = [256, 256, 256]
        if obs_mode == "cat":
            cat_num = self.meta_param["obs_window"]
        elif obs_mode == "stack":
            # NOTE(review): the networks below take no recurrence flag, so "stack" only
            # changes cat_num here -- confirm this is intended for SAC.
            cat_num = 1
        else:
            raise NotImplementedError("obs_mode not supported")

        # model (creation order kept stable so parameter initialisation is reproducible)
        net_a = Net(self.state_shape, hidden_sizes=hidden_sizes, device=self.device, cat_num=cat_num)
        actor = ActorProb(
            net_a,
            self.action_shape,
            device=self.device,
            unbounded=True,
            conditioned_sigma=True,
        ).to(self.device)
        actor_optim = torch.optim.Adam(actor.parameters(), lr=lr)
        net_c1 = Net(
            self.state_shape,
            self.action_shape,
            hidden_sizes=hidden_sizes,
            concat=True,
            device=self.device,
            cat_num=cat_num
        )
        net_c2 = Net(
            self.state_shape,
            self.action_shape,
            hidden_sizes=hidden_sizes,
            concat=True,
            device=self.device,
            cat_num=cat_num
        )
        critic1 = Critic(net_c1, device=self.device).to(self.device)
        critic1_optim = torch.optim.Adam(critic1.parameters(), lr=lr)
        critic2 = Critic(net_c2, device=self.device).to(self.device)
        critic2_optim = torch.optim.Adam(critic2.parameters(), lr=lr)

        policy = SACPolicy(
            actor,
            actor_optim,
            critic1,
            critic1_optim,
            critic2,
            critic2_optim,
            tau=tau,
            gamma=gamma,
            alpha=alpha,
            estimation_step=n_step,
            action_space=self.action_space,
        )
        return policy

    def run(self, policy,
            stack_num,
            cat_num,
            step_per_collect,
            update_per_step,
            batch_size,
            start_timesteps,
            **kwargs):
        """Train ``policy`` with tianshou's OffpolicyTrainer and reload the best checkpoint.

        The replay buffer's ``stack_num`` is ``max(stack_num, cat_num)``; using both > 1 is rejected.

        Returns:
            ``(policy, test_fn)``: ``policy`` with the weights from ``best_policy.pth`` in
            ``self.log_path``, and a no-op ``test_fn``.
        """
        assert not (cat_num > 1 and stack_num > 1), "does not support both categorical and frame stack"
        stack_num = max(stack_num, cat_num)
        # replay buffer
        if self.meta_param["training_num"] > 1:
            buffer = VectorReplayBuffer(self.meta_param["buffer_size"], len(self.train_envs), stack_num=stack_num)
        else:
            buffer = ReplayBuffer(self.meta_param["buffer_size"], stack_num=stack_num)
        train_collector = Collector(policy, self.train_envs, buffer, exploration_noise=True)
        # evaluate on the held-out test envs (previously this reused self.train_envs)
        test_collector = Collector(policy, self.test_envs)
        if start_timesteps > 0:
            train_collector.collect(n_step=start_timesteps, random=True)

        def save_best_fn(policy):
            torch.save(policy.state_dict(), os.path.join(self.log_path, "best_policy.pth"))

        # keyword arguments (as in DQNObjective) so the call does not depend on the
        # trainer's positional parameter order, which differs between tianshou versions
        OffpolicyTrainer(
            policy,
            max_epoch=self.meta_param["epoch"],
            batch_size=batch_size,
            train_collector=train_collector,
            test_collector=test_collector,
            step_per_epoch=self.meta_param["step_per_epoch"],
            step_per_collect=step_per_collect,
            episode_per_test=self.meta_param["test_num"],
            train_fn=None,
            test_fn=None,
            save_best_fn=save_best_fn,
            logger=self.logger,
            update_per_step=update_per_step,
            stop_fn=self.early_stop_fn,
            save_checkpoint_fn=self.save_checkpoint_fn,
        ).run()

        # load the best policy to test again
        policy.load_state_dict(torch.load(os.path.join(self.log_path, "best_policy.pth")))
        return policy, lambda epoch, env_step: None


class LLM_Objective(RLObjective):
    """Objective that wraps an LLM actor in an LLM_Policy and evaluates it on all patients."""

    def __init__(self, env_name, env_args, hparam_space: LLM_HyperParams, device, **kwargs):
        super().__init__(env_name, env_args, hparam_space, device=device, **kwargs)

    def define_policy(self, inference_mode, transformers_mode, llm_mode, num_try, need_summary, need_meta_info, **kwargs):
        """Create an LLM_Policy backed by either a remote (API) or a local LLM actor.

        Args:
            inference_mode: "API" uses ActorLM_API; "local" loads ActorLM from ``./model_hub``
                and moves it to ``self.device``.
            transformers_mode: forwarded to ActorLM (local mode only).
            llm_mode: mapping providing the ``"llm"`` and ``"context_window"`` entries.
            num_try, need_summary, need_meta_info: forwarded to LLM_Policy.

        Returns:
            The constructed LLM_Policy.
        """
        assert inference_mode in ["API", "local"], f"inference_mode can only be API or local, got '{inference_mode}'. "
        model_name = llm_mode["llm"]
        window = llm_mode["context_window"]

        if inference_mode != "API":
            actor = ActorLM(llm=model_name, context_window=window,
                            device=self.device, transformers_mode=transformers_mode,
                            model_dir="./model_hub")
            actor = actor.to(self.device)
        else:
            actor = ActorLM_API(llm=model_name, context_window=window)

        return LLM_Policy(
            actor,
            action_space=self.action_space,
            observation_space=self.state_space,
            num_try=num_try,
            need_summary=need_summary,
            need_meta_info=need_meta_info,
        )

    def wandb_search(self):
        """Run one W&B sweep trial: build the policy from ``wandb.config`` and test it on all patients.

        No training happens here; the policy is only evaluated (20 episodes) via
        ``test_all_patients``.
        """
        self.logger = WandbLogger(train_interval=24 * 15)
        meta = self.meta_param
        meta["training_num"] = 1
        meta["num_actions"] = None
        config = wandb.config

        seed = int(config["seed"])
        self.prepare_env(seed, self.env_name, **self.env_args)
        set_global_seed(seed)

        # start training
        print("prepare policy")
        policy_kwargs = {**config, **meta}  # meta_param entries override sweep values
        self.policy = self.define_policy(**policy_kwargs)

        # test on all envs
        self.test_all_patients(self.policy, None, seed, self.logger, n_episode=20)


class LLM_Instruct_Objective(RLObjective):
    """Objective that trains an instruction LM paired with an actor LM using GRPOTrainer."""

    def __init__(self, env_name, env_args, hparam_space: LLM_Instruct_HyperParams, device, **kwargs):
        super().__init__(env_name, env_args, hparam_space, device=device, **kwargs)

    def define_policy(self, inference_mode, transformers_mode, lr, gae_lambda, eps_clip, ent_coef, kl_coef,
                      actor_llm_mode, instruct_llm_mode, num_try, need_meta_info, **kwargs):
        """Build an LLM_Instruct_Policy around a DoubleLM(actor LM, instruct LM).

        Args:
            inference_mode: "API" uses ActorLM_API for the actor; "local" loads ActorLM from
                ``./model_hub`` onto ``self.device``. The instruct LM is always loaded locally.
            transformers_mode: forwarded to ActorLM (local mode only).
            lr: Adam learning rate; only the instruct LM's parameters are given to the optimizer.
            actor_llm_mode, instruct_llm_mode: mappings with ``"llm"`` and ``"context_window"`` entries.
            gae_lambda, eps_clip, ent_coef, kl_coef, num_try, need_meta_info: forwarded to the policy.

        Returns:
            The constructed LLM_Instruct_Policy.
        """
        import gc

        assert inference_mode in ["API", "local"], f"inference_mode can only be API or local, got '{inference_mode}'. "
        if inference_mode == "API":
            actor_lm = ActorLM_API(llm=actor_llm_mode["llm"], context_window=actor_llm_mode["context_window"])
        else:
            actor_lm = ActorLM(llm=actor_llm_mode["llm"], context_window=actor_llm_mode["context_window"],
                               device=self.device, transformers_mode=transformers_mode,
                               model_dir="./model_hub").to(self.device)
        instruct_lm = InstructLM(llm=instruct_llm_mode["llm"], context_window=instruct_llm_mode["context_window"],
                                 device=self.device,
                                 model_dir="./model_hub").to(self.device)
        net = DoubleLM(actor_lm, instruct_lm)
        optimizer = optim.Adam(net.instruct_lm.parameters(), lr=lr)

        # Drop the local aliases and release cached allocator memory. NOTE: `net` still references
        # both LMs, so this does not free their weights; gc/empty_cache can only reclaim memory
        # that is no longer referenced (presumably loading temporaries).
        del actor_lm
        del instruct_lm
        gc.collect()
        torch.cuda.empty_cache()

        return LLM_Instruct_Policy(
            net,
            optimizer,
            action_space=self.action_space,
            observation_space=self.state_space,
            gae_lambda=float(gae_lambda),
            eps_clip=eps_clip,
            ent_coef=ent_coef,
            kl_coef=kl_coef,
            max_grad_norm=1.,
            num_try=num_try,
            need_meta_info=need_meta_info
        )

    def run(self, policy, step_per_collect, repeat_per_collect, batch_size, start_timesteps, group_num, **kwargs):
        """Train ``policy`` with GRPOTrainer.

        If ``start_timesteps > 0``, the buffer is first filled with transitions from a RandomPolicy.

        Returns:
            ``(policy, test_fn)``: ``policy`` (reloaded from ``best_policy.pth`` only if that file
            exists in ``self.log_path``), and a no-op ``test_fn``.
        """
        # replay buffer (one sub-buffer per training env when running several envs)
        if self.meta_param["training_num"] > 1:
            buffer = VectorReplayBuffer(
                self.meta_param["buffer_size"],
                buffer_num=len(self.train_envs),
                ignore_obs_next=False,
                save_only_last_obs=False,
                stack_num=1
            )
        else:
            buffer = ReplayBuffer(self.meta_param["buffer_size"],
                                  ignore_obs_next=False,
                                  save_only_last_obs=False,
                                  stack_num=1)

        # collectors
        train_collector = Collector(policy, self.train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, self.test_envs, exploration_noise=False)
        if start_timesteps > 0:
            print(f"warmup with random policy for {start_timesteps} steps..")
            warmup_policy = RandomPolicy(min_act=0, max_act=2 if self.env_args["discrete"] else 0.1,
                                         action_space=self.action_space)
            warmup_collector = Collector(warmup_policy, self.train_envs, buffer, exploration_noise=True)
            warmup_collector.collect(n_step=start_timesteps)

        # NOTE: no save_best_fn is passed to GRPOTrainer (best-policy saving is disabled here).
        GRPOTrainer(
            policy,
            batch_size=batch_size,
            group_num=group_num,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=self.meta_param["epoch"],
            step_per_epoch=self.meta_param["step_per_epoch"],
            step_per_collect=step_per_collect,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=self.meta_param["test_num"],
            train_fn=None,
            test_fn=None,
            stop_fn=self.early_stop_fn,
            logger=self.logger,
            save_checkpoint_fn=self.save_checkpoint_fn,
        ).run()

        # Since this method no longer guarantees that best_policy.pth is written, an unconditional
        # load raised FileNotFoundError on a fresh log dir. Reload only when the file is present;
        # otherwise return the policy as trained.
        # NOTE(review): a file left over from an earlier run in the same log_path would still be
        # loaded -- confirm log_path is per-run.
        best_path = os.path.join(self.log_path, "best_policy.pth")
        if os.path.exists(best_path):
            policy.load_state_dict(torch.load(best_path))
        return policy, lambda epoch, env_step: None