import re
import torch
import numpy as np
from accelerate import Accelerator
from tabulate import tabulate
from dlm.trainer import ModelConfig, EnvConfig
from dlm.data_utils import ChatMLProcessor
from dlm.models.utils import chat_template
from dlm.trainer.utils import get_quantization_config, get_kbit_device_map
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser


_ASSISTANT_REPLY_RE = re.compile(
    r"<\|start_header_id\|>assistant<\|end_header_id\|>\s*\n*(.*?)<\|eot_id\|>",
    re.DOTALL
)
_EOT_TOKEN = "<|eot_id|>"
_HEADER_TOKEN = "<|start_header_id|>"


def extract_latest_assistant_reply(text: str) -> str:
    """Return the body of the last complete assistant turn found in ``text``.

    Two input shapes are handled:

    * Text containing one or more ``<|start_header_id|>assistant<|end_header_id|>``
      headers: the body of the LAST such turn, up to its ``<|eot_id|>``, is returned.
    * A bare continuation with no ``<|start_header_id|>`` markers at all. An example is
      the tokens generated after a prompt that already ended with the assistant header.
      In that case the text before the first ``<|eot_id|>`` is returned.

    Args:
        text: decoded model output, with special tokens kept.

    Returns:
        The stripped reply text. If no terminated reply is found, returns
        ``"[No assistant reply found]"``; this covers generation cut off before
        ``<|eot_id|>``.
    """
    # findall returns the captured reply bodies in order; the last one is the latest turn.
    matches = _ASSISTANT_REPLY_RE.findall(text)
    if matches:
        return matches[-1].strip()
    # Header-free continuation: accept it only if it was terminated by <|eot_id|>.
    if _HEADER_TOKEN not in text and _EOT_TOKEN in text:
        reply = text.split(_EOT_TOKEN, 1)[0].strip()
        if reply:
            return reply
    return "[No assistant reply found]"


def _encode_prompts(tokenizer, messages, device, max_length):
    """Render conversations with the assistant generation prompt and tokenize them as one padded batch."""
    prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # NOTE(review): the rendered template may already start with a BOS token, and tokenizer()
    # adds special tokens by default, which can duplicate it. This is kept unchanged so prompts
    # match the original tokenization. Confirm against how the model was trained.
    return tokenizer(
        prompt_text,
        truncation=True,
        max_length=max_length,
        padding=True,
        return_tensors="pt"
    ).to(device)


def _match_action(reply, action_list):
    """Return ``(action_id, cleaned_reply)``; ``action_id`` is None when the reply names no action."""
    # Tolerate surrounding whitespace and trailing periods ("Attack enemy 1.").
    reply_clean = reply.strip().rstrip('.')
    try:
        return action_list.index(reply_clean), reply_clean
    except ValueError:
        return None, reply_clean


