import datetime
import random
from rl_algorithm.ddqn.replay_memory import ReplayMemory
from rl_algorithm.ddqn.model import DDQN
from rl_algorithm.ddqn.config import batch_size, discount
import numpy as np
import utils
import wandb, pdb
import torch
import torch.nn as nn
from torch.distributions import Bernoulli
import numpy as np

class DDQNAgent:
    """Double-DQN agent with optional periodic network resets.

    When ``reset_multi`` is 2 or 4 the agent keeps that many policy/target
    network pairs (``policy_network``/``target_network``, ``policy_network2``/
    ``target_network2``, ...), trains every pair on the same sampled batches
    and, every ``reset_itv`` frames, re-initialises one pair in round-robin
    order. Any other ``reset_multi`` value uses a single pair.
    """

    # Tag embedded in the log line printed when ensemble member ``idx`` is reset.
    _RESET_TAGS = ("11111111111", "22222222222", "33333333333", "44444444444")

    def __init__(
        self,
        env,
        eval_env,
        device,
        preprocess_obs,
        model_dir,
        args,
        env_size
    ):
        """Build the network(s), replay memory and shared training state.

        Args:
            env: training environment; its ``observation_space`` and
                ``action_space`` are used to build the networks.
            eval_env: environment used by ``test_collect_experiences``.
            device: torch device the networks are moved to.
            preprocess_obs: observation preprocessor, also handed to the
                replay memory.
            model_dir: forwarded to ``utils.common_init.init_log``.
            args: run configuration. Reads ``env``, ``reset_ver``,
                ``reset_multi``, ``no_reset``, ``reset_itv``, ``reset_rr``,
                ``lr``, ``max_memory``, ``train_interval`` (and ``reset_ww``
                when ``reset_multi > 1``) and is forwarded to the
                ``utils.common_init`` helpers.
            env_size: stored as ``self.env_size``.

        Note:
            ``args.reset_multi`` is overwritten with 1 when
            ``args.reset_ver < 0``.
        """
        if args.reset_ver < 0:
            # Resets are disabled, so a single network pair is enough.
            args.reset_multi = 1

        self.no_reset = args.no_reset
        self.env = env
        self.eval_env = eval_env
        self.env_size = env_size
        self.reset_ver = args.reset_ver
        self.reset_itv = args.reset_itv
        self.reset_rr = args.reset_rr
        self.reset_multi = args.reset_multi
        if self.reset_multi > 1:
            # The round-robin pointer is advanced before each reset, so the
            # first reset targets member 0.
            self.reset_time = -1
            self.reset_ww = args.reset_ww
        self.lr = args.lr

        obs_space, _ = utils.get_obss_preprocessor(env.observation_space)
        enable_mission = utils.check_run.enable_mission(args.env)

        def build_network():
            return DDQN(obs_space, env.action_space, True, enable_mission).to(device)

        # Creation order (policy, target, policy2, target2, ...) is kept fixed
        # in case weight initialisation draws from a seeded RNG.
        for idx in range(self._ensemble_size()):
            setattr(self, self._member_attr("policy_network", idx), build_network())
            setattr(self, self._member_attr("target_network", idx), build_network())

        self.memory = ReplayMemory(args.max_memory, preprocess_obs)

        utils.common_init.init(
            self, env=env, preprocess_obs=preprocess_obs, args=args, train_interval=args.train_interval
        )
        utils.common_init.init_exploration(self, args=args)
        utils.common_init.init_log(self, model_dir=model_dir)
        self._ensure_member_optimizers()

    def _ensemble_size(self):
        """Return the number of active policy/target pairs (``reset_multi`` if 2 or 4, else 1)."""
        return self.reset_multi if self.reset_multi in (2, 4) else 1

    @staticmethod
    def _member_attr(base, idx):
        """Return the attribute name of ensemble member ``idx``: ``base``, ``base2``, ``base3``, ..."""
        return base if idx == 0 else f"{base}{idx + 1}"

    def _ensure_member_optimizers(self):
        """Give every extra ensemble member an RMSprop optimizer if it has none yet.

        ``train`` uses ``optimizer2``..``optimizer4`` from the first training
        step, while this class otherwise only creates them when that member is
        reset. Optimizers that already exist are left untouched.
        """
        for idx in range(1, self._ensemble_size()):
            name = self._member_attr("optimizer", idx)
            if not hasattr(self, name):
                policy = getattr(self, self._member_attr("policy_network", idx))
                setattr(self, name, torch.optim.RMSprop(policy.parameters(), self.lr))

    def _reset_member(self, idx, *message):
        """Re-initialise member ``idx``'s networks, rebuild its optimizer and print ``message`` 20 times."""
        policy = getattr(self, self._member_attr("policy_network", idx))
        target = getattr(self, self._member_attr("target_network", idx))
        policy.reinit_weight(ver=self.reset_ver)
        target.reinit_weight(ver=self.reset_ver)
        # A fresh optimizer carries no RMSprop state from before the reset.
        setattr(
            self,
            self._member_attr("optimizer", idx),
            torch.optim.RMSprop(policy.parameters(), self.lr),
        )
        for _ in range(20):
            print(*message)

    def _maybe_reset_networks(self, num_frames):
        """Reset network weights when resets are enabled and ``num_frames`` hits ``reset_itv``.

        With ``reset_multi < 2`` the single pair is reset. With 2 or 4 members
        the round-robin pointer ``reset_time`` advances and, unless
        ``no_reset`` is set, that member is reset. Other values do nothing.
        """
        # Short-circuit keeps ``reset_itv`` unused when resets are disabled.
        if self.reset_ver < 0 or num_frames % self.reset_itv != 0:
            return
        if self.reset_multi < 2:
            # NOTE(review): the single-network path does not consult ``no_reset``.
            self._reset_member(0, "################################ reset paramters ... !!!!!")
        elif self.reset_multi in (2, 4):
            # The pointer advances even when ``no_reset`` suppresses the reset.
            self.reset_time = (self.reset_time + 1) % self.reset_multi
            if self.no_reset == 0:
                tag = self._RESET_TAGS[self.reset_time]
                self._reset_member(
                    self.reset_time,
                    f"################################ reset paramters ... {tag} now reset time is ",
                    self.reset_time,
                )

    def collect_experiences(
        self,
        start_time,
        episode,
        num_frames,
        return_per_frame_,
        test_return_per_frame_,
    ):
        """Run one training episode, storing transitions and training as it goes.

        ``start_time``, ``episode`` and ``return_per_frame_`` are accepted for
        interface compatibility but not read here.

        Args:
            num_frames: global frame counter at the start of the episode.
            test_return_per_frame_: list that ``test`` appends evaluation
                returns to.

        Returns:
            dict with the updated ``num_frames``, the per-step (scaled)
            ``rewards`` and the per-update ``loss`` values. Also stored as
            ``self.logs``.
        """
        obs = self.env.reset()[0]
        done = False
        # NOTE(review): not read in this class; presumably used by utils helpers.
        self.n = 0

        log_loss, log_reward = [], []
        episode_step = 0
        while not done and episode_step < self.max_episode_length:
            episode_step += 1
            preprocessed_obs = self.preprocess_obs([obs], device=self.device)

            action, _ = utils.action.select_action(
                self,
                exploration_type=self.exploration_type,
                preprocessed_obs=preprocessed_obs,
                num_frames=num_frames,
            )

            # NOTE(review): the 4th value (presumably gymnasium's ``truncated``)
            # is discarded; episodes end on ``done`` or max_episode_length only.
            new_obs, reward, done, _, _ = self.env.step(action)
            reward = reward * 10  # reward scaling applied to both memory and logs

            self.memory.add(
                {
                    "step": num_frames,
                    "obs": obs,
                    "action": action,
                    "reward": reward,
                    "new_obs": new_obs,
                    "done": 0.0 if done else 1.0,  # mask: 0 on terminal transitions
                }
            )

            obs = new_obs
            num_frames += 1
            log_reward.append(reward)

            if num_frames % self.train_interval == 0 and len(self.memory) >= batch_size:
                self._maybe_reset_networks(num_frames)
                # ``reset_rr`` gradient updates per training interval.
                for _ in range(self.reset_rr):
                    loss = self.train(self.memory.sample(batch_size))
                    if isinstance(loss, float):
                        log_loss.append(loss)

            self.test(num_frames, test_return_per_frame_)

        logs = {"num_frames": num_frames, "rewards": log_reward, "loss": log_loss}
        self.logs = logs
        return logs

    def train(self, collected_experience):
        """Run one update on every ensemble member using the same sampled batch.

        Syncs target networks every ``update_target`` learn steps first.

        Returns:
            float: the sum of the members' scalar losses.
        """
        if self.learn_step_counter % self.update_target == 0:
            self.update_target_network()

        total_loss = 0.0
        for idx in range(self._ensemble_size()):
            loss = DDQN.train_model(
                online_net=getattr(self, self._member_attr("policy_network", idx)),
                target_net=getattr(self, self._member_attr("target_network", idx)),
                optimizer=getattr(self, self._member_attr("optimizer", idx)),
                collected_experience=collected_experience,
                is_rnd=False,
            )
            total_loss += loss.item()

        self.learn_step_counter += 1
        return total_loss

    def update_target_network(self):
        """Copy every policy network's weights into its target network.

        Also syncs the RND and option-Q target networks when ``utils.check_run``
        reports those features as enabled.
        """
        print("Target network update")
        for idx in range(self._ensemble_size()):
            policy = getattr(self, self._member_attr("policy_network", idx))
            target = getattr(self, self._member_attr("target_network", idx))
            target.load_state_dict(policy.state_dict())

        if utils.check_run.enable_rnd(self):
            self.rnd_target_network.load_state_dict(
                self.rnd_policy_network.state_dict()
            )
        if utils.check_run.enable_optionQ(self):
            self.option_target_network.load_state_dict(
                self.option_policy_network.state_dict()
            )

    def test(self, num_frames, test_return_per_frame_):
        """Every ``test_interval`` frames, evaluate 20 greedy episodes and log the mean.

        The mean is written to TensorBoard as ``test_return_sum`` and appended
        to ``test_return_per_frame_``. Otherwise this is a no-op.
        """
        if num_frames % self.test_interval != 0:
            return

        print(f"test start @ num frames: {num_frames}")
        test_return = []
        for _ in range(20):
            stats = utils.synthesize(self.test_collect_experiences()["rewards"])
            # NOTE(review): picks the third value of the synthesize() dict by
            # position; confirm which statistic that is.
            test_return.append(list(stats.values())[2])

        mean_return = np.mean(test_return)
        self.tb_writer.add_scalar(
            "test_return_sum", mean_return, num_frames / self.test_interval
        )
        test_return_per_frame_.append(mean_return)

    def test_collect_experiences(self):
        """Play one greedy episode on ``eval_env`` and return its raw rewards.

        Returns:
            dict with ``num_frames`` None, the unscaled per-step ``rewards`` and
            an empty ``loss`` list. Also stored as ``self.logs``.
        """
        obs = self.eval_env.reset()[0]
        done = False
        log_reward = []
        episode_step = 0
        while not done and episode_step < self.max_episode_length:
            episode_step += 1
            preprocessed_obs = self.preprocess_obs([obs], device=self.device)

            action, _ = utils.action.select_action(
                self,
                exploration_type="greedy",
                preprocessed_obs=preprocessed_obs,
                num_frames=None,
            )

            obs, reward, done, _, _ = self.eval_env.step(action)
            log_reward.append(reward)

        logs = {"num_frames": None, "rewards": log_reward, "loss": []}
        self.logs = logs
        return logs
