import os
import torch
import torch.optim as optim
import wandb
from tianshou.data import VectorReplayBuffer, ReplayBuffer
from tianshou.trainer import OnpolicyTrainer
from trainer.base_obj import RLObjective
from trainer.RLHparams import OnPolicyRLHyperParameterSpace, LLM_HyperParams, LLM_Instruct_HyperParams
from utils.network import define_single_network, define_continuous_critic
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.common import ActorCritic
from torch.distributions import Distribution, Independent, Normal
from trainer.collector import GlucoseCollector as Collector
from trainer.base_policy import RandomPolicy
from trainer.RLPolicy import PPOPolicy, LLM_Policy, LLM_Instruct_Policy
from trainer.net import ActorLM, ActorLM_API,  InstructLM, DoubleLM
from utils.misc import set_global_seed
from utils.wandb import WandbLogger


class PPOObjective(RLObjective):
    """Objective that builds a ``PPOPolicy`` and trains it with tianshou's ``OnpolicyTrainer``."""

    def __init__(self, env_name, env_args, hparam_space: OnPolicyRLHyperParameterSpace, device, **kwargs):
        """Forward every argument unchanged to ``RLObjective.__init__``."""
        super().__init__(env_name, env_args, hparam_space, device, **kwargs)

    def define_policy(self, gamma, lr, gae_lambda, vf_coef, ent_coef, eps_clip, value_clip, dual_clip,
                      advantage_normalization, recompute_advantage, n_step, epoch, batch_size, obs_mode, linear, **kwargs):
        """Create the actor, the critic, a shared Adam optimizer and the ``PPOPolicy`` around them.

        ``obs_mode`` decides what is handed to the network builders:

        * ``"cat"``   -> ``cat_num=self.meta_param["obs_window"]``, ``use_rnn=False``
        * ``"stack"`` -> ``cat_num=1``, ``use_rnn=True``

        ``gamma`` becomes the policy's ``discount_factor`` and ``gae_lambda`` is cast to ``float``;
        the remaining PPO hyper-parameters are passed through as given.  ``max_grad_norm=1.``,
        ``action_scaling=True`` and ``action_bound_method='clip'`` are fixed here.
        ``n_step``, ``epoch``, ``batch_size`` and ``**kwargs`` are accepted but not used by this method.

        Returns:
            The constructed ``PPOPolicy``.

        Raises:
            NotImplementedError: if ``obs_mode`` is neither ``"cat"`` nor ``"stack"``.
        """
        if obs_mode not in ("cat", "stack"):
            raise NotImplementedError("obs_mode not supported")
        recurrent = obs_mode == "stack"
        window = 1 if recurrent else self.meta_param["obs_window"]

        backbone = define_single_network(
            self.state_shape, self.action_shape,
            use_dueling=False, num_layer=3, hidden_size=128,
            use_rnn=recurrent, device=self.device, cat_num=window,
        )
        actor = ActorProb(backbone, self.action_shape, unbounded=True, device=self.device).to(self.device)
        critic = define_continuous_critic(
            self.state_shape, self.action_shape,
            linear=linear, use_rnn=recurrent, cat_num=window,
            use_action_net=False, state_net_hidden_size=128,
            device=self.device,
        )
        # One optimizer over the parameters of both networks.
        optimizer = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=lr)

        # Sigma-parameter and orthogonal weight initialisation are not applied here.
        # do last policy layer scaling, this will make initial actions have (close to)
        # 0 mean and std, and will help boost performances,
        # see https://arxiv.org/abs/2006.05990, Fig.24 for details
        mu_linear_layers = (mod for mod in actor.mu.modules() if isinstance(mod, torch.nn.Linear))
        for layer in mu_linear_layers:
            torch.nn.init.zeros_(layer.bias)
            layer.weight.data.copy_(0.01 * layer.weight.data)

        def to_distribution(*loc_scale: torch.Tensor) -> Distribution:
            # Exactly two positional tensors are expected: location and scale.
            mean, std = loc_scale
            return Independent(Normal(mean, std), 1)

        return PPOPolicy(
            actor=actor,
            critic=critic,
            optim=optimizer,
            dist_fn=to_distribution,
            discount_factor=gamma,
            gae_lambda=float(gae_lambda),
            vf_coef=vf_coef,
            ent_coef=ent_coef,
            max_grad_norm=1.,
            action_scaling=True,
            action_bound_method='clip',
            action_space=self.action_space,
            eps_clip=eps_clip,
            value_clip=value_clip,
            dual_clip=dual_clip,
            advantage_normalization=advantage_normalization,
            recompute_advantage=recompute_advantage,
        )

    def run(self, policy, step_per_collect, repeat_per_collect, batch_size, start_timesteps, **kwargs):
        """Train ``policy`` on-policy and reload the best checkpoint written during training.

        Steps, in order: create the replay buffer, create the train/test collectors, optionally
        collect ``start_timesteps`` steps with a ``RandomPolicy`` into the same buffer, run
        ``OnpolicyTrainer``, then load ``best_policy.pth`` from ``self.log_path`` back into ``policy``.

        Returns:
            ``(policy, fn)`` where ``fn(epoch, env_step)`` always returns ``None``.
        """
        def best_policy_path():
            # Resolved on every call so it always reflects the current ``self.log_path``.
            return os.path.join(self.log_path, "best_policy.pth")

        def persist_best(policy):
            torch.save(policy.state_dict(), best_policy_path())

        # replay buffer: vectorised when several training envs are used, plain otherwise
        buffer_options = dict(ignore_obs_next=False, save_only_last_obs=False, stack_num=1)
        if self.meta_param["training_num"] > 1:
            buffer = VectorReplayBuffer(
                self.meta_param["buffer_size"],
                buffer_num=len(self.train_envs),
                **buffer_options,
            )
        else:
            buffer = ReplayBuffer(self.meta_param["buffer_size"], **buffer_options)

        # collectors
        train_collector = Collector(policy, self.train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, self.test_envs, exploration_noise=False)

        # optional warm-up: fill the training buffer using random actions
        if start_timesteps > 0:
            print(f"warmup with random policy for {start_timesteps} steps..")
            max_random_action = 2 if self.env_args["discrete"] else 0.1
            random_policy = RandomPolicy(min_act=0, max_act=max_random_action, action_space=self.action_space)
            random_collector = Collector(random_policy, self.train_envs, buffer, exploration_noise=True)
            random_collector.collect(n_step=start_timesteps)

        trainer = OnpolicyTrainer(
            policy,
            batch_size=batch_size,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=self.meta_param["epoch"],
            step_per_epoch=self.meta_param["step_per_epoch"],
            step_per_collect=step_per_collect,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=self.meta_param["test_num"],
            train_fn=None,
            test_fn=None,
            stop_fn=self.early_stop_fn,
            save_best_fn=persist_best,
            logger=self.logger,
            save_checkpoint_fn=self.save_checkpoint_fn,
        )
        trainer.run()

        # load the best policy to test again
        best_state = torch.load(best_policy_path())
        policy.load_state_dict(best_state)

        def no_op(epoch, env_step):
            return None

        return policy, no_op


