from torch import optim
from torch.nn.modules.loss import MSELoss
from .stable_ppo.custom_rollout_buffers import RolloutBuffer
from .stable_sac.custom_replaybuffer import ReplayBuffer
import gym
import torch
import torch.nn as nn
from .models import EnvModel
import numpy as np


class ModelLearner():
    """Teacher/student ensemble of learned environment models.

    Each teacher is an ``EnvModel`` with its own Adam optimizer and its own
    ``RolloutBuffer``. Students are ``EnvModel`` instances that share a single
    ``student_buffer``. The models are called as ``model(states, actions)`` and
    are expected to return a ``(next_states, rewards)`` pair.

    Args:
        num_teachers: Number of teacher models (and teacher buffers) to build.
        num_students: Number of student models to build.
        observation_space: Space whose ``shape[0]`` is used as the model's
            observation dimension; also forwarded to the rollout buffers.
        action_space: Space whose ``shape[0]`` is used as the model's action
            dimension; also forwarded to the rollout buffers and used to
            decide whether sampled actions are clipped.
        env: Environment; only ``env.reset()`` is called by this class.
        n_steps: Size argument forwarded to every ``RolloutBuffer``.
        device: Device that models and input tensors are moved to.
        gamma: Forwarded to every ``RolloutBuffer``.
        gae_lambda: Forwarded to every ``RolloutBuffer``.
    """

    def __init__(self,
                 num_teachers,
                 num_students,
                 observation_space,
                 action_space,
                 env,
                 n_steps,
                 device,
                 gamma,
                 gae_lambda) -> None:
        self.num_teachers = num_teachers
        # A stray trailing comma previously stored this as a 1-tuple.
        self.num_students = num_students
        self.teacher_learners = []
        self.student_learners = []
        self.teacher_optims = []
        self.student_optims = []
        self.buffers = []
        self.env = env
        self.n_steps = n_steps
        self.obs_dim = observation_space.shape[0]
        self.action_dim = action_space.shape[0]
        # Keep the space itself: the clipping / reshape checks need the space
        # object, not the integer dimension stored in ``action_dim``.
        self.action_space = action_space
        self.device = device
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        # Misspelled legacy name, kept so existing readers do not break.
        self.gae_labda = gae_lambda
        self.loss_func = nn.MSELoss()

        self.student_buffer = self._make_rollout_buffer(
            observation_space, action_space)
        for _ in range(self.num_teachers):
            teacher_learner = self._make_model()
            self.teacher_learners.append(teacher_learner)
            self.teacher_optims.append(
                torch.optim.Adam(teacher_learner.parameters()))
            self.buffers.append(
                self._make_rollout_buffer(observation_space, action_space))
        for _ in range(num_students):
            student_learner = self._make_model()
            self.student_learners.append(student_learner)
            self.student_optims.append(
                torch.optim.Adam(student_learner.parameters()))

    def _make_model(self):
        """Build one ``EnvModel`` and place it on ``self.device``."""
        # Inputs are moved to ``self.device`` before every forward pass, so
        # the model has to live there too.
        # NOTE(review): assumes EnvModel is a torch.nn.Module (it already
        # exposes ``parameters()`` and is callable) - confirm.
        return EnvModel(self.obs_dim, self.action_dim).to(self.device)

    def _make_rollout_buffer(self, observation_space, action_space):
        """Build one ``RolloutBuffer`` with this learner's settings."""
        return RolloutBuffer(
            self.n_steps,
            observation_space,
            action_space,
            self.device,
            gamma=self.gamma,
            gae_lambda=self.gae_lambda,
        )

    def train(self, batch_size):
        """
        Train every teacher model on the contents of its own buffer.

        For each minibatch yielded by ``buffer.get(batch_size)`` the loss is
        MSE(predicted next state, next state) + MSE(predicted reward, reward)
        and one optimizer step is taken.
        """
        # Todo: normalize inputs
        for learner, buffer, optimizer in zip(self.teacher_learners,
                                              self.buffers,
                                              self.teacher_optims):
            for data in buffer.get(batch_size):
                optimizer.zero_grad()
                pred_states, pred_rewards = learner(
                    data.observations, data.actions)
                # Flatten both sides the same way before comparing them.
                state_loss = self.loss_func(
                    pred_states.view(-1, 1),
                    data.next_observations.view(-1, 1))
                reward_loss = self.loss_func(pred_rewards, data.rewards)
                loss = state_loss + reward_loss
                loss.backward()
                optimizer.step()

    def train_student(self, buffer, batch_size):
        """
        Distil the teachers into the students using data from ``buffer``.

        The loss is the sum of four MSE terms: student vs. mean teacher
        prediction (states and rewards) and student vs. buffer targets
        (next states and rewards).

        NOTE(review): minibatches and students are zipped together, so the
        i-th student takes a single step on the i-th minibatch and any
        further minibatches are ignored - confirm this is intended.
        """
        n_teachers = len(self.teacher_learners)
        for data, student_learner, student_optim in zip(
                buffer.get(batch_size),
                self.student_learners,
                self.student_optims):
            states = data.observations
            actions = data.actions
            next_states = data.next_observations
            rewards = data.rewards

            student_optim.zero_grad()
            stu_pred_states, stu_pred_rewards = student_learner(
                states, actions)

            # Teachers only provide targets, so no gradients are tracked.
            aggre_tea_pred_states, aggre_tea_pred_rewards = 0, 0
            with torch.no_grad():
                for teacher in self.teacher_learners:
                    tea_pred_states, tea_pred_rewards = teacher(
                        states, actions)
                    aggre_tea_pred_states += tea_pred_states
                    aggre_tea_pred_rewards += tea_pred_rewards

            loss = self.loss_func(aggre_tea_pred_states / n_teachers,
                                  stu_pred_states) +\
                self.loss_func(aggre_tea_pred_rewards / n_teachers,
                               stu_pred_rewards) +\
                self.loss_func(next_states, stu_pred_states) + \
                self.loss_func(rewards, stu_pred_rewards)
            loss.backward()
            student_optim.step()

    def _clip_actions(self, actions):
        """Clip ``actions`` to the bounds of a ``Box`` action space.

        Actions for any other kind of space are returned unchanged.
        """
        if isinstance(self.action_space, gym.spaces.Box):
            clipped = np.clip(
                actions, self.action_space.low, self.action_space.high)
            # Keep the policy's dtype so the model input dtype is unchanged.
            return clipped.astype(actions.dtype, copy=False)
        return actions

    @staticmethod
    def _predict(models, obs_tensor, action_tensor):
        """Return the mean ``(next_obs, rewards)`` prediction of ``models``."""
        with torch.no_grad():
            predictions = [model(obs_tensor, action_tensor)
                           for model in models]
        if len(predictions) == 1:
            return predictions[0]
        next_obs = torch.mean(
            torch.stack([pred[0] for pred in predictions]), dim=0)
        rewards = torch.mean(
            torch.stack([pred[1] for pred in predictions]), dim=0)
        return next_obs, rewards

    def _collect_rollout(self,
                         policy,
                         n_rollout_steps,
                         models,
                         rollout_buffer,
                         reset_on_invalid):
        """Fill ``rollout_buffer`` with transitions imagined by ``models``.

        The first observation comes from ``self.env.reset()``; every later
        observation is the (averaged) model prediction. When
        ``reset_on_invalid`` is true, a predicted observation containing
        NaN/Inf triggers another ``self.env.reset()``.
        """
        rollout_buffer.reset()
        obs = self.env.reset()
        n_steps = 0
        # The learned models never signal termination.
        dones = False

        while n_steps < n_rollout_steps:
            with torch.no_grad():
                obs_tensor = torch.as_tensor(obs).to(self.device)
                actions_tensor, values, log_probs = policy.forward(
                    obs_tensor)
            actions = actions_tensor.cpu().numpy()
            clipped_actions = self._clip_actions(actions)

            # The learned model replaces ``env.step`` here.
            next_obs, rewards = self._predict(
                models,
                obs_tensor,
                torch.from_numpy(clipped_actions).to(self.device))
            next_obs = next_obs.detach().cpu()
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                actions = actions.reshape(-1, 1)
            rollout_buffer.add(
                obs, actions, rewards.detach().cpu(), dones, values,
                log_probs, next_obs)

            if reset_on_invalid and not bool(torch.isfinite(next_obs).all()):
                print("self env reset for steps {}".format(n_steps))
                obs = self.env.reset()
            else:
                obs = next_obs

        # Bootstrap from the observation the next step would start from.
        # After a reset this is the fresh observation rather than the invalid
        # prediction, and it is also defined when no step was taken.
        with torch.no_grad():
            obs_tensor = torch.as_tensor(obs).to(self.device)
            _, values, _ = policy.forward(obs_tensor)
        rollout_buffer.compute_returns_and_advantage(
            last_values=values, dones=dones)
        return rollout_buffer

    def sample_rollouts(self,
                        policy,
                        n_rollout_steps
                        ):
        """Roll ``policy`` out in every teacher model, one buffer per teacher.

        Returns:
            The list of teacher rollout buffers (``self.buffers``).
        """
        for learner, rollout_buffer in zip(self.teacher_learners,
                                           self.buffers):
            self._collect_rollout(
                policy, n_rollout_steps, [learner], rollout_buffer,
                reset_on_invalid=True)
        return self.buffers

    def sample_student_rollouts(self,
                                policy,
                                n_rollout_steps
                                ):
        """Roll ``policy`` out in the averaged student models.

        The environment is reset whenever the predicted observation contains
        NaN or Inf.

        Returns:
            ``self.student_buffer`` filled with the imagined transitions.
        """
        return self._collect_rollout(
            policy, n_rollout_steps, self.student_learners,
            self.student_buffer, reset_on_invalid=True)

    def student_sample_rollouts(self,
                                policy,
                                n_rollout_steps
                                ):
        """Roll ``policy`` out in the averaged student models.

        Unlike ``sample_student_rollouts`` this variant never resets the
        environment on invalid predictions.

        Returns:
            ``self.student_buffer`` filled with the imagined transitions.
        """
        return self._collect_rollout(
            policy, n_rollout_steps, self.student_learners,
            self.student_buffer, reset_on_invalid=False)


