import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import pprint
import gym
import numpy as np
import torch as th
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule

from .off_policy_algorithm import OffPolicyAlgorithm
from .policies import BCPolicy


class BC(OffPolicyAlgorithm):
    """
    Behavior Cloning built on ``OffPolicyAlgorithm``.

    The actor is trained to reproduce stored actions with an MSE loss.
    With ``without_exploration=True`` (the default) every batch is sampled
    from ``offline_buffer`` (naive BC). Otherwise batches come from the online
    ``replay_buffer`` (online BC).

    :param offline_buffer: buffer sampled from when ``without_exploration`` is True.
    :param max_grad_norm: maximum gradient norm used when clipping policy gradients.
    :param without_exploration: if True train from ``offline_buffer`` only;
        if False train from the online replay buffer (OnlineBC).
    :param source_model: optional path to a checkpoint. When the actor's
        features extractor is in ``meta_mode``, matching tensors from it are
        copied into the actor (see ``_load_source_model_weights``).

    All other parameters are forwarded unchanged to ``OffPolicyAlgorithm``.
    """

    def __init__(
        self,
        policy: Union[str, Type[BCPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 128,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = (1, 'episode'),
        gradient_steps: int = 1,
        offline_buffer: Optional[ReplayBuffer] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        without_exploration: bool = True,  # if False -> OnlineBC
        source_model: Optional[str] = None,
    ):
        super().__init__(
            policy,
            env,
            BCPolicy,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Discrete, gym.spaces.Box,),
            support_multi_env=True,
            without_exploration=without_exploration,
        )

        self.offline_buffer = offline_buffer
        self.max_grad_norm = max_grad_norm
        if _init_setup_model:
            self._setup_model()

        self.source_model = source_model

        # ``self.actor`` is assigned by ``_setup_model`` -> ``_create_aliases``.
        # When ``_init_setup_model`` is False it may not exist yet, so look it
        # up defensively instead of raising AttributeError.
        actor = getattr(self, "actor", None)
        if actor is None:
            return

        if actor.features_extractor.meta_mode:
            if self.source_model is not None:
                self._load_source_model_weights(self.source_model)
            self._freeze_meta_head()

        self._freeze_encoders()
        self._report_trainable_parameters()

    def _load_source_model_weights(self, path: str) -> None:
        """Copy tensors from the checkpoint at ``path`` into the actor where keys match."""
        # map_location lets a checkpoint saved on another device (e.g. GPU) be
        # loaded on this model's device (e.g. a CPU-only host).
        source_dict = th.load(path, map_location=self.device)
        model_dict = self.actor.state_dict()
        # Each checkpoint key loses its first 6 characters (presumably an
        # "actor." prefix -- confirm against the checkpoint). Only keys the
        # actor actually has are kept; on duplicate stripped keys the last wins.
        pretrained_dict = {
            key[6:]: value
            for key, value in source_dict.items()
            if key[6:] in model_dict
        }
        model_dict.update(pretrained_dict)
        self.actor.load_state_dict(model_dict)

    def _freeze_meta_head(self) -> None:
        """Disable gradients for the fixed set of actor parameters frozen in meta mode."""
        # Layers 0/2/4 of the actor's ``actor`` submodule.
        # NOTE(review): presumably the Linear layers of an MLP head -- confirm.
        frozen_names = {
            "actor.0.weight", "actor.0.bias",
            "actor.2.weight", "actor.2.bias",
            "actor.4.weight", "actor.4.bias",
        }
        for name, param in self.actor.named_parameters():
            if name in frozen_names:
                param.requires_grad_(False)

    def _freeze_encoders(self) -> None:
        """Disable gradients in the features extractor's encoders and optional CLIP model."""
        print("Turning off gradients in both the image and the text encoder")
        extractor = self.actor.features_extractor
        for name, param in extractor.extractors.named_parameters():
            # Parameters whose name contains "prompt" remain trainable.
            if "prompt" not in name:
                param.requires_grad_(False)
        if extractor.clip_model is not None:
            for param in extractor.clip_model.parameters():
                param.requires_grad_(False)

    def _report_trainable_parameters(self) -> None:
        """Print the set of actor parameter names that still require gradients."""
        enabled = {
            name
            for name, param in self.actor.named_parameters()
            if param.requires_grad
        }
        print("Parameters to be updated:")
        pprint.pprint(enabled)

    def _setup_model(self) -> None:
        """Build the model via the parent class, then expose ``self.actor``."""
        super()._setup_model()
        self._create_aliases()

    def _create_aliases(self) -> None:
        """Alias ``self.actor`` to the policy's actor module."""
        self.actor = self.policy.actor

    def train(self, gradient_steps: int, batch_size: int = 256) -> None:
        """
        Run ``gradient_steps`` behavior-cloning updates.

        Each step does the following:
        - Sample ``batch_size`` transitions, from ``offline_buffer`` when
          ``without_exploration`` is True and otherwise from ``replay_buffer``.
        - Minimize the MSE between the actor's output and the stored actions.
        - Clip the policy gradient norm to ``max_grad_norm``.
        - Step the policy optimizer.

        The mean negative log-probability returned by
        ``self.actor.action_log_prob`` is logged for monitoring only; it is not
        part of the loss.
        """
        # Naive BC trains from the offline dataset; online BC from collected data.
        buffer = self.offline_buffer if self.without_exploration else self.replay_buffer

        actor_losses, mse_losses, neglogps = [], [], []
        for _ in range(gradient_steps):
            replay_data = buffer.sample(batch_size, env=self._vec_normalize_env)

            actions_pred = self.actor(replay_data.observations)
            # F.mse_loss already reduces to a scalar mean by default.
            mse_loss = F.mse_loss(actions_pred, replay_data.actions)

            # The log-probability is only logged, never optimized, so no
            # autograd graph is needed for it.
            with th.no_grad():
                _, log_prob = self.actor.action_log_prob(replay_data.observations)
            if log_prob is not None:
                neglogps.append(-log_prob.mean().item())

            loss = mse_loss
            actor_losses.append(loss.item())
            mse_losses.append(mse_loss.item())

            self.policy.optimizer.zero_grad()
            loss.backward()
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        self._n_updates += gradient_steps
        self.logger.record("train/n_udpates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/mse_loss", np.mean(mse_losses))
        if neglogps:
            self.logger.record("train/leglogp", np.mean(neglogps))

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "BC",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "BC":
        """Delegate to ``OffPolicyAlgorithm.learn`` with a BC-specific default log name."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )

    def _excluded_save_params(self) -> List[str]:
        """Exclude the ``actor`` alias; its weights are saved through ``policy``."""
        return super()._excluded_save_params() + ["actor"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Save the policy and its optimizer as torch state dicts; no extra tensors."""
        state_dicts = ["policy", "policy.optimizer"]
        return state_dicts, []
