import pickle
import torch
import numpy as np
from datetime import datetime
from env.badminton.env import BadmintonEnv
from model.baseline.BC.bc import BCAgent
from omegaconf import OmegaConf
from hydra.utils import to_absolute_path
from env_utils import *
from load_agent import Load_model
from badminton.Agent.RallyNet import RallyNet


"""
1: "Serve short 發短球", 2: "Serve long 發長球", 3: "Clear 長球", 4: "Smash 殺球", 5: "Drop 切球", 
6: "Lob 挑球", 7: "Drive 平球", 8: "Net Shot 網前球", 9: "Push Shot 推撲球", 10: "Smash Defence 接殺防守", 
11: "Missed shot 接不到"
"""

def _load_opponent(opp_name, device):
    """Build the RallyNet opponent for player ``opp_name`` and load its pretrained weights.

    Args:
        opp_name: Player name; must be an entry of the pickled ``target_players_ids.pkl`` sequence.
        device: torch device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The RallyNet model on ``device`` with ``gen_e_1000.trc`` loaded into it.

    Raises:
        ValueError: if ``opp_name`` is not one of the known target players.
    """
    # NOTE: pickle.load can execute arbitrary code; these weight files must come from a trusted source.
    with open('./env/badminton/weight/target_players_ids.pkl', 'rb') as f:
        target_players = pickle.load(f)
    try:
        opp_id = target_players.index(opp_name)
    except ValueError:
        raise ValueError(
            f"Unknown opponent {opp_name!r}; expected one of {target_players}"
        ) from None

    with open('./env/badminton/weight/hyperparameters.pkl', 'rb') as f:
        hyperparameters = pickle.load(f)

    opponent_agent = RallyNet(
        data_size=hyperparameters['data_size'],
        latent_size=hyperparameters['latent_size'],
        context_size=hyperparameters['context_size'],
        hidden_size=hyperparameters['hidden_size'],
        player_ids_len=hyperparameters['player_ids_len'],
        target_players=hyperparameters['target_players'],
        shot_type_len=hyperparameters['shot_type_len'],
        ts=hyperparameters['ts'],
        id=opp_id,
        device=device,
        player_name=opp_name,
    ).to(device)
    state_dict = torch.load(
        "./env/badminton/weight/gen_e_1000.trc", map_location=device, weights_only=True
    )
    # strict=False: keys missing from / unexpected in the checkpoint are ignored instead of raising.
    opponent_agent.load_state_dict(state_dict, strict=False)
    return opponent_agent


def _play_round(env, player_agent, round_idx, rallies, traj_length):
    """Play the rallies of one evaluation round (match ``round_idx``).

    Rallies are played until ``rallies`` have been started or ``env.reset()``
    reports a match number different from ``round_idx``.

    Args:
        env: The BadmintonEnv being evaluated.
        player_agent: The learned agent; must provide ``get_action(states, info)``.
        round_idx: 1-based round number, compared against ``info["match"]``.
        rallies: Maximum number of rallies to start in this round.
        traj_length: A rally is cut off once ``info['round'][-1]`` reaches this value.

    Returns:
        ``(score_self, score_opp, steps)``: per-rally lists of the point change
        for the agent and for the opponent, and ``info['round'][-1] - 1`` at the
        end of each rally.
    """
    score_self, score_opp, steps = [], [], []
    # info['env_score'] is diffed against the previous rally's value to get
    # per-rally point changes; the baseline restarts at 0 for every round.
    prev_self_score, prev_opp_score = 0, 0

    for rally in range(1, rallies + 1):
        states, info, done, launch = env.reset()
        print("rally :", rally)
        print("match: ", info["match"])

        if info["match"] != round_idx:
            break

        while not done:
            action = player_agent.get_action(states, info)
            # The state entry used to fill in the action coordinates depends on the
            # `launch` flag returned by the env (TODO confirm what states[1]/states[3] hold).
            coord_source = states[1] if launch else states[3]
            action = insert_action_coord(action, coord_source)
            states, _reward, info, done, launch = env.step(action, launch)
            # Cap the rally length so a rally that never ends cannot stall evaluation.
            if info['round'][-1] >= traj_length:
                break

        self_total, opp_total = info['env_score'][0], info['env_score'][1]
        print("score: ", info['env_score'])
        print("round: ", info['round'][-1]-1)
        score_self.append(self_total - prev_self_score)
        score_opp.append(opp_total - prev_opp_score)
        steps.append(info['round'][-1]-1)
        prev_self_score, prev_opp_score = self_total, opp_total
        print()

    return score_self, score_opp, steps


