import torch as th
import numpy as np
import os 
import torch.nn as nn 
import torch.nn.functional as F

'''
def load_traj(num=1, save_directory='replay_buffer/',device='cuda') :
    data = ['actions','avail_actions','obs','probs','reward','state','terminated']
    replay_data = dict()
    for d in data :
        replay_data[d] = th.load(save_directory+'{}_{}.pt'.format(num,d)).to(device)
    return replay_data

def get_one_step_reward(reward_func, traj, step, n_agent) :
    r = [0 for i in range(n_agent)]
    s = traj['state'][0,step,:].reshape((1,-1))
    s_next = traj['state'][0,step+1,:].reshape((1,-1))
    for agent in range(n_agent) :
        obs = traj['obs'][0,step,agent,:].reshape((1,-1))
        action = traj['obs'][0,step+1,agent,30:39].reshape((1,-1))
        r[agent] = reward_func(s,s_next,obs,action)
    return r        

def get_loss(num_1,num_2,r_hat,rank) :
    p = (th.exp(r_hat[num_1])+1e-6)/(th.exp(r_hat[num_1])+th.exp(r_hat[num_2])+1e-6)
    if rank[num_1] < rank[num_2] :
        pref = 1
    elif rank[num_1] > rank[num_2] :
        pref = 0 
    else :
        pref = 0.5
    loss_1 = -1 * ((pref* th.log(p) + (1-pref)*th.log(1-p)))
    return loss_1 

def get_reward(reward_func, traj, end, n_agent) :
    #for name, param in reward_func.named_parameters():
        #print(f"Parameter: {name}, Requires_grad: {param.requires_grad}")
        
    r = 0
    s = traj['state'][0,:end,:]
    print(s.requires_grad)

    s_next = traj['state'][0,1:end+1,:]
    print(s_next.requires_grad)

    for agent in range(n_agent) :
        obs = traj['obs'][0,:end,agent,:]
        action = traj['obs'][0,1:end+1,agent,30:39]
        r_step = reward_func(s,s_next,obs,action)
        
        r = r+ th.sum(r_step)
    return r
    

def update_reward_model(args,
                        seq=0,
                        traj_path='replay_buffer/3m_LLM_1_0',
                        device='cuda',
                        model_path='save_reward_model',
                        training_steps=50) :
    
    n_agents = args.n_agents
    traj_path = '{}/{}_{}_{}/'.format(args.replay_buffer_save_path,args.map_name,args.reward_model_info,seq-1)
    replay_data = load_traj(save_directory=traj_path)
    reward_function = reward_model(replay_data['state'].shape[2],replay_data['obs'].shape[3],9).to(device)
    print(reward_function)

    load_model_path = model_path+'/{}_{}_{}.pt'.format(args.map_name,args.reward_model_info,seq-1)
    reward_function.load_state_dict(th.load(load_model_path))
    for param in reward_function.parameters():
        param.requires_grad = True
    opt= th.optim.Adam(reward_function.parameters(),lr=args.reward_model_lr)
    
    save_model_path = model_path+'/{}_{}_{}.pt'.format(args.map_name,args.reward_model_info,seq)

    ### Trajectory Preference
    save_traj_pref_directory = '{}/{}_{}_{}_traj.pt'.format(args.preference_save_path,args.map_name,args.reward_model_info,seq-1)
    pref_traj_data = th.load(save_traj_pref_directory)

    ### Step Preference
    if args.compare_agents :
        save_step_pref_directory = '{}/{}_{}_{}_step.pt'.format(args.preference_save_path,args.map_name,args.reward_model_info,seq-1)
        pref_step_data = th.load(save_step_pref_directory)
        
    for training_step in range(training_steps) :
            ## Trajectory
            loss = th.zeros(1, requires_grad=True).to(device)
            for t_1, t_2, pref in pref_traj_data :
                t_1 = int(t_1.item())
                t_2 = int(t_2.item())
                pref = int(pref.item())
                ### When the length is fixed
                traj_1 = load_traj(num=t_1,save_directory=traj_path)
                traj_2 = load_traj(num=t_2,save_directory=traj_path)
                end_1 = th.where(traj_1['terminated']==1)[1].item()
                end_2 = th.where(traj_2['terminated']==1)[1].item()
                if end_1 >= end_2 :
                    end_state = end_2
                else :
                    end_state = end_1

                r_1 = get_reward(reward_function, traj_1, end_state,n_agents)
                r_2 = get_reward(reward_function, traj_2, end_state,n_agents)
                p = th.exp(r_1) / (th.exp(r_1)+th.exp(r_2))
                if pref == 0 :
                    pref = 0.5
                elif pref == 1 :
                    pref = 1
                else : 
                    pref = 0 
                loss_1 = -1 * ((pref* th.log(p) + (1-pref)*th.log(1-p)))
                if th.isnan(loss_1) == False : 
                    loss = loss+loss_1
                print(loss_1)
                print(loss)
            
            

            ## Step update
            if args.compare_agents :
                loss2 = th.zeros(1, requires_grad=True).to(device)
                for traj_num,step,rank_1,rank_2,rank_3 in pref_step_data :
                    traj_num = int(traj_num.item())
                    if traj_num != 0 :
                        step = int(step.item())
                        rank = [int(rank_1.item()),int(rank_2.item()),int(rank_3.item())]

                        traj = load_traj(num=traj_num,save_directory=traj_path)
                        r_hat = get_one_step_reward(reward_function,traj,step,n_agents)

                        loss_1 = get_loss(0,1,r_hat,rank)
                        if th.isnan(loss_1) == False : 
                            loss2 = loss2+loss_1
                        loss_1 = get_loss(0,2,r_hat,rank)
                        if th.isnan(loss_1) == False : 
                            loss2 = loss2+loss_1
                        loss_1 = get_loss(1,2,r_hat,rank)
                        if th.isnan(loss_1) == False : 
                            loss2 = loss2+loss_1

                if th.isnan(loss) or th.isnan(loss2) :
                    print('nan')
                    break
                else :
                    opt.zero_grad()
                    print(f'loss : {loss} / loss2 : {loss2}')
                    final_loss = loss + loss2
                    #print(final_loss)
                    #final_loss.requires_grad_(True)
                    print(final_loss)

                    final_loss.backward()
                    opt.step()
                    
                output_loss_1 = loss.item()/900
                output_loss_2 = loss2.item()/(150*3)

                print('{} : {:.5f} / {:.5f}'.format(training_step,output_loss_1,output_loss_2))             
                 
            else :
                if th.isnan(loss):
                    print('nan')
                    break
                else :
                    opt.zero_grad()
                    final_loss = loss
                    print(final_loss)
                    final_loss.requires_grad_(True)
                    print(final_loss)
                    final_loss.backward()
                    opt.step()
                print('{} : {:.5f} '.format(training_step,output_loss_1))             
                
    print(save_model_path)
    th.save(reward_function.state_dict(), save_model_path)
'''

