import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os

import gymnasium as gym
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from vanilla_transformer.transformer import Model, CustomTransformerEncoder


# Module-wide torch device: CUDA when torch reports it as available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class StateDictWrapper(gym.ObservationWrapper):
    """Wrapper that nests every observation under a ``'state'`` key.

    ``reset`` and ``step`` are overridden directly and return the wrapped
    environment's observation as ``{'state': obs}``. This class does not
    redefine ``observation_space``.
    """

    def __init__(self, env):
        super().__init__(env)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``({'state': obs}, info)``."""
        raw_obs, info = self.env.reset(**kwargs)
        wrapped_obs = {'state': raw_obs}
        return wrapped_obs, info

    def step(self, action):
        """Step the wrapped env; only the observation is re-packaged."""
        transition = self.env.step(action)
        raw_obs, reward, terminated, truncated, info = transition
        wrapped_obs = {'state': raw_obs}
        return wrapped_obs, reward, terminated, truncated, info
    
    

def env_constructor(env_name, num_envs, obs_mode, reconf_freq=None):
    """Build a vectorised ManiSkill environment for the given observation mode.

    Args:
        env_name: Environment id passed to ``gym.make``.
        num_envs: Number of parallel environments.
        obs_mode: One of ``'state'``, ``'rgb'`` or ``'rgbd'``.
        reconf_freq: Forwarded to ``gym.make`` as ``reconfiguration_freq``.

    Returns:
        ``(env, s_d, a_d)`` where ``s_d`` is the last dimension of the state
        observation space and ``a_d`` the last dimension of the action space.

    Raises:
        ValueError: If ``obs_mode`` is not one of the supported modes.
    """
    supported_modes = ('state', 'rgb', 'rgbd')
    # Fail early: without this check an unknown mode fell through every branch
    # and returned None, which only surfaced later when the caller unpacked it.
    if obs_mode not in supported_modes:
        raise ValueError(
            f"Unsupported obs_mode {obs_mode!r}; expected one of {supported_modes}"
        )

    env_kwargs = dict(obs_mode=obs_mode, sim_backend="gpu", control_mode="pd_joint_delta_pos")
    env = gym.make(env_name, num_envs=num_envs, reconfiguration_freq=reconf_freq, **env_kwargs)

    if obs_mode == 'state':
        env = ManiSkillVectorEnv(env, num_envs, ignore_terminations=True, record_metrics=True)
        # Observations are re-packaged as {'state': obs}; the observation space
        # itself is not changed by the wrapper, so it is read directly here.
        env = StateDictWrapper(env)
        s_d = env.observation_space.shape[-1]
    else:
        # 'rgb' and 'rgbd' differ only in whether depth is kept.
        env = FlattenRGBDObservationWrapper(env, rgb=True, depth=(obs_mode == 'rgbd'), state=True)
        env = ManiSkillVectorEnv(env, num_envs=num_envs, ignore_terminations=True, record_metrics=True)
        s_d = env.observation_space['state'].shape[-1]

    a_d = env.action_space.shape[-1]
    return env, s_d, a_d

class Actor(nn.Module):
    """Feed-forward policy: three ReLU hidden layers of width 256 and a tanh head.

    The output is ``max_action * tanh(...)``, so every action component lies
    in ``[-max_action, max_action]``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        width = 256
        # Attribute names and creation order are kept so state_dict keys and
        # the default weight initialisation sequence stay the same.
        self.l1 = nn.Linear(state_dim, width)
        self.l2 = nn.Linear(width, width)
        self.l2_2 = nn.Linear(width, width)
        self.l3 = nn.Linear(width, action_dim)

        self.max_action = max_action

    def forward(self, state):
        """Map ``state`` (last dim ``state_dim``) to a bounded action."""
        hidden = state
        for layer in (self.l1, self.l2, self.l2_2):
            hidden = F.relu(layer(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed


class Critic(nn.Module):
    """Twin Q-networks over the concatenation ``[state, action]``.

    Each head is an independent MLP with three ReLU hidden layers of width
    256 and a scalar output.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Q1 architecture
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l2_2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Q2 architecture
        self.l4 = nn.Linear(state_dim + action_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l5_2 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def _q1_head(self, sa):
        """Run the first Q head on an already concatenated state-action tensor."""
        q1 = F.relu(self.l1(sa))
        q1 = F.relu(self.l2(q1))
        q1 = F.relu(self.l2_2(q1))
        return self.l3(q1)

    def _q2_head(self, sa):
        """Run the second Q head on an already concatenated state-action tensor."""
        q2 = F.relu(self.l4(sa))
        q2 = F.relu(self.l5(q2))
        q2 = F.relu(self.l5_2(q2))
        return self.l6(q2)

    def forward(self, state, action, img_state=None):
        """Return ``(q1, q2)`` for the given state-action pairs.

        Args:
            state: Tensor whose last dimension is ``state_dim``.
            action: Tensor whose last dimension is ``action_dim``; leading
                dimensions must match ``state``.
            img_state: Accepted so the signature matches ``Q1`` and callers
                that pass image observations positionally; it is not used.
        """
        sa = torch.cat([state, action], -1)
        return self._q1_head(sa), self._q2_head(sa)

    def Q1(self, state, action, img_state=None):
        """Return only the first head's estimate; ``img_state`` is not used."""
        sa = torch.cat([state, action], -1)
        return self._q1_head(sa)

class lstm_Critic(nn.Module):
    """Twin Q-networks that summarise a state history with an LSTM.

    Each head has its own single-layer LSTM. The LSTM output at the last
    context step is concatenated with the action and passed through a
    two-layer MLP producing a scalar.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=512):
        super().__init__()

        self.hidden_dim = hidden_dim
        # Q1 head
        self.lstm1 = nn.LSTM(state_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

        # Q2 head
        self.lstm2 = nn.LSTM(state_dim, hidden_dim, num_layers=1, batch_first=True)
        self.fc3 = nn.Linear(hidden_dim + action_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)

    def _encode(self, lstm, state):
        """Return the LSTM output at the last context step, shaped (n_e, bs, hidden_dim)."""
        n_e, bs, cont, s_d = state.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted batches,
        # are accepted instead of raising.
        flat_state = state.reshape(-1, cont, s_d)          # n_e*bs, context, s_d
        lstm_out, _ = lstm(flat_state)                     # n_e*bs, context, hidden_dim
        last_step = lstm_out[:, -1, :]                     # n_e*bs, hidden_dim
        return last_step.reshape(n_e, bs, self.hidden_dim)

    def _q1_from(self, state, action):
        """Compute the first head's value for a 4-D state history and an action."""
        sa = torch.cat([self._encode(self.lstm1, state), action], dim=-1)
        return self.fc2(F.relu(self.fc1(sa)))

    def forward(self, state, action):
        """Return ``(q1, q2)``.

        Args:
            state: Tensor of shape ``(n_e, bs, context, state_dim)``.
            action: Tensor of shape ``(n_e, bs, action_dim)``.

        Returns:
            Two tensors of shape ``(n_e, bs, 1)``.
        """
        q1 = self._q1_from(state, action)

        sa = torch.cat([self._encode(self.lstm2, state), action], dim=-1)
        q2 = self.fc4(F.relu(self.fc3(sa)))

        return q1, q2

    def Q1(self, state, action):
        """Return only the first head's estimate, shape ``(n_e, bs, 1)``."""
        return self._q1_from(state, action)
#########################################################################################################
#########################################################################################################
#########################################################################################################
class transformer_Critic(nn.Module):
    """Twin Q-heads on top of a shared transformer encoding of the state history.

    States are projected to ``hidden_dim``, passed through a single
    ``CustomTransformerEncoder``, and the encoding at the last context step is
    concatenated with the action. Both Q heads read the same encoding.

    ``num_layers`` is accepted but not used by this class.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256, num_heads=2, num_layers=1):
        super().__init__()

        self.hidden_dim = hidden_dim
        # Projects raw states (despite the attribute name) to the model width.
        self.act_encoder = nn.Linear(state_dim, hidden_dim)

        self.transformer_encoder = CustomTransformerEncoder(hidden_dim, num_heads, 512, 0.05, False, False, True, 'GRU', 'Trans')

        # Q1 head
        self.fc1 = nn.Linear(hidden_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

        # Q2 head
        self.fc3 = nn.Linear(hidden_dim + action_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, 1)

    def _encode(self, state, action):
        """Return the concatenated (last-step encoding, action) tensor, shape (n_e, bs, d_m + a_d)."""
        n_e, bs, cont, s_d = state.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        tokens = state.reshape(-1, cont, s_d)               # n_e*bs, context, s_d
        tokens = self.act_encoder(tokens)

        encoded = self.transformer_encoder(tokens)          # n_e*bs, context, d_m
        last_step = encoded[:, -1, :]                       # n_e*bs, d_m
        last_step = last_step.reshape(n_e, bs, self.hidden_dim)
        return torch.cat([last_step, action], dim=-1)

    def forward(self, state, action):
        """Return ``(q1, q2)`` for ``state`` (n_e, bs, context, s_d) and ``action`` (n_e, bs, a_d)."""
        sa = self._encode(state, action)

        q1 = self.fc2(F.relu(self.fc1(sa)))
        q2 = self.fc4(F.relu(self.fc3(sa)))

        return q1, q2

    def Q1(self, state, action):
        """Return only the first head's estimate."""
        sa = self._encode(state, action)
        return self.fc2(F.relu(self.fc1(sa)))



    
class TD3(object):
    """TD3-style trainer whose policy is the ``Model`` held in ``self.trans``.

    ``stage_2_train`` reads two attributes that ``__init__`` does not create
    and that the caller must attach beforehand:

    * ``self.new_trans_RB`` - buffer exposing ``sample(batch_size)``.
    * ``self.experiment`` - logger exposing ``add_scalar(tag, value, step)``.

    ``self.actor`` / ``self.actor_target`` are built and checkpointed but are
    not updated by ``stage_2_train``; only ``self.trans`` and ``self.critic``
    are optimised there.
    """

    def __init__(
        self,
        num_envs,
        obs_mode,
        context_length,
        model_config,
        state_dim,
        action_dim,
        max_action,
        discount,
        tau,
        policy_noise,
        noise_clip,
        policy_freq
):

        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        # Alternative critics operating on the full context window:
        #self.critic = transformer_Critic(state_dim, action_dim, hidden_dim=512, num_heads=2, num_layers=1).to(device)
        #self.critic = lstm_Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        self.trans = Model(**model_config, state_dim=state_dim, act_dim=action_dim, obs_mode='state').to(device)
        self.trans_target = copy.deepcopy(self.trans).to(device)
        self.trans_optimizer = torch.optim.Adam(self.trans.parameters(), lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.context_length = context_length
        self.obs_mode = obs_mode
        self.total_it = 0
        self.eval_counter = 0

    def select_action(self, state):
        """Return the MLP actor's action for ``state`` as a NumPy array."""
        with torch.no_grad():
            return self.actor(state).cpu().numpy()

    def stage_2_train(self, batch_size):
        """Run one TD3 update of the critic and, every ``policy_freq`` calls, of ``self.trans``.

        Args:
            batch_size: Number of transitions requested from ``self.new_trans_RB``.
        """
        self.total_it += 1
        train_batch = self.new_trans_RB.sample(batch_size)
        state_only = self.obs_mode == 'state'
        if state_only:
            states, actions, rewards, dones, next_states = train_batch
        else:
            states, actions, rewards, dones, next_states, img_states, img_next_states = train_batch
            img_states = img_states.to(device).requires_grad_(True)
            img_next_states = img_next_states.to(device).requires_grad_(True)

        # NOTE(review): requires_grad on the sampled inputs is kept as in the
        # existing training recipe; parameter gradients should not need it -
        # confirm against Model before dropping.
        states = states.to(device).requires_grad_(True)             # n_e, bs, context, state_dim
        actions = actions.to(device).requires_grad_(True)           # n_e, bs, action_dim
        rewards = rewards.to(device).requires_grad_(True)           # n_e, bs, 1
        dones = dones.to(device)                                    # n_e, bs, 1
        next_states = next_states.to(device).requires_grad_(True)   # n_e, bs, context, state_dim

        self.trans.train()
        self.critic.train()

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (                                               # n_e, bs, a_d
                torch.randn_like(actions) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            if state_only:
                next_action = (
                    self.trans_target.actor_forward(next_states) + noise   # n_e, bs, a_d
                ).clamp(-self.max_action, self.max_action)
            else:
                next_action = self.trans_target.actor_forward(next_states, img_next_states)
                # Noise is redrawn with the shape of the predicted action.
                noise = (torch.randn_like(next_action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
                next_action = (next_action + noise).clamp(-self.max_action, self.max_action)

            # The critic only sees the most recent state of each context window.
            if state_only:
                target_Q1, target_Q2 = self.critic_target(next_states[:, :, -1, :], next_action)
            else:
                target_Q1, target_Q2 = self.critic_target(next_states[:, :, -1, :], next_action, img_next_states[:, :, -1, ])
            #target_Q1, target_Q2 = self.critic_target(next_states, next_action)
            target_Q = torch.min(target_Q1, target_Q2)                      # n_e, bs, 1
            target_Q = rewards + (1 - dones) * self.discount * target_Q     # n_e, bs, 1

        if state_only:
            current_Q1, current_Q2 = self.critic(states[:, :, -1, :], actions)   # n_e, bs, 1
        else:
            current_Q1, current_Q2 = self.critic(states[:, :, -1, :], actions, img_states[:, :, -1, ])
        #current_Q1, current_Q2 = self.critic(states, actions)

        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.experiment.add_scalar('Critic_loss', critic_loss.item(), self.total_it)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()

        # Logged value is the sum of per-parameter gradient norms.
        critic_grad_norm = sum(p.grad.norm().item() for p in self.critic.parameters() if p.grad is not None)
        self.experiment.add_scalar('critic_grad_norm', critic_grad_norm, self.total_it)

        self.critic_optimizer.step()

        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:

            # Compute actor loss
            if state_only:
                trans_loss = -self.critic.Q1(states[:, :, -1, :], self.trans.actor_forward(states)).mean()
                #trans_loss = -self.critic.Q1(states, self.trans.actor_forward(states)).mean()
            else:
                trans_loss = -self.critic.Q1(states[:, :, -1, :], self.trans.actor_forward(states, img_states), img_states[:, :, -1, ]).mean()
            # Log a plain float, consistent with the other scalars.
            self.experiment.add_scalar('Actor_loss', trans_loss.item(), self.total_it)

            # Optimize the actor
            self.trans_optimizer.zero_grad()
            trans_loss.backward()
            trans_grad_norm = sum(p.grad.norm().item() for p in self.trans.parameters() if p.grad is not None)
            self.experiment.add_scalar('actor_grad_norm', trans_grad_norm, self.total_it)

            self.trans_optimizer.step()

            # Update the frozen target models (Polyak averaging with rate tau)
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.trans.parameters(), self.trans_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    @staticmethod
    def _device_of(module):
        """Return the device of ``module``'s first parameter, or CPU if it has none."""
        first_param = next(module.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def save(self, filename):
        """Write model and optimizer state dicts to files prefixed with ``filename``.

        Files written: ``_critic``, ``_critic_optimizer``, ``_actor``,
        ``_actor_optimizer``, ``_trans`` and ``_trans_optimizer``. The last two
        hold the policy that ``stage_2_train`` actually optimises.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")

        torch.save(self.trans.state_dict(), filename + "_trans")
        torch.save(self.trans_optimizer.state_dict(), filename + "_trans_optimizer")

    def load(self, filename):
        """Restore state written by ``save`` and rebuild the target networks.

        Tensors are mapped onto the device each module currently lives on, so
        a checkpoint written on a GPU host can be read on a CPU-only one.
        ``_trans`` / ``_trans_optimizer`` files are optional so that
        checkpoints written without them still load.
        """
        critic_device = self._device_of(self.critic)
        self.critic.load_state_dict(torch.load(filename + "_critic", map_location=critic_device))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer", map_location=critic_device))
        self.critic_target = copy.deepcopy(self.critic)

        actor_device = self._device_of(self.actor)
        self.actor.load_state_dict(torch.load(filename + "_actor", map_location=actor_device))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer", map_location=actor_device))
        self.actor_target = copy.deepcopy(self.actor)

        trans_device = self._device_of(self.trans)
        if os.path.exists(filename + "_trans"):
            self.trans.load_state_dict(torch.load(filename + "_trans", map_location=trans_device))
            self.trans_target = copy.deepcopy(self.trans)
        if os.path.exists(filename + "_trans_optimizer"):
            self.trans_optimizer.load_state_dict(torch.load(filename + "_trans_optimizer", map_location=trans_device))
        

############################################################################################################################
class New_Trans_RB():
    """Fixed-size ring buffer of context-window transitions for parallel envs.

    Every stored tensor has layout ``(num_envs, size, ...)``; one call to
    ``recieve_traj`` fills one slot along the ``size`` axis for all
    environments at once. When the buffer is full, writing wraps around to
    slot 0 and ``overfilled`` is set.
    """

    # Divisor applied to the first three (RGB) channels of stored images.
    # NOTE(review): this used to be 225.0, which looks like a typo for 255
    # (assumes 8-bit RGB input) - confirm any other code that normalises
    # images uses the same constant.
    RGB_MAX = 255.0

    def __init__(self, num_envs, size, context, state_dim, act_dim, obs_mode, img_shape=(128, 128)):
        """Allocate zero-filled float32 storage.

        Args:
            num_envs: Number of parallel environments.
            size: Number of slots in the ring.
            context: Length of each stored state window.
            state_dim: Dimension of one state vector.
            act_dim: Dimension of one action vector.
            obs_mode: ``'state'`` for state-only storage; any other value also
                allocates image storage with 3 channels for ``'rgb'`` and 4
                otherwise.
            img_shape: ``(height, width)`` of stored images.
        """
        self.size = size
        self.context = context
        self.idx = 0
        self.overfilled = False
        self.obs_mode = obs_mode

        self.observations = torch.zeros((num_envs, size, context, state_dim), dtype=torch.float32)
        self.actions = torch.zeros((num_envs, size, act_dim), dtype=torch.float32)
        self.returns = torch.zeros((num_envs, size, 1), dtype=torch.float32)
        self.dones = torch.zeros((num_envs, size, 1), dtype=torch.float32)
        self.next_observations = torch.zeros((num_envs, size, context, state_dim), dtype=torch.float32)

        if obs_mode != 'state':
            channels = 3 if obs_mode == 'rgb' else 4
            height, width = img_shape
            img_storage_shape = (num_envs, size, context, height, width, channels)
            self.img_observations = torch.zeros(img_storage_shape, dtype=torch.float32)
            self.img_next_observations = torch.zeros(img_storage_shape, dtype=torch.float32)

    def recieve_traj(self, obs, acts, rets, dones, next_obs, img_obs=None, img_next_obs=None):
        """Store one transition per environment in the current slot and advance.

        Args:
            obs, next_obs: Sequences of per-step state tensors; they are
                stacked along a new dim 1 to form ``(n_e, context, s_d)``.
            acts: Tensor ``(n_e, a_d)``.
            rets: Tensor ``(n_e, 1)``.
            dones: Tensor ``(n_e, 1)``.
            img_obs, img_next_obs: Sequences of per-step image tensors; only
                read when ``obs_mode`` is not ``'state'``.
        """
        slot = self.idx

        if self.obs_mode != 'state':
            # stack() allocates new tensors, so the in-place division below
            # does not modify the caller's images.
            img_obs = torch.stack(img_obs, dim=1).float()
            img_next_obs = torch.stack(img_next_obs, dim=1).float()
            # Only the RGB channels are rescaled; a 4th channel is left as is.
            img_obs[..., :3] /= self.RGB_MAX
            img_next_obs[..., :3] /= self.RGB_MAX
            self.img_observations[:, slot] = img_obs
            self.img_next_observations[:, slot] = img_next_obs

        self.observations[:, slot] = torch.stack(obs, dim=1)            # n_e, cont, s_d
        self.actions[:, slot] = acts.to(torch.float32)                  # n_e, a_d
        self.returns[:, slot] = rets.to(torch.float32)                  # n_e, 1
        self.dones[:, slot] = dones                                     # n_e, 1
        self.next_observations[:, slot] = torch.stack(next_obs, dim=1)  # n_e, cont, s_d

        self.idx += 1
        if self.idx >= self.size:
            self.idx = 0
            self.overfilled = True

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct slots uniformly at random.

        The same slot indices are used for every environment. If fewer than
        ``batch_size`` slots have been written, all written slots are returned
        (possibly none).

        Returns:
            ``(observations, actions, returns, dones, next_observations)`` in
            ``'state'`` mode, with ``img_observations`` and
            ``img_next_observations`` appended otherwise.

        Raises:
            ValueError: If ``batch_size`` exceeds the buffer capacity.
        """
        if batch_size > self.size:
            raise ValueError("batch > size")

        # Slots written so far: the whole ring once it has wrapped.
        filled = self.size if self.overfilled else self.idx
        idxs = torch.randperm(filled)[:batch_size]

        batch = (
            self.observations[:, idxs],
            self.actions[:, idxs],
            self.returns[:, idxs],
            self.dones[:, idxs],
            self.next_observations[:, idxs],
        )
        if self.obs_mode != 'state':
            batch += (
                self.img_observations[:, idxs],
                self.img_next_observations[:, idxs],
            )
        return batch



