import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import matplotlib.pyplot as plt
from matplotlib import font_manager

# Device configuration: use CUDA when PyTorch reports it as available, otherwise the CPU.
# Tensors and modules created in this file are placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Chinese font display configuration
def set_chinese_font():
    """Choose the matplotlib sans-serif font list used for CJK labels.

    Two Windows font files are registered first.  If either file is missing,
    a list without the Windows fonts is used; on any other error a minimal
    Latin list is used and the error is printed.  In every case
    ``axes.unicode_minus`` is switched off.
    """
    windows_font_files = (
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyh.ttc",
    )
    failure = None

    try:
        for font_file in windows_font_files:
            font_manager.fontManager.addfont(font_file)
        sans_serif = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Zen Hei', 'DejaVu Sans']
    except FileNotFoundError:
        sans_serif = ['WenQuanYi Zen Hei', 'Arial Unicode MS', 'DejaVu Sans']
    except Exception as e:
        sans_serif = ['DejaVu Sans', 'Arial']
        # Keep a reference: the name bound by ``except ... as`` is cleared
        # when the handler exits.
        failure = e

    plt.rcParams['font.sans-serif'] = sans_serif
    plt.rcParams['axes.unicode_minus'] = False

    if failure is not None:
        print(f"Chinese font display configuration failed: {failure}")


# Apply the font configuration once when the module is loaded, before any figure is drawn.
set_chinese_font()


