{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Reinforcement Learning Multi-Agent System for Modeling Demographic Policies in Russian Regions.\n",
    "CHANGES:\n",
    "1. The number of training epochs has been increased to 600 for all nine experiments.\n",
    "2. Data from the year 2000 onward is utilized.\n",
    "3. Experiment 0 (MADDPG-EVO-0) remains unchanged.\n",
    "4. Modifications have been introduced into the remaining experiments to enhance stability and incorporate real-world data.\n",
    "\"\"\"\n",
    "\n",
    "# Import libraries for numerical computations and data processing\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "# Import PyTorch libraries for neural network construction and training\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "# Import data structures and random number generation module\n",
    "from collections import defaultdict, deque\n",
    "import random\n",
    "# Import modules for JSON handling and plotting\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "# Import datetime module\n",
    "from datetime import datetime\n",
    "# Import warning management module\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Set random seeds for reproducibility\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "random.seed(42)\n",
    "\n",
    "# Configure matplotlib for proper rendering of non-ASCII text\n",
    "plt.rcParams['font.family'] = 'DejaVu Sans'\n",
    "plt.rcParams['axes.unicode_minus'] = False\n",
    "\n",
    "# --- BLOCK 1: Processing and Generation of Demographic Data ---\n",
    "class DemographicDataProcessor:\n",
    "    \"\"\"Processor for demographic data based on real regional statistics.\"\"\"\n",
    "    def __init__(self, real_data_path, crisis_scenarios_path):\n",
    "        \"\"\"Initialize the data processor.\"\"\"\n",
    "        self.df = pd.read_csv(real_data_path)\n",
    "        # Load crisis scenarios from JSON file\n",
    "        with open(crisis_scenarios_path, 'r') as f:\n",
    "            self.crisis_scenarios = json.load(f)\n",
    "        self.preprocess_data()\n",
    "\n",
    "    def preprocess_data(self):\n",
    "        \"\"\"Preprocess real data for simulation.\"\"\"\n",
    "        # Fill missing values via forward and backward filling\n",
    "        self.df = self.df.fillna(method='ffill').fillna(method='bfill')\n",
    "        # Compute additional indicators where data is absent\n",
    "        self.df.loc[self.df['natural_increase_rate'].isna(), 'natural_increase_rate'] = \\\n",
    "            self.df['birth_rate'] - self.df['death_rate']\n",
    "        # Normalize statistics by region\n",
    "        self.region_stats = {}\n",
    "        for region in self.df['region_name'].unique():\n",
    "            region_data = self.df[self.df['region_name'] == region]\n",
    "            # Compute statistical metrics per region\n",
    "            self.region_stats[region] = {\n",
    "                'birth_rate_mean': region_data['birth_rate'].mean(),\n",
    "                'birth_rate_std': region_data['birth_rate'].std(),\n",
    "                'death_rate_mean': region_data['death_rate'].mean(),\n",
    "                'death_rate_std': region_data['death_rate'].std(),\n",
    "                'migration_mean': region_data['migration_balance'].mean(),\n",
    "                'migration_std': region_data['migration_balance'].std(),\n",
    "                'gdp_mean': region_data['gdp_per_capita'].mean(),\n",
    "                'gdp_std': region_data['gdp_per_capita'].std(),\n",
    "                'unemployment_mean': region_data['unemployment_rate'].mean(),\n",
    "                'unemployment_std': region_data['unemployment_rate'].std(),\n",
    "                'population_trend': self._calculate_trend(region_data['population']),\n",
    "                'region_id': region_data['region_id'].iloc[0]\n",
    "            }\n",
    "\n",
    "    def _calculate_trend(self, series):\n",
    "        \"\"\"Compute the linear trend of a time series.\"\"\"\n",
    "        series = series.dropna()\n",
    "        if len(series) < 2:\n",
    "            return 0\n",
    "        x = np.arange(len(series))\n",
    "        z = np.polyfit(x, series, 1)\n",
    "        return z[0]\n",
    "\n",
    "    def apply_crisis_impact(self, base_values, crisis_scenario, year, crisis_start_year):\n",
    "        \"\"\"Apply crisis impact on demographic indicators.\"\"\"\n",
    "        crisis_duration = year - crisis_start_year + 1\n",
    "        max_duration = crisis_scenario['end_year'] - crisis_scenario['start_year'] + 1\n",
    "        # Crisis intensity decays over time\n",
    "        intensity = max(0, 1 - (crisis_duration - 1) / max_duration)\n",
    "        impacts = crisis_scenario['demographic_impacts']\n",
    "        # Compute impacts on each indicator\n",
    "        birth_rate_impact = impacts['birth_rate_change'] * intensity\n",
    "        death_rate_impact = impacts['death_rate_change'] * intensity\n",
    "        migration_impact = impacts['migration_change'] * intensity\n",
    "        economic_impact = impacts['economic_impact'] * intensity\n",
    "        modified_values = base_values.copy()\n",
    "        # Apply impacts to indicators\n",
    "        modified_values['birth_rate'] *= (1 + birth_rate_impact)\n",
    "        modified_values['death_rate'] *= (1 + death_rate_impact)\n",
    "        modified_values['migration_balance'] *= (1 + migration_impact)\n",
    "        modified_values['gdp_per_capita'] *= (1 + economic_impact)\n",
    "        modified_values['unemployment_rate'] *= (1 - economic_impact * 0.5)\n",
    "        return modified_values\n",
    "\n",
    "    def generate_training_data(self, years, regions, apply_crisis=True):\n",
    "        \"\"\"Generate training data based on real-world observations.\"\"\"\n",
    "        training_data = []\n",
    "        for year in years:\n",
    "            for region in regions:\n",
    "                if region not in self.region_stats:\n",
    "                    continue\n",
    "                # Retrieve real data if available\n",
    "                real_data = self.df[(self.df['region_name'] == region) & (self.df['year'] == year)]\n",
    "                if len(real_data) > 0:\n",
    "                    # Use observed data\n",
    "                    base_values = {\n",
    "                        'region_id': real_data['region_id'].iloc[0],\n",
    "                        'region_name': region,\n",
    "                        'year': year,\n",
    "                        'birth_rate': real_data['birth_rate'].iloc[0],\n",
    "                        'death_rate': real_data['death_rate'].iloc[0],\n",
    "                        'migration_balance': real_data['migration_balance'].iloc[0],\n",
    "                        'gdp_per_capita': real_data['gdp_per_capita'].iloc[0],\n",
    "                        'unemployment_rate': real_data['unemployment_rate'].iloc[0],\n",
    "                        'population': real_data['population'].iloc[0],\n",
    "                        'average_wage': real_data['average_wage'].iloc[0]\n",
    "                    }\n",
    "                else:\n",
    "                    # Generate synthetic data using statistical profiles\n",
    "                    stats = self.region_stats[region]\n",
    "                    base_values = {\n",
    "                        'region_id': stats['region_id'],\n",
    "                        'region_name': region,\n",
    "                        'year': year,\n",
    "                        'birth_rate': max(0, np.random.normal(stats['birth_rate_mean'], stats['birth_rate_std'])),\n",
    "                        'death_rate': max(0, np.random.normal(stats['death_rate_mean'], stats['death_rate_std'])),\n",
    "                        'migration_balance': np.random.normal(stats['migration_mean'], stats['migration_std']),\n",
    "                        'gdp_per_capita': max(0, np.random.normal(stats['gdp_mean'], stats['gdp_std'])),\n",
    "                        'unemployment_rate': max(0, min(100, np.random.normal(stats['unemployment_mean'], stats['unemployment_std'])))\n",
    "                    }\n",
    "                # Apply crisis scenarios if enabled\n",
    "                if apply_crisis:\n",
    "                    for scenario in self.crisis_scenarios:\n",
    "                        if scenario['start_year'] <= year <= scenario['end_year']:\n",
    "                            base_values = self.apply_crisis_impact(\n",
    "                                base_values, scenario, year, scenario['start_year']\n",
    "                            )\n",
    "                # Compute derived indicators\n",
    "                base_values['natural_increase_rate'] = base_values['birth_rate'] - base_values['death_rate']\n",
    "                # Estimate population based on trend\n",
    "                if 'population' not in base_values:\n",
    "                    base_population = 1000000 + stats['population_trend'] * (year - 2010)\n",
    "                    base_values['population'] = max(0, int(base_population +\n",
    "                        (base_values['natural_increase_rate'] + base_values['migration_balance']/1000) * 1000))\n",
    "                # Estimate average wage\n",
    "                if 'average_wage' not in base_values or pd.isna(base_values['average_wage']):\n",
    "                    base_values['average_wage'] = max(0, base_values['gdp_per_capita'] * 0.03 *\n",
    "                        (1 - base_values['unemployment_rate']/100))\n",
    "                training_data.append(base_values)\n",
    "        return pd.DataFrame(training_data)"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 2: Definition of the Learning Environment (DemographicEnvironment) ---\n",
    "class DemographicEnvironment:\n",
    "    \"\"\"Multi-agent environment for demographic modeling.\"\"\"\n",
    "    def __init__(self, data, n_regions=8, max_steps=50):\n",
    "        self.data = data\n",
    "        # Obtain first n_regions unique region names\n",
    "        self.regions = data['region_name'].unique()[:n_regions]\n",
    "        self.n_regions = len(self.regions)\n",
    "        self.max_steps = max_steps\n",
    "        self.current_step = 0\n",
    "        # State and action space dimensions\n",
    "        self.state_dim = 8  # Core demographic indicators\n",
    "        self.action_dim = 4  # Policy actions\n",
    "        # Stability history for each region (used in enhanced metric)\n",
    "        self.stability_history = {region: deque(maxlen=20) for region in self.regions}\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reset environment to initial state.\"\"\"\n",
    "        self.current_step = 0\n",
    "        self.states = {}\n",
    "        self.histories = {region: [] for region in self.regions}\n",
    "        # Initialize state for all regions\n",
    "        for i, region in enumerate(self.regions):\n",
    "            region_data = self.data[self.data['region_name'] == region].iloc[0]\n",
    "            self.states[region] = self._normalize_state(region_data)\n",
    "        return self.get_observations()\n",
    "\n",
    "    def _normalize_state(self, region_data):\n",
    "        \"\"\"Normalize regional state variables.\"\"\"\n",
    "        return np.array([\n",
    "            region_data['birth_rate'] / 20.0,  # Normalize birth rate\n",
    "            region_data['death_rate'] / 30.0,  # Normalize death rate\n",
    "            region_data['natural_increase_rate'] / 10.0,  # Normalize natural increase\n",
    "            min(region_data['migration_balance'] / 100000.0, 1.0),  # Normalize migration balance\n",
    "            region_data['gdp_per_capita'] / 2000000.0,  # Normalize GDP per capita\n",
    "            region_data['unemployment_rate'] / 100.0,  # Normalize unemployment rate\n",
    "            region_data['population'] / 10000000.0,  # Normalize population\n",
    "            region_data['average_wage'] / 200000.0  # Normalize average wage\n",
    "        ])\n",
    "\n",
    "    def get_observations(self):\n",
    "        \"\"\"Obtain observations for all agents.\"\"\"\n",
    "        observations = {}\n",
    "        for region in self.regions:\n",
    "            # Local observation (current region state)\n",
    "            obs = self.states[region].copy()\n",
    "            # Aggregate states of other regions\n",
    "            other_states = [self.states[r] for r in self.regions if r != region]\n",
    "            # Append averaged information from other regions\n",
    "            if other_states:\n",
    "                avg_other = np.mean(other_states, axis=0)\n",
    "                obs = np.concatenate([obs, avg_other])\n",
    "            else:\n",
    "                obs = np.concatenate([obs, np.zeros(self.state_dim)])\n",
    "            observations[region] = obs\n",
    "        return observations\n",
    "\n",
    "    def step(self, actions):\n",
    "        \"\"\"Execute one step in the environment.\"\"\"\n",
    "        rewards = {}\n",
    "        for region in self.regions:\n",
    "            if region in actions:\n",
    "                action = actions[region]\n",
    "                old_state = self.states[region].copy()\n",
    "                # Update region state based on policy action\n",
    "                self.states[region] = self._apply_policy_action(self.states[region], action)\n",
    "                # Compute reward for state transition\n",
    "                rewards[region] = self._calculate_reward(old_state, self.states[region])\n",
    "                # Record historical trajectory\n",
    "                self.histories[region].append({\n",
    "                    'step': self.current_step,\n",
    "                    'state': old_state.copy(),\n",
    "                    'action': action.copy(),\n",
    "                    'reward': rewards[region]\n",
    "                })\n",
    "        self.current_step += 1\n",
    "        done = self.current_step >= self.max_steps\n",
    "        return self.get_observations(), rewards, done, {}\n",
    "\n",
    "    def _apply_policy_action(self, state, action):\n",
    "        \"\"\"Apply policy actions to the regional state.\"\"\"\n",
    "        new_state = state.copy()\n",
    "        # action[0] - maternal capital (affects birth rate)\n",
    "        # action[1] - healthcare investment (affects death rate)\n",
    "        # action[2] - migration policy (affects migration balance)\n",
    "        # action[3] - economic incentives (affects GDP and unemployment)\n",
    "        birth_rate_change = action[0] * 0.1\n",
    "        death_rate_change = -action[1] * 0.05\n",
    "        migration_change = action[2] * 0.05\n",
    "        economic_change = action[3] * 0.02\n",
    "        # Update indicators with constraints\n",
    "        new_state[0] = max(0, new_state[0] + birth_rate_change)  # birth rate\n",
    "        new_state[1] = max(0, new_state[1] + death_rate_change)  # death rate\n",
    "        new_state[2] = new_state[0] - new_state[1]  # natural increase\n",
    "        new_state[3] = new_state[3] + migration_change  # migration balance\n",
    "        new_state[4] = max(0, new_state[4] + economic_change)  # GDP per capita\n",
    "        new_state[5] = max(0, min(1, new_state[5] - economic_change * 0.5))  # unemployment\n",
    "        # Update population\n",
    "        population_change = (new_state[2] + new_state[3]) * 0.01\n",
    "        new_state[6] = max(0, new_state[6] + population_change)  # population\n",
    "        # Update average wage\n",
    "        new_state[7] = new_state[4] * 0.1 * (1 - new_state[5])  # average wage\n",
    "        return new_state\n",
    "\n",
    "    def _calculate_reward(self, old_state, new_state):\n",
    "        \"\"\"Compute reward based on improvement in demographic indicators.\"\"\"\n",
    "        # Positive changes with weights\n",
    "        birth_rate_improvement = (new_state[0] - old_state[0]) * 10\n",
    "        death_rate_improvement = (old_state[1] - new_state[1]) * 10\n",
    "        natural_increase_improvement = (new_state[2] - old_state[2]) * 15\n",
    "        migration_improvement = (new_state[3] - old_state[3]) * 5\n",
    "        gdp_improvement = (new_state[4] - old_state[4]) * 5\n",
    "        unemployment_improvement = (old_state[5] - new_state[5]) * 5\n",
    "        # Enhanced penalty for population instability\n",
    "        population_stability = -abs(new_state[6] - old_state[6]) * 5\n",
    "        # Sum components\n",
    "        total_reward = (birth_rate_improvement + death_rate_improvement +\n",
    "                       natural_increase_improvement + migration_improvement +\n",
    "                       gdp_improvement + unemployment_improvement + population_stability)\n",
    "        return total_reward"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 3: Definition of Neural Networks for Agents ---\n",
    "class ActorNetwork(nn.Module):\n",
    "    \"\"\"Actor network for MADDPG.\"\"\"\n",
    "    def __init__(self, input_dim, action_dim, hidden_dim=256):\n",
    "        super(ActorNetwork, self).__init__()\n",
    "        # Fully connected layers\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3 = nn.Linear(hidden_dim, action_dim)\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = F.relu(self.fc1(state))\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = torch.tanh(self.fc3(x))  # Actions bounded between -1 and 1\n",
    "        return x\n",
    "\n",
    "class CriticNetwork(nn.Module):\n",
    "    \"\"\"Critic network for MADDPG.\"\"\"\n",
    "    def __init__(self, state_dim, action_dim, n_agents, hidden_dim=256):\n",
    "        super(CriticNetwork, self).__init__()\n",
    "        # Total input dimension: concatenated states and actions of all agents\n",
    "        total_input_dim = state_dim * n_agents + action_dim * n_agents\n",
    "        self.fc1 = nn.Linear(total_input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3 = nn.Linear(hidden_dim, 1)\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        x = torch.cat([states, actions], dim=1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "# --- NEW NETWORKS AND AGENTS FOR ADDITIONAL EXPERIMENTS ---\n",
    "class ActorNetworkPPO(nn.Module):\n",
    "    \"\"\"Actor network for MAPPO (outputting distribution parameters).\"\"\"\n",
    "    def __init__(self, input_dim, action_dim, hidden_dim=256):\n",
    "        super(ActorNetworkPPO, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc_mean = nn.Linear(hidden_dim, action_dim)\n",
    "        self.fc_log_std = nn.Parameter(torch.zeros(1, action_dim))  # Log standard deviation\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = F.relu(self.fc1(state))\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.fc2(x))\n",
    "        mean = torch.tanh(self.fc_mean(x))  # Mean bounded between -1 and 1\n",
    "        return mean\n",
    "\n",
    "    def get_action_log_prob(self, state):\n",
    "        mean = self.forward(state)\n",
    "        log_std = self.fc_log_std.expand_as(mean)\n",
    "        std = torch.exp(log_std)\n",
    "        dist = torch.distributions.Normal(mean, std)\n",
    "        action_raw = dist.rsample()  # Reparameterization trick\n",
    "        action = torch.tanh(action_raw)  # Clamp actions to [-1, 1]\n",
    "        # Correct density for tanh transformation\n",
    "        log_prob = dist.log_prob(action_raw) - torch.log(1 - action.pow(2) + 1e-6)\n",
    "        log_prob = log_prob.sum(dim=-1, keepdim=True)\n",
    "        return action, log_prob\n",
    "\n",
    "    def get_log_prob_entropy(self, state, action):\n",
    "        mean = self.forward(state)\n",
    "        log_std = self.fc_log_std.expand_as(mean)\n",
    "        std = torch.exp(log_std)\n",
    "        dist = torch.distributions.Normal(mean, std)\n",
    "        # Invert tanh to compute log_prob\n",
    "        u = torch.atanh(action.clamp(-0.999, 0.999))  # Avoid numerical inf\n",
    "        log_prob = dist.log_prob(u) - torch.log(1 - action.pow(2) + 1e-6)\n",
    "        log_prob = log_prob.sum(dim=-1, keepdim=True)\n",
    "        entropy = dist.entropy().sum(dim=-1, keepdim=True)\n",
    "        return log_prob, entropy\n",
    "\n",
    "class CriticNetworkPPO(nn.Module):\n",
    "    \"\"\"Critic network for MAPPO.\"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim=256):\n",
    "        # input_dim should be env.state_dim * 2 * n_agents (global state)\n",
    "        super(CriticNetworkPPO, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3 = nn.Linear(hidden_dim, 1)\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "\n",
    "    def forward(self, state):\n",
    "        x = F.relu(self.fc1(state))\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.fc2(x))\n",
    "        return self.fc3(x)\n",
    "\n",
    "class CriticNetworkTD3(nn.Module):\n",
    "    \"\"\"Critic network for MATD3 (dual Q-networks).\"\"\"\n",
    "    def __init__(self, state_dim, action_dim, n_agents, hidden_dim=256):\n",
    "        # state_dim here should be env.state_dim * 2\n",
    "        super(CriticNetworkTD3, self).__init__()\n",
    "        total_input_dim = state_dim * n_agents + action_dim * n_agents\n",
    "        # Q1\n",
    "        self.fc1_q1 = nn.Linear(total_input_dim, hidden_dim)\n",
    "        self.fc2_q1 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3_q1 = nn.Linear(hidden_dim, 1)\n",
    "        # Q2\n",
    "        self.fc1_q2 = nn.Linear(total_input_dim, hidden_dim)\n",
    "        self.fc2_q2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3_q2 = nn.Linear(hidden_dim, 1)\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        x = torch.cat([states, actions], dim=1)\n",
    "        # Q1 forward\n",
    "        q1 = F.relu(self.fc1_q1(x))\n",
    "        q1 = self.dropout(q1)\n",
    "        q1 = F.relu(self.fc2_q1(q1))\n",
    "        q1 = self.fc3_q1(q1)\n",
    "        # Q2 forward\n",
    "        q2 = F.relu(self.fc1_q2(x))\n",
    "        q2 = self.dropout(q2)\n",
    "        q2 = F.relu(self.fc2_q2(q2))\n",
    "        q2 = self.fc3_q2(q2)\n",
    "        return q1, q2\n",
    "\n",
    "    def Q1(self, states, actions):\n",
    "        x = torch.cat([states, actions], dim=1)\n",
    "        q1 = F.relu(self.fc1_q1(x))\n",
    "        q1 = self.dropout(q1)\n",
    "        q1 = F.relu(self.fc2_q1(q1))\n",
    "        q1 = self.fc3_q1(q1)\n",
    "        return q1\n",
    "\n",
    "class CriticNetworkAC(nn.Module):\n",
    "    \"\"\"Critic network for MAAC (simplified version without attention).\"\"\"\n",
    "    def __init__(self, state_dim, action_dim, n_agents, hidden_dim=256):\n",
    "        # state_dim here should be env.state_dim * 2\n",
    "        super(CriticNetworkAC, self).__init__()\n",
    "        # Simplified version uses same structure as MADDPG\n",
    "        total_input_dim = state_dim * n_agents + action_dim * n_agents\n",
    "        self.fc1 = nn.Linear(total_input_dim, hidden_dim)\n",
    "        self.fc2 = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc3 = nn.Linear(hidden_dim, 1)\n",
    "        self.dropout = nn.Dropout(0.1)\n",
    "\n",
    "    def forward(self, states, actions):\n",
    "        x = torch.cat([states, actions], dim=1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = F.relu(self.fc2(x))\n",
    "        return self.fc3(x)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 4: Definition of Agents (MADDPG, MAPPO, MATD3, MAAC) ---\n",
    "class MADDPGAgent:\n",
    "    \"\"\"MADDPG Agent.\"\"\"\n",
    "    def __init__(self, agent_id, state_dim, action_dim, n_agents, lr_actor=1e-4, lr_critic=1e-3):\n",
    "        self.agent_id = agent_id\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "        self.n_agents = n_agents\n",
    "        # Actor networks (main and target)\n",
    "        self.actor = ActorNetwork(state_dim, action_dim)\n",
    "        self.actor_target = ActorNetwork(state_dim, action_dim)\n",
    "        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)\n",
    "        # Critic networks (main and target)\n",
    "        self.critic = CriticNetwork(state_dim, action_dim, n_agents)\n",
    "        self.critic_target = CriticNetwork(state_dim, action_dim, n_agents)\n",
    "        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)\n",
    "        # Initialize target networks by copying weights\n",
    "        self.hard_update(self.actor_target, self.actor)\n",
    "        self.hard_update(self.critic_target, self.critic)\n",
    "        # Training parameters\n",
    "        self.gamma = 0.95  # Discount factor\n",
    "        self.tau = 0.02    # Soft update parameter\n",
    "\n",
    "    def act(self, state, noise_scale=0.1):\n",
    "        \"\"\"Select action with exploration noise.\"\"\"\n",
    "        state = torch.FloatTensor(state).unsqueeze(0)\n",
    "        action = self.actor(state).squeeze(0).detach().numpy()\n",
    "        # Add noise for exploration\n",
    "        noise = np.random.normal(0, noise_scale, size=action.shape)\n",
    "        action = np.clip(action + noise, -1, 1)\n",
    "        return action\n",
    "\n",
    "    def hard_update(self, target, source):\n",
    "        \"\"\"Hard update of target network parameters.\"\"\"\n",
    "        for target_param, param in zip(target.parameters(), source.parameters()):\n",
    "            target_param.data.copy_(param.data)\n",
    "\n",
    "    def soft_update(self, target, source):\n",
    "        \"\"\"Soft update of target network parameters.\"\"\"\n",
    "        for target_param, param in zip(target.parameters(), source.parameters()):\n",
    "            target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)\n",
    "\n",
    "class MAPPOAgent:\n",
    "    \"\"\"MAPPO Agent.\"\"\"\n",
    "    def __init__(self, agent_id, state_dim, action_dim, n_agents, lr_actor=1e-4, lr_critic=1e-3):\n",
    "        # state_dim here should be env.state_dim * 2 (local + global)\n",
    "        self.agent_id = agent_id\n",
    "        self.state_dim = state_dim  # 16\n",
    "        self.action_dim = action_dim\n",
    "        self.n_agents = n_agents\n",
    "        # Actor receives local + global state (16)\n",
    "        self.actor = ActorNetworkPPO(state_dim, action_dim)\n",
    "        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)\n",
    "        # Critic receives global state (16 * 8 = 128)\n",
    "        self.critic = CriticNetworkPPO(state_dim * n_agents)\n",
    "        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)\n",
    "        self.gamma = 0.95\n",
    "        self.eps_clip = 0.2  # PPO clipping parameter\n",
    "\n",
    "    def act(self, state):\n",
    "        \"\"\"Select action deterministically during inference.\"\"\"\n",
    "        # state already has dimension env.state_dim * 2\n",
    "        state = torch.FloatTensor(state).unsqueeze(0)\n",
    "        with torch.no_grad():\n",
    "            action, log_prob = self.actor.get_action_log_prob(state)\n",
    "        return action.squeeze(0).numpy(), log_prob.squeeze(0).numpy()\n",
    "\n",
    "class MATD3Agent:\n",
    "    \"\"\"MATD3 Agent.\"\"\"\n",
    "    def __init__(self, agent_id, state_dim, action_dim, n_agents, lr_actor=1e-4, lr_critic=1e-3):\n",
    "        # state_dim here should be env.state_dim * 2\n",
    "        self.agent_id = agent_id\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "        self.n_agents = n_agents\n",
    "        self.train_step = 0\n",
    "        self.policy_delay = 2  # Policy update delay\n",
    "        self.actor = ActorNetwork(state_dim, action_dim)\n",
    "        self.actor_target = ActorNetwork(state_dim, action_dim)\n",
    "        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)\n",
    "        self.critic = CriticNetworkTD3(state_dim, action_dim, n_agents)\n",
    "        self.critic_target = CriticNetworkTD3(state_dim, action_dim, n_agents)\n",
    "        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)\n",
    "        self.hard_update(self.actor_target, self.actor)\n",
    "        self.hard_update(self.critic_target, self.critic)\n",
    "        self.gamma = 0.95\n",
    "        self.tau = 0.02\n",
    "\n",
    "    def act(self, state, noise_scale=0.1):\n",
    "        \"\"\"Select action with exploration noise.\"\"\"\n",
    "        # state already has dimension env.state_dim * 2\n",
    "        state = torch.FloatTensor(state).unsqueeze(0)\n",
    "        action = self.actor(state).squeeze(0).detach().numpy()\n",
    "        noise = np.random.normal(0, noise_scale, size=action.shape)\n",
    "        action = np.clip(action + noise, -1, 1)\n",
    "        return action\n",
    "\n",
    "    def hard_update(self, target, source):\n",
    "        for target_param, param in zip(target.parameters(), source.parameters()):\n",
    "            target_param.data.copy_(param.data)\n",
    "\n",
    "    def soft_update(self, target, source):\n",
    "        for target_param, param in zip(target.parameters(), source.parameters()):\n",
    "            target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)\n",
    "\n",
    "class MAACAgent:\n",
    "    \"\"\"MAAC agent (simplified version without attention).\n",
    "\n",
    "    Bundles an actor and a critic with their target copies and one Adam\n",
    "    optimizer per online network. The targets start as exact copies of the\n",
    "    online networks and can later be blended toward them with `soft_update`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, agent_id, state_dim, action_dim, n_agents, lr_actor=1e-4, lr_critic=1e-3):\n",
    "        \"\"\"Build the networks, optimizers and hyper-parameters.\n",
    "\n",
    "        Args:\n",
    "            agent_id: Identifier stored on the agent.\n",
    "            state_dim: Actor input size; callers should pass env.state_dim * 2.\n",
    "            action_dim: Size of the action vector.\n",
    "            n_agents: Number of agents, forwarded to the critic.\n",
    "            lr_actor: Adam learning rate for the actor.\n",
    "            lr_critic: Adam learning rate for the critic.\n",
    "        \"\"\"\n",
    "        self.agent_id = agent_id\n",
    "        self.state_dim = state_dim\n",
    "        self.action_dim = action_dim\n",
    "        self.n_agents = n_agents\n",
    "        # Discount factor and target-network blending rate\n",
    "        self.gamma = 0.95\n",
    "        self.tau = 0.02\n",
    "        # Policy network, its target copy and optimizer\n",
    "        self.actor = ActorNetwork(state_dim, action_dim)\n",
    "        self.actor_target = ActorNetwork(state_dim, action_dim)\n",
    "        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)\n",
    "        # Value network, its target copy and optimizer\n",
    "        self.critic = CriticNetworkAC(state_dim, action_dim, n_agents)\n",
    "        self.critic_target = CriticNetworkAC(state_dim, action_dim, n_agents)\n",
    "        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)\n",
    "        # Start both targets as exact copies of the online networks\n",
    "        for target_net, online_net in ((self.actor_target, self.actor),\n",
    "                                       (self.critic_target, self.critic)):\n",
    "            self.hard_update(target_net, online_net)\n",
    "\n",
    "    def act(self, state, noise_scale=0.1):\n",
    "        \"\"\"Return the actor's action for `state`, perturbed by Gaussian noise.\n",
    "\n",
    "        Args:\n",
    "            state: Observation vector for this agent (already sized env.state_dim * 2).\n",
    "            noise_scale: Standard deviation of the zero-mean exploration noise.\n",
    "\n",
    "        Returns:\n",
    "            NumPy array holding the noisy action, clipped element-wise to [-1, 1].\n",
    "        \"\"\"\n",
    "        batched_state = torch.FloatTensor(state).unsqueeze(0)\n",
    "        greedy_action = self.actor(batched_state).squeeze(0).detach().numpy()\n",
    "        exploration = np.random.normal(0, noise_scale, size=greedy_action.shape)\n",
    "        return np.clip(greedy_action + exploration, -1, 1)\n",
    "\n",
    "    def hard_update(self, target, source):\n",
    "        \"\"\"Copy every parameter of `source` into `target` in place.\"\"\"\n",
    "        for dst, src in zip(target.parameters(), source.parameters()):\n",
    "            dst.data.copy_(src.data)\n",
    "\n",
    "    def soft_update(self, target, source):\n",
    "        \"\"\"Blend `source` into `target` in place: target <- (1 - tau) * target + tau * source.\"\"\"\n",
    "        keep_fraction = 1.0 - self.tau\n",
    "        for dst, src in zip(target.parameters(), source.parameters()):\n",
    "            blended = dst.data * keep_fraction + src.data * self.tau\n",
    "            dst.data.copy_(blended)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 5: Evolutionary Booster (EvolutionaryBooster) ---\n",
    "class EvolutionaryBooster:\n",
    "    \"\"\"Evolutionary booster for optimizing MARL agents.\n",
    "\n",
    "    Keeps a population of agents and breeds new generations through elitism,\n",
    "    tournament selection, tensor-wise uniform crossover and Gaussian mutation\n",
    "    of the actor and critic weights.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, population_size=20, mutation_rate=0.1, crossover_rate=0.7):\n",
    "        \"\"\"Store the hyper-parameters and start with an empty population.\n",
    "\n",
    "        Args:\n",
    "            population_size: Number of agents kept in every generation.\n",
    "            mutation_rate: Probability that a given parameter tensor is mutated.\n",
    "            crossover_rate: Probability that a selected parent pair is recombined\n",
    "                and, within a recombination, that a given parameter tensor is mixed.\n",
    "        \"\"\"\n",
    "        self.population_size = population_size\n",
    "        self.mutation_rate = mutation_rate\n",
    "        self.crossover_rate = crossover_rate\n",
    "        self.population = []\n",
    "        self.fitness_history = []  # best fitness of each evolved generation\n",
    "\n",
    "    def initialize_population(self, agent_template):\n",
    "        \"\"\"Fill the population with strongly mutated copies of `agent_template`.\"\"\"\n",
    "        self.population = []\n",
    "        for _ in range(self.population_size):\n",
    "            agent_copy = self._copy_agent(agent_template)\n",
    "            self._mutate_agent(agent_copy, mutation_strength=0.3)\n",
    "            self.population.append(agent_copy)\n",
    "\n",
    "    def _copy_agent(self, agent):\n",
    "        \"\"\"Create a new agent of the same class carrying the same network weights.\n",
    "\n",
    "        The copy is constructed from the agent's id and dimensions only, so any\n",
    "        other constructor arguments fall back to their defaults; the weights of\n",
    "        the actor, critic and their targets are then loaded from `agent`.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `agent` is not one of the supported agent classes.\n",
    "        \"\"\"\n",
    "        # Determine agent type and instantiate corresponding class\n",
    "        if isinstance(agent, MADDPGAgent):\n",
    "            new_agent = MADDPGAgent(agent.agent_id, agent.state_dim, agent.action_dim, agent.n_agents)\n",
    "        elif isinstance(agent, MAPPOAgent):\n",
    "            new_agent = MAPPOAgent(agent.agent_id, agent.state_dim, agent.action_dim, agent.n_agents)\n",
    "        elif isinstance(agent, MATD3Agent):\n",
    "            new_agent = MATD3Agent(agent.agent_id, agent.state_dim, agent.action_dim, agent.n_agents)\n",
    "        elif isinstance(agent, MAACAgent):\n",
    "            new_agent = MAACAgent(agent.agent_id, agent.state_dim, agent.action_dim, agent.n_agents)\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported agent type: {type(agent)}\")\n",
    "        # Copy the weights of every network the source agent actually has\n",
    "        for net_name in ('actor', 'critic', 'actor_target', 'critic_target'):\n",
    "            if hasattr(agent, net_name):\n",
    "                getattr(new_agent, net_name).load_state_dict(getattr(agent, net_name).state_dict())\n",
    "        return new_agent\n",
    "\n",
    "    def _mutate_agent(self, agent, mutation_strength=0.1):\n",
    "        \"\"\"Add Gaussian noise to randomly chosen actor/critic parameter tensors.\n",
    "\n",
    "        Each parameter tensor is perturbed with probability `mutation_rate`;\n",
    "        `mutation_strength` is the standard deviation of the added noise.\n",
    "        Target networks are left untouched.\n",
    "        \"\"\"\n",
    "        for net_name in ('actor', 'critic'):\n",
    "            if not hasattr(agent, net_name):\n",
    "                continue\n",
    "            for param in getattr(agent, net_name).parameters():\n",
    "                if np.random.random() < self.mutation_rate:\n",
    "                    noise = torch.randn_like(param) * mutation_strength\n",
    "                    param.data += noise\n",
    "\n",
    "    def _crossover_agents(self, parent1, parent2):\n",
    "        \"\"\"Produce two children by uniform crossover of the parents' weights.\n",
    "\n",
    "        Children start as copies of their respective parent. For each actor/critic\n",
    "        parameter tensor chosen with probability `crossover_rate`, a random\n",
    "        element-wise mask decides which parent each child inherits from; the two\n",
    "        children receive complementary selections.\n",
    "\n",
    "        Returns:\n",
    "            Tuple (child1, child2).\n",
    "        \"\"\"\n",
    "        child1 = self._copy_agent(parent1)\n",
    "        child2 = self._copy_agent(parent2)\n",
    "        for net_name in ('actor', 'critic'):\n",
    "            if not (hasattr(parent1, net_name) and hasattr(parent2, net_name)):\n",
    "                continue\n",
    "            aligned_params = zip(getattr(parent1, net_name).parameters(),\n",
    "                                 getattr(parent2, net_name).parameters(),\n",
    "                                 getattr(child1, net_name).parameters(),\n",
    "                                 getattr(child2, net_name).parameters())\n",
    "            for p1, p2, c1, c2 in aligned_params:\n",
    "                if np.random.random() < self.crossover_rate:\n",
    "                    mask = torch.rand_like(p1) > 0.5\n",
    "                    c1.data = torch.where(mask, p1.data, p2.data)\n",
    "                    c2.data = torch.where(mask, p2.data, p1.data)\n",
    "        return child1, child2\n",
    "\n",
    "    def evolve_population(self, fitness_scores):\n",
    "        \"\"\"Evolve the population one generation and return its best agent.\n",
    "\n",
    "        The top quarter of the population (always at least one agent) survives\n",
    "        unchanged; the remaining slots are filled with mutated offspring of\n",
    "        tournament-selected parents.\n",
    "\n",
    "        Args:\n",
    "            fitness_scores: One score per agent, in the order of `self.population`.\n",
    "\n",
    "        Returns:\n",
    "            The highest-scoring agent of the evaluated generation (also stored as\n",
    "            the first member of the new population).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `fitness_scores` is empty or its length differs from\n",
    "                the population size.\n",
    "        \"\"\"\n",
    "        n_scores = len(fitness_scores)\n",
    "        if n_scores == 0:\n",
    "            raise ValueError(\"fitness_scores must not be empty\")\n",
    "        if n_scores != len(self.population):\n",
    "            # zip() below would otherwise silently drop agents or scores\n",
    "            raise ValueError(\n",
    "                f\"Expected {len(self.population)} fitness scores (one per agent), got {n_scores}\"\n",
    "            )\n",
    "        # Sort agents by fitness\n",
    "        population_fitness = list(zip(self.population, fitness_scores))\n",
    "        population_fitness.sort(key=lambda x: x[1], reverse=True)\n",
    "        # Select elite individuals; keep at least one so that the agent returned\n",
    "        # below really is the best one even when population_size < 4\n",
    "        elite_size = max(1, self.population_size // 4)\n",
    "        elite = [agent for agent, _ in population_fitness[:elite_size]]\n",
    "        # Generate new population\n",
    "        new_population = elite.copy()\n",
    "        while len(new_population) < self.population_size:\n",
    "            parent1 = self._tournament_selection(population_fitness)\n",
    "            parent2 = self._tournament_selection(population_fitness)\n",
    "            if np.random.random() < self.crossover_rate:\n",
    "                child1, child2 = self._crossover_agents(parent1, parent2)\n",
    "                self._mutate_agent(child1)\n",
    "                self._mutate_agent(child2)\n",
    "                new_population.extend([child1, child2])\n",
    "            else:\n",
    "                child = self._copy_agent(parent1)\n",
    "                self._mutate_agent(child)\n",
    "                new_population.append(child)\n",
    "        # Trim (crossover may overshoot by one) and update population\n",
    "        self.population = new_population[:self.population_size]\n",
    "        # Record fitness history\n",
    "        self.fitness_history.append(max(fitness_scores))\n",
    "        return self.population[0]  # First elite = best agent of the evaluated generation\n",
    "\n",
    "    def _tournament_selection(self, population_fitness, tournament_size=3):\n",
    "        \"\"\"Pick the fittest agent among `tournament_size` randomly sampled (agent, fitness) pairs.\"\"\"\n",
    "        tournament = random.sample(population_fitness, min(tournament_size, len(population_fitness)))\n",
    "        return max(tournament, key=lambda x: x[1])[0]"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 6: Experiment Logger (ExperimentLogger) ---\n",
    "class ExperimentLogger:\n",
    "    \"\"\"Collects per-episode and per-generation metrics and can plot or export them.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        \"\"\"Start with empty metric containers.\"\"\"\n",
    "        self.metrics = defaultdict(list)          # column name -> one value per episode\n",
    "        self.episode_rewards = defaultdict(list)  # agent id -> reward per episode\n",
    "        self.evolution_history = []               # one dict per logged generation\n",
    "\n",
    "    def log_episode(self, episode, agent_rewards, avg_reward, stability_metric):\n",
    "        \"\"\"Record one episode: its index, mean reward, stability and per-agent rewards.\"\"\"\n",
    "        episode_row = (('episode', episode),\n",
    "                       ('avg_reward', avg_reward),\n",
    "                       ('stability', stability_metric))\n",
    "        for column, value in episode_row:\n",
    "            self.metrics[column].append(value)\n",
    "        # Store individual agent rewards\n",
    "        for agent_id in agent_rewards:\n",
    "            self.episode_rewards[agent_id].append(agent_rewards[agent_id])\n",
    "\n",
    "    def log_evolution(self, generation, best_fitness, avg_fitness):\n",
    "        \"\"\"Record the best and average fitness of one evolutionary generation.\"\"\"\n",
    "        record = dict(generation=generation, best_fitness=best_fitness, avg_fitness=avg_fitness)\n",
    "        self.evolution_history.append(record)\n",
    "\n",
    "    def plot_results(self):\n",
    "        \"\"\"Draw a 2x3 dashboard of the logged metrics and display it.\n",
    "\n",
    "        Panels: average reward, stability, per-agent rewards (first 50 episodes),\n",
    "        population fitness (only if evolution was logged), reward histogram, and a\n",
    "        10-episode moving average (only once at least 10 episodes are logged).\n",
    "        \"\"\"\n",
    "        episodes = self.metrics['episode']\n",
    "        avg_rewards = self.metrics['avg_reward']\n",
    "\n",
    "        def decorate(title, xlabel, ylabel, with_legend=False):\n",
    "            # Shared title / axis labels / legend / grid styling for the current axes\n",
    "            plt.title(title)\n",
    "            plt.xlabel(xlabel)\n",
    "            plt.ylabel(ylabel)\n",
    "            if with_legend:\n",
    "                plt.legend()\n",
    "            plt.grid(True)\n",
    "\n",
    "        plt.figure(figsize=(15, 10))\n",
    "        # Panel 1: average reward over episodes\n",
    "        plt.subplot(2, 3, 1)\n",
    "        plt.plot(episodes, avg_rewards)\n",
    "        decorate('Average Reward Over Episodes', 'Episode', 'Average Reward')\n",
    "        # Panel 2: stability metric\n",
    "        plt.subplot(2, 3, 2)\n",
    "        plt.plot(episodes, self.metrics['stability'])\n",
    "        decorate('System Stability Metric', 'Episode', 'Stability')\n",
    "        # Panel 3: agent rewards (first 50 episodes)\n",
    "        plt.subplot(2, 3, 3)\n",
    "        for agent_id, rewards in self.episode_rewards.items():\n",
    "            plt.plot(rewards[:50], label=f'Agent {agent_id}', alpha=0.7)\n",
    "        decorate('Agent Rewards (First 50 Episodes)', 'Episode', 'Reward', with_legend=True)\n",
    "        # Panel 4: population evolution (if available)\n",
    "        if self.evolution_history:\n",
    "            plt.subplot(2, 3, 4)\n",
    "            generations = [entry['generation'] for entry in self.evolution_history]\n",
    "            best_fitness = [entry['best_fitness'] for entry in self.evolution_history]\n",
    "            avg_fitness = [entry['avg_fitness'] for entry in self.evolution_history]\n",
    "            plt.plot(generations, best_fitness, label='Best Fitness', linewidth=2)\n",
    "            plt.plot(generations, avg_fitness, label='Average Fitness', alpha=0.7)\n",
    "            decorate('Population Evolution', 'Generation', 'Fitness', with_legend=True)\n",
    "        # Panel 5: distribution of all per-agent rewards\n",
    "        plt.subplot(2, 3, 5)\n",
    "        all_rewards = [reward\n",
    "                       for rewards in self.episode_rewards.values()\n",
    "                       for reward in rewards]\n",
    "        plt.hist(all_rewards, bins=30, alpha=0.7, edgecolor='black')\n",
    "        decorate('Reward Distribution', 'Reward', 'Frequency')\n",
    "        # Panel 6: moving average of rewards\n",
    "        plt.subplot(2, 3, 6)\n",
    "        window_size = 10\n",
    "        if len(avg_rewards) >= window_size:\n",
    "            moving_avg = pd.Series(avg_rewards).rolling(window=window_size).mean()\n",
    "            plt.plot(episodes, moving_avg)\n",
    "            decorate(f'Moving Average Reward (Window Size: {window_size})', 'Episode', 'Moving Average')\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "    def save_results(self, filename):\n",
    "        \"\"\"Write the per-episode metrics to `filename` as CSV (no index column).\"\"\"\n",
    "        pd.DataFrame(self.metrics).to_csv(filename, index=False)\n",
    "        print(f\"Results saved to {filename}\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 7: Auxiliary Training Functions (_train_agent_*) ---\n",
    "def _train_agent_maddpg(agent, replay_buffer, batch_size=32):\n",
    "    \"\"\"Perform one MADDPG gradient step for one agent (identical to original _train_agent).\n",
    "\n",
    "    Samples `batch_size` transitions, fits the critic to a one-step TD target\n",
    "    built from the target networks, improves the actor through the critic, and\n",
    "    soft-updates both targets. Returns immediately while the buffer holds fewer\n",
    "    than `batch_size` transitions.\n",
    "\n",
    "    Args:\n",
    "        agent: Agent exposing actor/critic networks, their targets and optimizers,\n",
    "            plus `gamma`, `n_agents` and `soft_update`.\n",
    "        replay_buffer: Sequence of (state, action, reward, next_state, done) tuples.\n",
    "        batch_size: Number of transitions drawn for the update.\n",
    "    \"\"\"\n",
    "    if len(replay_buffer) < batch_size:\n",
    "        return\n",
    "    transitions = random.sample(replay_buffer, batch_size)\n",
    "\n",
    "    def column(idx):\n",
    "        # One field of every sampled transition, in sample order\n",
    "        return [transition[idx] for transition in transitions]\n",
    "\n",
    "    def tile(tensor):\n",
    "        # Repeat the per-agent tensor n_agents times along the feature axis, as fed to the critic\n",
    "        return tensor.repeat(1, agent.n_agents).view(batch_size, -1)\n",
    "\n",
    "    states = torch.FloatTensor(column(0))\n",
    "    actions = torch.FloatTensor(column(1))\n",
    "    rewards = torch.FloatTensor(column(2)).unsqueeze(1)\n",
    "    next_states = torch.FloatTensor(column(3))\n",
    "    not_done = ~torch.BoolTensor(column(4)).unsqueeze(1)\n",
    "    # --- Critic step: regress Q(s, a) onto r + gamma * Q_target(s', pi_target(s')) ---\n",
    "    with torch.no_grad():\n",
    "        target_actions = agent.actor_target(next_states)\n",
    "        bootstrap_q = agent.critic_target(tile(next_states), tile(target_actions))\n",
    "        td_target = rewards + agent.gamma * bootstrap_q * not_done\n",
    "    joint_states = tile(states)\n",
    "    q_estimate = agent.critic(joint_states, tile(actions))\n",
    "    critic_loss = F.mse_loss(q_estimate, td_target)\n",
    "    agent.critic_optimizer.zero_grad()\n",
    "    critic_loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(agent.critic.parameters(), 0.5)\n",
    "    agent.critic_optimizer.step()\n",
    "    # --- Actor step: ascend the critic's value of the actor's own actions ---\n",
    "    policy_actions = agent.actor(states)\n",
    "    actor_loss = -agent.critic(joint_states, tile(policy_actions)).mean()\n",
    "    agent.actor_optimizer.zero_grad()\n",
    "    actor_loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), 0.5)\n",
    "    agent.actor_optimizer.step()\n",
    "    # --- Soft update of the target networks (actor first, then critic) ---\n",
    "    for target_net, online_net in ((agent.actor_target, agent.actor),\n",
    "                                   (agent.critic_target, agent.critic)):\n",
    "        agent.soft_update(target_net, online_net)\n",
    "\n",
    "def _train_agent_td3(agent, replay_buffer, batch_size=32, policy_noise=0.2, noise_clip=0.5):\n",
    "    \"\"\"Perform one TD3 update for a single MATD3 agent.\n",
    "\n",
    "    The twin critics are fitted to a TD target that uses the smaller of the two\n",
    "    target-critic estimates at a noise-smoothed target action. The actor and\n",
    "    both target networks are only updated when `agent.train_step` is a multiple\n",
    "    of `agent.policy_delay`; `agent.train_step` is incremented on every call\n",
    "    that performs an update. Returns immediately while the buffer holds fewer\n",
    "    than `batch_size` transitions.\n",
    "\n",
    "    Args:\n",
    "        agent: Agent exposing actor, twin critic (returning two Q-values and\n",
    "            offering `Q1`), targets, optimizers, `gamma`, `n_agents`,\n",
    "            `train_step`, `policy_delay` and `soft_update`.\n",
    "        replay_buffer: Sequence of (state, action, reward, next_state, done) tuples.\n",
    "        batch_size: Number of transitions drawn for the update.\n",
    "        policy_noise: Standard deviation of the target-policy smoothing noise.\n",
    "        noise_clip: Absolute bound applied to that noise.\n",
    "    \"\"\"\n",
    "    if len(replay_buffer) < batch_size:\n",
    "        return\n",
    "    transitions = random.sample(replay_buffer, batch_size)\n",
    "\n",
    "    def column(idx):\n",
    "        # One field of every sampled transition, in sample order\n",
    "        return [transition[idx] for transition in transitions]\n",
    "\n",
    "    def tile(tensor):\n",
    "        # Repeat the per-agent tensor n_agents times along the feature axis, as fed to the critic\n",
    "        return tensor.repeat(1, agent.n_agents).view(batch_size, -1)\n",
    "\n",
    "    states = torch.FloatTensor(column(0))\n",
    "    actions = torch.FloatTensor(column(1))\n",
    "    rewards = torch.FloatTensor(column(2)).unsqueeze(1)\n",
    "    next_states = torch.FloatTensor(column(3))\n",
    "    not_done = ~torch.BoolTensor(column(4)).unsqueeze(1)\n",
    "    # --- Critic step (dual Q-networks) ---\n",
    "    with torch.no_grad():\n",
    "        # Clipped Gaussian noise smooths the target action\n",
    "        smoothing = (torch.randn_like(actions) * policy_noise).clamp(-noise_clip, noise_clip)\n",
    "        target_actions = (agent.actor_target(next_states) + smoothing).clamp(-1, 1)\n",
    "        next_q1, next_q2 = agent.critic_target(tile(next_states), tile(target_actions))\n",
    "        # Take the smaller of the two target estimates\n",
    "        td_target = rewards + agent.gamma * torch.min(next_q1, next_q2) * not_done\n",
    "    joint_states = tile(states)\n",
    "    q1, q2 = agent.critic(joint_states, tile(actions))\n",
    "    critic_loss = F.mse_loss(q1, td_target) + F.mse_loss(q2, td_target)\n",
    "    agent.critic_optimizer.zero_grad()\n",
    "    critic_loss.backward()\n",
    "    torch.nn.utils.clip_grad_norm_(agent.critic.parameters(), 0.5)\n",
    "    agent.critic_optimizer.step()\n",
    "    # --- Delayed actor and target update ---\n",
    "    actor_update_due = agent.train_step % agent.policy_delay == 0\n",
    "    if actor_update_due:\n",
    "        policy_actions = agent.actor(states)\n",
    "        # Only the first critic head drives the policy gradient\n",
    "        actor_loss = -agent.critic.Q1(joint_states, tile(policy_actions)).mean()\n",
    "        agent.actor_optimizer.zero_grad()\n",
    "        actor_loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), 0.5)\n",
    "        agent.actor_optimizer.step()\n",
    "        # Soft update target networks (actor first, then critic)\n",
    "        for target_net, online_net in ((agent.actor_target, agent.actor),\n",
    "                                       (agent.critic_target, agent.critic)):\n",
    "            agent.soft_update(target_net, online_net)\n",
    "    agent.train_step += 1\n",
    "\n",
    "def _train_agent_ppo(agent, global_states, actions, old_log_probs, rewards, next_global_states, dones, batch_size=32, epochs=4):\n",
    "    \"\"\"Run several clipped-PPO epochs for one MAPPO agent on a collected rollout.\n",
    "\n",
    "    Advantages are plain one-step TD errors (no GAE), computed once from the\n",
    "    critic before any update. Each epoch reshuffles the rollout and, for every\n",
    "    mini-batch, first updates the critic toward the TD target and then the\n",
    "    actor with the clipped surrogate loss plus a 0.01 entropy bonus.\n",
    "\n",
    "    Args:\n",
    "        agent: Agent exposing `critic`, `actor.get_log_prob_entropy`, both\n",
    "            optimizers, `gamma`, `eps_clip` and `state_dim`.\n",
    "        global_states: Per-step vectors of size (state_dim * n_agents).\n",
    "        actions: Actions taken at each step.\n",
    "        old_log_probs: Log-probabilities of those actions under the behaviour policy.\n",
    "            NOTE(review): presumably shaped so that the probability ratio broadcasts\n",
    "            against the (N, 1) advantages - confirm against the actor's output.\n",
    "        rewards: Reward received at each step.\n",
    "        next_global_states: Successor of every entry in `global_states`.\n",
    "        dones: Episode-termination flag of every step.\n",
    "        batch_size: Mini-batch size.\n",
    "        epochs: Number of passes over the rollout.\n",
    "    \"\"\"\n",
    "    # Convert the rollout to tensors\n",
    "    global_states = torch.FloatTensor(global_states)\n",
    "    actions = torch.FloatTensor(actions)\n",
    "    old_log_probs = torch.FloatTensor(old_log_probs)\n",
    "    rewards = torch.FloatTensor(rewards).unsqueeze(1)\n",
    "    next_global_states = torch.FloatTensor(next_global_states)\n",
    "    not_done = ~torch.BoolTensor(dones).unsqueeze(1)\n",
    "    # Advantage = simple TD error (no GAE), frozen for all epochs\n",
    "    with torch.no_grad():\n",
    "        baseline = agent.critic(global_states)\n",
    "        bootstrap = agent.critic(next_global_states)\n",
    "        td_target = rewards + agent.gamma * bootstrap * not_done\n",
    "        advantage = td_target - baseline\n",
    "    clip_low = 1 - agent.eps_clip\n",
    "    clip_high = 1 + agent.eps_clip\n",
    "    dataset_size = global_states.shape[0]\n",
    "    order = np.arange(dataset_size)\n",
    "    for _ in range(epochs):\n",
    "        np.random.shuffle(order)\n",
    "        for start in range(0, dataset_size, batch_size):\n",
    "            picked = order[start:min(start + batch_size, dataset_size)]\n",
    "            mb_states = global_states[picked]\n",
    "            mb_actions = actions[picked]\n",
    "            mb_old_log_probs = old_log_probs[picked]\n",
    "            mb_advantage = advantage[picked]\n",
    "            mb_td_target = td_target[picked]\n",
    "            # Critic step\n",
    "            critic_loss = F.mse_loss(agent.critic(mb_states), mb_td_target)\n",
    "            agent.critic_optimizer.zero_grad()\n",
    "            critic_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(agent.critic.parameters(), 0.5)\n",
    "            agent.critic_optimizer.step()\n",
    "            # Actor step: the actor only sees the first agent.state_dim columns of the global state\n",
    "            actor_inputs = mb_states[:, :agent.state_dim]\n",
    "            new_log_probs, entropy = agent.actor.get_log_prob_entropy(actor_inputs, mb_actions)\n",
    "            ratio = torch.exp(new_log_probs - mb_old_log_probs)\n",
    "            unclipped = ratio * mb_advantage\n",
    "            clipped = torch.clamp(ratio, clip_low, clip_high) * mb_advantage\n",
    "            surrogate = torch.min(unclipped, clipped).mean()\n",
    "            actor_loss = -surrogate - 0.01 * entropy.mean()  # Entropy coefficient\n",
    "            agent.actor_optimizer.zero_grad()\n",
    "            actor_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(agent.actor.parameters(), 0.5)\n",
    "            agent.actor_optimizer.step()"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 8: Auxiliary Evaluation and Metric Functions ---\n",
    "def _evaluate_agent_fitness(agent, env, region, n_eval_episodes=3):\n",
    "    \"\"\"Estimate an agent's fitness as its mean episode return for `region`.\n",
    "\n",
    "    The agent controls `region` while every other region takes uniformly random\n",
    "    actions in [-1, 1]. The rollouts reset and step `env` itself, so the\n",
    "    environment's state is overwritten by the evaluation.\n",
    "\n",
    "    Args:\n",
    "        agent: Agent to evaluate.\n",
    "        env: Environment exposing `reset`, `step`, `regions`, `max_steps` and `action_dim`.\n",
    "        region: Key of the region controlled by `agent`.\n",
    "        n_eval_episodes: Number of evaluation episodes to average over.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-episode reward sums collected for `region`.\n",
    "    \"\"\"\n",
    "    is_mappo = isinstance(agent, MAPPOAgent)\n",
    "    opponents = [other for other in env.regions if other != region]\n",
    "    reward_sum = 0\n",
    "    for _ in range(n_eval_episodes):\n",
    "        observations = env.reset()  # dict {region: obs}\n",
    "        episode_return = 0\n",
    "        for _ in range(env.max_steps):\n",
    "            own_obs = observations[region]\n",
    "            if is_mappo:\n",
    "                # MAPPO's act() returns a pair; only the action is used, clipped to the valid range\n",
    "                sampled, _ = agent.act(own_obs)\n",
    "                own_action = np.clip(sampled, -1, 1)\n",
    "            else:\n",
    "                # Zero exploration noise for evaluation\n",
    "                own_action = agent.act(own_obs, noise_scale=0.0)\n",
    "            joint_actions = {region: own_action}\n",
    "            # Random actions for other agents\n",
    "            for other in opponents:\n",
    "                joint_actions[other] = np.random.uniform(-1, 1, env.action_dim)\n",
    "            observations, step_rewards, finished, _ = env.step(joint_actions)\n",
    "            episode_return += step_rewards.get(region, 0)\n",
    "            if finished:\n",
    "                break\n",
    "        reward_sum += episode_return\n",
    "    return reward_sum / n_eval_episodes\n",
    "\n",
    "def _calculate_stability_metric(env):\n",
    "    \"\"\"Compute the composite stability metric, averaged over all regions.\n",
    "\n",
    "    For each region a snapshot score is the mean of four components, each of\n",
    "    which is at most 1 (its most stable value). The snapshot is appended to\n",
    "    `env.stability_history[region]` (side effect) and blended 40/60 with a\n",
    "    long-term term that shrinks as the standard deviation of that history grows.\n",
    "\n",
    "    Args:\n",
    "        env: Environment exposing `regions`, `states` (region -> indexable state\n",
    "            vector) and `stability_history` (region -> list of floats).\n",
    "\n",
    "    Returns:\n",
    "        Mean composite stability over `env.regions`.\n",
    "    \"\"\"\n",
    "    stability_scores = []\n",
    "    for region in env.regions:\n",
    "        state = env.states[region]\n",
    "        # 1. Natural balance: negative growth is a penalty subtracted from 1.\n",
    "        #    The earlier form, max(0, -state[2]) * 2, scored higher the more negative\n",
    "        #    growth became and gave 0 for healthy growth, i.e. it rewarded decline.\n",
    "        #    NOTE(review): state[2] is taken to be natural growth, as the original\n",
    "        #    comment implies - confirm against the environment's state layout.\n",
    "        natural_balance = 1 - max(0, -state[2]) * 2\n",
    "        # 2. Economic stability (more sensitive)\n",
    "        economic_stability = 1 - (state[5] ** 2) * 1.5\n",
    "        # 3. Migration stability (stricter)\n",
    "        migration_stability = 1 / (1 + abs(state[3]) * 20)\n",
    "        # 4. Population stability: distance of state[6] from 1.0, capped at 1\n",
    "        #    (assumes normalized population ~1.0)\n",
    "        population_stability = 1 - min(1, abs(state[6] - 1.0))\n",
    "        # Current regional stability\n",
    "        current_stability = (natural_balance + economic_stability +\n",
    "                           migration_stability + population_stability) / 4\n",
    "        # Store in history\n",
    "        env.stability_history[region].append(current_stability)\n",
    "        # Long-term stability (lower variance = higher stability)\n",
    "        if len(env.stability_history[region]) > 5:\n",
    "            long_term_stability = 1 - np.std(env.stability_history[region]) * 2\n",
    "        else:\n",
    "            long_term_stability = 1.0  # Insufficient data\n",
    "        # Final composite metric\n",
    "        region_stability = 0.4 * current_stability + 0.6 * long_term_stability\n",
    "        stability_scores.append(region_stability)\n",
    "    return np.mean(stability_scores)"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 9: Main Training Functions (Experiments) ---\n",
    "# --- Original MADDPG with Evolutionary Booster (from source file) ---\n",
    "# EXPERIMENT 0: MADDPG-EVO-0 (original, unchanged)\n",
    "def train_maddpg_with_evolution_original(env, n_episodes=600, evolution_frequency=50,\n",
    "                               population_size=8, save_results=True, experiment_name=\"MADDPG-EVO-0\"):\n",
    "    \"\"\"Main training function for MADDPG with evolutionary boosting (original version).\n",
    "\n",
    "    Every episode collects one rollout with all regional agents acting under\n",
    "    fixed exploration noise (0.1). Every 10th episode each agent takes one\n",
    "    MADDPG gradient step; every `evolution_frequency`-th episode each region's\n",
    "    population is scored, evolved, and its best member becomes the acting agent.\n",
    "\n",
    "    Args:\n",
    "        env: Environment exposing `regions`, `n_regions`, `state_dim`,\n",
    "            `action_dim`, `max_steps`, `reset` and `step`.\n",
    "        n_episodes: Number of training episodes.\n",
    "        evolution_frequency: Episodes between evolutionary optimizations.\n",
    "        population_size: Population size of each region's booster.\n",
    "        save_results: Whether to write the logged metrics to a timestamped CSV.\n",
    "        experiment_name: Label used in console output and in the CSV file name.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict region -> final agent, and the ExperimentLogger.\n",
    "    \"\"\"\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    # Initialize agents for each region\n",
    "    agents = {}\n",
    "    for i, region in enumerate(env.regions):\n",
    "        agents[region] = MADDPGAgent(\n",
    "            agent_id=i,\n",
    "            state_dim=env.state_dim * 2,  # local + global state\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions\n",
    "        )\n",
    "    # Initialize evolutionary boosters for each region.\n",
    "    # NOTE(review): populations are seeded from the still-untrained agents, so at\n",
    "    # the first evolution the agent trained so far is replaced by a population\n",
    "    # member; kept as is because this experiment is the unchanged baseline.\n",
    "    evolution_boosters = {}\n",
    "    for region in env.regions:\n",
    "        booster = EvolutionaryBooster(population_size=population_size)\n",
    "        booster.initialize_population(agents[region])\n",
    "        evolution_boosters[region] = booster\n",
    "    # Create replay buffers\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    # Create logger\n",
    "    logger = ExperimentLogger()\n",
    "    # Main training loop\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()  # observations: dict {region: obs}\n",
    "        episode_rewards = {region: 0 for region in env.regions}\n",
    "        # Step loop within episode\n",
    "        for step in range(env.max_steps):\n",
    "            # Select actions for all agents\n",
    "            actions = {}\n",
    "            for region in env.regions:\n",
    "                action = agents[region].act(observations[region], noise_scale=0.1)\n",
    "                actions[region] = action\n",
    "            # Execute step in environment\n",
    "            next_observations, rewards, done, _ = env.step(actions)\n",
    "            # Store experience in buffers\n",
    "            for region in env.regions:\n",
    "                replay_buffers[region].append((\n",
    "                    observations[region], actions[region], rewards[region],\n",
    "                    next_observations[region], done\n",
    "                ))\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # Measure the episode's outcome now, while env still holds the final state of\n",
    "        # the training rollout: the fitness evaluations below reset and step this same\n",
    "        # env, which previously made the stability logged on evolution episodes\n",
    "        # describe an evaluation rollout instead of the training episode.\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        # Train agents every 10 episodes (starting from episode 10)\n",
    "        if episode % 10 == 0 and episode > 0:\n",
    "            for region in env.regions:\n",
    "                if len(replay_buffers[region]) > 100:\n",
    "                    _train_agent_maddpg(agents[region], replay_buffers[region], batch_size=32)\n",
    "        # Evolutionary optimization (at specified frequency, starting after first iteration)\n",
    "        if episode % evolution_frequency == 0 and episode > 0:\n",
    "            print(f\"{experiment_name} - Evolutionary optimization at episode {episode}...\")\n",
    "            for region in env.regions:\n",
    "                # Evaluate fitness of entire population\n",
    "                fitness_scores = []\n",
    "                for agent in evolution_boosters[region].population:\n",
    "                    fitness = _evaluate_agent_fitness(agent, env, region)\n",
    "                    fitness_scores.append(fitness)\n",
    "                # Evolve population and retrieve best agent\n",
    "                best_agent = evolution_boosters[region].evolve_population(fitness_scores)\n",
    "                agents[region] = best_agent\n",
    "                # Log evolution results (one entry per region and generation)\n",
    "                logger.log_evolution(\n",
    "                    episode // evolution_frequency,\n",
    "                    max(fitness_scores),\n",
    "                    np.mean(fitness_scores)\n",
    "                )\n",
    "        # Log episode results\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        # Print progress every 20 episodes\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, \"\n",
    "                  f\"Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    # Visualize and save results\n",
    "    logger.plot_results()\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# --- NEW TRAINING FUNCTIONS FOR ADDITIONAL EXPERIMENTS ---\n",
    "# EXPERIMENT 1: MADDPG (Improved)\n",
    "def train_maddpg(env, n_episodes=600, save_results=True, experiment_name=\"MADDPG\"):\n",
    "    \"\"\"Train one MADDPG agent per region with stability-oriented settings.\n",
    "\n",
    "    Compared with the evolutionary baseline this variant uses lower learning\n",
    "    rates, exploration noise that decays linearly from 0.2 to a floor of 0.05,\n",
    "    and one gradient step per agent every 5th episode with batch size 64 once\n",
    "    a buffer holds more than 200 transitions.\n",
    "\n",
    "    Args:\n",
    "        env: Environment exposing `regions`, `n_regions`, `state_dim`,\n",
    "            `action_dim`, `max_steps`, `reset` and `step`.\n",
    "        n_episodes: Number of training episodes.\n",
    "        save_results: Whether to write the logged metrics to a timestamped CSV.\n",
    "        experiment_name: Label used in console output and in the CSV file name.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict region -> trained agent, and the ExperimentLogger.\n",
    "    \"\"\"\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    # One agent per region; reduced learning rates for improved stability\n",
    "    agents = {\n",
    "        region: MADDPGAgent(\n",
    "            agent_id=idx,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4,\n",
    "        )\n",
    "        for idx, region in enumerate(env.regions)\n",
    "    }\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    logger = ExperimentLogger()\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = dict.fromkeys(env.regions, 0)\n",
    "        # Exploration noise decays linearly with training progress, floored at 0.05\n",
    "        noise_scale = max(0.05, 0.2 * (1 - episode / n_episodes))\n",
    "        for _step in range(env.max_steps):\n",
    "            actions = {\n",
    "                region: agents[region].act(observations[region], noise_scale=noise_scale)\n",
    "                for region in env.regions\n",
    "            }\n",
    "            next_observations, rewards, done, _info = env.step(actions)\n",
    "            for region in env.regions:\n",
    "                transition = (observations[region], actions[region], rewards[region],\n",
    "                              next_observations[region], done)\n",
    "                replay_buffers[region].append(transition)\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # One gradient step per agent every 5th episode, once its buffer is warm\n",
    "        if episode > 0 and episode % 5 == 0:\n",
    "            for region in env.regions:\n",
    "                region_buffer = replay_buffers[region]\n",
    "                if len(region_buffer) > 200:  # Increased buffer threshold\n",
    "                    _train_agent_maddpg(agents[region], region_buffer, batch_size=64)  # Larger batch\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 2: MAPPO (Improved)\n",
    "def train_mappo(env, n_episodes=600, save_results=True, experiment_name=\"MAPPO\"):\n",
    "    \"\"\"Run the MAPPO experiment and return the agents together with the logger.\n",
    "\n",
    "    Each episode is collected as a fresh rollout. Once the rollout ends, every\n",
    "    region with at least one stored step gets one _train_agent_ppo call\n",
    "    (batch_size=64, epochs=8) on that episode's data. The state stored for\n",
    "    training is the concatenation of all regions' observations, taken in\n",
    "    env.regions order.\n",
    "\n",
    "    Args:\n",
    "        env: Environment object; this function reads regions, n_regions,\n",
    "            state_dim, action_dim and max_steps and calls reset() and step().\n",
    "        n_episodes: Number of episodes to run.\n",
    "        save_results: If true, logger.save_results is called with a timestamped\n",
    "            file name ending in .csv.\n",
    "        experiment_name: Label used in console output and in the file name.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict mapping region -> MAPPOAgent and the\n",
    "        ExperimentLogger filled during the run.\n",
    "    \"\"\"\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    # One agent per region, built in env.regions order\n",
    "    agents = {\n",
    "        region: MAPPOAgent(\n",
    "            agent_id=idx,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "        for idx, region in enumerate(env.regions)\n",
    "    }\n",
    "    logger = ExperimentLogger()\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = dict.fromkeys(env.regions, 0)\n",
    "        # Per-region rollout storage, rebuilt from scratch every episode\n",
    "        rollout_states = {region: [] for region in env.regions}\n",
    "        rollout_actions = {region: [] for region in env.regions}\n",
    "        rollout_log_probs = {region: [] for region in env.regions}\n",
    "        rollout_rewards = {region: [] for region in env.regions}\n",
    "        rollout_next_states = {region: [] for region in env.regions}\n",
    "        rollout_dones = {region: [] for region in env.regions}\n",
    "        for _step in range(env.max_steps):\n",
    "            actions, log_probs = {}, {}\n",
    "            for region in env.regions:\n",
    "                # act() yields the action and its log-probability\n",
    "                actions[region], log_probs[region] = agents[region].act(observations[region])\n",
    "            next_observations, rewards, done, _info = env.step(actions)\n",
    "            # Joint state before and after the step, in env.regions order\n",
    "            global_state = np.concatenate([observations[r] for r in env.regions])\n",
    "            next_global_state = np.concatenate([next_observations[r] for r in env.regions])\n",
    "            for region in env.regions:\n",
    "                rollout_states[region].append(global_state)\n",
    "                rollout_actions[region].append(actions[region])\n",
    "                rollout_log_probs[region].append(log_probs[region])\n",
    "                rollout_rewards[region].append(rewards[region])\n",
    "                rollout_next_states[region].append(next_global_state)\n",
    "                rollout_dones[region].append(done)\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # PPO update on the rollout that was just collected\n",
    "        for region in env.regions:\n",
    "            if not rollout_states[region]:\n",
    "                continue  # nothing was collected for this region\n",
    "            _train_agent_ppo(\n",
    "                agents[region],\n",
    "                rollout_states[region],\n",
    "                rollout_actions[region],\n",
    "                rollout_log_probs[region],\n",
    "                rollout_rewards[region],\n",
    "                rollout_next_states[region],\n",
    "                rollout_dones[region],\n",
    "                batch_size=64,\n",
    "                epochs=8\n",
    "            )\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 3: MATD3 (Improved)\n",
    "def train_matd3(env, n_episodes=600, save_results=True, experiment_name=\"MATD3\"):\n",
    "    \"\"\"Run the MATD3 experiment and return the agents together with the logger.\n",
    "\n",
    "    One MATD3Agent and one replay buffer (deque, maxlen=10000) are created for\n",
    "    every entry of env.regions. Every fifth episode (episode 0 excluded), each\n",
    "    region whose buffer holds more than 200 transitions gets a single\n",
    "    _train_agent_td3 call with batch_size=64, the current policy_noise and\n",
    "    noise_clip set to twice that value.\n",
    "\n",
    "    Args:\n",
    "        env: Environment object; this function reads regions, n_regions,\n",
    "            state_dim, action_dim and max_steps and calls reset() and step().\n",
    "        n_episodes: Number of episodes to run.\n",
    "        save_results: If true, logger.save_results is called with a timestamped\n",
    "            file name ending in .csv.\n",
    "        experiment_name: Label used in console output and in the file name.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict mapping region -> MATD3Agent and the\n",
    "        ExperimentLogger filled during the run.\n",
    "    \"\"\"\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    # One agent per region, built in env.regions order\n",
    "    agents = {\n",
    "        region: MATD3Agent(\n",
    "            agent_id=idx,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "        for idx, region in enumerate(env.regions)\n",
    "    }\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    logger = ExperimentLogger()\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = dict.fromkeys(env.regions, 0)\n",
    "        # Both noise levels decay linearly with training progress, each with a floor\n",
    "        remaining = 1 - episode / n_episodes\n",
    "        noise_scale = max(0.05, 0.3 * remaining)\n",
    "        policy_noise = max(0.1, 0.2 * remaining)\n",
    "        for _step in range(env.max_steps):\n",
    "            actions = {\n",
    "                region: agents[region].act(observations[region], noise_scale=noise_scale)\n",
    "                for region in env.regions\n",
    "            }\n",
    "            next_observations, rewards, done, _info = env.step(actions)\n",
    "            for region in env.regions:\n",
    "                transition = (\n",
    "                    observations[region], actions[region], rewards[region],\n",
    "                    next_observations[region], done\n",
    "                )\n",
    "                replay_buffers[region].append(transition)\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # Periodic update: every 5th episode, never on episode 0\n",
    "        if episode > 0 and episode % 5 == 0:\n",
    "            for region in env.regions:\n",
    "                buffer = replay_buffers[region]\n",
    "                if len(buffer) <= 200:\n",
    "                    continue  # too few transitions collected so far\n",
    "                _train_agent_td3(agents[region], buffer, batch_size=64,\n",
    "                                 policy_noise=policy_noise, noise_clip=policy_noise * 2)\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 4: MAAC (Improved)\n",
    "def train_maac(env, n_episodes=600, save_results=True, experiment_name=\"MAAC\"):\n",
    "    \"\"\"Run the MAAC experiment and return the agents together with the logger.\n",
    "\n",
    "    One MAACAgent and one replay buffer (deque, maxlen=10000) are created for\n",
    "    every entry of env.regions. Updates reuse the MADDPG routine: every fifth\n",
    "    episode (episode 0 excluded), each region whose buffer holds more than 200\n",
    "    transitions gets a single _train_agent_maddpg call with batch_size=64.\n",
    "\n",
    "    Args:\n",
    "        env: Environment object; this function reads regions, n_regions,\n",
    "            state_dim, action_dim and max_steps and calls reset() and step().\n",
    "        n_episodes: Number of episodes to run.\n",
    "        save_results: If true, logger.save_results is called with a timestamped\n",
    "            file name ending in .csv.\n",
    "        experiment_name: Label used in console output and in the file name.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict mapping region -> MAACAgent and the\n",
    "        ExperimentLogger filled during the run.\n",
    "    \"\"\"\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    # One agent per region, built in env.regions order\n",
    "    agents = {\n",
    "        region: MAACAgent(\n",
    "            agent_id=idx,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "        for idx, region in enumerate(env.regions)\n",
    "    }\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    logger = ExperimentLogger()\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = dict.fromkeys(env.regions, 0)\n",
    "        # Exploration noise shrinks linearly from 0.2 and never drops below 0.05\n",
    "        noise_scale = max(0.05, 0.2 * (1 - episode / n_episodes))\n",
    "        for _step in range(env.max_steps):\n",
    "            actions = {\n",
    "                region: agents[region].act(observations[region], noise_scale=noise_scale)\n",
    "                for region in env.regions\n",
    "            }\n",
    "            next_observations, rewards, done, _info = env.step(actions)\n",
    "            for region in env.regions:\n",
    "                transition = (\n",
    "                    observations[region], actions[region], rewards[region],\n",
    "                    next_observations[region], done\n",
    "                )\n",
    "                replay_buffers[region].append(transition)\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # Periodic update (MADDPG routine reused): every 5th episode, never on episode 0\n",
    "        if episode > 0 and episode % 5 == 0:\n",
    "            for region in env.regions:\n",
    "                buffer = replay_buffers[region]\n",
    "                if len(buffer) <= 200:\n",
    "                    continue  # too few transitions collected so far\n",
    "                _train_agent_maddpg(agents[region], buffer, batch_size=64)\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 5: MADDPG-EVO (Improved)\n",
    "def train_maddpg_with_evolution(env, n_episodes=600, min_improvement=0.008,\n",
    "                               population_size=8, save_results=True, experiment_name=\"MADDPG-EVO\",\n",
    "                               stagnation_window=30, min_episodes_between_evolutions=20):\n",
    "    \"\"\"Train MADDPG agents and evolve each region's population when rewards stagnate.\n",
    "\n",
    "    Per region this creates one MADDPGAgent, one EvolutionaryBooster seeded from\n",
    "    that agent, and one replay buffer (deque, maxlen=10000). Every tenth episode\n",
    "    (episode 0 excluded), each region whose buffer holds more than 200\n",
    "    transitions gets a single _train_agent_maddpg call with batch_size=64.\n",
    "\n",
    "    Stagnation check, run after every episode: once more than stagnation_window\n",
    "    average rewards are recorded and at least min_episodes_between_evolutions\n",
    "    episodes have passed since the last evolution (initially counted from\n",
    "    episode 0), the newest reward in the window minus the oldest, divided by\n",
    "    stagnation_window, is compared with min_improvement. If it is lower, every\n",
    "    region's population is scored with _evaluate_agent_fitness and evolved, and\n",
    "    the agent returned by evolve_population replaces agents[region].\n",
    "\n",
    "    Args:\n",
    "        env: Environment object; this function reads regions, n_regions,\n",
    "            state_dim, action_dim and max_steps and calls reset() and step().\n",
    "        n_episodes: Number of episodes to run.\n",
    "        min_improvement: Threshold below which the improvement rate counts as\n",
    "            stagnation.\n",
    "        population_size: Passed to each EvolutionaryBooster.\n",
    "        save_results: If true, logger.save_results is called with a timestamped\n",
    "            file name ending in .csv.\n",
    "        experiment_name: Label used in console output and in the file name.\n",
    "        stagnation_window: Number of most recent episode rewards inspected by\n",
    "            the stagnation check (default 30, the value previously hard-coded).\n",
    "        min_episodes_between_evolutions: Minimum number of episodes between two\n",
    "            evolution rounds (default 20, the value previously hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict mapping region -> agent and the\n",
    "        ExperimentLogger filled during the run.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If stagnation_window is smaller than 1.\n",
    "    \"\"\"\n",
    "    # A window below 1 would slice the wrong span or divide by zero\n",
    "    if stagnation_window < 1:\n",
    "        raise ValueError(\"stagnation_window must be at least 1\")\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    print(f\"Training over {n_episodes} episodes. Adaptive evolution triggered upon stagnation (min_improvement={min_improvement}).\")\n",
    "    # Initialize agents\n",
    "    agents = {}\n",
    "    for i, region in enumerate(env.regions):\n",
    "        agents[region] = MADDPGAgent(\n",
    "            agent_id=i,\n",
    "            state_dim=env.state_dim * 2,  # local + global state\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "    # Initialize evolutionary boosters, each seeded from its region's agent\n",
    "    evolution_boosters = {}\n",
    "    for region in env.regions:\n",
    "        booster = EvolutionaryBooster(population_size=population_size, mutation_rate=0.05)\n",
    "        booster.initialize_population(agents[region])\n",
    "        evolution_boosters[region] = booster\n",
    "    # Create replay buffers\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    # Create logger\n",
    "    logger = ExperimentLogger()\n",
    "    print(f\"Starting training over {n_episodes} episodes...\")\n",
    "    # Track progress for adaptive evolution\n",
    "    reward_history = []\n",
    "    last_evolution_episode = 0\n",
    "    # Main training loop\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = {region: 0 for region in env.regions}\n",
    "        # Exploration noise shrinks linearly from 0.2 and never drops below 0.05\n",
    "        noise_scale = max(0.05, 0.2 * (1 - episode / n_episodes))\n",
    "        # Step loop\n",
    "        for _step in range(env.max_steps):\n",
    "            actions = {\n",
    "                region: agents[region].act(observations[region], noise_scale=noise_scale)\n",
    "                for region in env.regions\n",
    "            }\n",
    "            next_observations, rewards, done, _ = env.step(actions)\n",
    "            for region in env.regions:\n",
    "                replay_buffers[region].append((\n",
    "                    observations[region], actions[region], rewards[region],\n",
    "                    next_observations[region], done\n",
    "                ))\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # Train agents: every 10th episode, never on episode 0\n",
    "        if episode % 10 == 0 and episode > 0:\n",
    "            for region in env.regions:\n",
    "                if len(replay_buffers[region]) > 200:\n",
    "                    _train_agent_maddpg(agents[region], replay_buffers[region], batch_size=64)\n",
    "        # Log episode\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        # Track reward history\n",
    "        reward_history.append(avg_reward)\n",
    "        # Adaptive evolution: trigger if recent improvement is below the threshold\n",
    "        enough_history = len(reward_history) > stagnation_window\n",
    "        cooled_down = episode - last_evolution_episode >= min_episodes_between_evolutions\n",
    "        if enough_history and cooled_down:\n",
    "            recent_rewards = reward_history[-stagnation_window:]\n",
    "            # Divisor is the window length itself, kept as in the original formula\n",
    "            improvement_rate = (recent_rewards[-1] - recent_rewards[0]) / stagnation_window\n",
    "            if improvement_rate < min_improvement:\n",
    "                print(f\"{experiment_name} - Adaptive evolutionary optimization at episode {episode} (stagnation detected)\")\n",
    "                for region in env.regions:\n",
    "                    booster = evolution_boosters[region]\n",
    "                    fitness_scores = [\n",
    "                        _evaluate_agent_fitness(candidate, env, region)\n",
    "                        for candidate in booster.population\n",
    "                    ]\n",
    "                    # The agent returned by evolve_population takes over this region\n",
    "                    agents[region] = booster.evolve_population(fitness_scores)\n",
    "                    logger.log_evolution(\n",
    "                        episode,\n",
    "                        max(fitness_scores),\n",
    "                        np.mean(fitness_scores)\n",
    "                    )\n",
    "                last_evolution_episode = episode\n",
    "        # Print progress\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, \"\n",
    "                  f\"Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    logger.plot_results()\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 6: MAPPO-EVO (Improved)\n",
    "def train_mappo_with_evolution(env, n_episodes=600, min_improvement=0.008,\n",
    "                               population_size=8, save_results=True, experiment_name=\"MAPPO-EVO\",\n",
    "                               stagnation_window=35, min_episodes_between_evolutions=25):\n",
    "    \"\"\"Train MAPPO agents and evolve each region's population when rewards stagnate.\n",
    "\n",
    "    Per region this creates one MAPPOAgent and one EvolutionaryBooster seeded\n",
    "    from that agent. Each episode is collected as a fresh rollout; afterwards\n",
    "    every region with at least one stored step gets one _train_agent_ppo call\n",
    "    (batch_size=64, epochs=8). The state stored for training is the\n",
    "    concatenation of all regions' observations in env.regions order.\n",
    "\n",
    "    Stagnation check, run after every episode: once more than stagnation_window\n",
    "    average rewards are recorded and at least min_episodes_between_evolutions\n",
    "    episodes have passed since the last evolution (initially counted from\n",
    "    episode 0), the newest reward in the window minus the oldest, divided by\n",
    "    stagnation_window, is compared with min_improvement. If it is lower, every\n",
    "    region's population is scored with _evaluate_agent_fitness and evolved, and\n",
    "    the agent returned by evolve_population replaces agents[region].\n",
    "\n",
    "    Args:\n",
    "        env: Environment object; this function reads regions, n_regions,\n",
    "            state_dim, action_dim and max_steps and calls reset() and step().\n",
    "        n_episodes: Number of episodes to run.\n",
    "        min_improvement: Threshold below which the improvement rate counts as\n",
    "            stagnation.\n",
    "        population_size: Passed to each EvolutionaryBooster.\n",
    "        save_results: If true, logger.save_results is called with a timestamped\n",
    "            file name ending in .csv.\n",
    "        experiment_name: Label used in console output and in the file name.\n",
    "        stagnation_window: Number of most recent episode rewards inspected by\n",
    "            the stagnation check (default 35, the value previously hard-coded).\n",
    "        min_episodes_between_evolutions: Minimum number of episodes between two\n",
    "            evolution rounds (default 25, the value previously hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict mapping region -> agent and the\n",
    "        ExperimentLogger filled during the run.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If stagnation_window is smaller than 1.\n",
    "    \"\"\"\n",
    "    # A window below 1 would slice the wrong span or divide by zero\n",
    "    if stagnation_window < 1:\n",
    "        raise ValueError(\"stagnation_window must be at least 1\")\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    print(f\"Training over {n_episodes} episodes. Adaptive evolution triggered upon stagnation (min_improvement={min_improvement}).\")\n",
    "    agents = {}\n",
    "    for i, region in enumerate(env.regions):\n",
    "        agents[region] = MAPPOAgent(\n",
    "            agent_id=i,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "    # One booster per region, seeded from that region's agent\n",
    "    evolution_boosters = {}\n",
    "    for region in env.regions:\n",
    "        booster = EvolutionaryBooster(population_size=population_size, mutation_rate=0.05)\n",
    "        booster.initialize_population(agents[region])\n",
    "        evolution_boosters[region] = booster\n",
    "    logger = ExperimentLogger()\n",
    "    # Track progress for adaptive evolution\n",
    "    reward_history = []\n",
    "    last_evolution_episode = 0\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = {region: 0 for region in env.regions}\n",
    "        # Per-region rollout storage, rebuilt from scratch every episode\n",
    "        global_states_buffer = {region: [] for region in env.regions}\n",
    "        actions_buffer = {region: [] for region in env.regions}\n",
    "        old_log_probs_buffer = {region: [] for region in env.regions}\n",
    "        rewards_buffer = {region: [] for region in env.regions}\n",
    "        next_global_states_buffer = {region: [] for region in env.regions}\n",
    "        dones_buffer = {region: [] for region in env.regions}\n",
    "        for _step in range(env.max_steps):\n",
    "            actions = {}\n",
    "            log_probs = {}\n",
    "            for region in env.regions:\n",
    "                # act() yields the action and its log-probability\n",
    "                action, log_prob = agents[region].act(observations[region])\n",
    "                actions[region] = action\n",
    "                log_probs[region] = log_prob\n",
    "            next_observations, rewards, done, _ = env.step(actions)\n",
    "            # Joint state before and after the step, in env.regions order\n",
    "            global_state = np.concatenate([observations[r] for r in env.regions])\n",
    "            next_global_state = np.concatenate([next_observations[r] for r in env.regions])\n",
    "            for region in env.regions:\n",
    "                global_states_buffer[region].append(global_state)\n",
    "                actions_buffer[region].append(actions[region])\n",
    "                old_log_probs_buffer[region].append(log_probs[region])\n",
    "                rewards_buffer[region].append(rewards[region])\n",
    "                next_global_states_buffer[region].append(next_global_state)\n",
    "                dones_buffer[region].append(done)\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # PPO update on the rollout that was just collected\n",
    "        for region in env.regions:\n",
    "            if len(global_states_buffer[region]) > 0:\n",
    "                _train_agent_ppo(\n",
    "                    agents[region],\n",
    "                    global_states_buffer[region],\n",
    "                    actions_buffer[region],\n",
    "                    old_log_probs_buffer[region],\n",
    "                    rewards_buffer[region],\n",
    "                    next_global_states_buffer[region],\n",
    "                    dones_buffer[region],\n",
    "                    batch_size=64,\n",
    "                    epochs=8\n",
    "                )\n",
    "        # Log episode\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        # Track reward history\n",
    "        reward_history.append(avg_reward)\n",
    "        # Adaptive evolution: trigger if recent improvement is below the threshold\n",
    "        enough_history = len(reward_history) > stagnation_window\n",
    "        cooled_down = episode - last_evolution_episode >= min_episodes_between_evolutions\n",
    "        if enough_history and cooled_down:\n",
    "            recent_rewards = reward_history[-stagnation_window:]\n",
    "            # Divisor is the window length itself, kept as in the original formula\n",
    "            improvement_rate = (recent_rewards[-1] - recent_rewards[0]) / stagnation_window\n",
    "            if improvement_rate < min_improvement:\n",
    "                print(f\"{experiment_name} - Adaptive evolutionary optimization at episode {episode} (stagnation detected)\")\n",
    "                for region in env.regions:\n",
    "                    booster = evolution_boosters[region]\n",
    "                    fitness_scores = [\n",
    "                        _evaluate_agent_fitness(candidate, env, region)\n",
    "                        for candidate in booster.population\n",
    "                    ]\n",
    "                    # The agent returned by evolve_population takes over this region\n",
    "                    agents[region] = booster.evolve_population(fitness_scores)\n",
    "                    logger.log_evolution(\n",
    "                        episode,\n",
    "                        max(fitness_scores),\n",
    "                        np.mean(fitness_scores)\n",
    "                    )\n",
    "                last_evolution_episode = episode\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    logger.plot_results()\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 7: MATD3-EVO (Improved)\n",
    "def train_matd3_with_evolution(env, n_episodes=600, min_improvement=0.008,\n",
    "                               population_size=8, save_results=True, experiment_name=\"MATD3-EVO\",\n",
    "                               stagnation_window=25, min_episodes_between_evolutions=15):\n",
    "    \"\"\"Train MATD3 agents and evolve each region's population when rewards stagnate.\n",
    "\n",
    "    Per region this creates one MATD3Agent, one EvolutionaryBooster seeded from\n",
    "    that agent, and one replay buffer (deque, maxlen=10000). Every tenth episode\n",
    "    (episode 0 excluded), each region whose buffer holds more than 200\n",
    "    transitions gets a single _train_agent_td3 call with batch_size=64, the\n",
    "    current policy_noise and noise_clip set to twice that value.\n",
    "\n",
    "    Stagnation check, run after every episode: once more than stagnation_window\n",
    "    average rewards are recorded and at least min_episodes_between_evolutions\n",
    "    episodes have passed since the last evolution (initially counted from\n",
    "    episode 0), the newest reward in the window minus the oldest, divided by\n",
    "    stagnation_window, is compared with min_improvement. If it is lower, every\n",
    "    region's population is scored with _evaluate_agent_fitness and evolved, and\n",
    "    the agent returned by evolve_population replaces agents[region].\n",
    "\n",
    "    Args:\n",
    "        env: Environment object; this function reads regions, n_regions,\n",
    "            state_dim, action_dim and max_steps and calls reset() and step().\n",
    "        n_episodes: Number of episodes to run.\n",
    "        min_improvement: Threshold below which the improvement rate counts as\n",
    "            stagnation.\n",
    "        population_size: Passed to each EvolutionaryBooster.\n",
    "        save_results: If true, logger.save_results is called with a timestamped\n",
    "            file name ending in .csv.\n",
    "        experiment_name: Label used in console output and in the file name.\n",
    "        stagnation_window: Number of most recent episode rewards inspected by\n",
    "            the stagnation check (default 25, the value previously hard-coded).\n",
    "        min_episodes_between_evolutions: Minimum number of episodes between two\n",
    "            evolution rounds (default 15, the value previously hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (agents, logger): dict mapping region -> agent and the\n",
    "        ExperimentLogger filled during the run.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If stagnation_window is smaller than 1.\n",
    "    \"\"\"\n",
    "    # A window below 1 would slice the wrong span or divide by zero\n",
    "    if stagnation_window < 1:\n",
    "        raise ValueError(\"stagnation_window must be at least 1\")\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    print(f\"Training over {n_episodes} episodes. Adaptive evolution triggered upon stagnation (min_improvement={min_improvement}).\")\n",
    "    agents = {}\n",
    "    for i, region in enumerate(env.regions):\n",
    "        agents[region] = MATD3Agent(\n",
    "            agent_id=i,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "    # One booster per region, seeded from that region's agent\n",
    "    evolution_boosters = {}\n",
    "    for region in env.regions:\n",
    "        booster = EvolutionaryBooster(population_size=population_size, mutation_rate=0.05)\n",
    "        booster.initialize_population(agents[region])\n",
    "        evolution_boosters[region] = booster\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    logger = ExperimentLogger()\n",
    "    # Track progress for adaptive evolution\n",
    "    reward_history = []\n",
    "    last_evolution_episode = 0\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = {region: 0 for region in env.regions}\n",
    "        # Both noise levels decay linearly with training progress, each with a floor\n",
    "        noise_scale = max(0.05, 0.3 * (1 - episode / n_episodes))\n",
    "        policy_noise = max(0.1, 0.2 * (1 - episode / n_episodes))\n",
    "        for _step in range(env.max_steps):\n",
    "            actions = {\n",
    "                region: agents[region].act(observations[region], noise_scale=noise_scale)\n",
    "                for region in env.regions\n",
    "            }\n",
    "            next_observations, rewards, done, _ = env.step(actions)\n",
    "            for region in env.regions:\n",
    "                replay_buffers[region].append((\n",
    "                    observations[region], actions[region], rewards[region],\n",
    "                    next_observations[region], done\n",
    "                ))\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # Train agents: every 10th episode, never on episode 0\n",
    "        if episode % 10 == 0 and episode > 0:\n",
    "            for region in env.regions:\n",
    "                if len(replay_buffers[region]) > 200:\n",
    "                    _train_agent_td3(agents[region], replay_buffers[region], batch_size=64,\n",
    "                                   policy_noise=policy_noise, noise_clip=policy_noise * 2)\n",
    "        # Log episode\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        # Track reward history\n",
    "        reward_history.append(avg_reward)\n",
    "        # Adaptive evolution: trigger if recent improvement is below the threshold\n",
    "        enough_history = len(reward_history) > stagnation_window\n",
    "        cooled_down = episode - last_evolution_episode >= min_episodes_between_evolutions\n",
    "        if enough_history and cooled_down:\n",
    "            recent_rewards = reward_history[-stagnation_window:]\n",
    "            # Divisor is the window length itself, kept as in the original formula\n",
    "            improvement_rate = (recent_rewards[-1] - recent_rewards[0]) / stagnation_window\n",
    "            if improvement_rate < min_improvement:\n",
    "                print(f\"{experiment_name} - Adaptive evolutionary optimization at episode {episode} (stagnation detected)\")\n",
    "                for region in env.regions:\n",
    "                    booster = evolution_boosters[region]\n",
    "                    fitness_scores = [\n",
    "                        _evaluate_agent_fitness(candidate, env, region)\n",
    "                        for candidate in booster.population\n",
    "                    ]\n",
    "                    # The agent returned by evolve_population takes over this region\n",
    "                    agents[region] = booster.evolve_population(fitness_scores)\n",
    "                    logger.log_evolution(\n",
    "                        episode,\n",
    "                        max(fitness_scores),\n",
    "                        np.mean(fitness_scores)\n",
    "                    )\n",
    "                last_evolution_episode = episode\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    logger.plot_results()\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger\n",
    "\n",
    "# EXPERIMENT 8: MAAC-EVO (Improved)\n",
    "def train_maac_with_evolution(env, n_episodes=600, min_improvement=0.008,  # Reduced threshold\n",
    "                              population_size=8, save_results=True, experiment_name=\"MAAC-EVO\"):\n",
    "    \"\"\"Train MAAC with adaptive evolutionary optimization.\"\"\"\n",
    "    print(f\"=== Starting Experiment: {experiment_name} ===\")\n",
    "    print(f\"Training over {n_episodes} episodes. Adaptive evolution triggered upon stagnation (min_improvement={min_improvement}).\")\n",
    "    agents = {}\n",
    "    for i, region in enumerate(env.regions):\n",
    "        agents[region] = MAACAgent(\n",
    "            agent_id=i,\n",
    "            state_dim=env.state_dim * 2,\n",
    "            action_dim=env.action_dim,\n",
    "            n_agents=env.n_regions,\n",
    "            lr_actor=5e-5,  # Reduced learning rate\n",
    "            lr_critic=1e-4\n",
    "        )\n",
    "    evolution_boosters = {}\n",
    "    for region in env.regions:\n",
    "        booster = EvolutionaryBooster(population_size=population_size, mutation_rate=0.05)\n",
    "        booster.initialize_population(agents[region])\n",
    "        evolution_boosters[region] = booster\n",
    "    replay_buffers = {region: deque(maxlen=10000) for region in env.regions}\n",
    "    logger = ExperimentLogger()\n",
    "    # Track progress\n",
    "    reward_history = []\n",
    "    last_evolution_episode = 0\n",
    "    min_episodes_between_evolutions = 20  # Increased interval\n",
    "    for episode in range(n_episodes):\n",
    "        observations = env.reset()\n",
    "        episode_rewards = {region: 0 for region in env.regions}\n",
    "        # Adaptive noise\n",
    "        noise_scale = max(0.05, 0.2 * (1 - episode / n_episodes))\n",
    "        for step in range(env.max_steps):\n",
    "            actions = {}\n",
    "            for region in env.regions:\n",
    "                action = agents[region].act(observations[region], noise_scale=noise_scale)\n",
    "                actions[region] = action\n",
    "            next_observations, rewards, done, _ = env.step(actions)\n",
    "            for region in env.regions:\n",
    "                replay_buffers[region].append((\n",
    "                    observations[region], actions[region], rewards[region],\n",
    "                    next_observations[region], done\n",
    "                ))\n",
    "                episode_rewards[region] += rewards[region]\n",
    "            observations = next_observations\n",
    "            if done:\n",
    "                break\n",
    "        # Train agents (reuse MADDPG training)\n",
    "        if episode % 10 == 0 and episode > 0:\n",
    "            for region in env.regions:\n",
    "                if len(replay_buffers[region]) > 200:\n",
    "                    _train_agent_maddpg(agents[region], replay_buffers[region], batch_size=64)\n",
    "        # Log episode\n",
    "        avg_reward = np.mean(list(episode_rewards.values()))\n",
    "        stability_metric = _calculate_stability_metric(env)\n",
    "        logger.log_episode(episode, episode_rewards, avg_reward, stability_metric)\n",
    "        # Track reward history\n",
    "        reward_history.append(avg_reward)\n",
    "        # Adaptive evolution\n",
    "        if len(reward_history) > 30 and episode - last_evolution_episode >= min_episodes_between_evolutions:\n",
    "            recent_rewards = reward_history[-30:]\n",
    "            improvement_rate = (recent_rewards[-1] - recent_rewards[0]) / 30\n",
    "            if improvement_rate < min_improvement:\n",
    "                print(f\"{experiment_name} - Adaptive evolutionary optimization at episode {episode} (stagnation detected)\")\n",
    "                for region in env.regions:\n",
    "                    fitness_scores = []\n",
    "                    for agent in evolution_boosters[region].population:\n",
    "                        fitness = _evaluate_agent_fitness(agent, env, region)\n",
    "                        fitness_scores.append(fitness)\n",
    "                    best_agent = evolution_boosters[region].evolve_population(fitness_scores)\n",
    "                    agents[region] = best_agent\n",
    "                    logger.log_evolution(\n",
    "                        episode,\n",
    "                        max(fitness_scores),\n",
    "                        np.mean(fitness_scores)\n",
    "                    )\n",
    "                last_evolution_episode = episode\n",
    "        if episode % 20 == 0:\n",
    "            print(f\"{experiment_name} - Episode {episode}, Avg Reward: {avg_reward:.3f}, Stability: {stability_metric:.3f}\")\n",
    "    print(f\"{experiment_name} - Training completed!\")\n",
    "    logger.plot_results()\n",
    "    if save_results:\n",
    "        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "        logger.save_results(f\"{experiment_name.lower()}_results_{timestamp}.csv\")\n",
    "    return agents, logger"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# --- BLOCK 10: Main Execution Block ---\n",
    "def _plot_metric_comparison(results, metric_key, title, ylabel, window=10):\n",
    "    \"\"\"Draw one logged metric for every experiment on a single shared figure.\n",
    "\n",
    "    Args:\n",
    "        results: Mapping of experiment name -> logger. Each logger must expose\n",
    "            a `metrics` mapping containing 'episode' and `metric_key` series.\n",
    "        metric_key: Key of the series in `logger.metrics` to plot.\n",
    "        title: Figure title.\n",
    "        ylabel: Label for the y axis.\n",
    "        window: Rolling-mean window. Series with more than `window` points are\n",
    "            smoothed; shorter ones are drawn raw with markers.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(15, 10))\n",
    "    for name, logger in results.items():\n",
    "        # Loggers that never recorded this metric are left off the figure.\n",
    "        if 'episode' not in logger.metrics or metric_key not in logger.metrics:\n",
    "            continue\n",
    "        episodes = logger.metrics['episode']\n",
    "        values = logger.metrics[metric_key]\n",
    "        if len(values) > window:\n",
    "            smoothed = pd.Series(values).rolling(window=window, min_periods=1).mean()\n",
    "            plt.plot(episodes, smoothed, label=name, linewidth=2)\n",
    "        else:\n",
    "            plt.plot(episodes, values, label=name, linewidth=2, marker='o', markersize=3)\n",
    "    plt.title(title)\n",
    "    plt.xlabel('Episode')\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.legend()\n",
    "    plt.grid(True)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Runs the nine experiments back to back on one shared environment, then\n",
    "    # compares their logged metrics (two plots, a printed table and a CSV file).\n",
    "    print(\"=== STARTING ADDITIONAL EXPERIMENTS ===\")\n",
    "    # --- Shared experiment settings (previously repeated as literals) ---\n",
    "    N_EPISODES = 600  # same training budget for every experiment\n",
    "    POPULATION_SIZE = 8  # population size handed to every *-EVO trainer\n",
    "    MIN_IMPROVEMENT = 0.008  # min_improvement handed to the adaptive *-EVO trainers\n",
    "    # --- Data and Environment Preparation ---\n",
    "    print(\"Loading and processing real-world data...\")\n",
    "    data_processor = DemographicDataProcessor(\n",
    "        'regions_data_selective.csv',  # path to real data\n",
    "        'crisis.txt'  # path to crisis scenarios\n",
    "    )\n",
    "    # CHANGED: Use data from 2000 onward\n",
    "    years = list(range(2000, 2025))\n",
    "    target_regions = ['Moscow', 'St. Petersburg', 'Tatarstan', 'Krasnodar Krai',\n",
    "                     'Sverdlovsk Oblast', 'Novosibirsk Oblast', 'Samara Oblast', 'Rostov Oblast']\n",
    "    training_data = data_processor.generate_training_data(years, target_regions, apply_crisis=True)\n",
    "    print(f\"Processed {len(training_data)} records for {len(target_regions)} regions\")\n",
    "    # --- Dictionary to store results of all experiments ---\n",
    "    experiment_results = {}\n",
    "    # --- Run Experiments ---\n",
    "    # One environment instance is shared by all runs; it is reset before each\n",
    "    # subsequent experiment. Its size is derived from the region list so the\n",
    "    # two cannot drift apart when regions are added or removed.\n",
    "    # 0. MADDPG-EVO-0 (original, unchanged)\n",
    "    print(\"--- EXPERIMENT 0: MADDPG-EVO-0 ---\")\n",
    "    env = DemographicEnvironment(training_data, n_regions=len(target_regions), max_steps=50)\n",
    "    agents_maddpg_evo_0, logger_maddpg_evo_0 = train_maddpg_with_evolution_original(\n",
    "        env, n_episodes=N_EPISODES, evolution_frequency=50,\n",
    "        population_size=POPULATION_SIZE, experiment_name=\"MADDPG-EVO-0\"\n",
    "    )\n",
    "    experiment_results[\"MADDPG-EVO-0\"] = logger_maddpg_evo_0\n",
    "    # 1. MADDPG (improved)\n",
    "    print(\"--- EXPERIMENT 1: MADDPG ---\")\n",
    "    env.reset()\n",
    "    agents_maddpg, logger_maddpg = train_maddpg(env, n_episodes=N_EPISODES, experiment_name=\"MADDPG\")\n",
    "    experiment_results[\"MADDPG\"] = logger_maddpg\n",
    "    # 2. MAPPO (improved)\n",
    "    print(\"--- EXPERIMENT 2: MAPPO ---\")\n",
    "    env.reset()\n",
    "    agents_mappo, logger_mappo = train_mappo(env, n_episodes=N_EPISODES, experiment_name=\"MAPPO\")\n",
    "    experiment_results[\"MAPPO\"] = logger_mappo\n",
    "    # 3. MATD3 (improved)\n",
    "    print(\"--- EXPERIMENT 3: MATD3 ---\")\n",
    "    env.reset()\n",
    "    agents_matd3, logger_matd3 = train_matd3(env, n_episodes=N_EPISODES, experiment_name=\"MATD3\")\n",
    "    experiment_results[\"MATD3\"] = logger_matd3\n",
    "    # 4. MAAC (improved)\n",
    "    print(\"--- EXPERIMENT 4: MAAC ---\")\n",
    "    env.reset()\n",
    "    agents_maac, logger_maac = train_maac(env, n_episodes=N_EPISODES, experiment_name=\"MAAC\")\n",
    "    experiment_results[\"MAAC\"] = logger_maac\n",
    "    # 5. MADDPG-EVO (improved)\n",
    "    print(\"--- EXPERIMENT 5: MADDPG-EVO ---\")\n",
    "    env.reset()\n",
    "    agents_maddpg_evo, logger_maddpg_evo = train_maddpg_with_evolution(\n",
    "        env, n_episodes=N_EPISODES, min_improvement=MIN_IMPROVEMENT,\n",
    "        population_size=POPULATION_SIZE, experiment_name=\"MADDPG-EVO\"\n",
    "    )\n",
    "    experiment_results[\"MADDPG-EVO\"] = logger_maddpg_evo\n",
    "    # 6. MAPPO-EVO (improved)\n",
    "    print(\"--- EXPERIMENT 6: MAPPO-EVO ---\")\n",
    "    env.reset()\n",
    "    agents_mappo_evo, logger_mappo_evo = train_mappo_with_evolution(\n",
    "        env, n_episodes=N_EPISODES, min_improvement=MIN_IMPROVEMENT,\n",
    "        population_size=POPULATION_SIZE, experiment_name=\"MAPPO-EVO\"\n",
    "    )\n",
    "    experiment_results[\"MAPPO-EVO\"] = logger_mappo_evo\n",
    "    # 7. MATD3-EVO (improved)\n",
    "    print(\"--- EXPERIMENT 7: MATD3-EVO ---\")\n",
    "    env.reset()\n",
    "    agents_matd3_evo, logger_matd3_evo = train_matd3_with_evolution(\n",
    "        env, n_episodes=N_EPISODES, min_improvement=MIN_IMPROVEMENT,\n",
    "        population_size=POPULATION_SIZE, experiment_name=\"MATD3-EVO\"\n",
    "    )\n",
    "    experiment_results[\"MATD3-EVO\"] = logger_matd3_evo\n",
    "    # 8. MAAC-EVO (improved)\n",
    "    print(\"--- EXPERIMENT 8: MAAC-EVO ---\")\n",
    "    env.reset()\n",
    "    agents_maac_evo, logger_maac_evo = train_maac_with_evolution(\n",
    "        env, n_episodes=N_EPISODES, min_improvement=MIN_IMPROVEMENT,\n",
    "        population_size=POPULATION_SIZE, experiment_name=\"MAAC-EVO\"\n",
    "    )\n",
    "    experiment_results[\"MAAC-EVO\"] = logger_maac_evo\n",
    "    print(\"=== ALL EXPERIMENTS COMPLETED ===\")\n",
    "    # --- Comparative Analysis ---\n",
    "    print(\"=== GENERATING COMPARATIVE RESULTS ===\")\n",
    "    # 1. Comparative plot: Average reward\n",
    "    _plot_metric_comparison(\n",
    "        experiment_results, 'avg_reward',\n",
    "        'Comparison of Average Reward Across Algorithms (Moving Average, Window=10)',\n",
    "        'Average Reward'\n",
    "    )\n",
    "    # 2. Comparative plot: Stability metric\n",
    "    _plot_metric_comparison(\n",
    "        experiment_results, 'stability',\n",
    "        'Comparison of System Stability Across Algorithms (Moving Average, Window=10)',\n",
    "        'Stability'\n",
    "    )\n",
    "    # 3. Summary table of final results\n",
    "    summary_data = []\n",
    "    for name, logger in experiment_results.items():\n",
    "        if 'avg_reward' not in logger.metrics or 'stability' not in logger.metrics:\n",
    "            continue\n",
    "        reward_series = logger.metrics['avg_reward']\n",
    "        stability_series = logger.metrics['stability']\n",
    "        # np.max raises ValueError on an empty sequence (and [-0:] would select\n",
    "        # the whole list), so a run that logged nothing is reported and skipped.\n",
    "        if len(reward_series) == 0 or len(stability_series) == 0:\n",
    "            print(f\"Skipping {name} in summary: no logged episodes\")\n",
    "            continue\n",
    "        # Use last 20 episodes, or fewer if insufficient\n",
    "        last_n = min(20, len(reward_series))\n",
    "        final_avg_reward = np.mean(reward_series[-last_n:])\n",
    "        final_stability = np.mean(stability_series[-last_n:])\n",
    "        max_avg_reward = np.max(reward_series)\n",
    "        max_stability = np.max(stability_series)\n",
    "        summary_data.append({\n",
    "            'Algorithm': name,\n",
    "            'Avg Reward (Final)': f\"{final_avg_reward:.3f}\",\n",
    "            'Stability (Final)': f\"{final_stability:.3f}\",\n",
    "            'Max Reward': f\"{max_avg_reward:.3f}\",\n",
    "            'Max Stability': f\"{max_stability:.3f}\"\n",
    "        })\n",
    "    summary_df = pd.DataFrame(summary_data)\n",
    "    print(\"Summary Table of Results (averaged over final episodes):\")\n",
    "    print(summary_df.to_string(index=False))\n",
    "    # 4. Save summary table to CSV (timestamped so earlier runs are not overwritten)\n",
    "    timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "    summary_df.to_csv(f\"comparison_summary_{timestamp}.csv\", index=False)\n",
    "    print(f\"\\nSummary table saved to comparison_summary_{timestamp}.csv\")"
   ],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [],
   "metadata": {},
   "execution_count": null,
   "outputs": []
  }
 ]
}