class LLM_Objective(RLObjective):
    """Objective that wraps a language-model actor in an ``LLM_Policy`` and evaluates it."""

    def __init__(self, env_name, env_args, hparam_space: LLM_HyperParams, device, **kwargs):
        """Forward every argument to ``RLObjective.__init__`` (``device`` is passed by keyword)."""
        super().__init__(env_name, env_args, hparam_space, device=device, **kwargs)

    def define_policy(self, inference_mode, transformers_mode, llm_mode, num_try, need_summary, need_meta_info, **kwargs):
        """Build the language-model network and return an ``LLM_Policy`` around it.

        Args:
            inference_mode: ``"API"`` builds an ``ActorLM_API``; ``"local"`` builds an ``ActorLM``
                (with ``model_dir="./model_hub"``) and moves it to ``self.device``.
            transformers_mode: forwarded to ``ActorLM``; unused in ``"API"`` mode.
            llm_mode: mapping that must provide the keys ``"llm"`` and ``"context_window"``.
            num_try, need_summary, need_meta_info: forwarded unchanged to ``LLM_Policy``.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if ``inference_mode`` is neither ``"API"`` nor ``"local"``.
        """
        # Raised explicitly instead of via ``assert`` so the check still runs under ``python -O``;
        # without it an invalid mode would silently take the local-model branch.  The exception
        # type and message are kept as before.
        if inference_mode not in ("API", "local"):
            raise AssertionError(f"inference_mode can only be API or local, got '{inference_mode}'. ")

        if inference_mode == "API":
            net = ActorLM_API(llm=llm_mode["llm"], context_window=llm_mode["context_window"])
        else:
            net = ActorLM(llm=llm_mode["llm"], context_window=llm_mode["context_window"],
                          device=self.device, transformers_mode=transformers_mode,
                          model_dir="./model_hub").to(self.device)
        return LLM_Policy(
            net,
            action_space=self.action_space,
            observation_space=self.state_space,
            num_try=num_try,
            need_summary=need_summary,
            need_meta_info=need_meta_info
        )

    def wandb_search(self):
        """Build the policy from ``wandb.config`` and evaluate it via ``self.test_all_patients``.

        No training is started here: the method replaces ``self.logger``, overrides two
        ``meta_param`` entries, prepares the environments, seeds the run, stores the policy on
        ``self.policy`` and then calls ``test_all_patients`` with ``n_episode=20``.
        """
        self.logger = WandbLogger(train_interval=24 * 15)
        self.meta_param["training_num"] = 1
        self.meta_param["num_actions"] = None
        hparams = wandb.config

        self.prepare_env(int(hparams["seed"]), self.env_name, **self.env_args)
        set_global_seed(int(hparams["seed"]))

        # build the policy; on duplicate keys meta_param takes precedence over the sweep values
        print("prepare policy")
        self.policy = self.define_policy(**{**hparams, **self.meta_param})

        # test on all envs
        self.test_all_patients(self.policy, None, int(hparams["seed"]), self.logger, n_episode=20)


