from __future__ import annotations
import math
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
from torch.utils.data import DataLoader
from .sensor import ControlGSensor
from .planner import ControlGPlanner
from .controller import ControlGController

class ControlGTrainer:
    """Drive multi-task pre-training with a ControlG sensor, planner and controller.

    The trainer reads its hyper-parameters from ``config``, builds the three
    ControlG components in :meth:`setup`, and schedules one task per block of
    steps in :meth:`train_epoch`.

    Call order: construct, call :meth:`setup`, make sure ``blocks_per_epoch``
    is set (via the ``controlg_blocks_per_epoch`` config key or by assigning
    the attribute), then call :meth:`train_epoch` once per epoch.
    """

    def __init__(self, dataset_name: str, config: Dict[str, Any], seed: int, device: str = "cuda:0",
                 output_dir: Optional[Path] = None, pretrain_mode: str = "node") -> None:
        """Store the run settings and read hyper-parameters from ``config``.

        Args:
            dataset_name: Name of the dataset; stored as given.
            config: Hyper-parameter mapping. ``None`` or an empty mapping
                selects every default below.
            seed: Seed for the trainer's private ``RandomState`` and, in
                :meth:`setup`, for the global torch / NumPy generators.
            device: Device string handed to the ControlG components.
            output_dir: Optional output location; stored as given.
            pretrain_mode: Pre-training mode label; stored as given.

        Raises:
            ValueError: If ``controlg_plan_every_epochs`` is 0, which would
                otherwise surface as a ``ZeroDivisionError`` in
                :meth:`train_epoch`.
        """
        self.dataset_name = dataset_name
        self.config = config
        self.seed = seed
        self.device = device
        self.output_dir = output_dir
        self.pretrain_mode = pretrain_mode

        # Work on a shallow copy so a ``None`` config behaves like an empty one.
        cfg = dict(config or {})

        # General optimisation settings.
        self.hidden_dim = cfg.get("hidden_dim", [256])
        self.lr = float(cfg.get("lr", 0.001))
        self.weight_decay = float(cfg.get("weight_decay", 0.0))
        self.batch_size = int(cfg.get("batch_size", 256))

        # Task list; K is the number of tasks handed to every component.
        self.tasks = cfg.get("tasks", ["p_minsg", "p_decor", "p_recon", "p_par", "p_link"])
        self.K = len(self.tasks)

        # Scheduling granularity.
        self.block_size = int(cfg.get("controlg_block_size", 5))
        self.sense_period = int(cfg.get("controlg_sense_period", 5))
        self.plan_every_epochs = int(cfg.get("controlg_plan_every_epochs", 1))
        if self.plan_every_epochs == 0:
            # ``epoch % 0`` would crash on the first call to train_epoch().
            raise ValueError("controlg_plan_every_epochs must be non-zero")

        # Number of blocks scheduled per epoch. It may legitimately be unknown
        # at construction time; train_epoch() refuses to run while it is None.
        blocks_per_epoch = cfg.get("controlg_blocks_per_epoch")
        self.blocks_per_epoch: Optional[int] = (
            None if blocks_per_epoch is None else int(blocks_per_epoch)
        )

        # Sensor parameters (forwarded to ControlGSensor in setup()).
        self.alpha = float(cfg.get("controlg_alpha", 0.5))
        self.beta = float(cfg.get("controlg_beta", 0.5))
        self.rho = float(cfg.get("controlg_rho", 0.2))
        self.D_min = float(cfg.get("controlg_D_min", 0.0))
        self.D_max = float(cfg.get("controlg_D_max", 1.0))

        # Planner parameters (forwarded to ControlGPlanner in setup()).
        self.ref_delta = float(cfg.get("controlg_ref_delta", 0.1))
        self.gamma = float(cfg.get("controlg_gamma", 1.0))
        self.f_min = float(cfg.get("controlg_f_min", 0.05))

        # Controller parameters (forwarded to ControlGController in setup()).
        self.Kp = float(cfg.get("controlg_Kp", 1.0))
        self.Ki = float(cfg.get("controlg_Ki", 0.5))
        self.Kd = float(cfg.get("controlg_Kd", 0.1))
        self.I_max = float(cfg.get("controlg_I_max", 5.0))
        self.epsilon = float(cfg.get("controlg_epsilon", 0.1))
        self.tau0 = float(cfg.get("controlg_tau0", 1.0))
        self.tau_min = float(cfg.get("controlg_tau_min", 0.1))
        self.tau_anneal = float(cfg.get("controlg_tau_anneal", 100.0))

        # Private generator passed to the controller on every step.
        self._rng = np.random.RandomState(int(seed))

        # Progress counters.
        self._epoch_idx = 0
        self._global_block = 0
        self._global_step = 0

        # Components are created in setup(); the allocation in train_epoch().
        self.sensor = None
        self.planner = None
        self.controller = None
        self.allocation = None

        # Model, optimizer and graph are not created by this class as shown.
        self._model = None
        self._optimizer = None
        self._graph = None

    def setup(self) -> None:
        """Seed the global torch / NumPy generators and build the components."""
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        self.sensor = ControlGSensor(
            K=self.K,
            device=self.device,
            alpha=self.alpha,
            beta=self.beta,
            rho=self.rho,
            D_min=self.D_min,
            D_max=self.D_max,
        )
        self.planner = ControlGPlanner(
            K=self.K,
            device=self.device,
            ref_delta=self.ref_delta,
            gamma=self.gamma,
            f_min=self.f_min,
        )
        self.controller = ControlGController(
            K=self.K,
            device=self.device,
            Kp=self.Kp,
            Ki=self.Ki,
            Kd=self.Kd,
            I_max=self.I_max,
            epsilon=self.epsilon,
            tau0=self.tau0,
            tau_min=self.tau_min,
            tau_anneal=self.tau_anneal,
        )

    def train_epoch(self) -> float:
        """Run one epoch of block-wise task scheduling.

        Every ``plan_every_epochs`` epochs the planner produces a fresh
        allocation from the sensor's normalized losses and difficulty, and the
        controller is reset and told the current global block index. Then, for
        each of ``blocks_per_epoch`` blocks, the controller picks one task and
        the step counters advance by ``block_size``.

        Returns:
            ``total_loss`` divided by the number of steps taken (at least 1).

        Raises:
            RuntimeError: If the components have not been created yet or
                ``blocks_per_epoch`` has not been set.

        NOTE(review): no loss is computed inside the step loop as written, so
        ``total_loss`` stays 0.0 and the return value is always 0.0. The
        per-task optimisation step appears to be missing - confirm.
        """
        # Validate up front so a misconfigured trainer fails before the
        # planner / controller state has been touched.
        if self.sensor is None or self.planner is None or self.controller is None:
            raise RuntimeError("ControlGTrainer.setup() must be called before train_epoch()")
        blocks_per_epoch = getattr(self, "blocks_per_epoch", None)
        if blocks_per_epoch is None:
            raise RuntimeError(
                "blocks_per_epoch is not set; provide 'controlg_blocks_per_epoch' in the "
                "config or assign trainer.blocks_per_epoch before calling train_epoch()"
            )

        if self._epoch_idx % self.plan_every_epochs == 0:
            difficulty = self.sensor.get_difficulty()
            normalized_losses = self.sensor.get_normalized_losses()
            self.allocation = self.planner.plan(normalized_losses, difficulty)
            self.controller.reset()
            self.controller.set_global_block(self._global_block)

        total_loss, num_steps = 0.0, 0
        for _ in range(blocks_per_epoch):
            selected_task_idx, _ctrl_info = self.controller.step(
                self.allocation, self.sensor.get_difficulty(), self._rng
            )
            # Looked up so an out-of-range controller choice still fails here;
            # the task name itself is not consumed yet (see NOTE above).
            selected_task = self.tasks[selected_task_idx]  # noqa: F841
            for _ in range(self.block_size):
                self._global_step += 1
                num_steps += 1
            self._global_block += 1

        self._epoch_idx += 1
        return total_loss / max(1, num_steps)

    def get_embeddings(self) -> torch.Tensor:
        """Return the node module's output for the stored graph, on the CPU.

        The model is switched to eval mode and left there; the forward pass
        runs under ``torch.no_grad()``.

        Raises:
            RuntimeError: If no model or graph has been attached.
        """
        if self._model is None or self._graph is None:
            raise RuntimeError("a model and a graph must be attached before get_embeddings()")
        self._model.eval()
        with torch.no_grad():
            features = self._graph.ndata["feat"]
            embeddings = self._model.big_model.node_module(self._graph, features)
            return embeddings.detach().cpu()
