# train/train_agent.py
"""
Training loop for ChronosCore DQN agent with transformer encoder.
This is a compact DQN implementation with replay buffer and target network updates.
"""
import random
from collections import deque
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from configs import Config
from chronoscore.encoder import ChronosEncoder
from chronoscore.q_network import QHead
from data.dataset_builder import generate_random_taskset, Task

class ReplayBuffer:
    """Bounded FIFO memory of ``(s, a, r, s2, done)`` transitions.

    Backed by a ``deque`` with ``maxlen=capacity``: once the buffer is full,
    each new transition silently evicts the oldest one.
    """

    def __init__(self, capacity=10000):
        self.buf = deque(maxlen=capacity)

    def push(self, s, a, r, s2, done):
        """Append one transition to the buffer."""
        transition = (s, a, r, s2, done)
        self.buf.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns five tuples ``(states, actions, rewards, next_states, dones)``,
        each of length ``batch_size`` and aligned by position.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``), or if it is 0
                (nothing to unpack into the five columns).
        """
        drawn = random.sample(self.buf, batch_size)
        # Transpose the list of 5-tuples into five column tuples.
        states, actions, rewards, next_states, dones = zip(*drawn)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buf)

class Agent(nn.Module):
    """Q-network made of a ``ChronosEncoder`` followed by a ``QHead``."""

    def __init__(self, cfg: Config, n_tasks: int):
        """Build the encoder and the Q-value head.

        Args:
            cfg: Configuration object; passed whole to the encoder, and its
                ``latent_dim`` field sizes the Q-head input.
            n_tasks: Task count, forwarded to both sub-modules.
        """
        super().__init__()
        self.encoder = ChronosEncoder(cfg, n_tasks)
        self.qhead = QHead(cfg.latent_dim, n_tasks)

    def forward(self, state):
        """Map an environment state to a vector of Q-values.

        NOTE(review): the earlier inline comments gave the encoder output as
        ``[n_tasks, latent_dim]`` and the result as ``[n_tasks + 1]``; confirm
        against ``ChronosEncoder`` / ``QHead``, which are not defined here.
        """
        return self.qhead(self.encoder(state))

def _td_loss(policy, target, transitions, gamma):
    """Return the one-step TD mean-squared error over sampled transitions.

    Each transition gets its own forward pass through ``policy``. The
    regression target is a detached copy of the prediction with only the
    taken action's entry replaced by the TD target, so the other entries
    contribute zero error.

    Args:
        policy: Network being trained; called as ``policy(state)``.
        target: Frozen network used for the bootstrap value.
        transitions: Iterable of ``(s, a, r, s2, done)`` tuples.
        gamma: Discount factor applied to the bootstrap value.
    """
    predictions = []
    regression_targets = []
    for s_i, a_i, r_i, s2_i, done_i in transitions:
        q_pred = policy(s_i)
        q_target = q_pred.detach().clone()
        if done_i:
            # Terminal transition: no bootstrap, so the target net is not needed.
            q_target[a_i] = r_i
        else:
            with torch.no_grad():
                bootstrap = target(s2_i).max().item()
            q_target[a_i] = r_i + gamma * bootstrap
        predictions.append(q_pred)
        regression_targets.append(q_target)
    return nn.functional.mse_loss(
        torch.stack(predictions), torch.stack(regression_targets)
    )


def train(cfg: Config, n_tasks: int) -> Tuple[nn.Module, dict]:
    """Train a DQN agent, one episode per freshly generated task set.

    Every episode builds a random task set, wraps it in a
    ``TaskSchedulingEnvironment`` and runs an epsilon-greedy rollout of at
    most ``env.L`` steps. After each environment step, once the replay buffer
    holds at least ``cfg.batch_size`` transitions, one optimizer step is taken
    on a sampled mini-batch. Epsilon starts at 1.0 and is multiplied by 0.9995
    after every episode, floored at 0.01.

    Args:
        cfg: Run configuration. Fields read here: ``device``, ``lr``,
            ``memory_size``, ``n_tasksets_train``, ``total_utilization``,
            ``min_period``, ``max_period``, ``n_quanta``, ``batch_size``,
            ``gamma`` and ``update_target_every``.
        n_tasks: Number of tasks per generated task set; also used to size
            the agent.

    Returns:
        ``(policy, diag)``: the trained policy network and a dict with
        ``"losses"`` (one float per optimizer step) and ``"rewards"`` (total
        reward of each episode).
    """
    # Kept as a function-local import, but resolved once per call instead of
    # once per episode.
    from chronoscore.environment import TaskSchedulingEnvironment

    device = cfg.device
    policy = Agent(cfg, n_tasks).to(device)
    target = Agent(cfg, n_tasks).to(device)
    target.load_state_dict(policy.state_dict())
    # The target network only supplies bootstrap values; eval mode keeps any
    # dropout / batch-norm layers deterministic. load_state_dict() does not
    # change the mode, so setting it once is enough.
    target.eval()
    optimizer = optim.Adam(policy.parameters(), lr=cfg.lr)
    buffer = ReplayBuffer(capacity=cfg.memory_size)
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.9995

    losses = []
    rewards_episode = []

    # generate training set on the fly
    for episode in range(cfg.n_tasksets_train):
        tasks = generate_random_taskset(
            n_tasks=n_tasks,
            total_utilization=cfg.total_utilization,
            min_period=cfg.min_period,
            max_period=cfg.max_period,
        )
        env = TaskSchedulingEnvironment(tasks=tasks, n_quanta=cfg.n_quanta)
        state = env.reset()
        total_reward = 0.0
        for _ in range(env.L):
            # epsilon-greedy action selection
            if random.random() < epsilon:
                action = env.sample_random_action()
            else:
                with torch.no_grad():
                    qvals = policy(state).cpu().numpy()
                action = int(qvals.argmax())

            next_state, reward, done = env.step(action)
            buffer.push(state, action, reward, next_state, done)
            total_reward += reward

            # experience replay: one gradient step per environment step
            if len(buffer) >= cfg.batch_size:
                # sample() returns five column tuples; zip them back into
                # per-transition rows.
                batch = zip(*buffer.sample(cfg.batch_size))
                loss = _td_loss(policy, target, batch, cfg.gamma)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.append(loss.item())

            state = next_state
            if done:
                break

        # Hard target update (also fires on episode 0, where it is a no-op
        # unless replay has already changed the policy).
        if episode % cfg.update_target_every == 0:
            target.load_state_dict(policy.state_dict())

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        rewards_episode.append(total_reward)

    # return trained policy and diagnostics
    diag = {"losses": losses, "rewards": rewards_episode}
    return policy, diag