class LLM_Instruct_Objective(RLObjective):
    """Objective that pairs an actor LM with an instruct LM and trains an ``LLM_Instruct_Policy``."""

    def __init__(self, env_name, env_args, hparam_space: LLM_Instruct_HyperParams, device, **kwargs):
        """Forward every argument to ``RLObjective.__init__`` (``device`` is passed by keyword)."""
        super().__init__(env_name, env_args, hparam_space, device=device, **kwargs)

    def define_policy(self, inference_mode, transformers_mode, lr, gae_lambda, eps_clip, ent_coef, kl_coef,
                      actor_llm_mode, instruct_llm_mode, num_try, need_meta_info, **kwargs):
        """Build both language models, their optimizers and the ``LLM_Instruct_Policy``.

        Args:
            inference_mode: ``"API"`` builds the actor as ``ActorLM_API``; ``"local"`` builds an
                ``ActorLM`` (with ``model_dir="./model_hub"``) and moves it to ``self.device``.
            transformers_mode: forwarded to ``ActorLM``; unused in ``"API"`` mode.
            lr: learning rate used for both Adam optimizers.
            gae_lambda: cast to ``float`` before being handed to the policy.
            eps_clip, ent_coef, kl_coef, num_try, need_meta_info: forwarded unchanged to the policy.
            actor_llm_mode, instruct_llm_mode: mappings that must provide the keys ``"llm"`` and
                ``"context_window"``.
            **kwargs: accepted and ignored.

        Returns:
            The constructed ``LLM_Instruct_Policy`` (``max_grad_norm`` is fixed to ``1.``).

        Raises:
            AssertionError: if ``inference_mode`` is neither ``"API"`` nor ``"local"``.
        """
        import gc

        # Raised explicitly instead of via ``assert`` so the check still runs under ``python -O``;
        # without it an invalid mode would silently take the local-model branch.  The exception
        # type and message are kept as before.
        if inference_mode not in ("API", "local"):
            raise AssertionError(f"inference_mode can only be API or local, got '{inference_mode}'. ")

        if inference_mode == "API":
            actor_lm = ActorLM_API(llm=actor_llm_mode["llm"], context_window=actor_llm_mode["context_window"])
        else:
            actor_lm = ActorLM(llm=actor_llm_mode["llm"], context_window=actor_llm_mode["context_window"],
                               device=self.device, transformers_mode=transformers_mode,
                               model_dir="./model_hub").to(self.device)
        # The instruct model is always built locally, regardless of inference_mode.
        instruct_lm = InstructLM(llm=instruct_llm_mode["llm"], context_window=instruct_llm_mode["context_window"],
                                 device=self.device,
                                 model_dir="./model_hub").to(self.device)
        net = DoubleLM(actor_lm, instruct_lm)

        # Separate optimizers for the actor and critic parameter groups of the instruct model.
        actor_optim = optim.Adam(net.instruct_lm.actor_parameters(), lr=lr)
        critic_optim = optim.Adam(net.instruct_lm.critic_parameters(), lr=lr)

        # Drop the local names and ask Python / the CUDA caching allocator to release unused memory.
        # NOTE(review): ``net`` was built from both modules and presumably still references them,
        # so this can only reclaim temporaries left over from model construction - confirm.
        del actor_lm
        del instruct_lm
        gc.collect()
        torch.cuda.empty_cache()

        return LLM_Instruct_Policy(
            net,
            actor_optim,
            critic_optim,
            action_space=self.action_space,
            observation_space=self.state_space,
            gae_lambda=float(gae_lambda),
            eps_clip=eps_clip,
            ent_coef=ent_coef,
            kl_coef=kl_coef,
            max_grad_norm=1.,
            num_try=num_try,
            need_meta_info=need_meta_info
        )

    def run(self, policy, step_per_collect, repeat_per_collect, batch_size, start_timesteps, **kwargs):
        """Train ``policy`` on-policy and reload the best checkpoint written during training.

        Steps, in order: create the replay buffer, create the train/test collectors, optionally
        collect ``start_timesteps`` steps with a ``RandomPolicy`` into the same buffer, run
        ``OnpolicyTrainer``, then load ``best_policy.pth`` from ``self.log_path`` back into ``policy``.

        Returns:
            ``(policy, fn)`` where ``fn(epoch, env_step)`` always returns ``None``.
        """
        def save_best_fn(policy):
            torch.save(policy.state_dict(), os.path.join(self.log_path, "best_policy.pth"))

        # replay buffer: vectorised when several training envs are used, plain otherwise
        if self.meta_param["training_num"] > 1:
            buffer = VectorReplayBuffer(
                self.meta_param["buffer_size"],
                buffer_num=len(self.train_envs),
                ignore_obs_next=False,
                save_only_last_obs=False,
                stack_num=1
            )
        else:
            buffer = ReplayBuffer(self.meta_param["buffer_size"],
                                  ignore_obs_next=False,
                                  save_only_last_obs=False,
                                  stack_num=1)

        # collectors
        train_collector = Collector(policy, self.train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, self.test_envs, exploration_noise=False)

        # optional warm-up: fill the training buffer using random actions
        if start_timesteps > 0:
            print(f"warmup with random policy for {start_timesteps} steps..")
            warmup_policy = RandomPolicy(min_act=0, max_act=2 if self.env_args["discrete"] else 0.1,
                                         action_space=self.action_space)
            warmup_collector = Collector(warmup_policy, self.train_envs, buffer, exploration_noise=True)
            warmup_collector.collect(n_step=start_timesteps)

        OnpolicyTrainer(
            policy,
            batch_size=batch_size,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=self.meta_param["epoch"],
            step_per_epoch=self.meta_param["step_per_epoch"],
            step_per_collect=step_per_collect,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=self.meta_param["test_num"],
            train_fn=None,
            test_fn=None,
            stop_fn=self.early_stop_fn,
            save_best_fn=save_best_fn,
            logger=self.logger,
            save_checkpoint_fn=self.save_checkpoint_fn,
        ).run()

        # load the best policy to test again
        policy.load_state_dict(torch.load(os.path.join(self.log_path, "best_policy.pth")))
        return policy, lambda epoch, env_step: None