from __future__ import annotations

import os
import random
from collections import deque
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from .agent import Agent
from .groupemb import build_group_embedding
from .agentemb import build_agent_embedding
from .skill_q_networks import build_skill_q_network


# Run on the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Set to True by the first SlightAgent constructed so the device banner is printed once.
_DEVICE_LOGGED = False


class SlightAgent(Agent):
    def __init__(self, dic_agent_conf, dic_traffic_env_conf, dic_path, cnt_round=None, intersection_id="0"):
        super(SlightAgent, self).__init__(dic_agent_conf, dic_traffic_env_conf, dic_path, intersection_id)

        self.device = DEVICE
        global _DEVICE_LOGGED
        if not _DEVICE_LOGGED:
            visible = os.environ.get("CUDA_VISIBLE_DEVICES", "all")
            print(f"[SlightAgent] Using device {self.device} (CUDA_VISIBLE_DEVICES={visible})")
            if torch.cuda.is_available():
                try:
                    current_idx = torch.cuda.current_device()
                    num_devices = torch.cuda.device_count()
                    name = torch.cuda.get_device_name(current_idx)
                    print(f"[SlightAgent] torch.cuda.current_device={current_idx}, count={num_devices}, name={name}")
                except Exception as exc:
                    print(f"[SlightAgent] Could not query cuda device info: {exc}")
            _DEVICE_LOGGED = True

        self.num_skills = int(dic_agent_conf.get("NUM_SKILLS", 4))
        self.meta_controller_freq = int(dic_agent_conf.get("META_CONTROLLER_FREQ", 10))
        self.num_agents = dic_traffic_env_conf["NUM_INTERSECTIONS"]
        self.num_neighbors = min(dic_traffic_env_conf["TOP_K_ADJACENCY"], self.num_agents)
        self.num_actions = len(self.dic_traffic_env_conf["PHASE"])
        self.len_feature = self._cal_len_feature()
        self.meta_input_dim = self.len_feature + self.num_actions + 1

        self.memory_skills = [[[] for _ in range(self.num_agents)] for _ in range(self.num_skills)]
        max_buffer = int(self.dic_agent_conf.get("GROUP_EMBED_BUFFER_SIZE", self.dic_agent_conf.get("CVAE_BUFFER_SIZE", 2000)))
        self.group_embed_memory = [deque(maxlen=max_buffer) for _ in range(self.num_skills)]

        self._skill_mlp_layers = dic_agent_conf.get("MLP_LAYERS", [32, 32])
        self._skill_cnn_layers = dic_agent_conf.get("CNN_layers", [[32, 32]])
        self.skill_q_type = str(dic_agent_conf.get("SKILL_Q_TYPE", "colight_gat")).lower()
        self.skill_q_num_heads = int(dic_agent_conf.get("SKILL_Q_NUM_HEADS", 5))

        self.skill_q_networks = [self._build_skill_network() for _ in range(self.num_skills)]
        self.skill_q_networks_bar = [self.build_network_from_copy(net) for net in self.skill_q_networks]

        for net in self.skill_q_networks + self.skill_q_networks_bar:
            net.to(self.device)

        if cnt_round and cnt_round > 0:
            try:
                self.load_network(f"round_{cnt_round - 1}_inter_{self.intersection_id}")
            except Exception:
                print(f"fail to load network, current round: {cnt_round}")

        lr = float(self.dic_agent_conf.get("LEARNING_RATE", 1e-3))
        self.skill_optimizers = [torch.optim.Adam(net.parameters(), lr=lr) for net in self.skill_q_networks]
        self.skill_loss = nn.MSELoss()
        self.max_grad_norm = float(self.dic_agent_conf.get("MAX_GRAD_NORM", 5.0))

        default_latent = int(self.dic_agent_conf.get("CVAE_LATENT_DIM", 16))
        self.group_latent_dim = int(self.dic_agent_conf.get("GROUP_LATENT_DIM", default_latent))
        self.agent_embed_dim = int(self.dic_agent_conf.get("AGENT_EMBED_DIM", self.group_latent_dim))

        requested_embed = str(self.dic_agent_conf.get("AGENT_EMBED_TYPE", "lstm")).lower()
        if requested_embed != "lstm":
            raise ValueError("Only 'lstm' agent embedding is supported in the cleaned codebase")
        self.agent_embed_type = "lstm"

        lstm_hidden = int(self.dic_agent_conf.get("LSTM_EMBED_HIDDEN", 128))
        lstm_layers = int(self.dic_agent_conf.get("LSTM_EMBED_LAYERS", 1))
        lstm_dropout = float(self.dic_agent_conf.get("LSTM_EMBED_DROPOUT", 0.0))
        agent_embed_kwargs = {
            "input_dim": self.meta_input_dim,
            "hidden_dim": lstm_hidden,
            "embed_dim": self.agent_embed_dim,
            "num_layers": lstm_layers,
            "dropout": lstm_dropout,
        }
        self.agent_embed_model = build_agent_embedding(self.agent_embed_type, **agent_embed_kwargs).to(self.device)
        self.agent_embed_model.eval()

        use_cvae = bool(self.dic_agent_conf.get("USE_CVAE", True))
        default_group_type = "cvae" if use_cvae else "none"
        requested_group = str(self.dic_agent_conf.get("GROUP_EMBED_TYPE", default_group_type)).lower()
        if requested_group not in ("cvae", "none"):
            raise ValueError("Only 'cvae' group embedding is supported in the cleaned codebase")
        self.group_embed_type = requested_group

        self.group_embed_model: Optional[nn.Module] = None
        if self.group_embed_type == "cvae":
            self.group_embed_model = build_group_embedding(
                "cvae",
                state_dim=self.len_feature,
                action_dim=self.num_actions,
                latent_dim=self.group_latent_dim,
                lr=float(self.dic_agent_conf.get("CVAE_LEARNING_RATE", 1e-3)),
                kl_weight=float(self.dic_agent_conf.get("CVAE_KL_WEIGHT", 1e-4)),
                hidden_dim=int(self.dic_agent_conf.get("CVAE_HIDDEN_DIM", 128)),
                device=self.device,
            )
            self.group_embed_model.eval()

        self.group_interaction = nn.Linear(
            self.num_agents + self.agent_embed_dim,
            self.group_latent_dim,
            bias=False,
        ).to(self.device)
        self.group_interaction.eval()

        decayed_epsilon = self.dic_agent_conf["EPSILON"] * pow(
            self.dic_agent_conf["EPSILON_DECAY"], cnt_round if cnt_round is not None else 0
        )
        self.dic_agent_conf["EPSILON"] = max(decayed_epsilon, self.dic_agent_conf["MIN_EPSILON"])

        self.history_buffer: deque = deque(maxlen=self.meta_controller_freq)
        self.update_counter = 0
        self.current_skill = -1
        self._meta_acc_reward = 0.0
        self._last_group_assignments: Optional[np.ndarray] = None
        self._last_state_features: Optional[np.ndarray] = None

    def _build_skill_network(self) -> nn.Module:
        return build_skill_q_network(
            self.skill_q_type,
            len_feature=self.len_feature,
            mlp_layers=self._skill_mlp_layers,
            cnn_layers=self._skill_cnn_layers,
            num_actions=self.num_actions,
            num_agents=self.num_agents,
            num_neighbors=self.num_neighbors,
            num_heads=self.skill_q_num_heads,
        )

    def build_network_from_copy(self, network_copy: nn.Module) -> nn.Module:
        clone = self._build_skill_network()
        clone.load_state_dict(network_copy.state_dict())
        return clone

    def _cal_len_feature(self):
        N = 0
        used_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:-1]
        for feat_name in used_feature:
            if "cur_phase" in feat_name:
                N += 8
            else:
                N += 12
        return N

    def adjacency_index2matrix(self, adjacency_index: np.ndarray) -> np.ndarray:
        adjacency_index = np.asarray(adjacency_index, dtype=object)
        if adjacency_index.ndim < 3:
            raise ValueError("adjacency_index must be [batch, agents, neighbors or agents]")

        batch, num_agents = adjacency_index.shape[:2]
        mat = np.zeros((batch, num_agents, self.num_neighbors, num_agents), dtype=np.float32)
        for b in range(batch):
            for i in range(num_agents):
                row = np.asarray(adjacency_index[b, i])
                if row.size == num_agents and np.all((row == 0) | (row == 1)):
                    neighbors = np.flatnonzero(row)[: self.num_neighbors]
                else:
                    neighbors = row.astype(np.int64, copy=False).ravel()
                neighbors = neighbors[(neighbors >= 0) & (neighbors < num_agents)]
                for k, neigh in enumerate(neighbors[: self.num_neighbors]):
                    mat[b, i, k, neigh] = 1.0
        return mat

    def convert_state_to_input(self, s) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        used_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:-1]
        feats0 = []
        adj = []
        for i in range(self.num_agents):
            adj.append(s[i]["adjacency_matrix"])
            tmp = []
            for feature in used_feature:
                if feature == "cur_phase":
                    if self.dic_traffic_env_conf["BINARY_PHASE_EXPANSION"]:
                        tmp.extend(self.dic_traffic_env_conf["PHASE"][s[i][feature][0]])
                    else:
                        tmp.extend(s[i][feature])
                else:
                    tmp.extend(s[i][feature])
            feats0.append(tmp)
        feats = np.array([feats0], dtype=np.float32)
        adj_neighbor = self.adjacency_index2matrix(np.array([adj], dtype=object))
        adj_matrix = adj_neighbor.max(axis=2)
        eye = np.eye(self.num_agents, dtype=np.float32)
        adj_matrix = np.maximum(adj_matrix, adj_matrix.transpose(0, 2, 1))
        adj_matrix = np.maximum(adj_matrix, eye.reshape(1, self.num_agents, self.num_agents))
        return feats, adj_neighbor, adj_matrix

    def choose_action(self, count, states):
        feats_np, adj_np, adj_mat_np = self.convert_state_to_input(states)
        state_tensor = torch.from_numpy(feats_np).to(self.device)
        adj_tensor = torch.from_numpy(adj_np).to(self.device)
        adj_mat_tensor = torch.from_numpy(adj_mat_np).to(self.device)
        _ = count
        _ = adj_mat_tensor

        meta_state = self._get_meta_state_from_buffer()
        with torch.no_grad():
            meta_state_tensor = torch.from_numpy(meta_state).to(self.device)
            agent_embeds = self.agent_embed_model(meta_state_tensor).squeeze(0).cpu().numpy()

        if self.group_embed_model is not None:
            Z = []
            for gid in range(self.num_skills):
                Z.append(self._sample_skill_latent(gid))
            Z = np.stack(Z, axis=0)
        else:
            Z = np.random.normal(size=(self.num_skills, self.group_latent_dim)).astype(np.float32)

        identity = np.eye(self.num_agents, dtype=np.float32)
        X = np.concatenate([identity, agent_embeds.astype(np.float32)], axis=1)
        X_tensor = torch.from_numpy(X).to(self.device)
        Z_tensor = torch.from_numpy(Z).to(self.device)
        with torch.no_grad():
            interaction = self.group_interaction(X_tensor)
            compat_tensor = torch.matmul(interaction, Z_tensor.transpose(0, 1))
        compat = compat_tensor.cpu().numpy()
        compat = compat - np.max(compat, axis=1, keepdims=True)
        probs_group = np.exp(compat) / np.clip(np.sum(np.exp(compat), axis=1, keepdims=True), 1e-6, None)

        eps = float(self.dic_agent_conf.get("EPSILON", 0.1)) / max(1.0, float(self.num_skills))
        group_ids = []
        for i in range(self.num_agents):
            if random.random() < eps:
                gid = np.random.randint(self.num_skills)
            else:
                gid = int(np.argmax(probs_group[i]))
            group_ids.append(gid)
        group_ids = np.asarray(group_ids, dtype=np.int64)

        q_values_all = []
        with torch.no_grad():
            for net in self.skill_q_networks:
                net.eval()
                q_values_all.append(net(state_tensor, adj_tensor).cpu().numpy())

        q_pick = np.zeros((self.num_agents, self.num_actions), dtype=np.float32)
        for i in range(self.num_agents):
            q_pick[i] = q_values_all[group_ids[i]][0, i]

        if random.random() <= self.dic_agent_conf["EPSILON"] / 5:
            action = np.random.randint(self.num_actions, size=self.num_agents)
        else:
            action = np.argmax(q_pick, axis=1)

        action_onehot = np.zeros((self.num_agents, self.num_actions), dtype=np.float32)
        for i in range(self.num_agents):
            action_onehot[i, action[i]] = 1.0

        self._last_group_assignments = group_ids.copy()
        self._last_state_features = feats_np[0].copy()
        self._record_history(feats_np[0], action_onehot)

        return action

    def _get_meta_state_from_buffer(self) -> np.ndarray:
        if len(self.history_buffer) == 0:
            entries = [self._empty_history_entry() for _ in range(self.meta_controller_freq)]
        else:
            seq = list(self.history_buffer)
            if len(seq) < self.meta_controller_freq:
                pad_len = self.meta_controller_freq - len(seq)
                seq = [self._empty_history_entry() for _ in range(pad_len)] + seq
            else:
                seq = seq[-self.meta_controller_freq:]
            entries = seq

        stacked = []
        for item in entries:
            state = item["state"].astype(np.float32)
            action = item["action"] if item["action"] is not None else np.zeros((self.num_agents, self.num_actions), dtype=np.float32)
            reward = item["reward"] if item["reward"] is not None else np.zeros((self.num_agents, 1), dtype=np.float32)
            combined = np.concatenate([state, action.astype(np.float32), reward.astype(np.float32)], axis=1)
            stacked.append(combined)

        arr = np.stack(stacked, axis=0)
        return np.expand_dims(arr.astype(np.float32), axis=0)

    def _empty_history_entry(self) -> dict:
        return {
            "state": np.zeros((self.num_agents, self.len_feature), dtype=np.float32),
            "action": np.zeros((self.num_agents, self.num_actions), dtype=np.float32),
            "reward": np.zeros((self.num_agents, 1), dtype=np.float32),
        }

    def _record_history(self, state_features: np.ndarray, action_onehot: np.ndarray) -> None:
        entry = {
            "state": state_features.astype(np.float32),
            "action": action_onehot.astype(np.float32),
            "reward": None,
        }
        self.history_buffer.append(entry)

    def _sample_skill_latent(self, skill_id: int) -> np.ndarray:
        if self.group_embed_model is None or not (0 <= skill_id < self.num_skills):
            return np.random.normal(size=self.group_latent_dim).astype(np.float32)

        memory = self.group_embed_memory[skill_id]
        if len(memory) == 0:
            return np.random.normal(size=self.group_latent_dim).astype(np.float32)

        sample_size = min(
            len(memory),
            int(self.dic_agent_conf.get("GROUP_EMBED_SAMPLE_SIZE", self.dic_agent_conf.get("CVAE_SAMPLE_SIZE", 32))),
        )
        indices = np.random.choice(len(memory), sample_size, replace=False)
        states = np.stack([memory[idx][0] for idx in indices], axis=0)
        actions = np.array([memory[idx][1] for idx in indices], dtype=np.int64)

        states_tensor = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions_tensor = torch.as_tensor(actions, dtype=torch.long, device=self.device)

        with torch.no_grad():
            mu = self.group_embed_model.encode(states_tensor, actions_tensor)
        if mu.numel() == 0:
            return np.random.normal(size=self.group_latent_dim).astype(np.float32)
        return mu.mean(dim=0).detach().cpu().numpy().astype(np.float32)

    def _append_group_embed_sample(self, skill_id: int, state_vec: np.ndarray, action_idx: int) -> None:
        if self.group_embed_model is None or not (0 <= skill_id < self.num_skills):
            return
        self.group_embed_memory[skill_id].append((state_vec.astype(np.float32), int(action_idx)))

    @staticmethod
    def _concat_list(ls: Sequence[Sequence]) -> List:
        tmp = []
        for item in ls:
            tmp += list(item)
        return [tmp]

    def prepare_Xs_Y(self, samples_list):
        try:
            for skill_id in range(self.num_skills):
                self.memory_skills[skill_id] = samples_list
        except Exception:
            pass

    def prepare_Xs_Y_skill(self, memory, skill_id):
        if len(memory[0]) == 0:
            return None

        slice_size = len(memory[0])
        _adjs = []
        _state = [[] for _ in range(self.num_agents)]
        _next_state = [[] for _ in range(self.num_agents)]
        _action = [[] for _ in range(self.num_agents)]
        _reward = [[] for _ in range(self.num_agents)]

        used_feature = self.dic_traffic_env_conf["LIST_STATE_FEATURE"][:-1]

        for i in range(slice_size):
            _adj = []
            for j in range(self.num_agents):
                state, action, next_state, reward, _, _, _ = memory[j][i]
                _action[j].append(action)
                _reward[j].append(reward)
                _adj.append(state["adjacency_matrix"])
                _state[j].append(self._concat_list([state[feat] for feat in used_feature]))
                _next_state[j].append(self._concat_list([next_state[feat] for feat in used_feature]))
            _adjs.append(_adj)

        _adjs2 = self.adjacency_index2matrix(np.array(_adjs))
        _state2 = np.concatenate([np.array(ss) for ss in _state], axis=1)
        _next_state2 = np.concatenate([np.array(ss) for ss in _next_state], axis=1)

        state_tensor = torch.from_numpy(_state2).float().to(self.device)
        next_state_tensor = torch.from_numpy(_next_state2).float().to(self.device)
        adj_tensor = torch.from_numpy(_adjs2).float().to(self.device)

        with torch.no_grad():
            target = self.skill_q_networks[skill_id](state_tensor, adj_tensor)
            next_state_qvalues = self.skill_q_networks_bar[skill_id](next_state_tensor, adj_tensor)

        final_target = target.clone()
        for i in range(slice_size):
            for j in range(self.num_agents):
                a = _action[j][i]
                final_target[i, j, a] = (
                    _reward[j][i] / self.dic_agent_conf["NORMAL_FACTOR"]
                    + self.dic_agent_conf["GAMMA"] * torch.max(next_state_qvalues[i, j])
                )

        return (
            state_tensor.cpu(),
            adj_tensor.cpu(),
            final_target.cpu(),
        )

    def train_network(self):
        print("--- Training SLight Agent (PyTorch) ---")
        epochs = int(self.dic_agent_conf.get("EPOCHS", 5))
        batch_size = int(self.dic_agent_conf.get("BATCH_SIZE", 32))

        for skill_id in range(self.num_skills):
            data = self.prepare_Xs_Y_skill(self.memory_skills[skill_id], skill_id)
            if data is None:
                continue

            states, adjacency, targets = data
            dataset = TensorDataset(states, adjacency, targets)
            loader = DataLoader(dataset, batch_size=min(batch_size, len(dataset)), shuffle=True)

            net = self.skill_q_networks[skill_id]
            net.train()

            for _ in range(max(1, epochs)):
                for state_batch, adj_batch, target_batch in loader:
                    state_batch = state_batch.to(self.device)
                    adj_batch = adj_batch.to(self.device)
                    target_batch = target_batch.to(self.device)

                    self.skill_optimizers[skill_id].zero_grad()
                    pred = net(state_batch, adj_batch)
                    loss = self.skill_loss(pred, target_batch)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), self.max_grad_norm)
                    self.skill_optimizers[skill_id].step()

            self.skill_q_networks_bar[skill_id].load_state_dict(net.state_dict())

        if self.group_embed_model is not None:
            samples: List[Tuple[np.ndarray, int]] = []
            for memory in self.group_embed_memory:
                samples.extend(memory)

            if samples:
                batch = min(
                    int(self.dic_agent_conf.get("GROUP_EMBED_BATCH_SIZE", self.dic_agent_conf.get("CVAE_BATCH_SIZE", 64))),
                    len(samples),
                )
                idx = np.random.choice(len(samples), batch, replace=False)
                state_batch = torch.as_tensor(
                    np.stack([samples[i][0] for i in idx], axis=0), dtype=torch.float32, device=self.device
                )
                action_batch = torch.as_tensor(
                    np.array([samples[i][1] for i in idx], dtype=np.int64), dtype=torch.long, device=self.device
                )
                self.group_embed_model.train()
                self.group_embed_model.train_batch(state_batch, action_batch)
                self.group_embed_model.eval()

        self.update_counter += 1

    def save_network(self, file_name):
        base_dir = self.dic_path.get("PATH_TO_MODEL", ".")
        os.makedirs(base_dir, exist_ok=True)

        for idx, net in enumerate(self.skill_q_networks):
            path = os.path.join(base_dir, f"{file_name}_skill_{idx}.pt")
            torch.save(net.state_dict(), path)

        embed_path = os.path.join(base_dir, f"{file_name}_embed.pt")
        torch.save(self.agent_embed_model.state_dict(), embed_path)

        if self.group_embed_model is not None:
            group_path = os.path.join(base_dir, f"{file_name}_group_embed.pt")
            self.group_embed_model.save(group_path)

    def load_network(self, file_name, file_path=None):
        if file_path is None:
            file_path = self.dic_path.get("PATH_TO_MODEL", ".")

        for idx, net in enumerate(self.skill_q_networks):
            path = os.path.join(file_path, f"{file_name}_skill_{idx}.pt")
            if os.path.exists(path):
                state = torch.load(path, map_location=self.device)
                net.load_state_dict(state)
                self.skill_q_networks_bar[idx].load_state_dict(state)

        embed_path = os.path.join(file_path, f"{file_name}_embed.pt")
        if os.path.exists(embed_path):
            state = torch.load(embed_path, map_location=self.device)
            self.agent_embed_model.load_state_dict(state)

        if self.group_embed_model is not None:
            group_paths = [
                os.path.join(file_path, f"{file_name}_group_embed.pt"),
                os.path.join(file_path, f"{file_name}_cvae.pt"),
            ]
            for path in group_paths:
                if os.path.exists(path):
                    self.group_embed_model.load(path)
                    break

    def store_step(self, states, actions, next_states, rewards, done=False):
        _ = done
        skill_id = int(self.current_skill if self.current_skill is not None else -1)
        if skill_id < 0 or skill_id >= self.num_skills:
            return
        for j in range(self.num_agents):
            self.memory_skills[skill_id][j].append(
                (states[j], int(actions[j]), next_states[j], float(rewards[j]), None, None, None)
            )
        try:
            self._meta_acc_reward += float(np.sum(rewards))
        except Exception:
            self._meta_acc_reward += float(np.mean(rewards)) * self.num_agents

        if len(self.history_buffer) > 0:
            reward_arr = np.asarray(rewards, dtype=np.float32).reshape(self.num_agents, 1)
            self.history_buffer[-1]["reward"] = reward_arr

        if (
            self.group_embed_model is not None
            and self._last_group_assignments is not None
            and self._last_state_features is not None
        ):
            action_arr = np.asarray(actions, dtype=np.int64)
            for agent_idx in range(self.num_agents):
                skill = int(self._last_group_assignments[agent_idx])
                state_vec = self._last_state_features[agent_idx]
                self._append_group_embed_sample(skill, state_vec, int(action_arr[agent_idx]))

            self._last_group_assignments = None
            self._last_state_features = None