class GroupMixHPPO:
    """PPO-style agent with a shared GRU encoder and one policy head per agent type.

    A rolling window of the last ``seq_len`` observations is encoded by a
    shared GRU followed by LayerNorm.  The features of the last time step feed
    (a) one softmax MLP per agent type that outputs a distribution over the
    ``N`` battlefields and (b) a value MLP whose per-agent outputs are averaged
    into a single state value.

    Args:
        K: Number of agents (rows of the observation window).
        N: Number of battlefields (size of the action space; observations
            have ``4 * N`` features).
        M: Number of agent types (number of policy heads).
    """

    def __init__(self, K, N, M):
        self.K = K
        self.N = N  # Kept so callers can check head compatibility before loading parameters
        self.M = M
        # Snapshot of the module-level device so every method allocates on the same one.
        self.device = device
        self.hidden_size = 64

        if M == 3:
            self.lr = 6e-5
            self.entropy_coef = 0.001  # Initial entropy coefficient; callers may decay it during training
            self.seq_len = 6
            self.gae_lambda = 0.98
            self.ppo_epochs = 6
            self.eps_clip = 0.2
        else:
            self.lr = 4e-5
            self.entropy_coef = 0.001  # Initial entropy coefficient; callers may decay it during training
            self.seq_len = 10
            self.gae_lambda = 0.99
            self.ppo_epochs = 5
            self.eps_clip = 0.2

        self.gamma = 0.99

        self.shared_gru = nn.GRU(input_size=N * 4, hidden_size=self.hidden_size, batch_first=True).to(self.device)
        self.gru_norm = nn.LayerNorm(self.hidden_size).to(self.device)
        self._init_weights(self.shared_gru)
        self._init_weights(self.gru_norm)

        self.type_mlps = nn.ModuleList()
        for _ in range(M):
            mlp = nn.Sequential(
                nn.Linear(self.hidden_size, 32),
                nn.Tanh(),
                nn.Dropout(0.1),
                nn.Linear(32, N),  # Output dimension equals the number of battlefields N
                nn.Softmax(dim=-1)
            ).to(self.device)
            self._init_weights(mlp)
            self.type_mlps.append(mlp)

        self.value_net = nn.Sequential(
            nn.Linear(self.hidden_size, 32),
            nn.Tanh(),
            nn.LayerNorm(32),
            nn.Linear(32, 1)
        ).to(self.device)
        self._init_weights(self.value_net)

        self.optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=2e-5)
        self.obs_history = torch.zeros(K, self.seq_len, N * 4, dtype=torch.float32, device=self.device)
        # Observation window as it was before the first get_action() call since
        # the last update(); update() replays the rollout from it.
        self._rollout_start_history = None

    def _init_weights(self, module):
        """Initialise ``module`` and every sub-module it contains.

        Linear/GRU weights with two or more dimensions get Xavier-normal
        values, other weights N(0, 0.01), biases 0.05; LayerNorm is reset to
        weight 1 / bias 0.  Containers such as ``nn.Sequential`` are walked so
        the layers inside them are initialised too.
        """
        for sub in module.modules():
            if isinstance(sub, (nn.Linear, nn.GRU)):
                for name, param in sub.named_parameters(recurse=False):
                    if 'weight' in name:
                        if param.dim() >= 2:
                            nn.init.xavier_normal_(param)
                        else:
                            nn.init.normal_(param, mean=0.0, std=0.01)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0.05)
            elif isinstance(sub, nn.LayerNorm):
                nn.init.constant_(sub.weight, 1.0)
                nn.init.constant_(sub.bias, 0.0)

    def parameters(self):
        """Return all trainable parameters (encoder, heads, value net) as a list."""
        params = list(self.shared_gru.parameters())
        params += list(self.gru_norm.parameters())
        params += list(self.type_mlps.parameters())
        params += list(self.value_net.parameters())
        return params

    def named_parameters(self):
        """Yield ``(qualified_name, parameter)`` pairs for every trainable parameter."""
        for name, param in self.shared_gru.named_parameters():
            yield f"shared_gru.{name}", param
        for name, param in self.gru_norm.named_parameters():
            yield f"gru_norm.{name}", param
        for i, mlp in enumerate(self.type_mlps):
            for name, param in mlp.named_parameters():
                yield f"type_mlps.{i}.{name}", param
        for name, param in self.value_net.named_parameters():
            yield f"value_net.{name}", param

    def _check_nan(self, tensor, name):
        """Return ``tensor`` unchanged if it is finite, otherwise a finite replacement.

        A 1-D tensor containing NaN/Inf is replaced by the uniform vector
        ``1 / len``.  Any other tensor is filled with the mean of its finite
        entries, or 0.0 when it has none.  The replacement is a new tensor
        that is not connected to the autograd graph.

        Args:
            tensor: Tensor to sanitise.
            name: Label of the tensor; kept for call-site readability only.
        """
        finite = torch.isfinite(tensor)
        if bool(finite.all()):
            return tensor
        if tensor.dim() == 1:
            return torch.full_like(tensor, 1 / tensor.shape[0])
        valid = torch.masked_select(tensor, finite)
        # Only finite entries contribute; an all-bad tensor falls back to zero.
        fill_value = valid.mean().item() if valid.numel() > 0 else 0.0
        return torch.full_like(tensor, fill_value)

    def _prepare_obs(self, obs):
        """Convert an observation to a sanitised float32 ``(K, 4N)`` tensor on the agent device."""
        obs = torch.as_tensor(obs).to(dtype=torch.float32, device=self.device)
        obs = self._check_nan(obs, "obs")
        if obs.dim() == 1:
            # A single global observation is shared by all K agents.
            obs = obs.unsqueeze(0).repeat(self.K, 1)
        return obs

    def _push_obs(self, history, obs):
        """Return a new window with the oldest step dropped and ``obs`` appended."""
        rolled = torch.cat([history[:, 1:], obs.unsqueeze(1)], dim=1)
        return self._check_nan(rolled, "obs_history")

    def _group_agents_by_type(self, agent_types):
        """Return, for each policy head, the indices of the agents that use it.

        Raises:
            IndexError: If fewer than K types are given or a type is outside
                ``[0, number_of_heads)``.
        """
        types = torch.as_tensor(agent_types, dtype=torch.long, device=self.device).reshape(-1)
        if types.numel() < self.K:
            raise IndexError(f"agent_types has {types.numel()} entries, expected at least K={self.K}")
        types = types[:self.K]
        num_heads = len(self.type_mlps)
        if types.numel() > 0 and (int(types.min()) < 0 or int(types.max()) >= num_heads):
            raise IndexError(f"agent type out of range [0, {num_heads})")
        return [torch.nonzero(types == m).squeeze(-1) for m in range(num_heads)]

    def _forward(self, history, type_groups):
        """Run encoder, policy heads and value net on one observation window.

        Args:
            history: ``(K, seq_len, 4N)`` observation window.
            type_groups: Output of :meth:`_group_agents_by_type`.

        Returns:
            ``(probs, agent_values)`` where ``probs`` is ``(K, N)`` with rows
            summing to 1 and ``agent_values`` is ``(K,)``.
        """
        h0 = torch.zeros(1, history.shape[0], self.hidden_size, dtype=torch.float32, device=self.device)
        gru_out, _ = self.shared_gru(history, h0)
        gru_out = self._check_nan(self.gru_norm(gru_out), "gru_out")
        features = gru_out[:, -1, :]

        # One batched call per head instead of one call per agent.
        probs = torch.zeros(features.shape[0], self.N, dtype=features.dtype, device=features.device)
        for head, idx in zip(self.type_mlps, type_groups):
            if idx.numel() > 0:
                probs = probs.index_copy(0, idx, head(features.index_select(0, idx)))

        # Any agent whose distribution is not finite falls back to uniform.
        bad_rows = ~torch.isfinite(probs).all(dim=-1, keepdim=True)
        probs = torch.where(bad_rows, torch.full_like(probs, 1.0 / self.N), probs)
        # Mix in 5% uniform mass so no action ever has zero probability.
        probs = (probs * 0.95) + (1 / self.N * 0.05)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        agent_values = self.value_net(features).squeeze(-1)
        return probs, agent_values

    def get_action(self, obs, agent_types):
        """Append ``obs`` to the observation window and sample one action per agent.

        Args:
            obs: ``(4N,)`` observation shared by all agents, or ``(K, 4N)``.
            agent_types: Length-K integer sequence/tensor selecting each
                agent's policy head.

        Returns:
            ``(actions, log_probs, value)``: ``actions`` and ``log_probs`` have
            shape ``(K,)``; ``value`` is a scalar tensor clamped to [-5, 5].
        """
        obs = self._prepare_obs(obs)
        if getattr(self, "_rollout_start_history", None) is None:
            self._rollout_start_history = self.obs_history
        self.obs_history = self._push_obs(self.obs_history, obs)

        probs, agent_values = self._forward(self.obs_history, self._group_agents_by_type(agent_types))
        dist = Categorical(probs)
        actions = dist.sample()

        value = self._check_nan(agent_values.mean(), "value")
        value = torch.clamp(value, -5.0, 5.0)
        return actions, dist.log_prob(actions), value

    def compute_gae(self, rewards, dones, values):
        """Compute GAE advantages and value targets for one rollout.

        Args:
            rewards: Per-step rewards; each is clipped to ``[-N, N]``.
            dones: Per-step termination flags.
            values: Per-step value estimates, optionally followed by one
                bootstrap value.  Non-finite entries are replaced (previous
                value, else 0; a non-finite next value counts as 0).

        Returns:
            ``(advantages, returns)`` as float32 tensors of length
            ``len(rewards)``.  ``returns`` is the raw advantage plus the
            baseline value.  ``advantages`` is standardised only when there is
            more than one step, because standardising a single sample would
            always give 0.

        Raises:
            ValueError: If ``values`` is shorter than ``rewards``.
        """
        eps = 1e-8
        num_steps = len(rewards)
        if num_steps == 0:
            empty = torch.zeros(0, dtype=torch.float32, device=self.device)
            return empty, empty.clone()

        vals = np.asarray(values, dtype=np.float64).reshape(-1)
        if len(vals) < num_steps:
            raise ValueError("values must contain at least one entry per reward")

        raw_advantages = np.zeros(num_steps, dtype=np.float64)
        baselines = np.zeros(num_steps, dtype=np.float64)
        last_advantage = 0.0

        for t in range(num_steps - 1, -1, -1):
            reward_t = float(np.clip(rewards[t], -self.N, self.N))
            if np.isfinite(vals[t]):
                val_t = vals[t]
            elif t > 0 and np.isfinite(vals[t - 1]):
                val_t = vals[t - 1]
            else:
                val_t = 0.0
            has_next = t + 1 < len(vals) and np.isfinite(vals[t + 1])
            val_t1 = vals[t + 1] if has_next else 0.0
            not_done = 1.0 - float(dones[t])

            delta = reward_t + self.gamma * not_done * val_t1 - val_t
            last_advantage = delta + self.gamma * self.gae_lambda * not_done * last_advantage
            raw_advantages[t] = last_advantage
            baselines[t] = val_t

        # Value targets must use the un-normalised advantages.
        returns = raw_advantages + baselines
        if num_steps > 1:
            advantages = (raw_advantages - raw_advantages.mean()) / (raw_advantages.std() + eps)
        else:
            advantages = raw_advantages

        return torch.tensor(advantages, dtype=torch.float32, device=self.device), \
            torch.tensor(returns, dtype=torch.float32, device=self.device)

    def _reset_nan_params(self):
        """Re-initialise every layer that holds a non-finite parameter.

        The optimizer state of the affected parameters is dropped as well so
        stale moment estimates are not applied to the fresh weights.
        """
        optimizer = getattr(self, "optimizer", None)
        for root in (self.shared_gru, self.gru_norm, self.type_mlps, self.value_net):
            for sub in root.modules():
                own_params = list(sub.parameters(recurse=False))
                if not own_params or all(bool(torch.isfinite(p).all()) for p in own_params):
                    continue
                self._init_weights(sub)
                if optimizer is not None:
                    for p in own_params:
                        optimizer.state.pop(p, None)

    def update(self, old_log_probs, states, actions, rewards, dones, values, agent_types):
        """Run ``ppo_epochs`` clipped-PPO updates on one rollout.

        The observation windows seen during the rollout are rebuilt from the
        window recorded at the first ``get_action`` call since the previous
        update plus ``states``; this assumes those calls were made with
        ``states`` in order.  The stored ``actions`` are re-scored under the
        current policy.  ``self.obs_history`` is not modified.

        Args:
            old_log_probs: ``(T, K)`` log-probabilities recorded at sampling time.
            states: Sequence of at least T observations, in rollout order.
            actions: ``(T, K)`` actions taken.
            rewards: T rewards.
            dones: T termination flags.
            values: T value estimates, optionally plus one bootstrap value.
            agent_types: Length-K agent type indices.

        Returns:
            Mean loss over the epochs that produced a finite loss, or 0.0 if
            there were none.

        Raises:
            ValueError: If fewer states than rewards are supplied.
        """
        num_steps = len(rewards)
        if num_steps == 0:
            self._rollout_start_history = None
            return 0.0
        if len(states) < num_steps:
            raise ValueError("update() needs one state per reward")

        advantages, returns = self.compute_gae(rewards, dones, values)
        advantages = torch.nan_to_num(advantages, nan=0.0, posinf=0.0, neginf=0.0).unsqueeze(-1)  # (T, 1)
        returns = torch.nan_to_num(returns, nan=0.0, posinf=0.0, neginf=0.0)

        old_log_probs = torch.as_tensor(old_log_probs).detach()
        old_log_probs = old_log_probs.to(dtype=torch.float32, device=self.device).reshape(num_steps, self.K)
        old_log_probs = self._check_nan(old_log_probs, "old_log_probs")
        actions = torch.as_tensor(actions).to(dtype=torch.long, device=self.device).reshape(num_steps, self.K)
        type_groups = self._group_agents_by_type(agent_types)

        # Replay the rollout on a private copy of the window.
        history = getattr(self, "_rollout_start_history", None)
        if history is None:
            history = self.obs_history
        windows = []
        for state in states[:num_steps]:
            history = self._push_obs(history, self._prepare_obs(state))
            windows.append(history)
        self._rollout_start_history = None

        total_loss = 0.0
        applied_epochs = 0
        for _ in range(self.ppo_epochs):
            step_log_probs, step_entropies, step_values = [], [], []
            for window, step_actions in zip(windows, actions):
                probs, agent_values = self._forward(window, type_groups)
                dist = Categorical(probs)
                step_log_probs.append(dist.log_prob(step_actions))
                step_entropies.append(dist.entropy())
                step_values.append(torch.clamp(agent_values.mean(), -5.0, 5.0))

            curr_log_probs = torch.stack(step_log_probs)   # (T, K)
            curr_values = torch.stack(step_values)         # (T,)
            entropy = torch.stack(step_entropies).mean()

            ratio = torch.exp(curr_log_probs - old_log_probs)
            ratio = torch.clamp(ratio, 0.5, 2.0)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            ppo_loss = -torch.min(surr1, surr2).mean()

            value_loss = nn.MSELoss()(curr_values, returns)
            entropy_loss = self.entropy_coef * entropy
            loss = ppo_loss + 0.5 * value_loss - entropy_loss

            if not bool(torch.isfinite(loss)):
                # Nothing useful to back-propagate; repair the weights and move on.
                self._reset_nan_params()
                continue

            self.optimizer.zero_grad()
            loss.backward()
            # Clipping has to happen after backward() to act on this step's gradients.
            torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), max_norm=0.1)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=0.15)
            if bool(torch.isfinite(grad_norm)):
                self.optimizer.step()

            self._reset_nan_params()

            total_loss += loss.item()
            applied_epochs += 1

        if applied_epochs == 0:
            return 0.0
        avg_loss = total_loss / applied_epochs
        return float(avg_loss) if np.isfinite(avg_loss) else 0.0