def _build_save_name(args, is_single_agent):
    """Return the file name used for the evaluation log of this run."""
    if not args.eval.IsAgent:
        return f"origin_{args.eval.episodes}.txt"
    player_tag = args.player_type if is_single_agent else "all"
    return f"{args.model_name}_{player_tag}_s{args.strength}_{args.eval.rounds}.txt"


def main(args, is_single_agent=True):
    """Evaluate the learned player agent against a RallyNet opponent and save a text log.

    Plays ``args.eval.rounds`` rounds of up to ``args.eval.episodes`` rallies each,
    prints progress to stdout, collects per-round score / step statistics, and
    writes a summary via ``save_log`` to ``{args.eval.result_path}/{args.env_name}``.

    Args:
        args: Run configuration (presumably an OmegaConf config — confirm). Fields read:
            ``eval.IsAgent``, ``eval.opponent``, ``eval.episodes``, ``eval.rounds``,
            ``eval.result_path``, ``model_name``, ``player_name``, ``player_type``,
            ``strength``, ``env_name``, ``traj_length``.
        is_single_agent: Forwarded to ``Load_model``; also selects the log file name
            and whether the RallyNet opponent is listed in the logged model list.

    Raises:
        ValueError: if ``args.eval.IsAgent`` is False (no player agent would be
            loaded, but the evaluation loop requires one), or if the opponent name
            is unknown.
    """
    if not args.eval.IsAgent:
        raise ValueError(
            "main() requires args.eval.IsAgent=True: the evaluation loop drives the "
            "learned player agent, which is only loaded in agent mode"
        )

    # ---------------------------------------------------
    # Agent
    model_list = []
    player_agent = Load_model(args, is_single_agent)
    model_list.append(args.model_name)

    # ---------------------------------------------------
    # RallyNet opponent
    self_name, opp_name = args.player_name, args.eval.opponent
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    opponent_agent = _load_opponent(opp_name, device)
    if is_single_agent:
        model_list.append(f'RallyNet: {opp_name}')

    # ================ Specify rally length =============== #
    rallies = args.eval.episodes
    n_rounds = args.eval.rounds

    # ================ Construct the environment for interaction =============== #
    env = BadmintonEnv(player_agent, opponent_agent,
                       self_name, opp_name,
                       rallies, is_match=True, is_constraint=True, is_have_serve_state=False,
                       filepath=f'./evaluation/data/badminton/{args.model_name}_{self_name}_vs_{opp_name}_{n_rounds}.csv')

    # ---------------- log prepare ----------------
    # Statistics are keyed by role, not by player name, so that a mirror match
    # (player_name == opponent) does not merge both sides into one entry.
    roles = (('self', self_name), ('opp', opp_name))
    scores = {'self': [], 'opp': []}
    round_score_stats = {'self': [], 'opp': []}
    round_lost_stats = {'self': [], 'opp': []}
    win_counts = {'self': 0, 'opp': 0}
    draws = 0
    score_diffs = []
    round_counts = []

    log_lines = []
    today_str = datetime.now().strftime("%Y-%m-%d")
    log_lines.append(f"Data: {today_str}")
    log_lines.append(f"Total rounds: {n_rounds} with per {rallies} episodes")
    log_lines.append(f"Model: {model_list}")
    log_lines.append(f"Player: [{self_name}] vs. [{opp_name}]")
    log_lines.append("\n" + "*" * 50 + "\n")

    # ================ Competition starts =============== #
    try:
        for r in range(1, n_rounds + 1):
            print(f'\n====================== round {r} ======================\n')
            log_lines.append(f"-round {r}-")

            score_self, score_opp, steps = _play_round(
                env, player_agent, r, rallies, args.traj_length
            )

            # ---------------------- end one round ----------------------
            print(f"\n------- round {r} is end -------\n")
            print('\n====================================================\n')

            score_s = np.array(score_self)
            score_o = np.array(score_opp)
            rally_steps = np.array(steps)

            total_score_s = int(score_s.sum())
            total_score_o = int(score_o.sum())
            score_diff = abs(total_score_s - total_score_o)

            # Step counts of the rallies each side scored in; a rally one side
            # scores in is a rally the other side loses.
            steps_self_scored = rally_steps[score_s == 1]
            steps_opp_scored = rally_steps[score_o == 1]

            scores['self'].append(total_score_s)
            scores['opp'].append(total_score_o)
            score_diffs.append(score_diff)
            round_counts.append(len(rally_steps))

            round_score_stats['self'].append(round_stats(steps_self_scored))
            round_lost_stats['self'].append(round_stats(steps_opp_scored))
            round_score_stats['opp'].append(round_stats(steps_opp_scored))
            round_lost_stats['opp'].append(round_stats(steps_self_scored))

            # A tied round is a draw and is not credited to either side.
            if total_score_s > total_score_o:
                win_counts['self'] += 1
            elif total_score_o > total_score_s:
                win_counts['opp'] += 1
            else:
                draws += 1

            log_lines.append(f"  get point: {self_name}= {total_score_s}, {opp_name}= {total_score_o}")
            log_lines.append(f"  score diff: {score_diff}")
            log_lines.append(f"  rally len: {len(rally_steps)}")
            log_lines.append("-" * 50)
    finally:
        # Release the environment even if a round raises.
        env.close()

    # ---------------------- Total Save ----------------------
    print('\n====================================================\n')
    log_lines.append("\n=== Total statistic ===\n")

    for role, name in roles:
        log_lines.append(f"{name}:")
        log_lines.append(f"  win num: {win_counts[role]} / {n_rounds}")
        win_rate = win_counts[role] / n_rounds if n_rounds else 0.0
        log_lines.append(f"  win rate: {win_rate:.3f}")
        s = stat_summary(scores[role])
        log_lines.append(f"  score: avg= {s['avg']:.2f}, min= {s['min']}, max= {s['max']}")
        rs = round_stat_summary(round_score_stats[role])
        log_lines.append(f"  Number of steps when get points: avg= {rs['avg']:.2f}, min= {rs['min']}, max= {rs['max']}")
        rl = round_stat_summary(round_lost_stats[role])
        log_lines.append(f"  Number of steps when losing points: avg= {rl['avg']:.2f}, min= {rl['min']}, max= {rl['max']}")
        log_lines.append("")

    if draws:
        log_lines.append(f"Draws: {draws} / {n_rounds}")
        log_lines.append("")

    diff_stat = stat_summary(score_diffs)
    log_lines.append(f"   Score difference statistics: avg= {diff_stat['avg']:.2f}, min= {diff_stat['min']}, max= {diff_stat['max']}")

    round_stat = stat_summary(round_counts)
    log_lines.append(f"   Number of rounds per game: avg= {round_stat['avg']:.2f}, min= {round_stat['min']}, max= {round_stat['max']}")
    log_lines.append('\n====================================================\n')

    save_name = _build_save_name(args, is_single_agent)
    log_text = "\n".join(log_lines)
    save_path = f"{args.eval.result_path}/{args.env_name}"
    save_log(log_text, save_name, save_path)
        