import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.agent_utils import *
from utils.plot import LamCollector
from emb.emb import StateEmb
from demo.ddpm import DDPM



class Agent:
    """Evaluation-time agent that picks actions by sampling a DDPM policy.

    On construction the agent resolves which player it represents, builds the
    ``DDPM`` policy (plus a ``StateEmb`` embedding model for the environments
    listed in ``_EMB_ENVS``), loads their weights from disk and switches both
    models to eval mode.  ``get_action`` then maps one environment state to an
    action in the format that environment expects.

    Attributes read from ``args``: ``env_name``, ``action_dim``,
    ``player_name``/``player_type``/``player_type_idx`` (only when not
    overridden by constructor arguments), ``use_ddgi``, ``strength``,
    ``traj_length``, ``emb_path``, ``policy_path``, ``eval.data_path`` and
    ``model.{d_weight, sample_alpha, sample_steps, x_max, x_min}``.
    """

    # Environments whose raw state is passed through the StateEmb model.
    _EMB_ENVS = ('tennis', 'box', 'connect4')

    def __init__(self, args, p_type=None, p_idx=None, player_name=None):
        """Build the models and load their checkpoints.

        Args:
            args: Run configuration (see the class docstring for the fields
                that are read).
            p_type: Overrides ``args.player_type`` for non-badminton envs.
            p_idx: Overrides ``args.player_type_idx`` for non-badminton envs.
            player_name: Badminton only.  When given, it overrides
                ``args.player_name`` and forces ``player_type`` to ``'self'``
                and ``player_type_idx`` to ``0``.
        """
        print('## DDGI agent eval ##')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.env = args.env_name
        self.action_dim = args.action_dim
        if self.env == 'badminton':
            # An explicit player name wins over every player_* field in args.
            self.player_name = player_name if player_name is not None else args.player_name
            print(f"=== Player name: {self.player_name} === ")
            self.player_type = 'self' if player_name is not None else args.player_type
            self.player_type_idx = 0 if player_name is not None else args.player_type_idx
        else:
            self.player_name = None
            self.player_type = p_type if p_type is not None else args.player_type
            self.player_type_idx = p_idx if p_idx is not None else args.player_type_idx

        self.use_ddgi = args.use_ddgi
        self.player_strength = args.strength
        self.traj_length = args.traj_length
        self.d_weight = args.model.d_weight
        self.sample_alpha = args.model.sample_alpha

        # Collects the per-call "lam" sequences returned by the sampler.
        # NOTE: the attribute name `analyer` is kept as-is because code
        # outside this class may read it.
        self.save_lam_path = args.eval.data_path
        self.analyer = LamCollector(t=args.model.sample_steps)

        # State-embedding model (only for the environments that need one).
        if self.env in self._EMB_ENVS:
            self.emb_model = StateEmb(args).to(self.device)
            self.load_weights(emb_path=args.emb_path)
            self.emb_model.eval()
        else:
            self.emb_model = None

        # Per-dimension upper / lower bounds handed to the DDPM; the actual
        # values come from args.model.x_max / args.model.x_min.
        x_max = torch.ones((1, self.action_dim), device=self.device) * args.model.x_max
        x_min = torch.ones((1, self.action_dim), device=self.device) * args.model.x_min
        self.model = DDPM(args, x_max=x_max, x_min=x_min).to(self.device)

        # Policy weights.
        self.load_weights(policy_path=args.policy_path)
        self.model.eval()

        print("=====================================================")
        print(f"Player Idx: {self.player_type_idx}, Player: {self.player_type}")
        if self.player_name is not None:
            print(f"Player name: {self.player_name}")
        print("=====================================================")

    @staticmethod
    def _as_float_tensor(state):
        """Return ``state`` unchanged if it is a tensor, else a float32 tensor of it."""
        if isinstance(state, torch.Tensor):
            return state
        return torch.tensor(state, dtype=torch.float32)

    @torch.no_grad()
    def state_process(self, state, name):
        """Convert a raw environment state into the policy's conditioning input.

        Args:
            state: Raw state; a tensor or anything ``torch.tensor`` accepts
                (for badminton, whatever ``badminton_process_state`` accepts).
            name: Passed through to ``StateEmb.state_embed``; unused for
                environments without an embedding model.

        Returns:
            The processed state, moved to ``self.device``.
        """
        if self.env in self._EMB_ENVS:
            # NOTE(review): the tensor is not moved to self.device before the
            # embedding call; presumably state_embed handles placement -
            # confirm against StateEmb.
            state = self.emb_model.state_embed(self._as_float_tensor(state), name)
        elif self.env == 'badminton':
            state = badminton_process_state(state)
        else:
            state = self._as_float_tensor(state)
        return state.to(self.device)

    @torch.no_grad()
    def get_action(self, state, info):
        """Sample the policy once and return an action for the current env.

        Args:
            state: Raw environment state (see ``state_process``).
            info: Environment info object; forwarded to the action
                post-processing helpers for connect4/holdem/badminton.

        Returns:
            * tennis / box: the argmax action as a tensor.
            * connect4 / holdem: result of ``select_action_with_mask``.
            * badminton: result of ``badminton_action_process``.
            * anything else: the argmax action as a Python int.
        """
        cond = {'state': self.state_process(state, self.player_type)}
        prior = torch.zeros((1, self.action_dim), device=self.device)
        act_one_hot, lams = self.model.sample(
            prior=prior,
            condition=cond,
            d_weight=self.d_weight,
            sample_alpha=self.sample_alpha,
            ddgi=self.use_ddgi,
        )
        self.analyer.add_sequence(lams)

        prob = torch.softmax(act_one_hot, dim=-1)
        action = torch.argmax(prob, dim=-1)

        # Disabled debugging aid (kept for reference): every 20th episode the
        # lam values used to be appended as a CSV line to
        # f"{self.save_lam_path}/{self.env}_lam.txt".

        if self.env in ['tennis', 'box']:
            return action
        if self.env in ['connect4', 'holdem']:
            return select_action_with_mask(action, prob.squeeze(0), info, use_random=False)
        if self.env in ['badminton']:
            return badminton_action_process(act_one_hot.squeeze(0), info, state)
        return action.item()

    def output_lam(self):
        """Summarize the collected lam sequences, write plot + CSV, print mean/std.

        Output files are ``evaluation/plot/<env>/lam_a<sample_alpha>_d<d_weight>``
        with ``.png`` and ``.csv`` extensions.
        """
        name = f"evaluation/plot/{self.env}/lam_a{self.sample_alpha}_d{self.d_weight}"
        mean, std = self.analyer.summarize_and_plot(plot_path=f"{name}.png", csv_path=f"{name}.csv")
        print(mean)
        print(std)

    def ouptut_lam(self):
        """Backward-compatible alias for :meth:`output_lam` (original misspelling)."""
        self.output_lam()

    # ========================== Load weight ===============================
    def load_weights(self, emb_path=None, policy_path=None):
        """Load embedding and/or policy weights from disk.

        Args:
            emb_path: Root directory of embedding checkpoints; the file read is
                ``<emb_path>/<env>/emb.pth``.  Skipped when falsy.
            policy_path: Root directory of policy checkpoints; the file read is
                ``<policy_path>/<env>/<who>_policy_<strength>.pth`` where
                ``<who>`` is ``player_name`` if set, else ``player_type``.
                Skipped when falsy.
        """
        if emb_path:
            path = f'{emb_path}/{self.env}/emb.pth'
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # host (self.device falls back to CPU when CUDA is unavailable).
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            self.emb_model.load_state_dict(checkpoint)

        if policy_path:
            who = self.player_type if self.player_name is None else self.player_name
            path = f"{policy_path}/{self.env}/{who}_policy_{self.player_strength}.pth"
            self.model.load_weights(path)