import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.agent_utils import *
from emb.emb import StateEmb


class BCAgent(nn.Module):
    """BC (presumably behaviour-cloning) agent wrapping a ``PolicyModule``.

    The agent owns the policy network and its Adam optimizer. For the
    environments listed in ``_EMBEDDED_ENVS`` it also owns a ``StateEmb``
    model (loaded from ``args.emb_path`` and put in eval mode) that turns
    raw states into the latent vectors the policy consumes.

    Args:
        args: Configuration object. The attributes read here are ``device``,
            ``player_type``, ``player_type_idx``, ``strength``, ``env_name``,
            ``player_name`` (badminton only), ``model.latent_dim``,
            ``model.hidden_dim``, ``state_dim_list``, ``action_dim``,
            ``agent_num``, ``train.lr``, ``emb_path`` and ``policy_path``.
        p_type: Overrides ``args.player_type`` when not ``None``.
        p_idx: Overrides ``args.player_type_idx`` when not ``None``.
        eval: When true, policy weights are loaded from ``args.policy_path``
            and the policy is switched to eval mode. (The name shadows the
            builtin but is kept so keyword callers keep working.)
    """

    # Environments whose raw states are passed through the StateEmb model.
    _EMBEDDED_ENVS = ('tennis', 'box', 'connect4')

    def __init__(self, args, p_type=None, p_idx=None, eval=False):
        super().__init__()
        self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
        self.player_type = p_type if p_type is not None else args.player_type
        self.player_type_idx = p_idx if p_idx is not None else args.player_type_idx
        self.strength = args.strength

        # Environment-dependent dimensions.
        self.env = args.env_name
        self.player_name = args.player_name if self.env == 'badminton' else None
        uses_embedding = self.env in self._EMBEDDED_ENVS
        if uses_embedding:
            # The policy consumes the embedding model's latent vector.
            self.state_dim = args.model.latent_dim
        else:
            self.state_dim = list(args.state_dim_list)[self.player_type_idx]
        self.action_dim = args.action_dim

        # Network hyper-parameters.
        self.hidden_dim = args.model.hidden_dim
        self.agent_num = args.agent_num

        self.policy = PolicyModule(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            hidden_dim=self.hidden_dim,
            device=self.device,
            continuous=(self.env == 'badminton'),
        ).to(self.device)

        # The policy is trained with 30x the configured base learning rate.
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=args.train.lr * 30)

        if uses_embedding:
            self.emb_model = StateEmb(args).to(self.device)
            self.load_weights(emb_path=args.emb_path)
            self.emb_model.eval()

        if eval:
            print("## BC Agent eval ##")
            self.load_weights(policy_path=args.policy_path)
            self.policy.eval()
        else:
            print("## BC Agent train ##")

    @staticmethod
    def _as_float_tensor(state):
        """Return ``state`` as-is if it is a tensor, else as a float32 tensor."""
        if isinstance(state, torch.Tensor):
            return state
        return torch.tensor(state, dtype=torch.float32)

    @torch.no_grad()
    def state_process(self, state, name):
        """Convert a training-time state into a tensor on the agent's device.

        Args:
            state: A tensor, or anything ``torch.tensor`` accepts.
            name: Forwarded to ``emb_model.state_embed`` for embedded envs.

        Returns:
            The (optionally embedded) state tensor on ``self.device``.
        """
        state = self._as_float_tensor(state)
        if self.env in self._EMBEDDED_ENVS:
            state = self.emb_model.state_embed(state, name)
        return state.to(self.device)

    @torch.no_grad()
    def eval_state_process(self, state, name):
        """Convert an evaluation-time state into a tensor on the agent's device.

        Same as ``state_process`` except that badminton states go through
        ``badminton_process_state`` (a helper defined outside this class).
        """
        if self.env == 'badminton':
            return badminton_process_state(state).to(self.device)
        return self.state_process(state, name)

    def forward(self, state, action):
        """Run one supervised optimisation step and return the scalar loss.

        Args:
            state: Batch of states, see ``state_process``.
            action: Target actions. Badminton targets are regressed with MSE;
                other environments use NLL against the policy's
                log-probabilities, so targets must be class indices.

        Returns:
            The loss of this step as a Python float.
        """
        state = self.state_process(state, self.player_type)
        # Targets must live on the same device as the policy output.
        action = torch.as_tensor(action, device=self.device)
        recon_probs, _ = self.policy(state)

        if self.env == 'badminton':
            loss = F.mse_loss(recon_probs, action)
        else:
            loss = F.nll_loss(recon_probs.squeeze(-1), action.squeeze(-1))

        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()

        return loss.item()

    @torch.no_grad()
    def get_action(self, state, info=None):
        """Pick an action for ``state`` in the format the environment expects.

        Args:
            state: Raw environment state.
            info: Extra data forwarded to ``select_action_with_mask``
                (connect4 / holdem) or ``badminton_action_process``.

        Returns:
            connect4 / holdem: the result of ``select_action_with_mask``.
            tennis / pong / box: the greedy action tensor with a leading axis.
            badminton: the result of ``badminton_action_process``.
            otherwise: the greedy action as a Python scalar.
        """
        s = self.eval_state_process(state, self.player_type)
        probs, action = self.policy(s)

        if self.agent_num == 1:
            probs = probs.squeeze(-1)
            action = action.squeeze(-1)

        if self.env in ('connect4', 'holdem'):
            return select_action_with_mask(action, probs, info, use_random=False)
        if self.env in ('tennis', 'pong', 'box'):
            return action.unsqueeze(0)
        if self.env == 'badminton':
            return badminton_action_process(probs.squeeze(0), info, state)
        return action.item()

    # Load / Save weights
    def save_weights(self, save_weight_path, env_name, datalen):
        """Save the policy ``state_dict`` under ``save_weight_path/env_name``.

        The file is named ``bc_<player_type>_<datalen>.pth``; for badminton
        the player name is inserted after ``bc_``.
        """
        path = f'{save_weight_path}/{env_name}'
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(path, exist_ok=True)
        if self.env == 'badminton':
            save_path = f'{path}/bc_{self.player_name}_{self.player_type}_{datalen}.pth'
        else:
            save_path = f'{path}/bc_{self.player_type}_{datalen}.pth'

        torch.save(self.policy.state_dict(), save_path)

    def load_weights(self, emb_path=None, policy_path=None):
        """Load embedding and/or policy weights from disk.

        Args:
            emb_path: Root directory containing ``<env>/emb.pth``; skipped
                when falsy.
            policy_path: Root directory containing the policy checkpoint named
                after the player type and ``self.strength``; skipped when falsy.
        """
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        if emb_path:
            e_path = f'{emb_path}/{self.env}/emb.pth'
            checkpoint = torch.load(e_path, map_location=self.device, weights_only=True)
            self.emb_model.load_state_dict(checkpoint)
        if policy_path:
            if self.env == 'badminton':
                path = f'{policy_path}/{self.env}/bc_{self.player_name}_{self.player_type}_{self.strength}.pth'
            else:
                path = f'{policy_path}/{self.env}/bc_{self.player_type}_{self.strength}.pth'
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            self.policy.load_state_dict(checkpoint)
        

