import os
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm


from misc import (
    batchify,
    env_batchify,
    HSICLoss,
    DiscreteConditionalExpecationTest,
    SampleCovariance,
)

class BAGTrainer:
    def __init__(self, model, loss_fn, reg_lambda, config, causal_dir=True):

        self.model = copy.deepcopy(model)
        self.config = config
        self.causal_dir = causal_dir
        self.classification = bool(config.classification)
        self.num_class = int(config.num_class)
        self.z_dim = int(config.z_dim)

        self.criterion = loss_fn
        self.reg_lambda = float(reg_lambda)
        self.vae_lambda = float(getattr(config, "vae_lambda", 1.0))


        self.outer_optimizer = torch.optim.Adam(
            self.model.get_parameters(base_lr=float(config.lr))
        )

        self.eta_test_ind = 0
        self.test_tta_result_list = []
        self.model_path = os.path.join(config.model_save_dir, "adap_invar.tar")
        self.emb_path = os.path.join(config.model_save_dir, "adap_invar_emb")
        if getattr(config, "save_test_phi", False):
            os.makedirs(self.emb_path, exist_ok=True)
    

    def reg_loss(self, f_beta, f_eta, y, env_ind):
        """Return the regularisation term computed from the two model outputs.

        When ``self.causal_dir`` is falsy the term is the summed absolute value
        of ``DiscreteConditionalExpecationTest(f_beta, f_eta, y)``, for both
        task types. When it is truthy the term is ``HSICLoss(f_beta, f_eta)``
        for classification and entry ``[0][0]`` of
        ``SampleCovariance(f_beta, f_eta)`` otherwise.

        ``env_ind`` is accepted for interface compatibility but is not used.
        """
        if not self.causal_dir:
            # Same statistic regardless of whether the task is classification.
            conditional_stat = DiscreteConditionalExpecationTest(f_beta, f_eta, y)
            return torch.sum(torch.abs(conditional_stat))

        if self.classification:
            return HSICLoss(f_beta, f_eta)

        covariance = SampleCovariance(f_beta, f_eta)
        return covariance[0][0]


    def train(self, train_dataset, batch_size, test_dataset,
              n_outer_loop=100, n_inner_loop=30, log_dir="./log", verbose=True):
        """Train the model and evaluate it after every epoch.

        Each epoch pulls per-environment batches from ``env_batchify``, sums
        the prediction, regularisation and VAE losses over all training
        environments, and takes one Adam step per group of batches. After the
        epoch a confusion matrix is built from the training data, and a deep
        copy of the model is evaluated with ``test`` and then with
        ``test_with_TTA``. The six resulting accuracies are appended to
        ``<log_dir>/epoch_metrics.txt``.

        Args:
            train_dataset: Sequence of per-environment datasets; every element
                must support ``len``.
            batch_size: Batch size forwarded to the batching helpers.
            test_dataset: Dataset used for the per-epoch evaluation.
            n_outer_loop: Not used by this method; kept for compatibility.
            n_inner_loop: Not used by this method; kept for compatibility.
            log_dir: Directory of the metrics file (created if missing).
            verbose: Not used by this method; kept for compatibility.

        Returns:
            The largest value in ``self.test_tta_result_list`` (the list is an
            instance attribute, so it also holds results of earlier calls), or
            0.0 when the list is empty.
        """
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, "epoch_metrics.txt")
        if not os.path.exists(log_path):
            # Write the header only when the file is new; later runs append.
            with open(log_path, "w") as f:
                f.write("Epoch\tacc_c_test\tacc_s_test\tacc_full_test\tacc_c_tta\tacc_s_tta\tacc_full_tta\n")

        device = next(self.model.parameters()).device
        epochs = int(self.config.epochs)
        n_train_envs = len(train_dataset)

        # Loop-invariant settings, read once instead of once per batch.
        c_max = getattr(self.config, "C_max", 5.0)
        c_stop_iter = max(1, int(getattr(self.config, "C_stop_iter", 10)))
        kl_weight = getattr(self.config, "beta", 1.0)

        # VAE prior: standard normal over the latent space.
        normal_distribution = torch.distributions.MultivariateNormal(
            torch.zeros(self.z_dim, device=device),
            torch.eye(self.z_dim, device=device)
        )

        # Kept so the model is put in training mode even when epochs == 0.
        self.model.train()

        for t in range(epochs):
            # The per-epoch evaluation at the bottom of this loop can leave
            # self.model in eval mode. Re-enter training mode every epoch so
            # mode-dependent layers (dropout, batch norm) train correctly in
            # every epoch, not just the first one.
            self.model.train()

            # The smallest environment bounds the number of steps per epoch.
            steps = max(1, min(len(ds) // batch_size for ds in train_dataset))
            train_batches = env_batchify(train_dataset, batch_size, self.config)

            # Target value subtracted from the KL term: grows linearly with
            # the epoch index and is capped at C_max. Constant within an epoch.
            C = torch.clamp(
                torch.tensor(c_max, device=device) * (t + 1) / c_stop_iter,
                0, c_max
            )

            bar = tqdm(
                train_batches,
                total=steps,
                desc=f"Train [{t+1}/{epochs}]",
                dynamic_ncols=True,
                leave=False
            )

            for step, env_batches in enumerate(bar):
                phi_loss = 0.0
                c_correct = 0
                s_correct = 0
                full_correct = 0
                total = 0

                for env_ind in range(n_train_envs):
                    x, y = env_batches[env_ind]

                    (f_beta, f_eta, y_dist, p_y_given, rep, rep_hat,
                     mu, log_var, latent_z) = self.model(x, env_ind)

                    # Prediction loss on the combined logits.
                    pY_loss = F.cross_entropy(f_beta + f_eta - y_dist, y)

                    reg_term = self.reg_loss(f_beta, f_eta, y, env_ind) * self.reg_lambda

                    # VAE part: per-sample reconstruction error plus the
                    # distance of the KL estimate from the target C.
                    recon = F.mse_loss(rep, rep_hat, reduction="sum") / len(rep)
                    # log_var is floored at -10 before being turned into a std.
                    q_dist = torch.distributions.Normal(mu, torch.exp(torch.clamp(log_var, min=-10) / 2))
                    # presumably latent_z is a sample drawn from q_dist, which
                    # would make this a sample-based KL estimate -- TODO confirm
                    kl = (q_dist.log_prob(latent_z).sum(dim=1) - normal_distribution.log_prob(latent_z)).mean()
                    vae_term = (recon + kl_weight * (kl - C).abs()) * self.vae_lambda

                    phi_loss = phi_loss + pY_loss + reg_term + vae_term
                    # NOTE(review): torch.clamp passes no gradient where the
                    # value is clipped, so once the running sum exceeds 50 the
                    # terms accumulated so far stop contributing gradient.
                    # Confirm this is intended.
                    phi_loss = torch.clamp(phi_loss, max=50)

                    if self.classification:
                        c_pred = f_beta.argmax(1)
                        s_pred = f_eta.argmax(1)
                        full_pred = p_y_given.argmax(1)

                        c_correct += (c_pred == y).sum().item()
                        s_correct += (s_pred == y).sum().item()
                        full_correct += (full_pred == y).sum().item()
                        total += y.size(0)

                self.outer_optimizer.zero_grad()
                phi_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                self.outer_optimizer.step()

                if self.classification and total > 0:
                    bar.set_postfix({
                        "loss": f"{(phi_loss.item() / max(1, n_train_envs)):.4f}",
                        "acc":  f"{full_correct/total:.4f}",
                        "acc_c":f"{c_correct/total:.4f}",
                        "acc_s":f"{s_correct/total:.4f}",
                    })

                # Stop after `steps` groups even if the iterator has more.
                if step + 1 >= steps:
                    break

            # Leaving the loop through `break` does not close the bar.
            bar.close()

            tqdm.write(f"Epoch {t+1}/{epochs} finished.")

            confusion_matrix = self.train_only_z_c_confusion_matrix(train_dataset, batch_size)

            # Evaluate on a copy: test_with_TTA updates the model it is given,
            # and those updates must not leak into the model being trained.
            current_model = copy.deepcopy(self.model)

            acc_c_test, acc_s_test, acc_full_test = self.test(
                test_dataset, rep_learning_flag=False,
                input_model=current_model, batch_size=batch_size, print_flag=True
            )

            acc_c_tta, acc_s_tta, acc_full_tta = self.test_with_TTA(
                test_dataset, confusion_matrix, rep_learning_flag=False,
                input_model=current_model, batch_size=batch_size, print_flag=True
            )

            self.test_tta_result_list.append(acc_full_tta)

            with open(log_path, "a") as f:
                f.write(f"{t}\t{acc_c_test:.4f}\t{acc_s_test:.4f}\t{acc_full_test:.4f}\t"
                        f"{acc_c_tta:.4f}\t{acc_s_tta:.4f}\t{acc_full_tta:.4f}\n")

        return max(self.test_tta_result_list) if self.test_tta_result_list else 0.0


    def test(self, test_dataset, rep_learning_flag=False, input_model=None,
             batch_size=128, print_flag=True):
        """Evaluate a model on ``test_dataset`` without any adaptation.

        Args:
            test_dataset: Dataset handed to ``batchify``; must support ``len``.
            rep_learning_flag: Forwarded to the model as ``rep_learning``.
            input_model: Model to evaluate; ``self.model`` when None. The
                model is switched to eval mode and left in that mode.
            batch_size: Batch size forwarded to ``batchify``.
            print_flag: Print the three accuracies when True.

        Returns:
            ``(acc_c, acc_s, acc_full)``: the fraction of samples whose label
            equals the argmax of ``f_beta``, ``f_eta`` and ``p_y_given``
            respectively. All three are 0.0 when no sample was seen, and also
            when ``self.classification`` is False, because no predictions are
            scored in that mode.
        """
        test_model = self.model if input_model is None else input_model
        test_model.eval()

        total = 0
        c_correct = 0
        s_correct = 0
        full_correct = 0

        # Only an estimate for the progress bar; it does not limit the loop.
        est_steps = max(1, len(test_dataset) // batch_size)
        it = batchify(test_dataset, batch_size, self.config)

        # Pure inference: no gradients are needed, so skip building the
        # autograd graph for every batch.
        with torch.no_grad():
            for x, y in tqdm(it, total=est_steps, desc="Eval", dynamic_ncols=True, leave=False):
                f_beta, f_eta, y_dist, p_y_given, rep, rep_hat, mu, log_var, latent_z = \
                    test_model(x, self.eta_test_ind, rep_learning=rep_learning_flag)

                if self.classification:
                    c_pred = f_beta.argmax(1)
                    s_pred = f_eta.argmax(1)
                    full_pred = p_y_given.argmax(1)

                    c_correct += (c_pred == y).sum().item()
                    s_correct += (s_pred == y).sum().item()
                    full_correct += (full_pred == y).sum().item()
                    total += y.size(0)
                else:
                    total += y.size(0)

        if total == 0:
            return 0.0, 0.0, 0.0

        acc_c = c_correct / total
        acc_s = s_correct / total
        acc_full = full_correct / total

        if print_flag:
            print(f"[No-TTA] Acc_c(prediction_c) = {acc_c:.4f}")
            print(f"[No-TTA] Acc_s(prediction_s) = {acc_s:.4f}")
            print(f"[No-TTA] Acc_full(p_Y_given_zc_zs_E) = {acc_full:.4f}")

        return acc_c, acc_s, acc_full


    def test_with_TTA(self, test_dataset, confusion_matrix, rep_learning_flag=False,
                      input_model=None, batch_size=128, print_flag=True):
        """Adapt a model on the test data with pseudo-labels, then evaluate it.

        A frozen deep copy of the model (the teacher) produces pseudo-labels
        as the argmax of its ``f_beta`` output. The student's ``f_eta`` output
        is trained against them with cross-entropy for ``train_z_s_epoch``
        epochs (config value, default 3), using the optimizer returned by
        ``get_optimizer_for_specific_parameters``. The adapted student is then
        evaluated with ``confusion_matrix`` passed to its forward call.

        Args:
            test_dataset: Dataset handed to ``batchify``; must support ``len``.
            confusion_matrix: Forwarded to the model during final evaluation.
            rep_learning_flag: Forwarded to the model as ``rep_learning``.
            input_model: Model to adapt; ``self.model`` when None. The model
                is updated in place and left in eval mode.
            batch_size: Batch size forwarded to ``batchify``.
            print_flag: Print the accuracies when True and samples were seen.

        Returns:
            ``(acc_c_tta, acc_s_tta, acc_full_tta)``: accuracy of the argmax
            of ``f_beta``, ``f_eta`` and ``p_y_given`` after adaptation.
            All 0.0 when the dataset yields no samples.
        """
        student = self.model if input_model is None else input_model
        criterion = nn.CrossEntropyLoss()

        # Teacher: frozen snapshot of the model before adaptation.
        teacher = copy.deepcopy(student).eval()
        for p in teacher.parameters():
            p.requires_grad = False

        optimizer = student.get_optimizer_for_specific_parameters(
            lr=float(getattr(self.config, "fine_tune_lr", 1e-3))
        )
        student.train()
        student.train_only_z_s_classfier()

        # Put these sub-modules back in eval mode after student.train().
        # `Phi` is required; the others are optional.
        student.Phi.eval()
        for name in ("encoder", "decoder", "fc_mu", "fc_logvar", "label_predict_with_z_c"):
            if hasattr(student, name):
                getattr(student, name).eval()

        # Progress-bar estimate only; shared by the adaptation and eval loops.
        est_steps = max(1, len(test_dataset) // batch_size)

        tta_epochs = int(getattr(self.config, "train_z_s_epoch", 3))
        for e in range(tta_epochs):
            it = batchify(test_dataset, batch_size, self.config)
            bar = tqdm(it, total=est_steps,
                       desc=f"TTA-[{e+1}/{tta_epochs}]",
                       dynamic_ncols=True, leave=False)

            total, correct, loss_sum = 0, 0, 0.0

            for x, y in bar:
                with torch.no_grad():
                    f_beta_t, _, _, _, _, _, _, _, _ = teacher(
                        x, self.eta_test_ind, rep_learning=rep_learning_flag, test_tta=True)
                    pseudo = f_beta_t.argmax(dim=1)

                _, f_eta, _, _, _, _, _, _, _ = student(
                    x, self.eta_test_ind, rep_learning=rep_learning_flag, test_tta=True)

                loss = criterion(f_eta, pseudo)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=5.0)
                optimizer.step()

                # `loss` is a per-batch mean. Weight it by the batch size so
                # that dividing by the sample count below gives a true
                # per-sample average rather than a value too small by roughly
                # the batch size.
                batch_n = y.size(0)
                loss_sum += float(loss.item()) * batch_n
                pred = f_eta.argmax(1)
                # True labels are used for monitoring only, never for the loss.
                correct += (pred == y).sum().item()
                total += batch_n

                bar.set_postfix({"loss": f"{loss_sum/max(1,total):.6f}",
                                 "acc": f"{correct/max(1,total):.4f}"})

        student.eval()
        total = 0
        c_correct = 0
        s_correct = 0
        full_correct = 0

        it = batchify(test_dataset, batch_size, self.config)

        # Final scoring needs no gradients.
        with torch.no_grad():
            for x, y in tqdm(it, total=est_steps, desc="Eval-TTA", dynamic_ncols=True, leave=False):
                f_beta, f_eta, y_dist, p_y_given, rep, rep_hat, mu, log_var, z = \
                    student(x, self.eta_test_ind, rep_learning=rep_learning_flag, confusion_matrix=confusion_matrix)

                c_pred = f_beta.argmax(1)
                s_pred = f_eta.argmax(1)
                full_pred = p_y_given.argmax(1)

                c_correct += (c_pred == y).sum().item()
                s_correct += (s_pred == y).sum().item()
                full_correct += (full_pred == y).sum().item()
                total += y.size(0)

        acc_c_tta = c_correct / max(1, total)
        acc_s_tta = s_correct / max(1, total)
        acc_full_tta = full_correct / max(1, total)

        if print_flag and total > 0:
            print(f"[TTA] Acc_c(prediction_c) = {acc_c_tta:.4f}")
            print(f"[TTA] Acc_s(prediction_s) = {acc_s_tta:.4f}")
            print(f"[TTA] Acc_full(p_Y_given_zc_zs_E) = {acc_full_tta:.4f}")

        return acc_c_tta, acc_s_tta, acc_full_tta

    def save_model(self):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.outer_optimizer.state_dict(),
        }, self.model_path)

    def train_only_z_c_confusion_matrix(self, train_dataset, batch_size):
        """Build a row-normalised confusion matrix of the ``f_beta`` predictions.

        The model is run in eval mode, without gradient tracking, over the
        per-environment training batches from ``env_batchify``. Entry
        ``[i, j]`` counts the samples with label ``i`` whose ``f_beta`` argmax
        is ``j``; every row is then divided by its sum plus 1e-6.

        The model's training/eval mode is restored before returning, so
        calling this in the middle of training does not leave the model in
        eval mode.

        Args:
            train_dataset: Sequence of per-environment datasets; every element
                must support ``len``.
            batch_size: Batch size forwarded to ``env_batchify``.

        Returns:
            A float32 tensor of shape ``(self.num_class, self.num_class)``.
            Rows of labels that never occur are all zeros.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            num_classes = self.num_class
            cf_raw = torch.zeros(num_classes, num_classes, dtype=torch.float32)

            n_envs = len(train_dataset)
            # The smallest environment bounds how many batch groups are used.
            est_steps = max(1, min(len(ds) // batch_size for ds in train_dataset))
            train_batches = env_batchify(train_dataset, batch_size, self.config)

            with torch.no_grad():
                for b_idx, train_data in tqdm(
                    enumerate(train_batches), total=est_steps,
                    desc="ConfMat", dynamic_ncols=True, leave=False
                ):
                    for env_ind in range(n_envs):
                        x, y = train_data[env_ind]
                        f_beta, _, _, _, _, _, _, _, _ = self.model(x, env_ind)
                        pred = f_beta.argmax(1)

                        # Count the whole batch in one scatter-add instead of
                        # one Python-level `.item()` round trip per sample.
                        # accumulate=True makes repeated (label, prediction)
                        # pairs add up rather than overwrite each other.
                        rows = y.reshape(-1).long().to(cf_raw.device)
                        cols = pred.reshape(-1).long().to(cf_raw.device)
                        cf_raw.index_put_(
                            (rows, cols),
                            torch.ones_like(rows, dtype=cf_raw.dtype),
                            accumulate=True,
                        )

                    if b_idx + 1 >= est_steps:
                        break
        finally:
            # Hand the model back in the mode it was in when we were called.
            self.model.train(was_training)

        # eps keeps rows of unseen labels at zero instead of dividing by zero.
        eps = 1e-6
        row_sums = cf_raw.sum(dim=1, keepdim=True) + eps
        cf = cf_raw / row_sums
        return cf