class BattleEnvironment:
    """Allocation game in which K typed agents are assigned to N battlefields.

    Each agent has a type that determines its strength.  The enemy strength
    ``int(B * enemy_strength_ratio)`` is spread evenly over the battlefields.
    A battlefield is won when the summed strength of the agents deployed
    there exceeds the enemy strength on it.

    Args:
        N: Number of battlefields.
        M: Number of agent types; must be between 1 and 4.
        K: Number of agents.
        B: Budget from which the enemy's total strength is derived.
        T: Deployment cap; the episode ends once ``deployed_agents >= T``
            (or all K agents are deployed).
        enemy_strength_ratio: Fraction of ``B`` given to the enemy.

    Raises:
        ValueError: If ``M`` is outside 1..4.
    """

    def __init__(self, N, M, K, B, T, enemy_strength_ratio=0.9):
        type_prior = np.array([0.4, 0.3, 0.2, 0.1])
        type_strength = np.array([3.0, 1.5, 0.8, 0.5])
        if not 1 <= M <= len(type_prior):
            raise ValueError(f"M must be between 1 and {len(type_prior)}, got {M}")

        self.N = N
        self.M = M
        self.K = K
        self.B = B
        self.T = T
        self.enemy_strength_ratio = enemy_strength_ratio
        self.enemy_total_strength = int(B * enemy_strength_ratio)

        # A private, fixed-seed generator gives the same type assignment on
        # every run without touching the caller's global NumPy RNG state.
        base_probs = type_prior[:M] / np.sum(type_prior[:M])
        rng = np.random.RandomState(42)
        self.agent_types = rng.choice(M, size=K, p=base_probs)

        self.strength_weights = type_strength[:M]
        self.our_total_strength = np.sum(self.strength_weights[self.agent_types])

        self.reset()

    def reset(self):
        """Clear all deployments and per-round statistics; return the first observation."""
        self.enemy_strength = np.ones(self.N) * (self.enemy_total_strength / self.N)
        self.our_deployment = np.zeros((self.K, self.N))
        self.deployed_agents = 0
        self.final_win_count = 0
        self.round_total_net_win = 0
        self.round_total_reward = []
        return self._get_observation()

    def _strength_per_battle(self):
        """Return the summed strength of our deployed agents on each battlefield."""
        return np.sum(self.our_deployment * self.strength_weights[self.agent_types, None], axis=0)

    def _get_observation(self):
        """Build the ``4 * N`` observation vector as a float32 tensor on ``device``.

        Layout: our strength per battlefield, enemy strength per battlefield,
        fraction of agents deployed per battlefield, and the global deployed
        fraction repeated N times.  Values are clipped to
        ``[0, our_total_strength]``.
        """
        deployment_ratio_per_battle = np.sum(self.our_deployment, axis=0) / (self.K + 1e-8)
        global_deployment_ratio = self.deployed_agents / (self.K + 1e-8)

        obs = np.concatenate([
            self._strength_per_battle(),
            self.enemy_strength,
            deployment_ratio_per_battle,
            np.full(self.N, global_deployment_ratio)
        ])

        obs = np.nan_to_num(obs)
        obs = np.clip(obs, 0.0, self.our_total_strength)
        # torch.as_tensor honours ``device``; the legacy FloatTensor constructor is CPU-only.
        return torch.as_tensor(obs, dtype=torch.float32, device=device)

    def step(self, actions):
        """Deploy every not-yet-deployed agent to the battlefield chosen for it.

        Because all agents without a deployment are placed in the same call,
        the first call after ``reset`` deploys all K agents.

        Args:
            actions: Length-K tensor or sequence of battlefield indices.

        Returns:
            ``(observation, step_reward, done, resource_util, final_win_count,
            round_total_net_win, round_total_reward)``.  ``step_reward`` is
            wins minus losses; ``resource_util`` is the strength on winning
            battlefields divided by our total strength, clipped to [0, 1].
        """
        # Convert once instead of calling .item() per agent.
        if isinstance(actions, torch.Tensor):
            chosen = actions.detach().cpu().numpy()
        else:
            chosen = np.asarray([a.item() if isinstance(a, torch.Tensor) else a for a in actions])
        chosen = chosen.reshape(-1)

        pending = np.flatnonzero(np.sum(self.our_deployment, axis=1) == 0)
        if pending.size > 0:
            self.our_deployment[pending, chosen[pending].astype(np.intp)] = 1
        self.deployed_agents += int(pending.size)

        our_strength_per_battle = self._strength_per_battle()
        win_battles = our_strength_per_battle > self.enemy_strength
        lose_battles = our_strength_per_battle < self.enemy_strength
        wins = int(win_battles.sum())
        step_reward = wins - int(lose_battles.sum())

        done = (self.deployed_agents >= self.K) or (self.deployed_agents >= self.T)
        if done:
            # Final statistics are only recorded when the round ends.
            self.final_win_count = wins
            self.round_total_net_win = step_reward
        self.round_total_reward.append(step_reward)

        win_strength = np.sum(our_strength_per_battle * win_battles)
        resource_util = win_strength / (self.our_total_strength + 1e-8)
        resource_util = np.clip(resource_util, 0.0, 1.0)

        return (self._get_observation(), step_reward, done, resource_util,
                self.final_win_count, self.round_total_net_win, self.round_total_reward)