class OffModelLearner():
    """Set of learned environment models, each paired with a replay buffer.

    One ``EnvModel``, one Adam optimizer and one ``ReplayBuffer`` are created
    per sector. The models are called as ``model(states, actions)`` and are
    expected to return a ``(next_states, rewards)`` pair.

    Args:
        num_secs: Number of models / optimizers / buffers to build.
        observation_space: Space whose ``shape[0]`` is used as the model's
            observation dimension; also forwarded to the buffers.
        action_space: Space whose ``shape[0]`` is used as the model's action
            dimension; also forwarded to the buffers and used to decide
            whether sampled actions are clipped.
        env: Environment; only ``env.reset()`` is called by this class.
        n_steps: Stored on the instance; not otherwise used by this class.
        device: Device that models and input tensors are moved to.
        buffer_size: Forwarded to every ``ReplayBuffer``.
        n_envs: Forwarded to every ``ReplayBuffer``.
    """

    def __init__(self,
                 num_secs,
                 observation_space,
                 action_space,
                 env,
                 n_steps,
                 device,
                 buffer_size,
                 n_envs) -> None:
        self.num_secs = num_secs
        self.learners = []
        self.optims = []
        self.buffers = []
        self.env = env
        self.n_steps = n_steps
        self.obs_dim = observation_space.shape[0]
        self.action_dim = action_space.shape[0]
        # Keep the space itself: the clipping / reshape checks need the space
        # object, not the integer dimension stored in ``action_dim``.
        self.action_space = action_space
        self.device = device
        self.loss_func = nn.MSELoss()
        for _ in range(self.num_secs):
            # Inputs are moved to ``self.device`` before every forward pass,
            # so the model has to live there too.
            # NOTE(review): assumes EnvModel is a torch.nn.Module - confirm.
            learner = EnvModel(self.obs_dim, self.action_dim).to(self.device)
            self.learners.append(learner)
            self.optims.append(torch.optim.Adam(learner.parameters()))
            self.buffers.append(ReplayBuffer(
                buffer_size,
                observation_space,
                action_space,
                self.device,
                n_envs
            ))

    def train(self, all_buffers, batch_size):
        """
        Train each sector model on the matching buffer of ``all_buffers``.

        For each minibatch yielded by ``buffer.get(batch_size)`` the loss is
        MSE(predicted next state, next state) + MSE(predicted reward, reward)
        and one optimizer step is taken.
        """
        # Todo: normalize inputs
        for learner, buffer, optimizer in zip(self.learners,
                                              all_buffers,
                                              self.optims):
            for data in buffer.get(batch_size):
                optimizer.zero_grad()
                pred_states, pred_rewards = learner(
                    data.observations, data.actions)
                # Flatten both sides the same way before comparing them.
                state_loss = self.loss_func(
                    pred_states.view(-1, 1),
                    data.next_observations.view(-1, 1))
                reward_loss = self.loss_func(pred_rewards, data.rewards)
                loss = state_loss + reward_loss
                loss.backward()
                optimizer.step()

    def _clip_actions(self, actions):
        """Clip ``actions`` to the bounds of a ``Box`` action space.

        Actions for any other kind of space are returned unchanged.
        """
        if isinstance(self.action_space, gym.spaces.Box):
            clipped = np.clip(
                actions, self.action_space.low, self.action_space.high)
            # Keep the policy's dtype so the model input dtype is unchanged.
            return clipped.astype(actions.dtype, copy=False)
        return actions

    def sample_rollouts(self,
                        policy,
                        n_rollout_steps
                        ):
        """Roll ``policy`` out in every learned model, one buffer per model.

        The first observation of each rollout comes from ``self.env.reset()``;
        every later observation is the model's own prediction.

        NOTE(review): the buffers are ``ReplayBuffer`` instances but are used
        through ``add(..., values, log_probs, next_obs)`` and
        ``compute_returns_and_advantage`` - confirm the custom buffer
        provides that interface.

        Returns:
            The list of buffers (``self.buffers``).
        """
        for learner, rollout_buffer in zip(self.learners, self.buffers):
            rollout_buffer.reset()
            obs = self.env.reset()
            n_steps = 0
            # The learned models never signal termination.
            dones = False

            while n_steps < n_rollout_steps:
                with torch.no_grad():
                    obs_tensor = torch.as_tensor(obs).to(self.device)
                    actions_tensor, values, log_probs = policy.forward(
                        obs_tensor)
                actions = actions_tensor.cpu().numpy()
                clipped_actions = self._clip_actions(actions)

                # The learned model replaces ``env.step`` here.
                with torch.no_grad():
                    next_obs, rewards = learner(
                        obs_tensor,
                        torch.from_numpy(clipped_actions).to(self.device))
                next_obs = next_obs.detach().cpu()
                n_steps += 1

                if isinstance(self.action_space, gym.spaces.Discrete):
                    actions = actions.reshape(-1, 1)
                rollout_buffer.add(
                    obs, actions, rewards.detach().cpu(), dones, values,
                    log_probs, next_obs)
                obs = next_obs

            # ``obs`` is the last predicted observation, or the reset
            # observation when no step was taken.
            with torch.no_grad():
                obs_tensor = torch.as_tensor(obs).to(self.device)
                _, values, _ = policy.forward(obs_tensor)
            rollout_buffer.compute_returns_and_advantage(
                last_values=values, dones=dones)
        return self.buffers