# Network modules used by BCAgent
class PolicyModule(nn.Module):
    """MLP trunk followed by an ``ActionDecoder`` head.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Size of the decoder output.
        hidden_dim: Width of every hidden layer and of the trunk output.
        device: Requested device; ignored in favour of CPU when CUDA is
            unavailable.
        continuous: Forwarded to ``ActionDecoder`` to select its output mode.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, device, continuous):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        self.model = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.output_act = ActionDecoder(hidden_dim, action_dim, continuous=continuous)

        # Move trunk AND head together; forward() sends inputs to self.device,
        # so leaving the head on CPU would break standalone use on CUDA.
        self.to(self.device)

    def forward(self, state):
        """Return the decoder's ``(probs, recon)`` pair for ``state``."""
        features = self.model(state.to(self.device))
        probs, recon = self.output_act(features)
        return probs, recon
    

class ActionDecoder(nn.Module):
    """Two-layer head that maps latent features to an action output.

    Args:
        latent_dim: Size of the incoming feature vector.
        output_dim: Size of the produced output vector.
        temperature: Divisor applied to logits before ``log_softmax``.
        action_limit: Stored on the instance; not used by this class.
        continuous: Selects the output mode of ``forward``.
        discrete_dim: In continuous mode, how many leading entries of the last
            axis are treated as categorical logits when choosing ``act``.
            Defaults to 12, the value previously hard-coded.
    """

    def __init__(self, latent_dim, output_dim, temperature=1.0, action_limit=0.01,
                 continuous=False, discrete_dim=12):
        super().__init__()
        self.temperature = temperature
        self.action_limit = action_limit
        self.continuous = continuous
        self.discrete_dim = discrete_dim
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, output_dim),
        )

    def forward(self, x):
        """Decode features ``x``.

        Returns:
            Discrete mode: ``(log_probs, argmax)`` over the last axis.
            Continuous mode: ``(raw_outputs, act)`` where ``act`` is the argmax
            over the first ``discrete_dim`` entries of the last axis.
        """
        logits = self.fc(x)

        if self.continuous:
            # Slice the feature (last) axis, not the batch axis: `logits[:12]`
            # truncated the batch for 2-D input instead of picking the
            # categorical part of every row.
            head = logits[..., :self.discrete_dim]
            head_log_probs = F.log_softmax(head / self.temperature, dim=-1)
            act = torch.argmax(head_log_probs, dim=-1)
            return logits, act

        log_probs = F.log_softmax(logits / self.temperature, dim=-1)
        recon = torch.argmax(log_probs, dim=-1)
        return log_probs, recon