def _entropy_coef_for_episode(episode, initial_coef, min_coef, decay_start, decay_end):
    """Return the entropy coefficient for a 1-based episode: flat, linear decay, then flat."""
    if episode < decay_start:
        return initial_coef
    if episode > decay_end:
        return min_coef
    decay_ratio = (episode - decay_start) / (decay_end - decay_start)
    return initial_coef - (initial_coef - min_coef) * decay_ratio


def _transfer_parameters(prev_agent, agent):
    """Warm-start ``agent`` from ``prev_agent`` where the layer shapes are compatible."""
    # gru_norm does not depend on N (number of battlefields), so it can always be copied.
    if hasattr(prev_agent, 'gru_norm') and hasattr(agent, 'gru_norm'):
        agent.gru_norm.load_state_dict(prev_agent.gru_norm.state_dict())

    # The policy heads output N values, so they are copied only when N is unchanged.
    if hasattr(prev_agent, 'N') and prev_agent.N == agent.N:
        shared_types = min(prev_agent.M, agent.M, len(prev_agent.type_mlps), len(agent.type_mlps))
        for i in range(shared_types):
            agent.type_mlps[i].load_state_dict(prev_agent.type_mlps[i].state_dict())

    # Heads for newly added types start from a damped copy of the type-1 head.
    if len(agent.type_mlps) > 1:
        for i in range(prev_agent.M, min(agent.M, len(agent.type_mlps))):
            agent.type_mlps[i].load_state_dict(agent.type_mlps[1].state_dict())
            with torch.no_grad():
                for param in agent.type_mlps[i].parameters():
                    param.mul_(0.6)


