# -*- coding: utf-8 -*-
"""
Group-to-Global alignment via KL-type f-divergence (no pairwise, no multi-head)
Each group g has its own TNet_g instance:
    D_hat_g = mean_{x~pi_g}[T_g(x)] - mean_{x~pi_all}[exp(T_g(x)-1)]
Fairness term = sum_g D_hat_g

Alternating training:
  (1) Fix classifier, update {T_g} to maximize sum_g D_hat_g
  (2) Fix {T_g}, update classifier to minimize CE + lambda * sum_g D_hat_g

Dependencies:
- optimalfair.algorithm.classifierbase.basicprocess
- optimalfair.utils.models.choose_model
- basicprocess provides: self.device, self.gpu, self.batch_size, self.lr, self.num_round,
  self.n_group, self.n_class, self.fair_metric, self.fair_evaluate(...)

Dataset:
- train_data/test_data: __getitem__ returns (x,y,a)
- data.Y, data.A are numpy arrays
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from optimalfair.algorithm.classifierbase import basicprocess
from optimalfair.utils.models import choose_model
from optimalfair.utils.models import *
from optimalfair.utils.model_utils import *


# -------------------------
# Utils
# -------------------------
def _toggle_grad(model: nn.Module, flag: bool):
    """Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: module whose parameters are modified in place.
        flag: ``True`` to make the parameters trainable, ``False`` to freeze them.
    """
    # nn.Module.requires_grad_ iterates model.parameters() and flips each one
    # in place; its return value (the module) is deliberately discarded.
    model.requires_grad_(flag)


def _to_1d_long(x: torch.Tensor) -> torch.Tensor:
    """Flatten ``x`` into a 1-D view and cast it to ``torch.long``."""
    flat = x.view(-1)
    return flat.to(torch.long)


def _sample_indices(idx: torch.Tensor, M: int) -> "torch.Tensor | None":
    """Draw ``M`` random entries from the 1-D index tensor ``idx``.

    Args:
        idx: 1-D tensor of candidate indices.
        M: number of entries to draw.

    Returns:
        A tensor of ``M`` entries taken from ``idx`` (on ``idx``'s device), or
        ``None`` when ``idx`` is empty or ``M`` is not positive. Callers must
        check for ``None`` before using the result.

    When ``idx`` holds at least ``M`` entries the draw is without replacement;
    otherwise it is with replacement so that ``M`` entries are still returned.
    """
    n = idx.numel()
    if n == 0 or M <= 0:
        return None
    if n >= M:
        # Without replacement: keep the first M positions of a random permutation.
        chosen = torch.randperm(n, device=idx.device)[:M]
    else:
        # Too few candidates: sample positions with replacement.
        chosen = torch.randint(0, n, (M,), device=idx.device)
    return idx[chosen]


def _min_pos_count(counts: torch.Tensor) -> int:
    """Return the smallest strictly positive entry of ``counts`` as an ``int``.

    Zero (and negative) entries are ignored so that empty groups do not drive
    the result to zero. Returns 0 when no entry is positive.
    """
    host_counts = counts.detach().cpu()
    positive = host_counts[host_counts > 0]
    return int(positive.min().item()) if positive.numel() > 0 else 0


def _build_index_dict_align_global(a: torch.Tensor, n_group: int, M_g: int, M_all: int):
    """Sample row indices per group and for the whole population.

    Args:
        a: group label of every row; flattened and cast to long internally.
        n_group: number of group ids to scan (``0 .. n_group - 1``).
        M_g: number of rows to draw for each group.
        M_all: number of rows to draw from all rows regardless of group.

    Returns:
        A dict mapping each group id ``g`` to its sampled row indices and the
        key ``"all"`` to the sampled global row indices. A key is omitted when
        the corresponding sampler returns ``None`` (e.g. the group has no rows).
    """
    labels = _to_1d_long(a)

    # Groups are sampled lazily in ascending order, then the global pool, so
    # the sequence of random draws is group 0, 1, ..., n_group - 1, "all".
    per_group = (
        (g, _sample_indices((labels == g).nonzero(as_tuple=True)[0], M_g))
        for g in range(n_group)
    )
    index_dict = {g: picked for g, picked in per_group if picked is not None}

    every_row = torch.arange(labels.numel(), device=labels.device)
    picked_all = _sample_indices(every_row, M_all)
    if picked_all is not None:
        index_dict["all"] = picked_all

    return index_dict


# -------------------------
# TNet (single-output)
# -------------------------
class TNet(nn.Module):
    """Small MLP that maps each input row to one scalar T(x).

    Input:  (B, input_dim) rows; the trainer in this file feeds predicted
            class probabilities, so input_dim is the number of classes there.
    Output: (B,) one scalar per row.

    The network has ``depth - 1`` Linear+ReLU stages of width ``hidden_dim``
    followed by a final Linear layer with a single output unit.
    """

    def __init__(self, input_dim, hidden_dim=5, depth=3):
        super().__init__()
        assert depth >= 2
        # widths[i] -> widths[i + 1] describes the i-th hidden stage.
        widths = [input_dim] + [hidden_dim] * (depth - 1)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stack.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        scores = self.net(x)
        # Drop the trailing singleton dimension: (B, 1) -> (B,).
        return scores.squeeze(-1)


# -------------------------
# Trainer
# -------------------------
class classifier(basicprocess):
    def __init__(self, dataset, options, name=""):
        """Initialise the base process and read this trainer's hyper-parameters.

        Every setting is read from ``self.options`` with ``.get`` and falls
        back to the default listed below when the key is absent:

        - ``inner_round`` (1): classifier epochs per round.
        - ``fdiv_lamb`` (1.0): weight of the fairness term, stored as ``self.lamb``.
        - ``inner_step`` (100): T-network update steps per round.
        - ``inner_lr`` (3e-3): learning rate of the T-network optimizer.
        - ``div_weight_decay`` (0.0): weight decay of the T-network optimizer.
        - ``tnet_hidden`` (5), ``tnet_depth`` (3): T-network width and depth.
        - ``global_ratio`` (1.0): scales the size of the global sample.
        - ``warmup_rounds`` (0): initial rounds trained without the fairness term.
        """
        super().__init__(dataset, options, name)
        read = self.options.get

        # Models are created lazily by train().
        self.model_ = None
        self.div_models = None  # ModuleList[TNet]

        # Classifier side: epochs per round, fairness weight, CE-only warm-up rounds.
        self.inner_round = int(read("inner_round", 1))
        self.lamb = float(read("fdiv_lamb", 1.0))
        self.warmup_rounds = int(read("warmup_rounds", 0))

        # Divergence-estimator side: steps per round, learning rate, weight decay.
        # Weight decay 0 is recommended so regularisation cannot collapse T to a constant.
        self.inner_step = int(read("inner_step", 100))
        self.inner_lr = float(read("inner_lr", 3e-3))
        self.div_weight_decay = float(read("div_weight_decay", 0.0))

        # TNet capacity.
        self.tnet_hidden = int(read("tnet_hidden", 5))
        self.tnet_depth = int(read("tnet_depth", 3))

        # Sampling control: M_all = min(N, int(M_g * n_group * global_ratio)).
        self.global_ratio = float(read("global_ratio", 1.0))

    # KL-type conjugate: f*(t)=exp(t-1)
    def f_star(self, t: torch.Tensor) -> torch.Tensor:
        # small stabilization to avoid exp overflow
        t = torch.clamp(t, -10.0, 10.0)
        return torch.exp(t - 1.0)

    def compute_f_divergence_align_global(self, pred_logit: torch.Tensor, index_dict: dict) -> torch.Tensor:
        """
        total = sum_g [ mean(T_g(x_g)) - mean(exp(T_g(x_all)-1)) ]
        """
        pred_prob = F.softmax(pred_logit, dim=1)
        total = pred_prob.new_tensor(0.0)

        if "all" not in index_dict:
            return total

        x_all = pred_prob[index_dict["all"]]  # (M_all, C)

        for g in range(self.n_group):
            if g not in index_dict:
                continue
            x_g = pred_prob[index_dict[g]]      # (M_g, C)

            T_g = self.div_models[g]
            t_pos = T_g(x_g)                    # (M_g,)
            t_neg = T_g(x_all)                  # (M_all,)

            total = total + (t_pos.mean() - self.f_star(t_neg).mean())

        return total

    def train(self):
        """Run the alternating training of the classifier and the per-group T networks.

        Setup: builds the classifier with ``choose_model(self.options)``, one
        ``TNet`` per group, an Adam optimizer for the classifier and a single
        AdamW optimizer shared by all T networks.

        For each of ``self.num_round`` rounds:
          (0) If the round index is below ``self.warmup_rounds``: train the
              classifier on cross-entropy only for ``self.inner_round`` epochs,
              evaluate on the test set, and move on to the next round.
          (1) Otherwise freeze the classifier, compute its logits on the whole
              training set once, and take ``self.inner_step`` optimizer steps on
              the T networks to maximize the summed divergence estimate.
          (2) Freeze the T networks and train the classifier for
              ``self.inner_round`` epochs on ``CE + self.lamb * fair_term``,
              where ``fair_term`` is computed per mini-batch.
          (3) Evaluate on test and train data; only the test metrics are logged.

        After the last round the test metrics are logged once more with
        ``round='final'``.

        Side effects: sets ``self.model_fn``, ``self.model_kwargs``,
        ``self.model_``, ``self.div_models`` and ``self.round``; prints progress
        when ``self.verbose`` is truthy. Returns ``None``.
        """
        # -------------------------
        # Build classifier model
        # -------------------------
        X = self.train_data.X
        input_dim = X.shape[1]  # assumes X is 2-D with features on axis 1 — TODO confirm
        self.model_fn = choose_model(self.options)
        self.model_kwargs = {"input_shape": input_dim, "output_dim": self.n_class}
        self.model_ = self.model_fn(**self.model_kwargs).to(self.device)

        # init logger
        # NOTE(review): make_run_dir / JSONLStepLogger are not defined in this file;
        # presumably they come from the wildcard imports at the top — confirm.
        run_dir = make_run_dir(self.options)
        logger = JSONLStepLogger(run_dir, config={"lr": self.lr, "bs": self.batch_size}, flush_every=10)

        # -------------------------
        # Build per-group TNet list (NO multi-head)
        # -------------------------
        # Each T network consumes class probabilities, hence input_dim=self.n_class.
        self.div_models = nn.ModuleList([
            TNet(input_dim=self.n_class, hidden_dim=self.tnet_hidden, depth=self.tnet_depth).to(self.device)
            for _ in range(self.n_group)
        ])

        # -------------------------
        # DataLoader
        # -------------------------
        dl = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

        optimizer = torch.optim.Adam(self.model_.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()

        # One optimizer for all T_g parameters (ModuleList registers params).
        # It is created once, so its state carries over from round to round.
        div_optimizer = torch.optim.AdamW(
            self.div_models.parameters(),
            lr=self.inner_lr,
            weight_decay=self.div_weight_decay
        )

        # -------------------------
        # Training loop
        # -------------------------
        # The loop variable is an instance attribute, so the current round index
        # stays available as self.round during and after the loop.
        for self.round in range(self.num_round):

            # =========================================================
            # (0) Optional warmup: only CE (helps if you want it)
            # =========================================================
            if self.round < self.warmup_rounds:
                self.model_.train()
                _toggle_grad(self.model_, True)
                # T networks are neither used nor updated during warm-up.
                for m in self.div_models:
                    m.eval()
                    _toggle_grad(m, False)

                for epoch in range(self.inner_round):
                    total_loss, n_batches = 0.0, 0
                    for (x, y, a) in dl:
                        # The group label `a` is unused here, so it is not moved to the device.
                        if self.gpu:
                            x, y = x.to(self.device), y.to(self.device)

                        optimizer.zero_grad(set_to_none=True)
                        out = self.model_(x)
                        loss = criterion(out, _to_1d_long(y))
                        loss.backward()
                        optimizer.step()

                        total_loss += float(loss.item())
                        n_batches += 1

                    if self.verbose:
                        print(f"[Warmup] Round {self.round+1}/{self.num_round} "
                              f"Epoch {epoch+1}/{self.inner_round} loss={total_loss/max(n_batches,1):.6f}")

                test_acc, test_diff, _, _ = self.model_eval(self.test_data)
                if self.verbose:
                    print(f"[Eval]  test accuracy={test_acc:.4f}, disparity={test_diff:.4f}")
                # Skip steps (1)-(3); note that warm-up rounds are therefore not
                # written to the step logger.
                continue

            # =========================================================
            # (1) Update {T_g}: maximize sum_g D_hat(π_g || π_all)
            # =========================================================
            self.model_.eval()
            _toggle_grad(self.model_, False)

            # compute logits on full train set once
            # (model_eval is itself decorated with torch.no_grad, so the context
            # manager here is redundant but harmless.)
            with torch.no_grad():
                _, _, _, pred_logit = self.model_eval(self.train_data)  # (N, C) on device

            # model_eval iterates with shuffle=False, so pred_logit rows follow the
            # dataset order; presumably train_data.A is stored in that same order —
            # verify against the dataset class.
            A_t = torch.tensor(self.train_data.A, device=self.device).view(-1).long()
            counts = torch.bincount(A_t, minlength=self.n_group)

            # Every non-empty group contributes the same number of samples: the
            # size of the smallest non-empty group.
            M_g = _min_pos_count(counts)
            if M_g <= 0:
                if self.verbose:
                    print("[Warn] M_g <= 0, skip divergence update this round.")
            else:
                N = A_t.numel()
                # Global sample size: M_g * n_group * global_ratio, kept within [1, N].
                M_all = int(min(N, max(1, int(M_g * self.n_group * self.global_ratio))))
                # Sampled once per round; the same indices are reused by every inner step.
                index_dict = _build_index_dict_align_global(A_t, self.n_group, M_g=M_g, M_all=M_all)

                # enable grads for T_g
                for m in self.div_models:
                    m.train()
                    _toggle_grad(m, True)

                for step in range(self.inner_step):
                    div_optimizer.zero_grad(set_to_none=True)
                    total_div = self.compute_f_divergence_align_global(pred_logit, index_dict)
                    # Maximize the divergence estimate by minimizing its negative.
                    loss_div = -total_div
                    loss_div.backward()
                    div_optimizer.step()

                    if self.verbose and (step % 20 == 0):
                        # debug one group: mean/std of T_0 on group 0's samples
                        # (evaluated after this step's parameter update)
                        with torch.no_grad():
                            g0 = 0
                            if g0 in index_dict:
                                x_dbg = F.softmax(pred_logit[index_dict[g0]], dim=1)
                                t_dbg = self.div_models[g0](x_dbg)
                                t_mean, t_std = t_dbg.mean().item(), t_dbg.std().item()
                            else:
                                t_mean, t_std = 0.0, 0.0
                        print(f"[Div] step {step}/{self.inner_step} | total_div {total_div.item():.6f} "
                              f"| T0(mean/std) {t_mean:.6f}/{t_std:.6f} | M_g={M_g} M_all={M_all}")

            # =========================================================
            # (2) Update classifier: minimize CE + lamb * sum_g D_hat
            # =========================================================
            self.model_.train()
            _toggle_grad(self.model_, True)

            # freeze {T_g} params (but keep graph w.r.t. model_ output)
            for m in self.div_models:
                m.eval()
                _toggle_grad(m, False)

            for epoch in range(self.inner_round):
                total_loss, n_batches = 0.0, 0
                for (x, y, a) in dl:
                    if self.gpu:
                        x, y, a = x.to(self.device), y.to(self.device), a.to(self.device)

                    optimizer.zero_grad(set_to_none=True)
                    out = self.model_(x)  # (B, C)

                    # build batch index_dict: indices refer to rows of this
                    # mini-batch and are resampled for every batch
                    a_1d = _to_1d_long(a)
                    counts_b = torch.bincount(a_1d, minlength=self.n_group)
                    M_g_b = _min_pos_count(counts_b)

                    if M_g_b <= 0:
                        # No group has any sample in this batch: no fairness term.
                        fair_term = out.new_tensor(0.0)
                    else:
                        B = a_1d.numel()
                        M_all_b = int(min(B, max(1, int(M_g_b * self.n_group * self.global_ratio))))
                        index_dict_b = _build_index_dict_align_global(a_1d, self.n_group, M_g=M_g_b, M_all=M_all_b)
                        fair_term = self.compute_f_divergence_align_global(out, index_dict_b)

                    loss = criterion(out, _to_1d_long(y)) + self.lamb * fair_term
                    loss.backward()
                    optimizer.step()

                    total_loss += float(loss.item())
                    n_batches += 1

                if self.verbose:
                    print(f"[Train] Round {self.round+1}/{self.num_round} "
                          f"Epoch {epoch+1}/{self.inner_round} loss={total_loss/max(n_batches,1):.6f}")

            # =========================================================
            # (3) Eval
            # =========================================================
            test_acc, test_diff, _, _ = self.model_eval(self.test_data)
            if self.verbose:
                print(f"[Eval]  test accuracy={test_acc:.4f}, disparity={test_diff:.4f}")

            # Only the test metrics go to the step logger.
            logger.log_step(round=self.round, metrics={"acc": float(test_acc) ,"fairness_level": float(test_diff)},)

            # Train metrics are computed for the console report only.
            train_acc, train_diff, _, _ = self.model_eval(self.train_data)
            if self.verbose:
                print(f"[Eval] train accuracy={train_acc:.4f}, disparity={train_diff:.4f}")

        # final: one more test evaluation, logged under round='final'
        self.model_.eval()
        test_acc, test_diff, _, _ = self.model_eval(self.test_data)
        if self.verbose:
            print(f"[Final] test accuracy={test_acc:.4f}, disparity={test_diff:.4f}")
        logger.log_step(round='final', metrics={"acc": float(test_acc) ,"fairness_level": float(test_diff)},)

    @torch.no_grad()
    def model_eval(self, data, ensemble=False, round=None):
        assert self.model_ is not None

        dataLoader = DataLoader(data, batch_size=self.batch_size, shuffle=False)
        self.model_.eval()

        test_correct = 0.0
        test_num = 0.0
        preds = []
        logits_list = []

        for (x, y, a) in dataLoader:
            if self.gpu:
                x, y = x.to(self.device), y.to(self.device)

            if not ensemble:
                pred = self.model_(x)
                logits_list.append(pred)

                _, predicted = torch.max(pred, 1)
                preds.append(predicted.detach().cpu())

                correct = predicted.eq(y.view(-1).long()).sum().item()
                bs = y.size(0)
                test_correct += correct
                test_num += bs
            else:
                raise NotImplementedError("ensemble=True is not implemented.")

        test_acc = test_correct / max(test_num, 1.0)
        pred_class = torch.cat(preds, dim=0).numpy()
        pred_logit = torch.cat(logits_list, dim=0)

        diff, matrix = self.fair_evaluate(
            Y=data.Y.ravel(),
            pred_Y=pred_class.ravel(),
            A=data.A.ravel()
        )
        return test_acc, diff, matrix, pred_logit
