import os, time, random
import torch, numpy as np
from torch.utils.data import DataLoader

from data import DataModule
from encoder_manager import create_encoder_manager
from models.baseline import load_model_api_from_cfg
from aggregations.registry import load_aggregator

from utils import (
    cfg_get, results_paths,
    atomic_write_round_csv, load_round_csv,
    print_round_logs_from_csv, set_global_seed
)

class Server:
    """Server-side driver for a round-based federated training experiment.

    On construction it builds the data module, the encoder manager, the model
    API and the aggregator from ``cfg``.  ``run`` then repeatedly samples a
    subset of clients, collects their updates, aggregates them into
    ``self.global_state``, evaluates on the test set and finally writes one CSV
    row per round (loss, accuracy, timing and communication volume).
    """

    def __init__(self, cfg):
        """Build every component of the experiment from ``cfg``.

        Raises:
            ValueError: if the ``feature_skew`` partition is requested but the
                data module holds no pre-computed training features.
        """
        self.cfg = cfg
        self.seed = int(cfg_get(cfg, "seed", 0))
        set_global_seed(self.seed)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[Server] device={self.device}")

        self.paths = results_paths(cfg)

        self.dm = DataModule(cfg)
        self.enc = create_encoder_manager(cfg, self.device)

        # Both hooks are optional; only wire them together when each side offers its half.
        if hasattr(self.dm, "apply_encoder_transform") and hasattr(self.enc, "get_transform"):
            self.dm.apply_encoder_transform(self.enc.get_transform())

        if hasattr(self.enc, "can_precompute") and self.enc.can_precompute():
            self.dm.maybe_precompute(self.enc)

        # "data.partition.type" wins; otherwise fall back to "data.partition", then "dirichlet".
        partition_type = cfg_get(self.cfg, "data.partition.type", cfg_get(self.cfg, "data.partition", "dirichlet"))
        if partition_type == "feature_skew":
            # getattr so that a data module without the attribute yields the
            # explanatory ValueError below instead of an AttributeError.
            train_feats = getattr(self.dm, "_train_feats", None)
            if train_feats is None:
                raise ValueError("Feature skew partition requires pre-computed features, but they are not available. Check encoder config.")
            self.dm.client_indices = self.dm.create_feature_skew_partition(train_feats)

            min_per = int(cfg_get(self.cfg, "data.min_samples_per_client", 1))
            if min_per > 0:
                self.dm.client_indices = self.dm._ensure_min_per_client(self.dm.client_indices, min_per=min_per)

        self.model = load_model_api_from_cfg(cfg, self.device)
        self.aggregator = load_aggregator(cfg_get(cfg, "aggregator_name", "fedavg"))

        self.num_rounds = int(cfg_get(cfg, "train.num_rounds", 10))
        self.active_ratio = float(cfg_get(cfg, "train.active_client_ratio", 1.0))
        self.log_every = int(cfg_get(cfg, "train.log_every", 1))

        self.mask_missing_classes = bool(cfg_get(cfg, "train.mask_missing_classes", False))
        self.stage1_rounds = int(cfg_get(cfg, "model.num_stage1_rounds", 0))

        enc_info = self.enc.get_info() if hasattr(self.enc, "get_info") else None
        self.global_state = self.model.init_global(enc_info=enc_info)

        self.losses, self.accuracies = [], []

    def _get_state_size_bytes(self, state):
        """Return the total payload size in bytes of all arrays inside ``state``.

        Tensors and numpy arrays are measured directly; dicts, lists and tuples
        are traversed recursively.  Anything else (``None``, scalars, strings)
        contributes 0 bytes.
        """
        if torch.is_tensor(state):
            return state.numel() * state.element_size()
        if isinstance(state, np.ndarray):
            return int(state.nbytes)
        if isinstance(state, dict):
            return sum(self._get_state_size_bytes(v) for v in state.values())
        if isinstance(state, (list, tuple)):
            return sum(self._get_state_size_bytes(v) for v in state)
        return 0

    def _cuda_sync(self):
        """Wait for pending CUDA work so wall-clock timings are meaningful."""
        if self.device.type == "cuda":
            torch.cuda.synchronize()

    def _warmup_once(self, input_type):
        """Run a few throw-away CUDA ops and one tiny evaluation before timing starts.

        Does nothing on non-CUDA devices.  The evaluation part is best-effort:
        any failure is reported and otherwise ignored.
        """
        if self.device.type != "cuda":
            return
        _ = torch.empty(1, device=self.device).fill_(0.0)
        _ = torch.mm(torch.randn(32, 32, device=self.device), torch.randn(32, 32, device=self.device))
        torch.cuda.synchronize()
        try:
            testset = self.dm.get_testset(input_type=input_type)

            if isinstance(testset, (tuple, list)):
                x, y = testset
            else:
                # Dataset-like object: pull a single small batch through a loader.
                loader = DataLoader(testset, batch_size=8, shuffle=False, num_workers=0)
                x, y = next(iter(loader))

            _ = self.model.evaluate(self.global_state, (x[:8], y[:8]), enc_mgr=self.enc)
            torch.cuda.synchronize()
        except Exception as exc:
            # Warm-up must never abort the experiment, but do not hide why it was skipped.
            print(f"[Server] warm-up evaluation skipped: {exc!r}")

    @staticmethod
    def _client_class_counts(data, num_classes):
        """Return a ``(num_classes,)`` long tensor of label counts, or ``None``.

        ``None`` means the labels could not be counted: ``data`` is not an
        ``(x, y)`` pair, ``y`` is ``None``, ``bincount`` rejects the labels, or
        a label is ``>= num_classes``.
        """
        if not (isinstance(data, (tuple, list)) and len(data) == 2 and data[1] is not None):
            return None
        try:
            labels = torch.as_tensor(data[1])
            counts = torch.bincount(labels.cpu(), minlength=num_classes)
        except Exception:
            # Counting is auxiliary; the caller treats None as "unknown".
            return None
        if counts.numel() != num_classes:
            return None
        return counts.to(torch.long)

    @staticmethod
    def _apply_class_mask(old_state, new_state, class_counts):
        """Keep the old classifier rows for classes with a zero count.

        Only acts when ``new_state`` has a ``"head"`` entry; it then rewrites
        ``new_state["head"]["fc.weight"]`` and ``["fc.bias"]`` in place.  Both
        tensors are computed before either is stored, so a failure leaves
        ``new_state`` untouched.
        """
        if "head" not in new_state:
            return
        present = class_counts > 0

        W_new = new_state["head"]["fc.weight"]
        b_new = new_state["head"]["fc.bias"]
        W_old = old_state["head"]["fc.weight"]
        b_old = old_state["head"]["fc.bias"]

        m2 = present.view(-1, 1).to(W_new.dtype).to(W_new.device)
        m1 = present.to(b_new.dtype).to(b_new.device)

        masked_W = W_new * m2 + W_old.to(W_new.device, dtype=W_new.dtype) * (1 - m2)
        masked_b = b_new * m1 + b_old.to(b_new.device, dtype=b_new.dtype) * (1 - m1)

        new_state["head"]["fc.weight"] = masked_W
        new_state["head"]["fc.bias"] = masked_b

    def run(self):
        """Execute all training rounds and write the per-round CSV.

        If the round CSV already exists its logs are printed, and when
        ``train.skip_if_exists`` is true the stored losses/accuracies are loaded
        and training is skipped.

        Raises:
            ValueError: if rounds are requested but the data module reports no clients.
        """
        req = self.model.get_requirements()
        input_type = req.get("input_type", "features")

        csv_path = self.paths["round_csv"]
        if os.path.exists(csv_path):
            print_round_logs_from_csv(csv_path, log_every=cfg_get(self.cfg, "train.log_every", 1))

            if bool(cfg_get(self.cfg, "train.skip_if_exists", True)):
                rec = load_round_csv(csv_path)
                self.losses = list(map(float, rec["loss"]))
                self.accuracies = list(map(float, rec["acc"]))
                print(f"[Server] existing results found for {self.paths['exp_id']} "
                      f"(rounds={len(rec['round'])}). skip training.")
                return

        # random.sample needs a sequence and a sample size <= population size.
        client_ids = list(self.dm.client_ids())
        if not client_ids and self.num_rounds > 0:
            raise ValueError("[Server] DataModule returned no client ids; nothing to train on.")
        k = min(len(client_ids), max(1, int(round(self.active_ratio * len(client_ids)))))
        self._warmup_once(input_type)

        cum_client_time = 0.0
        cum_each_client_time = 0.0
        cum_server_time = 0.0
        rows = []

        cum_comm_gb = 0.0
        cum_downlink_gb = 0.0
        cum_uplink_gb = 0.0

        C = self.dm.num_classes
        mask_warning_shown = False

        for r in range(self.num_rounds):
            # Per-round RNG seeded from (seed + round) keeps client selection reproducible.
            rng = random.Random(self.seed + r)
            sel = rng.sample(client_ids, k)

            apply_mask = self.mask_missing_classes and (r >= self.stage1_rounds)
            round_class_counts = torch.zeros(C, dtype=torch.long)
            counts_complete = True

            round_client_time = 0.0
            round_server_time = 0.0
            n_participants = 0

            buckets = []

            # Every selected client receives the current global state.
            downlink_bytes_per_client = self._get_state_size_bytes(self.global_state)
            round_downlink_bytes = downlink_bytes_per_client * len(sel)
            round_uplink_bytes = 0

            for cid in sel:
                data = self.dm.get_client_data(cid, input_type=input_type)

                # Label counting stays outside the timed region.
                client_counts = self._client_class_counts(data, C) if apply_mask else None

                self._cuda_sync()
                t_c0 = time.perf_counter()

                upd, weights = self.model.client_update(self.global_state, data, round_idx=r, enc_mgr=self.enc)

                self._cuda_sync()
                round_client_time += (time.perf_counter() - t_c0)
                if upd is not None:
                    buckets.append((upd, weights))
                    n_participants += 1
                    round_uplink_bytes += self._get_state_size_bytes(upd)
                    # Only clients that actually contributed count as "classes seen".
                    if apply_mask:
                        if client_counts is None:
                            counts_complete = False
                        else:
                            round_class_counts += client_counts

            if buckets:
                self._cuda_sync()
                t_s0 = time.perf_counter()

                # NOTE(review): if the aggregator mutates global_state in place and
                # returns it, pre_state aliases new_state and the mask is a no-op — confirm.
                pre_state = self.global_state
                new_state = self.aggregator.aggregate(self.global_state, buckets)

                if apply_mask:
                    if counts_complete:
                        try:
                            self._apply_class_mask(pre_state, new_state, round_class_counts)
                        except Exception as exc:
                            # Masking is optional post-processing; report and keep the aggregate.
                            print(f"[Server] round {r+1}: missing-class mask failed ({exc!r}); "
                                  f"using unmasked aggregate.")
                    elif not mask_warning_shown:
                        # With unknown labels the counts would be (partly) zero and the mask
                        # would wrongly restore old rows for classes that were trained on.
                        print(f"[Server] round {r+1}: class counts unavailable for some clients; "
                              f"skipping missing-class mask.")
                        mask_warning_shown = True

                self.global_state = new_state

                self._cuda_sync()
                round_server_time = (time.perf_counter() - t_s0)

            round_comm_bytes = round_downlink_bytes + round_uplink_bytes

            round_comm_gb = round_comm_bytes / (1024**3)
            downlink_gb = round_downlink_bytes / (1024**3)
            uplink_gb = round_uplink_bytes / (1024**3)

            cum_comm_gb += round_comm_gb
            cum_downlink_gb += downlink_gb
            cum_uplink_gb += uplink_gb

            # Average wall time per contributing client (total time if nobody contributed).
            round_each_client_time = round_client_time / max(1, n_participants)

            cum_client_time += round_client_time
            cum_each_client_time += round_each_client_time
            cum_server_time += round_server_time
            cum_time = cum_client_time + cum_server_time

            testset = self.dm.get_testset(input_type=input_type)
            loss, acc = self.model.evaluate(self.global_state, testset, enc_mgr=self.enc)[:2]
            loss, acc = float(loss), float(acc)
            self.losses.append(loss)
            self.accuracies.append(acc)

            # log_every <= 0 disables per-round printing instead of dividing by zero.
            if self.log_every > 0 and (r + 1) % self.log_every == 0:
                print(f">> round {r+1} | loss={loss:.4f} | acc={acc:.2f}% | "
                      f"t_round={round_client_time + round_server_time:.2f}s | "
                      f"t_cum={cum_time:.2f}s")

            rows.append((
                r+1, loss, acc,
                round_client_time, round_each_client_time, round_server_time,
                cum_client_time,  cum_each_client_time,  cum_server_time,  cum_time,
                round_comm_gb, cum_comm_gb, cum_downlink_gb, cum_uplink_gb
            ))

        atomic_write_round_csv(
            self.paths["round_csv"],
            rows,
            header=[
                "round","loss","acc",
                "round_client_time_sec","round_each_client_time_sec","round_server_time_sec",
                "cum_client_time_sec","cum_each_client_time_sec","cum_server_time_sec","cum_time_sec",
                "round_comm_gb", "cum_comm_gb", "cum_downlink_gb", "cum_uplink_gb"
            ]
        )