def _run_episode(env, agent, agent_types):
    """Play one episode and return the collected rollout and statistics as a dict."""
    obs = env.reset()
    rollout = {'states': [], 'actions': [], 'log_probs': [], 'rewards': [], 'dones': [], 'values': []}
    total_resource_util = 0.0
    steps = 0

    # Only detached quantities are consumed later, so no autograd graph is needed here.
    with torch.no_grad():
        while True:
            steps += 1
            action, log_prob, value = agent.get_action(obs, agent_types)
            next_obs, reward, done, util, win_count, round_total_net_win, round_total_reward = env.step(action)

            rollout['states'].append(obs)
            rollout['actions'].append(action)
            rollout['log_probs'].append(log_prob)
            rollout['rewards'].append(reward)
            rollout['dones'].append(done)
            rollout['values'].append(value.item())

            total_resource_util += util
            obs = next_obs

            if done:
                break

        # Bootstrap value for the state after the last step.
        _, _, final_value = agent.get_action(obs, agent_types)
    rollout['values'].append(final_value.item())

    rollout['win_count'] = win_count
    rollout['round_total_net_win'] = round_total_net_win
    rollout['round_total_reward'] = round_total_reward
    rollout['avg_resource_util'] = total_resource_util / steps if steps > 0 else 0.0
    return rollout


