from datetime import datetime
import os
import time

from gym.spaces import Space

import numpy as np
import statistics
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from algorithms.legorl.ppobc import RolloutStorage, BCTransformer, ActorCritic
from torch.nn import functional as F
import math

class PPO:
    """Behaviour-cloning trainer for a GPT student policy, driven by a pre-trained teacher.

    The PPO actor-critic, rollout storage and optimizer are built in
    ``__init__`` and can be saved with :meth:`save`. :meth:`run` does not
    optimise them. Instead it steps the environment with the sequence teacher
    loaded in :meth:`test`. On every step it trains ``self.gpt_model`` to
    maximise the log-likelihood of ``vec_env.task.bc_act_label`` given a
    sliding window of ``self.seq_len`` student observations.
    """

    # Width of the student observation slice and of the action label fed to the GPT model.
    _STUDENT_OBS_DIM = 30
    _ACTION_DIM = 23

    def __init__(self,
                 vec_env,
                 actor_critic_class,
                 num_transitions_per_env,
                 num_learning_epochs,
                 num_mini_batches,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 init_noise_std=1.0,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=0.5,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=None,
                 model_cfg=None,
                 device='cpu',
                 sampler='sequential',
                 log_dir='run',
                 is_testing=False,
                 print_log=True,
                 apply_reset=False,
                 asymmetric=False
                 ):
        """Build the PPO components and the GPT student, then load the teacher policies.

        ``actor_critic_class`` is accepted for interface compatibility, but
        ``ActorCritic`` is always instantiated.

        Raises:
            TypeError: if ``vec_env.observation_space``, ``state_space`` or
                ``action_space`` is not a gym ``Space``.
        """
        if not isinstance(vec_env.observation_space, Space):
            raise TypeError("vec_env.observation_space must be a gym Space")
        if not isinstance(vec_env.state_space, Space):
            raise TypeError("vec_env.state_space must be a gym Space")
        if not isinstance(vec_env.action_space, Space):
            raise TypeError("vec_env.action_space must be a gym Space")
        self.observation_space = vec_env.observation_space
        self.action_space = vec_env.action_space
        self.state_space = vec_env.state_space

        self.device = device
        self.asymmetric = asymmetric

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.step_size = learning_rate

        # PPO components
        self.vec_env = vec_env
        self.actor_critic = ActorCritic(self.vec_env, self.observation_space.shape, self.state_space.shape, self.action_space.shape,
                                        init_noise_std, model_cfg, asymmetric=asymmetric)
        self.actor_critic.to(self.device)
        self.storage = RolloutStorage(self.vec_env.num_envs, num_transitions_per_env, self.observation_space.shape,
                                      self.state_space.shape, self.action_space.shape, self.device, sampler)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.num_transitions_per_env = num_transitions_per_env
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

        # Log
        self.log_dir = log_dir
        self.print_log = print_log
        self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
        self.tot_timesteps = 0
        self.tot_time = 0
        self.is_testing = is_testing
        self.current_learning_iteration = 0
        self.current_learning_steps = 0

        self.apply_reset = apply_reset

        self.ce_loss = nn.CrossEntropyLoss()
        self.seq_len = 3
        self.num_eval_envs = int(self.vec_env.num_envs / 8)

        self.gpt_base_lr = 1e-4

        # Sliding windows (num_envs, seq_len, dim) of student observations and action labels.
        self.obs_buffer = torch.zeros((self.vec_env.num_envs, self.seq_len, self._STUDENT_OBS_DIM), dtype=torch.float32, device=self.device)
        self.act_buffer = torch.zeros((self.vec_env.num_envs, self.seq_len, self._ACTION_DIM), dtype=torch.float32, device=self.device)
        from .GPT_policy import GPT_wrapper
        self.gpt_model = GPT_wrapper(feat_dim=128, n_layer=4, n_head=4, gmm_modes=5,
                                     obs_dim=self._STUDENT_OBS_DIM, action_dim=self._ACTION_DIM).to(self.device)
        self.gpt_optimizer = optim.Adam(self.gpt_model.parameters(), lr=self.gpt_base_lr)
        self.gpt_model.train()
        self.mini_batch_size = 4096
        self.num_learning_steps = 50000

        # Built with the env's original max_episode_length; test() below overrides that value.
        self.success_storage = RolloutStorage(self.vec_env.num_envs, self.vec_env.task.max_episode_length, self.obs_buffer.shape[1:],
                                              self.state_space.shape, self.act_buffer.shape[1:], self.device, sampler)

        self.loaded_model = False
        self.test("")

    def test(self, path):
        """Load the teacher policies that drive the environment (only once).

        ``path`` is currently ignored: the checkpoint location is hard-coded.
        Slots 0 and 1 of ``self.teacher_policies`` are empty placeholders,
        because loading the per-stage teachers is disabled. Slot 2 holds the
        sequence controller that :meth:`run` selects through
        ``self.selected_policy``. Also sets the env's ``max_episode_length``.
        """
        if self.loaded_model:
            return
        from utils.sequence_controller.nn_controller import SeqNNController

        # Disabled per-stage NNController teachers (kept as placeholders so indices stay stable):
        #   runs/AllegroHandLegoTestPAI_20-09-48-42/nn/AllegroHandLegoTestPAI.pth
        #     with algorithms/legorl/ppobc/robot_controller/test_network.yaml
        #   runs/AllegroHandLegoTestPAInsert_20-07-26-55/nn/AllegroHandLegoTestPAInsert.pth
        #     with algorithms/legorl/ppobc/robot_controller/grasp_network.yaml
        self.teacher_policies = [[], []]

        self.teacher_policy = SeqNNController(num_actors=self.vec_env.task.num_envs, obs_dim=self._STUDENT_OBS_DIM)
        self.teacher_policy.load("/home/jmji/DexterousHandEnvs/dexteroushandenvs/runs/AllegroHandLegoTestOrient_24-08-53-13/nn/last_AllegroHandLegoTestOrient_ep_23000_rew_47.360867.pth", None)
        self.teacher_policies.append(self.teacher_policy)

        self.sequence_rollout = [150, 75, 150]
        self.vec_env.task.max_episode_length = self.sequence_rollout[2]
        self.selected_policy = 2

        self.loaded_model = True
        print("LOADED MODEL")

    def save(self, path):
        """Save the PPO actor-critic weights (not the GPT student) to ``path``."""
        torch.save(self.actor_critic.state_dict(), path)

    def load_transformer(self, path):
        """Load GPT student weights from ``path`` and resume the step counter from its name.

        The step is the integer following the ``model`` token of the file's
        basename. This matches the ``model_<it>_bcloss_<loss>.pt`` names written
        by :meth:`run` as well as older ``model_<it>_eval_rew_...`` and
        ``model_<it>_loss_...`` names. Underscores in directory names are ignored.

        Raises:
            ValueError: if the basename has no ``model_<int>`` token. This is
                raised before any weights are loaded.
        """
        parts = os.path.basename(path).split("_")
        try:
            step = int(parts[parts.index("model") + 1])
        except (ValueError, IndexError) as err:
            raise ValueError(f"cannot parse learning step from checkpoint name {path!r}") from err

        self.gpt_model.load_state_dict(torch.load(path, map_location=self.device))
        self.current_learning_steps = step
        print("current_learning_steps: ", self.current_learning_steps)
        self.gpt_model.train()

    @staticmethod
    def _push_history(buffer, newest):
        """Shift ``buffer`` one slot left along dim 1 in place and write ``newest`` into the last slot."""
        # clone() because source and destination slices overlap in memory.
        buffer[:, :-1] = buffer[:, 1:].clone()
        buffer[:, -1] = newest

    def run(self, num_learning_iterations, log_interval=1):
        """Train the GPT student by behaviour cloning, or roll it out when ``test_transformer`` is set.

        ``num_learning_iterations`` and ``log_interval`` are accepted for
        interface compatibility but are not used. Training runs from
        ``self.current_learning_steps`` to ``self.num_learning_steps``, with one
        env step and one gradient step per iteration. It logs every iteration
        and writes a checkpoint to ``self.log_dir`` every 1000 iterations.
        """
        current_obs = self.vec_env.reset()
        current_states = self.vec_env.get_state()  # result currently unused
        current_student_obs = torch.zeros_like(current_obs[:, :self._STUDENT_OBS_DIM])

        # Rolling window of recent BC losses, used for logging and checkpoint names.
        lossbuffer = deque(maxlen=self.vec_env.task.max_episode_length)

        # Set to True to roll out the student instead of training it.
        self.test_transformer = False
        # To resume training, call self.load_transformer("<log_dir>/model_<it>_bcloss_<loss>.pt") here.

        if self.test_transformer:
            self._rollout_student(current_obs, current_student_obs)
            return

        for it in range(self.current_learning_steps, self.num_learning_steps):
            with torch.no_grad():
                teacher_actions = self.teacher_policies[self.selected_policy].predict(current_obs, deterministic=True)
                next_obs, rews, dones, infos = self.vec_env.step(teacher_actions)

            # Append the newest (student obs, action label) pair to the history windows.
            # NOTE(review): the label is read after vec_env.step(); confirm it refers to
            # the action taken from current_student_obs.
            self._push_history(self.obs_buffer, current_student_obs[:, :self._STUDENT_OBS_DIM])
            self._push_history(self.act_buffer, self.vec_env.task.bc_act_label[:, :self._ACTION_DIM])

            # Set this iteration's warmup/cosine LR *before* the update so the
            # schedule is not applied one step late (and warmup is not skipped).
            lr = self.adjust_lr(self.gpt_optimizer, self.gpt_base_lr, it, 1000, self.num_learning_steps)

            # Negative log-likelihood of the action labels under the predicted distribution.
            predict_actions_dist = self.gpt_model.forward_train(self.obs_buffer)
            loss = (- predict_actions_dist.log_prob(self.act_buffer)).mean()

            self.gpt_optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.gpt_optimizer.step()

            lossbuffer.append(loss.item())

            current_obs.copy_(next_obs)
            current_student_obs.copy_(infos["student_obs_buf"][:, :self._STUDENT_OBS_DIM])

            # log() reads 'it', 'lossbuffer' and 'lr' from these locals.
            self.log(locals())

            if it % 1000 == 0 and lossbuffer:
                print("saving {} it".format(it))
                torch.save(self.gpt_model.state_dict(), os.path.join(self.log_dir, 'model_{}_bcloss_{}.pt'.format(it, statistics.mean(lossbuffer))))

    def _rollout_student(self, current_obs, current_student_obs):
        """Step the environment with the GPT student's actions forever (evaluation; never returns)."""
        from utils.transformer_controller.nn_controller import NNController
        self.seq_policy = NNController(num_actors=self.vec_env.num_envs, obs_dim=self._STUDENT_OBS_DIM)
        self.seq_policy.load(["/home/jmji/DexterousHandEnvs/dexteroushandenvs/logs/allegro_hand_lego_test_p_a_i/lego_bc/lego_bc_seed900/model_15000_loss_-52.815940602620444.pt",
                              "/home/jmji/DexterousHandEnvs/dexteroushandenvs/logs/allegro_hand_lego_test_p_a_insert/lego_bc/lego_bc_seed22/model_2000_eval_rew_0.015382986245676876_bcloss_-45.931198018391925.pt"])
        self.switch_policy = False

        self.gpt_model.eval()
        while True:
            with torch.no_grad():
                actions = self.gpt_model.forward_step(current_student_obs[:, :self._STUDENT_OBS_DIM]).detach()
                next_obs, rews, dones, infos = self.vec_env.step(actions)
                current_obs.copy_(next_obs)
                current_student_obs.copy_(infos["student_obs_buf"][:, :self._STUDENT_OBS_DIM])

    def log(self, locs, width=80, pad=35):
        """Print a progress banner for the current BC step and record the mean loss.

        ``locs`` is the caller's ``locals()``. It must contain ``'it'`` and
        ``'lossbuffer'``; ``'lr'`` is read only when the loss buffer is non-empty.
        When it is non-empty, the mean loss is also written to TensorBoard as
        ``Loss/bc_loss``.
        """
        title = f" \033[1m Learning step {locs['it']}/{self.num_learning_steps} \033[0m "
        lines = [f"{'#' * width}\n", f"{title.center(width, ' ')}\n\n"]
        success_line = f"{'success_storage.success_step:':>{pad}} {self.success_storage.success_step}\n"

        if locs['lossbuffer']:
            mean_loss = statistics.mean(locs['lossbuffer'])
            lines.append(f"{'Mean BC loss:':>{pad}} {mean_loss:.2f}\n")
            lines.append(success_line)
            lines.append(f"{'lr:':>{pad}} {locs['lr']}\n")
            self.writer.add_scalar('Loss/bc_loss', mean_loss, locs['it'])
        else:
            lines.append(success_line)

        lines.append(f"{'-' * width}\n")
        print("".join(lines))

    def update_bc(self):
        """Take one BC gradient step on the first mini-batch drawn from ``success_storage``.

        Returns:
            The loss tensor of the step (still attached to the graph).

        Raises:
            RuntimeError: if the storage's mini-batch generator yields no batch.
        """
        batches = self.success_storage.mini_batch_generator(mini_batch_size=self.mini_batch_size)
        selected_indices = next(iter(batches), None)
        if selected_indices is None:
            raise RuntimeError("success_storage produced no mini-batch; collect more transitions first")

        obs_buf_batch = self.success_storage.success_observations[selected_indices]
        actions_buf_batch = self.success_storage.success_actions[selected_indices]

        predict_actions_dist = self.gpt_model.forward_train(obs_buf_batch)
        loss = (- predict_actions_dist.log_prob(actions_buf_batch)).mean()

        self.gpt_optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
        self.gpt_optimizer.step()

        return loss

    def adjust_lr(self, optimizer, base_lr, cur_epoch, warmup_epoch, num_epoch):
        """Apply a linear-warmup + cosine-decay learning rate to every param group.

        The LR ramps linearly from 0 to ``base_lr`` over ``warmup_epoch``
        steps. It then follows half a cosine down to 0 at ``num_epoch`` and
        stays at 0 afterwards. A non-positive decay span is treated as a
        single step rather than dividing by zero.

        Returns:
            The learning rate that was set.
        """
        if cur_epoch < warmup_epoch:
            lr = base_lr * cur_epoch / warmup_epoch
        else:
            decay_span = max(num_epoch - warmup_epoch, 1)
            progress = min((cur_epoch - warmup_epoch) / decay_span, 1.0)
            lr = base_lr * 0.5 * (1. + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        return lr