class reward_model(nn.Module) :
    """Feed-forward reward network over (state, next state, observation, action).

    Each of the four inputs is passed through its own linear encoder followed
    by a ReLU. The four encodings are concatenated along the feature axis,
    passed through one hidden layer, and projected to a single value that is
    squashed with ``tanh``, so every output lies in ``[-1, 1]``.

    Args:
        state_dim: Size of the last dimension of ``s`` and ``s_next``.
        obs_dim: Size of the last dimension of ``o``.
        action_dim: Size of the last dimension of ``a``.
        hidden_dim: Width of every encoder and of the hidden layer.
    """

    def __init__(self, state_dim: int, obs_dim: int, action_dim: int,
                 hidden_dim: int = 16) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Kept as an attribute for the (currently disabled) recurrent variant.
        self.rnn_hidden_dim = hidden_dim

        # One encoder per input; the layer order is preserved so that
        # parameter ordering and state_dict keys stay unchanged.
        self.input_state = nn.Linear(self.state_dim, self.hidden_dim)
        self.input_next_state = nn.Linear(self.state_dim, self.hidden_dim)
        self.input_obs = nn.Linear(self.obs_dim, self.hidden_dim)
        self.input_action = nn.Linear(self.action_dim, self.hidden_dim)

        # Consumes the concatenation of the four encodings.
        self.hidden_layer = nn.Linear(self.hidden_dim * 4, self.hidden_dim)

        #self.obs_rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)

        self.output = nn.Linear(self.hidden_dim, 1)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise weight matrices and zero all biases."""
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.zeros_(p)

    def forward(self, s: th.Tensor, s_next: th.Tensor, o: th.Tensor,
                a: th.Tensor) -> th.Tensor:
        """Compute the reward estimate for a batch of transitions.

        Args:
            s: Tensor whose last dimension is ``state_dim``.
            s_next: Tensor whose last dimension is ``state_dim``.
            o: Tensor whose last dimension is ``obs_dim``.
            a: Tensor whose last dimension is ``action_dim``
                (presumably a one-hot action encoding -- confirm with callers).

        All inputs must share their leading dimensions.

        Returns:
            Tensor with the inputs' leading dimensions and a trailing
            dimension of size 1, with values in ``[-1, 1]``.
        """
        input_state = F.relu(self.input_state(s))
        # The next state goes through its own encoder. NOTE: checkpoints
        # trained while ``s_next`` was routed through ``input_state`` still
        # load (same keys) but should be retrained.
        input_next_state = F.relu(self.input_next_state(s_next))
        input_obs = F.relu(self.input_obs(o))
        input_one_hot_action = F.relu(self.input_action(a))

        # Concatenate on the feature axis so that 1-D, 2-D and higher-rank
        # inputs are all supported (identical to dim=1 for 2-D inputs).
        x = th.cat(
            (input_state, input_next_state, input_obs, input_one_hot_action),
            dim=-1,
        )
        hidden = F.relu(self.hidden_layer(x))
        #out = self.output(hidden)
        # ``F.tanh`` is deprecated in favour of ``torch.tanh``.
        out = th.tanh(self.output(hidden))
        return out
    
class team_reward_model(nn.Module) :
    """Feed-forward team reward network over (state, next state, joint action).

    The state, the next state and the joint action are each passed through
    their own linear encoder followed by a ReLU. The three encodings are
    concatenated along the feature axis, passed through one hidden layer, and
    projected to a single unbounded value.

    Args:
        state_dim: Size of the last dimension of ``s`` and ``s_next``.
        obs_dim: Input size of the ``input_obs`` layer. That layer is created
            but not used by ``forward``.
        action_dim: Per-agent action size; the action encoder expects an
            input of size ``action_dim * n_agents``.
        hidden_dim: Width of every encoder and of the hidden layer.
        n_agents: Number of agents whose actions are packed into ``a``.
    """

    def __init__(self, state_dim: int, obs_dim: int, action_dim: int,
                 hidden_dim: int = 16, n_agents: int = 3) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.n_agents = n_agents
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Kept as an attribute for the (currently disabled) recurrent variant.
        self.rnn_hidden_dim = hidden_dim

        self.input_state = nn.Linear(self.state_dim, self.hidden_dim)
        self.input_next_state = nn.Linear(self.state_dim, self.hidden_dim)

        # Not used by ``forward``; retained so that state_dict keys and
        # parameter ordering stay compatible with existing checkpoints.
        self.input_obs = nn.Linear(self.obs_dim, self.hidden_dim)
        self.input_action = nn.Linear(self.action_dim * self.n_agents, self.hidden_dim)

        # Consumes the concatenation of the three encodings used in forward.
        self.hidden_layer = nn.Linear(self.hidden_dim * 3, self.hidden_dim)

        #self.obs_rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)

        self.output = nn.Linear(self.hidden_dim, 1)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise weight matrices and zero all biases."""
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
            else:
                nn.init.zeros_(p)

    def forward(self, s: th.Tensor, s_next: th.Tensor, a: th.Tensor) -> th.Tensor:
        """Compute the team reward estimate for a batch of transitions.

        Args:
            s: Tensor whose last dimension is ``state_dim``.
            s_next: Tensor whose last dimension is ``state_dim``.
            a: Tensor whose last dimension is ``action_dim * n_agents``
                (presumably concatenated one-hot actions -- confirm with
                callers).

        All inputs must share their leading dimensions.

        Returns:
            Tensor with the inputs' leading dimensions and a trailing
            dimension of size 1 (no output non-linearity is applied).
        """
        input_state = F.relu(self.input_state(s))
        # The next state goes through its own encoder. NOTE: checkpoints
        # trained while ``s_next`` was routed through ``input_state`` still
        # load (same keys) but should be retrained.
        input_next_state = F.relu(self.input_next_state(s_next))
        #input_obs = F.relu(self.input_obs(o))
        input_one_hot_action = F.relu(self.input_action(a))

        # Concatenate on the feature axis so that 1-D, 2-D and higher-rank
        # inputs are all supported (identical to dim=1 for 2-D inputs).
        x = th.cat((input_state, input_next_state, input_one_hot_action), dim=-1)
        hidden = F.relu(self.hidden_layer(x))
        out = self.output(hidden)
        return out