def _resample_pending(model, tokenizer, messages, pending, action_list, avail_actions, actions,
                      max_new_tokens, num_candidates, max_length):
    """Run one sampling round for the ``pending`` agents and return those still without a valid action.

    The first valid candidate's action id is written into ``actions`` in place.
    """
    regen_inputs = _encode_prompts(tokenizer, [messages[i] for i in pending], model.device, max_length)
    with torch.no_grad():
        regen_outputs = model.generate(
            **regen_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=1.2,
            top_p=0.95,
            top_k=50,
            num_return_sequences=num_candidates,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() emits the num_candidates sequences of each prompt consecutively, so this
    # reshape groups them as (agent, candidate, tokens).
    regen_outputs = regen_outputs.reshape(len(pending), num_candidates, -1)
    prompt_len = regen_inputs["input_ids"].shape[-1]

    unresolved = []
    for idx, agent_idx in enumerate(pending):
        for candidate_idx in range(num_candidates):
            candidate_text = tokenizer.decode(regen_outputs[idx, candidate_idx, prompt_len:],
                                              skip_special_tokens=False)
            action_id, reply_clean = _match_action(extract_latest_assistant_reply(candidate_text), action_list)
            if action_id is None:
                print(
                    f"[Agent {agent_idx}] Candidate {candidate_idx + 1}/{num_candidates} '{reply_clean}' not in action list")
            elif avail_actions[agent_idx][action_id] == 1:
                actions[agent_idx] = action_id
                print(
                    f"[Agent {agent_idx}] Found valid action '{action_list[action_id]}' in candidate {candidate_idx + 1}/{num_candidates}")
                break
            else:
                print(
                    f"[Agent {agent_idx}] Candidate {candidate_idx + 1}/{num_candidates} '{action_list[action_id]}' is invalid")
        else:
            # The candidate loop never hit ``break``: no candidate was valid for this agent.
            unresolved.append(agent_idx)
    return unresolved


def infer_actions(model, tokenizer, obs_messages, history_messages, action_list, avail_actions, max_new_tokens=16,
                  max_attempts=2, num_candidates=4, max_length=1024):
    """Choose one action per agent by prompting the model, with resampling and a random fallback.

    Each agent's prompt is its history messages followed by its observation messages. The
    greedy reply is matched against ``action_list``. An agent is resampled when its reply
    names no action, or names one whose ``avail_actions`` entry is not 1. Resampling runs for
    up to ``max_attempts - 1`` extra rounds, each with ``num_candidates`` sampled replies, and
    the first valid candidate is kept. Agents still unresolved after that get a uniformly
    random available action.

    Args:
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer with a chat template. Prompts are assumed to be left-padded.
        obs_messages: per-agent lists of chat messages for the current observation.
        history_messages: per-agent lists of earlier chat messages, zipped with
            ``obs_messages`` in order.
        action_list: action names; an action id is its index in this list.
        avail_actions: per-agent availability masks indexed by action id (1 = available).
        max_new_tokens: generation budget per reply.
        max_attempts: total attempts, counting the initial greedy one. Values <= 1 skip resampling.
        num_candidates: sampled replies per agent in each resampling round.
        max_length: prompt token limit used for (left) truncation.

    Returns:
        ``(actions, ood_count, random_action_count)``:
        - ``actions``: an int numpy array holding one action id per agent.
        - ``ood_count``: the number of agents whose greedy reply was unusable.
        - ``random_action_count``: the number of agents that received the random fallback.
    """
    messages = [h + o for h, o in zip(history_messages, obs_messages)]
    inputs = _encode_prompts(tokenizer, messages, model.device, max_length)

    actions = np.zeros(len(messages), dtype=int)
    ood_count = 0
    random_action_count = 0

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=0.0,
            top_p=1.0,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Assumes left padding (as main() configures). Every prompt row then ends at the same
    # column, so slicing there leaves only the newly generated tokens.
    generated_tokens = outputs[:, inputs["input_ids"].shape[-1]:]
    output_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
    replies = [extract_latest_assistant_reply(text) for text in output_text]

    pending = []  # agents without a valid action yet, in ascending index order
    for i, reply in enumerate(replies):
        action_id, reply_clean = _match_action(reply, action_list)
        if action_id is not None and avail_actions[i][action_id] == 1:
            actions[i] = action_id
            continue
        pending.append(i)
        ood_count += 1
        if action_id is None:
            print(f"[Agent {i}] Generated '{reply_clean}' not in action list. Will try regenerating.")
        else:
            print(f"[Agent {i}] Action '{action_list[action_id]}' is invalid. Will try regenerating.")

    # The greedy pass is attempt 1; each further attempt is one sampling round.
    for _ in range(max_attempts - 1):
        if not pending:
            break
        pending = _resample_pending(model, tokenizer, messages, pending, action_list, avail_actions,
                                    actions, max_new_tokens, num_candidates, max_length)

    # Every agent still pending gets a random available action. This keeps an unresolved
    # agent from silently keeping the placeholder action 0, even when max_attempts <= 1.
    for agent_idx in pending:
        avail_actions_ind = np.nonzero(avail_actions[agent_idx])[0]
        actions[agent_idx] = np.random.choice(avail_actions_ind)
        random_action_count += 1
        print(f"[Agent {agent_idx}] No valid action after {max_attempts} attempt(s). "
              f"Using random fallback: '{action_list[actions[agent_idx]]}'")

    return actions, ood_count, random_action_count


def run_episode(model, tokenizer, inference, action_list):
    """Play one episode in which the model picks every agent's action each step.

    Each agent has a chat history, one per ``inference.args.n_agents``. Each step, it is
    extended with the messages that ``inference.infer`` produces for the taken actions.

    Returns:
        ``(episode_reward, win, ood_count, random_count, steps)``, where ``win`` is 1 if the
        terminating step's ``info`` has a truthy ``"battle_won"`` and 0 otherwise.
    """
    inference.reset()
    histories = [[] for _ in range(inference.args.n_agents)]
    total_reward = 0
    ood_total = 0
    random_total = 0
    won = 0
    step = 0
    done = False

    while not done:
        observation = inference.get_obs()
        prompts = inference.infer(observation, None, step)
        available = inference.get_avail_actions()
        chosen, n_ood, n_random = infer_actions(model, tokenizer, prompts, histories, action_list, available)
        ood_total += n_ood
        random_total += n_random

        reward, done, info = inference.step(chosen)
        # Record what was actually done this step so the next prompt includes it.
        transcript = inference.infer(observation, chosen, step)
        histories = [past + new for past, new in zip(histories, transcript)]
        total_reward += reward
        step += 1

        outcome = info.get("battle_won") if done else "-"
        print("Reward: {:.4f}, Actions: {}, Steps: {}, Terminated: {}, Win: {}".format(
            reward, chosen, step, done, outcome
        ))
        if done and info.get("battle_won"):
            won = 1

    print(f"[Episode OOD Count] {ood_total}, [Random Actions] {random_total}")
    return total_reward, won, ood_total, random_total, step


def _load_model_and_tokenizer(model_args, accelerator):
    """Load the causal LM and its left-padded tokenizer described by ``model_args`` for inference."""
    import os

    quant_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=True,
        device_map=get_kbit_device_map() if quant_config else None,
        quantization_config=quant_config,
    )

    # Never hard-code access tokens in source; read the token from the environment instead.
    # When HF_TOKEN is unset, token=None lets huggingface_hub use its own credential lookup.
    token = os.environ.get("HF_TOKEN")
    # Left padding/truncation keeps every prompt's end (the generation point) aligned and intact.
    # Remote code is only trusted when the user opted in, matching the model load below.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path,
                                              padding_side="left",
                                              truncation_side="left",
                                              trust_remote_code=model_args.trust_remote_code,
                                              use_fast=True,
                                              token=token)
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, token=token, **model_kwargs)
    if not quant_config:
        # Quantized models are already placed by device_map; only unquantized ones need moving.
        model = accelerator.prepare(model)
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = chat_template
    return model, tokenizer


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0 when ``whole`` is not positive."""
    return part / whole * 100 if whole > 0 else 0


def main(model_args, env_args, test_runs=5, episodes_per_run=100):
    """Evaluate the model over ``test_runs`` runs of ``episodes_per_run`` episodes and print a summary.

    Each run creates a fresh ``ChatMLProcessor`` for ``env_args.map_name`` and always closes it.
    The final table lists each run's win rate, OOD replies and random-fallback actions, followed
    by the aggregate and the standard deviation of the win rate.
    """
    accelerator = Accelerator()
    model, tokenizer = _load_model_and_tokenizer(model_args, accelerator)

    summary_table = []
    total_ood_count = 0
    total_random_count = 0
    total_steps = 0
    win_rates = []

    for run in range(test_runs):
        print(f"\n=== Test Run {run + 1}/{test_runs} ===")
        inference = ChatMLProcessor(env_args.map_name)

        win_count = 0
        run_ood_count = 0
        run_random_count = 0
        run_total_steps = 0

        try:
            action_list = inference.get_action_list()
            for e in range(episodes_per_run):
                reward, win, ood_count, random_count, steps = run_episode(model, tokenizer, inference, action_list)
                win_count += win
                run_ood_count += ood_count
                run_random_count += random_count
                run_total_steps += steps
                print(
                    f"Episode {e + 1}: Reward = {reward:.2f}, Win = {bool(win)}, OOD Count = {ood_count}, Random Actions = {random_count}, Steps = {steps}")
        finally:
            # Release the environment even if an episode raises.
            inference.close()

        win_rate = win_count / episodes_per_run if episodes_per_run > 0 else 0.0
        win_rates.append(win_rate)
        total_ood_count += run_ood_count
        total_random_count += run_random_count
        total_steps += run_total_steps

        summary_table.append([
            f"{run + 1}",
            f"{win_rate:.3f}",
            f"{run_ood_count} / {run_total_steps} steps",
            f"{run_random_count} ({_percent(run_random_count, run_total_steps):.2f}%)"
        ])

    avg_win_rate = np.mean(win_rates)
    std_win_rate = np.std(win_rates)
    random_action_rate = _percent(total_random_count, total_steps)

    summary_table.append([
        "Avg",
        f"{avg_win_rate:.3f}",
        f"{total_ood_count} / {total_steps} steps",
        f"{total_random_count} ({random_action_rate:.2f}%)"
    ])
    summary_table.append([
        "Std",
        f"{std_win_rate:.3f}",
        "",
        ""
    ])

    print("\n==== Final Summary ====")
    print(tabulate(summary_table, headers=["Run", "Win Rate", "Total OOD Count", "Random Actions (%)"],
                   tablefmt="pretty"))


if __name__ == "__main__":
    # Parse CLI flags into the model and environment config dataclasses, then run the evaluation.
    model_config, env_config = HfArgumentParser((ModelConfig, EnvConfig)).parse_args_into_dataclasses()
    main(model_config, env_config)