import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from optimalfair.algorithm.classifierbase import basicprocess
from optimalfair.utils.models import choose_model
from optimalfair.utils.models import *
from optimalfair.utils.model_utils import *
from math import sqrt

class _AdvMLP(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden: int = 64, layer: int = 1):
        """
        Args:
            in_dim:  input feature dimension
            out_dim: output dimension
            hidden:  hidden width
            layer:   number of hidden layers (Linear+Activation blocks)
                     - layer=0: single Linear(in_dim -> out_dim)
                     - layer>=1: [Linear(in_dim->hidden)+Act] + (layer-1)*[Linear(hidden->hidden)+Act] + Linear(hidden->out_dim)
        """
        super().__init__()
        assert layer >= 0, "layer must be >= 0"

        layers = []
        act = nn.LeakyReLU()

        if layer == 0:
            layers.append(nn.Linear(in_dim, out_dim))
        else:
            # first hidden layer
            layers += [nn.Linear(in_dim, hidden), act]
            # extra hidden layers
            for _ in range(layer - 1):
                layers += [nn.Linear(hidden, hidden), act]
            # output layer
            layers.append(nn.Linear(hidden, out_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class classifier(basicprocess):
    def __init__(self, dataset, options, name=""):
        """Adversarial-debiasing classifier.

        Args:
            dataset: forwarded unchanged to ``basicprocess.__init__``.
            options: option mapping. Besides what ``basicprocess`` reads, this
                class reads:
                - ``use_schedule`` (default True): enable the sqrt(step) alpha
                  schedule and per-epoch ExponentialLR decay in ``train``.
                - ``alpha0`` (default 0.3): scale of the scheduled alpha.
                - ``alpha`` (default 1.0): fixed alpha used when the schedule is off.
                - ``lr_gamma`` (default 0.995): ExponentialLR decay factor.
                - ``early_stop`` (default False): stop training once the
                  validation fairness gap drops below ``fair_bound``.
            name: forwarded unchanged to ``basicprocess.__init__``.
        """
        super().__init__(dataset, options, name)

        self.use_schedule = bool(options.get("use_schedule", True))
        self.alpha0 = float(options.get("alpha0", 0.3))  # initial alpha scale for the schedule
        self.alpha = float(options.get("alpha", 1.0))  # fixed alpha if no schedule
        self.lr_gamma = float(options.get("lr_gamma", 0.995))
        self.early_stop = bool(options.get("early_stop", False))

    def train(self):
        """Train the predictor adversarially against a sensitive-attribute adversary.

        Each mini-batch alternates two phases:

        (A) ``adv_steps`` updates of the adversary, which learns to predict the
            sensitive attribute ``a`` from the predictor logits computed without
            gradient (plus one-hot ``y`` for "eo"/"eop");
        (B) ``pred_steps`` updates of the predictor on
            ``CE(logits, y) - alpha * CE(adversary(logits[, y]), a)``, i.e. the
            predictor is pushed to increase the adversary's loss.

        ``alpha`` is fixed (``self.alpha``) or grows as ``alpha0 * sqrt(step)``
        when ``self.use_schedule`` is set. The same flag also enables per-epoch
        ExponentialLR decay for both optimizers. Validation and test metrics are
        computed every ``self.eval_round`` epochs. If ``self.early_stop`` is set,
        training stops once the validation fairness gap drops below
        ``self.fair_bound``.

        Side effects: sets ``self.model_`` and ``self.adversary_`` and logs
        per-evaluation and final test metrics through a ``JSONLStepLogger``.

        Raises:
            ValueError: if ``self.fair_metric`` is not "dp", "eo" or "eop", or if
                ``self.options['data']`` has no adversary defaults.
        """
        # init logger
        run_dir = make_run_dir(self.options)
        logger = JSONLStepLogger(run_dir, config={"lr": self.lr, "bs": self.batch_size})

        # -----------------------
        # 0) prepare predictor
        # -----------------------
        input_dim = self.train_data.X.shape[1]
        predictor = choose_model(self.options)(input_shape=input_dim, output_dim=self.n_class)
        self.model_ = predictor.to(self.device)

        # -----------------------
        # 1) prepare adversary
        # dp:       adv input = logits                (n_class)
        # eo / eop: adv input = [logits, onehot(y)]   (n_class + n_class)
        # -----------------------
        if self.fair_metric == "dp":
            adv_in_dim = self.n_class
            use_y_in_adv = False
        elif self.fair_metric in ["eo", "eop"]:
            # eop is trained with the same equalized-odds adversary as eo
            # (the adversary is conditioned on the true label y).
            adv_in_dim = self.n_class + self.n_class
            use_y_in_adv = True
        else:
            raise ValueError(f"AdvDebias not implemented for fairness metric {self.fair_metric}")

        # Per-dataset adversary defaults: (hidden width, number of hidden layers).
        # options["adv_hidden"] / options["adv_layer"] override either value.
        adv_defaults = {
            "drug": (64, 1),
            "adult": (200, 1),
            "enem": (200, 2),
            "celeba": (200, 2),
            "acs": (200, 2),
        }
        if self.options['data'] not in adv_defaults:
            raise ValueError(f"Dataset {self.options['data']} not recognized for AdvDebias")
        default_hidden, default_layer = adv_defaults[self.options['data']]
        adv_hidden = int(self.options.get("adv_hidden", default_hidden))
        adv_layer = int(self.options.get("adv_layer", default_layer))
        self.adversary_ = _AdvMLP(adv_in_dim, self.n_group, hidden=adv_hidden, layer=adv_layer).to(self.device)

        # -----------------------
        # 2) optimizers + schedulers
        # -----------------------
        opt_pred = torch.optim.AdamW(self.model_.parameters(), lr=float(self.lr))
        opt_adv = torch.optim.AdamW(self.adversary_.parameters(), lr=float(self.lr))

        pred_sch = torch.optim.lr_scheduler.ExponentialLR(opt_pred, gamma=float(self.lr_gamma)) if self.use_schedule else None
        adv_sch = torch.optim.lr_scheduler.ExponentialLR(opt_adv, gamma=float(self.lr_gamma)) if self.use_schedule else None

        alpha0 = float(self.alpha0)
        alpha_fixed = float(self.alpha)
        alpha_schedule = bool(self.use_schedule)

        adv_steps = int(self.options.get("adv_steps", 1))
        pred_steps = int(self.options.get("pred_steps", 1))

        dl = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

        # -----------------------
        # 3) training loop
        # -----------------------
        global_step = 0
        for epoch in range(int(self.num_round)):
            self.model_.train()
            self.adversary_.train()

            for (x, y, a) in dl:
                global_step += 1
                # Both networks were moved to self.device above, so batches must
                # follow them regardless of the gpu flag.
                x, y, a = x.to(self.device), y.to(self.device), a.to(self.device)

                y = y.view(-1).long()
                a = a.view(-1).long()
                # The one-hot label does not change within a batch; build it once.
                y_oh = F.one_hot(y, num_classes=self.n_class).float() if use_y_in_adv else None

                # alpha: fixed or sqrt(step) schedule (matching notebook style)
                alpha = alpha0 * sqrt(max(global_step, 1)) if alpha_schedule else alpha_fixed

                # ---------- (A) update adversary: minimize CE(adv(stopgrad(logits)[, y]), a) ----------
                for _ in range(adv_steps):
                    # no_grad keeps the predictor out of this graph (stop-gradient).
                    with torch.no_grad():
                        logits = self.model_(x)  # (B, C)

                    adv_in = torch.cat([logits, y_oh], dim=1) if use_y_in_adv else logits
                    adv_logits = self.adversary_(adv_in)
                    loss_adv = F.cross_entropy(adv_logits, a)

                    # set_to_none also discards the gradients phase (B) leaves on
                    # the adversary parameters.
                    opt_adv.zero_grad(set_to_none=True)
                    loss_adv.backward()
                    opt_adv.step()

                # ---------- (B) update predictor: minimize CE(y) - alpha * CE_adv(a) ----------
                for _ in range(pred_steps):
                    logits = self.model_(x)
                    loss_pred = F.cross_entropy(logits, y)

                    adv_in = torch.cat([logits, y_oh], dim=1) if use_y_in_adv else logits
                    adv_logits = self.adversary_(adv_in)
                    loss_adv_for_pred = F.cross_entropy(adv_logits, a)

                    loss_total = loss_pred - alpha * loss_adv_for_pred

                    # Only opt_pred steps here; the adversary is not updated in
                    # this phase.
                    opt_pred.zero_grad(set_to_none=True)
                    loss_total.backward()
                    opt_pred.step()

            # schedulers per-epoch
            if self.use_schedule:
                pred_sch.step()
                adv_sch.step()

            # ---------- eval ----------
            if (epoch + 1) % int(self.eval_round) == 0:
                val_acc, val_diff, _ = self.evaluate(split="val")
                test_acc, test_diff, _ = self.evaluate(split="test")
                if self.verbose:
                    print(f"[Epoch {epoch+1}] val_acc={val_acc:.4f}, val_diff={val_diff:.4f} | "
                          f"test_acc={test_acc:.4f}, test_diff={test_diff:.4f}")
                logger.log_step(round=epoch, metrics={"acc": float(test_acc), "fairness_level": float(test_diff)})

                # optional early stop
                if self.early_stop and val_diff < self.fair_bound:
                    if self.verbose:
                        print("[EarlyStop] triggered.")
                    break

        # Always evaluate the final model. Without this, the final log entry would
        # reuse stale metrics, or raise NameError if no periodic eval ran, whenever
        # verbose is off.
        test_acc, test_diff, _ = self.evaluate(split="test")
        if self.verbose:
            print(f"[Final] test_acc={test_acc:.4f}, test_diff={test_diff:.4f}")
        logger.log_step(round='final', metrics={"acc": float(test_acc), "fairness_level": float(test_diff)})

    def _predict_classes_from_logits(self, logits: torch.Tensor):
        """Map model outputs to integer class predictions.

        A single score per sample, shaped ``(N,)`` or ``(N, 1)``, is treated as a
        binary logit: class 1 when ``sigmoid(score) >= 0.5``, else 0. Any other
        shape takes the argmax over dimension 1.
        """
        if logits.ndim == 2 and logits.shape[1] == 1:
            # Collapse the (N, 1) column so it shares the (N,) binary path.
            logits = logits[:, 0]
        if logits.ndim == 1:
            return (torch.sigmoid(logits) >= 0.5).long()
        return logits.argmax(dim=1).long()

    @torch.no_grad()
    def model_eval(self, data, ensemble=False, round=None):
        """Evaluate the trained predictor on ``data``.

        Args:
            data: dataset yielding ``(x, y, a)`` batches and exposing ``Y`` and
                ``A`` attributes, which are passed to ``self.fair_evaluate``.
            ensemble: not used by this implementation; kept for interface compatibility.
            round: not used by this implementation; kept for interface compatibility.

        Returns:
            ``(acc, diff, matrix)``: accuracy over ``data`` followed by the two
            values returned by ``self.fair_evaluate``.
        """
        assert self.model_ is not None, "model is not trained yet"

        # shuffle=False keeps predictions in dataset order so they line up with
        # data.Y / data.A in the fairness evaluation below.
        dl = DataLoader(data, batch_size=self.batch_size, shuffle=False)
        self.model_.eval()

        correct = 0.0
        total = 0.0
        preds = []

        for (x, y, _a) in dl:
            # The model lives on self.device. The sensitive attribute is read
            # from data.A below, so only x and y need to be moved.
            x, y = x.to(self.device), y.to(self.device)
            logits = self.model_(x)
            pred_cls = self._predict_classes_from_logits(logits)
            preds.append(pred_cls.cpu())
            correct += pred_cls.eq(y.view(-1).long()).sum().item()
            total += y.size(0)

        acc = float(correct / max(total, 1.0))
        pred_class = torch.cat(preds, dim=0).numpy()

        diff, matrix = self.fair_evaluate(
            Y=data.Y.ravel(),
            pred_Y=pred_class.ravel(),
            A=data.A.ravel()
        )
        return acc, diff, matrix

    def evaluate(self, split="test"):
        """Return ``model_eval`` results for the requested split.

        ``"train"`` and ``"val"`` select the matching dataset. Any other value,
        including the default ``"test"``, evaluates the test split.
        """
        if split == "train":
            return self.model_eval(self.train_data)
        if split == "val":
            return self.model_eval(self.val_data)
        return self.model_eval(self.test_data)