def _plot_training_log(log, K, N):
    """Save the four-panel training figure as ``scenario_K{K}_N{N}_train_log.png``."""
    plt.figure(figsize=(18, 10))
    try:
        panels = [
            ('avg_10round_reward', '#1f77b4', '10-Episode Average Reward', 'Average Reward'),
            ('win_count', '#ff7f0e', 'Win Count per Episode', 'Win Count'),
            ('net_win_rate', '#2ca02c', 'Average Net Win Rate', 'Net Win Rate'),
        ]
        for position, (key, color, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.plot(log[key], color=color)
            plt.title(title)
            plt.xlabel('Training Episodes')
            plt.ylabel(ylabel)
            plt.grid(True, alpha=0.3)

        # Fourth panel: resource utilisation and loss on twin y-axes.
        plt.subplot(2, 2, 4)
        ax1 = plt.gca()
        ax2 = ax1.twinx()
        ax1.plot(log['resource_util'], color='#d62728', label='Resource Utilization')
        ax2.plot(log['loss'], color='#9467bd', label='Loss', alpha=0.7)
        ax1.set_xlabel('Training Episodes')
        ax1.set_ylabel('Resource Utilization', color='#d62728')
        ax2.set_ylabel('Loss', color='#9467bd')
        ax1.grid(True, alpha=0.3)
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

        plt.tight_layout()
        plt.savefig(f"scenario_K{K}_N{N}_train_log.png", dpi=100)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close()


def train_scenario(scenario_params, prev_agent=None):
    """Train a fresh agent on one scenario and save its training plot.

    Args:
        scenario_params: Dict with keys ``'K'``, ``'N'``, ``'M'``, ``'B'``,
            ``'T'`` and ``'epochs'`` (number of training episodes).
        prev_agent: Optional agent from an earlier scenario.  Its parameters
            are used as a warm start only when this scenario has more agent
            types than ``prev_agent``.

    Returns:
        The trained ``GroupMixHPPO`` agent.
    """
    K, N, M, B, T, max_episodes = (
        scenario_params['K'], scenario_params['N'], scenario_params['M'],
        scenario_params['B'], scenario_params['T'], scenario_params['epochs']
    )
    print(f"=== Starting training scenario (K={K}, N={N}, M={M}, B={B}, T={T}) ===")

    agent = GroupMixHPPO(K, N, M)

    # Entropy coefficient schedule: constant, then linear decay, then constant.
    initial_entropy_coef = agent.entropy_coef
    decay_start_episode = 200
    decay_end_episode = 500
    min_entropy_coef = 0.0001

    if prev_agent is not None and M > prev_agent.M:
        _transfer_parameters(prev_agent, agent)

    env = BattleEnvironment(N, M, K, B, T)
    # The type assignment is fixed when the environment is constructed.
    agent_types = torch.tensor(env.agent_types, dtype=torch.long, device=device)

    log = {
        'avg_10round_reward': [],
        'win_count': [],
        'net_win_rate': [],
        'resource_util': [],
        'loss': []
    }
    reward_10round_window = []

    print("Episode | Avg Reward | Win Count | Avg Net Win Rate | Resource Utilization | Loss")

    for episode in range(1, max_episodes + 1):
        agent.entropy_coef = _entropy_coef_for_episode(
            episode, initial_entropy_coef, min_entropy_coef, decay_start_episode, decay_end_episode)

        rollout = _run_episode(env, agent, agent_types)

        loss = 0.0
        if rollout['log_probs'] and rollout['states']:
            loss = agent.update(
                old_log_probs=torch.stack(rollout['log_probs']),
                states=rollout['states'],
                actions=torch.stack(rollout['actions']),
                rewards=rollout['rewards'],
                dones=rollout['dones'],
                values=rollout['values'],
                agent_types=agent_types
            )

        round_total_reward = rollout['round_total_reward']
        current_round_avg_reward = np.mean(round_total_reward) if round_total_reward else 0.0
        reward_10round_window.append(current_round_avg_reward)
        if len(reward_10round_window) > 10:
            reward_10round_window.pop(0)
        avg_10round_reward = np.mean(reward_10round_window) if reward_10round_window else 0.0

        current_win_count = rollout['win_count']
        current_net_win_rate = rollout['round_total_net_win'] / N if N > 0 else 0.0
        avg_resource_util = rollout['avg_resource_util']

        log['avg_10round_reward'].append(avg_10round_reward)
        log['win_count'].append(current_win_count)
        log['net_win_rate'].append(current_net_win_rate)
        log['resource_util'].append(avg_resource_util)
        log['loss'].append(loss)

        if episode % 10 == 0:
            print(
                f"Episode {episode:3d} | Avg Reward: {avg_10round_reward:6.2f} | Win Count: {current_win_count:3d} | Avg Net Win Rate: {current_net_win_rate:10.3f} | Resource Utilization: {avg_resource_util:8.3f} | Loss: {loss:.6f}")

    _plot_training_log(log, K, N)

    return agent


if __name__ == "__main__":
    # Scenarios are trained in order; each trained agent is handed to the next
    # call as ``prev_agent`` so it can be used as a warm start.
    scenarios = [
        {'K': 50, 'N': 5, 'M': 3, 'B': 100, 'T': 100, 'epochs': 1000},
        {'K': 200, 'N': 10, 'M': 3, 'B': 400, 'T': 100, 'epochs': 1000},
        {'K': 500, 'N': 15, 'M': 4, 'B': 1000, 'T': 100, 'epochs': 1000},
        {'K': 1000, 'N': 20, 'M': 4, 'B': 2000, 'T': 100, 'epochs': 1000}
    ]

    prev_agent = None
    for scenario in scenarios:
        prev_agent = train_scenario(scenario, prev_agent)
        print(f"=== Scenario (K={scenario['K']}) training completed ===")
        print(
            f"- Training plot for scenario (K={scenario['K']},N={scenario['N']}): scenario_K{scenario['K']}_N{scenario['N']}_train_log.png")
    # No combined comparison plot is produced by this script, so none is announced.
