import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from sw2 import Wasserstein_Distance
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
import numpy as np
from sw2 import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Max_Sliced_Wasserstein_Distance,
    Min_SWGG,
    Expected_Sliced_Transport,
)
from utils import optimal_alpha_general, optimal_alpha_simplex




class Wormhole(nn.Module):
    """Point-cloud autoencoder whose latent space mimics Wasserstein geometry.

    The encoder is trained so that the mean squared difference between two
    embeddings matches ``Wasserstein_Distance`` between the corresponding
    point clouds, while the decoder is trained to reconstruct each cloud
    (see :meth:`loss`). ``encoder``/``decoder`` are taken from ``transformer``.
    """

    def __init__(self, transformer, config, run_dir=None,
                 compute_stats=True,
                 save_best=True):
        """
        Args:
            transformer: object exposing ``encoder`` and ``decoder`` modules.
            config: object exposing ``device`` and ``coeff_dec`` (weight of the
                decoder term in the total loss).
            run_dir: output directory for checkpoints, the best model and the
                loss-curve figure; created if missing. ``None`` disables all
                file output.
            compute_stats: if True, compute Pearson correlation and R² between
                ground-truth and embedding pairwise distances on the last batch
                of every epoch (scipy / scikit-learn are imported lazily).
            save_best: if True (and ``compute_stats`` is True), keep a CPU copy
                of the weights from the epoch with the highest R².
        """
        super().__init__()
        self.encoder = transformer.encoder
        self.decoder = transformer.decoder
        self.config = config
        self.device = config.device
        self.coeff_dec = config.coeff_dec

        # options
        self.compute_stats = bool(compute_stats)
        self.save_best = bool(save_best)

        # per-epoch histories (each entry comes from the last batch of the epoch)
        self.loss_history = []
        self.enc_loss_history = []
        self.dec_loss_history = []
        self.corr_history = []
        self.r2_history = []

        # best-R² tracking
        self.best_r2 = float('-inf')
        self.best_weight = None
        self.best_epoch = -1

        # run dir
        self.run_dir = run_dir
        if self.run_dir is not None:
            os.makedirs(self.run_dir, exist_ok=True)

        if self.compute_stats:
            from scipy.stats import pearsonr as _pearsonr  # lazy import
            from sklearn.metrics import r2_score as _r2_score
            self._pearsonr = _pearsonr
            self._r2_score = _r2_score
        else:
            self._pearsonr = None
            self._r2_score = None

    def forward(self, x, weights=None):
        """Encode ``x`` (passing ``weights`` through to the encoder) and decode it.

        Returns:
            (recon_x, latent)
        """
        latent = self.encoder(x, weights)
        recon_x = self.decoder(latent)
        return recon_x, latent

    @staticmethod
    def _pair_indices(num_pc):
        """Return all index pairs (i, j) with i < j, in row-major order."""
        return [(i, j) for i in range(num_pc) for j in range(i + 1, num_pc)]

    @staticmethod
    def _pairwise_enc_dist(enc_pc):
        """Mean squared difference between every pair of embeddings (i < j).

        Pairs follow the same row-major order as :meth:`_pair_indices`
        ((0,1), (0,2), ..., (1,2), ...). Each value averages over all
        non-batch dimensions, like ``torch.mean((enc_pc[i] - enc_pc[j]) ** 2)``.
        The result has shape ``(n*(n-1)//2,)`` and is empty when ``n < 2``.
        """
        n = enc_pc.size(0)
        idx = torch.triu_indices(n, n, offset=1, device=enc_pc.device)
        sq = (enc_pc[idx[0]] - enc_pc[idx[1]]) ** 2
        if sq.dim() == 1:
            # Scalar embedding per cloud: nothing left to average over.
            return sq
        return sq.flatten(1).mean(dim=1)

    def loss(self, batch_pc):
        """Compute the training loss on a batch of point clouds.

        enc_loss = mean over pairs i<j of (W(x_i, x_j) - mean((z_i - z_j)**2))**2
        dec_loss = mean over i of W(x_i, decoder(z_i))
        total    = enc_loss + coeff_dec * dec_loss
        where W is ``Wasserstein_Distance`` and z = encoder(x).

        Args:
            batch_pc: tensor whose first dimension indexes point clouds.

        Returns:
            (total_loss, enc_loss, dec_loss). With a single cloud there are no
            pairs and ``enc_loss`` is zero.
        """
        num_pc = batch_pc.size(0)
        enc_pc = self.encoder(batch_pc)
        dec_pc = self.decoder(enc_pc)

        pw_enc = self._pairwise_enc_dist(enc_pc)
        if num_pc >= 2:
            pw_pc = torch.stack([
                Wasserstein_Distance(batch_pc[i], batch_pc[j], device=self.device)
                for i, j in self._pair_indices(num_pc)
            ])
            enc_loss = torch.mean((pw_pc - pw_enc) ** 2)
        else:
            # No pairs: torch.stack([]) would raise. Summing the empty tensor
            # yields a zero that stays attached to the autograd graph.
            enc_loss = pw_enc.sum()

        pw_dec = torch.stack([
            Wasserstein_Distance(batch_pc[i], dec_pc[i], device=self.device)
            for i in range(num_pc)
        ])
        dec_loss = torch.mean(pw_dec)
        total_loss = enc_loss + self.coeff_dec * dec_loss
        return total_loss, enc_loss, dec_loss

    def _compute_stats_on_batch(self, batch_pc):
        """Pearson correlation and R² between ground-truth and embedding distances.

        Returns:
            (corr, r2) as floats. Both are NaN when the batch yields fewer than
            2 pairs (fewer than 3 clouds), because ``pearsonr`` needs at least
            two points.

        Raises:
            RuntimeError: if the model was built with ``compute_stats=False``.
        """
        if not self.compute_stats:
            raise RuntimeError("compute_stats=False, cannot run")
        pairs = self._pair_indices(batch_pc.size(0))
        if len(pairs) < 2:
            return float("nan"), float("nan")

        # Evaluation only: no autograd graph is needed through the encoder.
        with torch.no_grad():
            enc_pc = self.encoder(batch_pc)
            pw_enc = self._pairwise_enc_dist(enc_pc)

        # Ground-truth distances are left outside no_grad so that
        # Wasserstein_Distance runs exactly as it does in training.
        pw_pc = torch.stack([
            Wasserstein_Distance(batch_pc[i], batch_pc[j], device=self.device)
            for i, j in pairs
        ]).detach().cpu().numpy()
        pw_enc = pw_enc.cpu().numpy()

        corr, _ = self._pearsonr(pw_pc, pw_enc)
        r2 = self._r2_score(pw_pc, pw_enc)
        return float(corr), float(r2)

    def _save_checkpoint(self, optimizer, scheduler, epoch, fname=None):
        """Save model/optimizer/scheduler state to ``run_dir``; no-op without run_dir.

        The default file name is ``checkpoint_epoch{epoch}.pth``.
        """
        if self.run_dir is None:
            return
        fname = fname or f"checkpoint_epoch{epoch}.pth"
        path = os.path.join(self.run_dir, fname)
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'epoch': epoch
        }, path)
        print(f"Saved checkpoint: {path}")

    def train_model(self, dataloader, optimizer, scheduler=None, epochs=10, verbose=True,
                    save_every=100):
        """Train for ``epochs`` epochs, tracking per-epoch history.

        Each element of ``dataloader`` is indexed with ``[0]`` to obtain the
        batch of point clouds. Loss histories (and, with ``compute_stats``,
        correlation / R²) record values from the LAST batch of each epoch.
        ``scheduler.step()`` is called once per epoch. A checkpoint is saved
        every ``save_every`` epochs (``None`` or ``0`` disables this). Epochs in
        which the dataloader yields nothing are skipped and leave no history.

        At the end, if best tracking is active and ``run_dir`` is set, the best
        weights are written to ``best_model.pth``. The optimizer/scheduler state
        stored there is the end-of-training state, not the state at the best epoch.
        Finally :meth:`plot_history` is called.
        """
        self.train()
        for epoch in range(epochs):
            last_total = last_enc = last_dec = None
            last_batch_pc = None

            for batch in dataloader:
                batch_pc = batch[0].to(self.device)
                last_batch_pc = batch_pc

                optimizer.zero_grad()
                total_loss, enc_loss, dec_loss = self.loss(batch_pc)
                total_loss.backward()
                optimizer.step()

                last_total = total_loss.item()
                last_enc = enc_loss.item()
                last_dec = dec_loss.item()

            if last_batch_pc is None:
                # Nothing was trained: recording None would poison the
                # histories and formatting it below would raise TypeError.
                if verbose:
                    print(f"Epoch {epoch+1}/{epochs} | dataloader yielded no batches, skipping")
                continue

            self.loss_history.append(last_total)
            self.enc_loss_history.append(last_enc)
            self.dec_loss_history.append(last_dec)

            if self.compute_stats:
                corr, r2 = self._compute_stats_on_batch(last_batch_pc)
                self.corr_history.append(corr)
                self.r2_history.append(r2)

                if self.save_best and (r2 > self.best_r2):
                    self.best_r2 = r2
                    self.best_weight = {k: v.detach().cpu().clone()
                                        for k, v in self.state_dict().items()}
                    self.best_epoch = epoch + 1

                if verbose:
                    print(
                        f"Epoch {epoch+1}/{epochs} | "
                        f"Total: {last_total:.4f} | Enc: {last_enc:.4f} | Dec: {last_dec:.4f} | "
                        f"R²(last): {r2:.4f} | Corr(last): {corr:.4f} | "
                        f"Best R²: {self.best_r2:.4f} (ep {self.best_epoch})"
                    )
            elif verbose:
                print(
                    f"Epoch {epoch+1}/{epochs} | "
                    f"Total: {last_total:.4f} | Enc: {last_enc:.4f} | Dec: {last_dec:.4f}"
                )

            if scheduler is not None:
                scheduler.step()

            # periodic checkpoint (a falsy save_every disables it)
            if save_every and ((epoch + 1) % save_every == 0):
                self._save_checkpoint(optimizer, scheduler, epoch + 1)

        if self.save_best and (self.run_dir is not None) and (self.best_weight is not None):
            best_path = os.path.join(self.run_dir, "best_model.pth")
            torch.save({
                'model_state_dict': self.best_weight,
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'epoch': self.best_epoch
            }, best_path)
            print(f"Saved best model: {best_path}")

        self.plot_history()

    def plot_history(self):
        """Plot recorded histories to ``run_dir/loss_curve.png``; no-op without run_dir.

        Without ``compute_stats`` the layout is 1x3 (total / encoder / decoder
        loss); with it, 2x2 with the R² history in the last panel.
        """
        if self.run_dir is None:
            return
        if len(self.loss_history) == 0:
            print("No history to plot."); return
        import matplotlib.pyplot as plt
        rows = 2 if self.compute_stats else 1
        fig, axs = plt.subplots(rows, 2 if rows == 2 else 3, figsize=(12, 6 if rows==1 else 8))

        if rows == 1:
            ax0, ax1, ax2 = axs
            ax0.plot(self.loss_history); ax0.set_title('Total Loss (per epoch)')
            ax1.plot(self.enc_loss_history); ax1.set_title('Encoder Loss (per epoch)')
            ax2.plot(self.dec_loss_history); ax2.set_title('Decoder Loss (per epoch)')
        else:
            axs[0,0].plot(self.loss_history); axs[0,0].set_title('Total Loss (per epoch)')
            axs[0,1].plot(self.enc_loss_history); axs[0,1].set_title('Encoder Loss (per epoch)')
            axs[1,0].plot(self.dec_loss_history); axs[1,0].set_title('Decoder Loss (per epoch)')
            axs[1,1].plot(self.r2_history); axs[1,1].set_title('R² (last batch per epoch)')

        plt.tight_layout()
        fig_path = os.path.join(self.run_dir, 'loss_curve.png')
        plt.savefig(fig_path)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
        print(f"Saved loss curve: {fig_path}")

    def compute_wormhole(self, batch_pc1, batch_pc2):
        """Embedding-space distance between matching clouds of two batches.

        Returns ``mean((enc(pc1) - enc(pc2)) ** 2)`` reduced over the LAST
        dimension only. This matches the pairwise distance used in :meth:`loss`
        when embeddings are (batch, dim) — NOTE(review): confirm for higher rank.
        """
        enc_pc1 = self.encoder(batch_pc1)
        enc_pc2 = self.encoder(batch_pc2)
        return torch.mean(torch.pow(enc_pc1 - enc_pc2, 2), dim=-1)





class fast_Wormhole(Wormhole):
    """Wormhole trained against a fitted linear surrogate of the ground truth.

    :meth:`estimate_alpha` fits weights ``alphas`` so that
    ``sum_k alphas[k] * metric_funcs[k](x, y)`` approximates
    ``ground_truth_func(x, y)`` on sample pairs. The resulting ``approx_gt``
    then replaces ``Wasserstein_Distance`` in :meth:`loss` and in the
    per-epoch statistics.
    """

    def __init__(
        self,
        estimate_alpha_general,
        transformer,
        config,
        lower_func=None,            # backward-compat optional
        upper_func=None,            # backward-compat optional
        ground_truth_func=None,
        run_dir=None,
        metric_funcs=None,          # list[callable(x_pc, y_pc) -> scalar]
        ridge_lambda: float = 0.0,  # ridge penalty for the alpha fit (0 -> OLS)
        metric_names=None,          # optional display names for the metrics
        compute_stats=True,
        save_best=True,
    ):
        """
        Args:
            estimate_alpha_general: if truthy, fit alphas with
                ``optimal_alpha_general``; otherwise with ``optimal_alpha_simplex``.
            transformer, config, run_dir, compute_stats, save_best: forwarded
                to :class:`Wormhole`.
            lower_func, upper_func: legacy way to pass exactly two metrics;
                used only when ``metric_funcs`` is None.
            ground_truth_func: callable(x_pc, y_pc) giving the target distance
                the surrogate is fitted to. Required.
            metric_funcs: callables(x_pc, y_pc) combined linearly by the surrogate.
            ridge_lambda: ridge penalty passed to the alpha solver.
            metric_names: display names for the metrics; defaults to each
                function's ``__name__``.

        Raises:
            ValueError: if no metrics are supplied (no ``metric_funcs`` and not
                both legacy funcs, or an empty list), or if ``ground_truth_func``
                is missing.
        """
        # forward flags to Wormhole
        super().__init__(transformer, config, run_dir,
                         compute_stats=compute_stats,
                         save_best=save_best)

        # ===== metric list configuration =====
        if metric_funcs is None:
            if (lower_func is None) or (upper_func is None):
                raise ValueError(
                    "Provide either metric_funcs=[...] or both lower_func and upper_func."
                )
            metric_funcs = [lower_func, upper_func]

        self.estimate_alpha_general = estimate_alpha_general

        self.metric_funcs = list(metric_funcs)
        if not self.metric_funcs:
            # An empty surrogate cannot be fitted or evaluated.
            raise ValueError("metric_funcs must contain at least one callable.")
        self.metric_names = (
            list(metric_names)
            if metric_names is not None
            else [getattr(f, "__name__", f"metric_{i}") for i, f in enumerate(self.metric_funcs)]
        )

        if ground_truth_func is None:
            raise ValueError("ground_truth_func is required.")
        self.ground_truth_func = ground_truth_func

        # alpha vector, initialised to uniform weights until estimate_alpha runs
        num_metrics = len(self.metric_funcs)
        self.alphas = torch.full((num_metrics,), 1.0 / num_metrics,
                                 device=self.device, dtype=torch.float32)

        self.ridge_lambda = float(ridge_lambda)
        self.approx_gt = None

    @staticmethod
    def _to_float(x):
        """Convert a metric output to a Python float.

        Accepts a 1-element tensor, a tuple/list (whose first element is the
        value), or anything ``float()`` accepts.
        """
        if isinstance(x, torch.Tensor):
            return float(x.detach().cpu().item())
        if isinstance(x, (tuple, list)):
            return float(x[0])
        return float(x)

    @staticmethod
    def _to_scalar_tensor(x, device):
        """Convert a metric output to a 0-d tensor on ``device``, keeping autograd.

        Mirrors :meth:`_to_float`: tuples/lists contribute their first element.
        ``torch.as_tensor`` keeps the autograd graph when it has to move a tensor.
        """
        if isinstance(x, (tuple, list)):
            x = x[0]
        return torch.as_tensor(x, device=device).reshape(())

    def _surrogate_distance(self, x, y):
        """Return ``sum_k alphas[k] * metric_funcs[k](x, y)``.

        Each metric output is moved to the device of ``alphas``, and dtypes are
        combined by normal type promotion. Unlike ``torch.dot``, this therefore
        works for float64, Python-number or tuple outputs. Metrics are evaluated
        in list order.
        """
        alphas = self.alphas
        total = None
        for a_k, f in zip(alphas, self.metric_funcs):
            term = a_k * self._to_scalar_tensor(f(x, y), alphas.device)
            total = term if total is None else total + term
        return total

    def estimate_alpha(self, samples):
        """
        Fit ``self.alphas`` so that the linear combination of ``self.metric_funcs``
        best approximates ``ground_truth_func``, then install ``self.approx_gt``.

        Intended objective (OLS, or ridge when ``ridge_lambda > 0``):
            a* = argmin || y - X a ||^2 (+ λ||a||^2 if ridge > 0)
        where each row of X holds the metric values for one pair and y holds the
        ground truth. The fit itself is delegated to ``optimal_alpha_general`` or
        ``optimal_alpha_simplex`` (from ``utils``). The simplex variant
        presumably also constrains ``a``; see utils.

        Args:
            samples: indexable sequence of point clouds (tensors). Every pair
                (i < j) is used.

        Raises:
            ValueError: if fewer than 2 samples are given (no pairs to fit on).
        """
        device = self.device
        n_samples = len(samples)
        train_pairs = [(i, j) for i in range(n_samples) for j in range(i + 1, n_samples)]
        if not train_pairs:
            raise ValueError("estimate_alpha needs at least 2 samples to form a pair.")
        print(f"Computing all pairwise metrics for train ({len(train_pairs)} pairs) on {len(self.metric_funcs)} metrics...")

        X_rows = []
        y_vals = []

        for i, j in train_pairs:
            x_pc = samples[i].to(device)
            y_pc = samples[j].to(device)

            # ground truth
            y_vals.append(self._to_float(self.ground_truth_func(x_pc, y_pc)))

            # features: one value per metric in the list
            X_rows.append([self._to_float(f(x_pc, y_pc)) for f in self.metric_funcs])

        X = np.asarray(X_rows, dtype=np.float64)   # shape: (num_pairs, d)
        y = np.asarray(y_vals, dtype=np.float64)   # shape: (num_pairs,)

        solver = optimal_alpha_general if self.estimate_alpha_general else optimal_alpha_simplex
        a = solver(X, y, ridge=self.ridge_lambda)

        self.alphas = torch.tensor(a, device=device, dtype=torch.float32)
        self.approx_gt = self._surrogate_distance

        # report the fitted coefficients
        coef_str = ", ".join(
            f"{name}={float(v):.4f}" for name, v in zip(self.metric_names, self.alphas.tolist())
        )
        print(f"[estimate_alpha] alphas: {coef_str}  | ridge={self.ridge_lambda}")

    def loss(self, batch_pc):
        """Same objective as :meth:`Wormhole.loss`, with ``approx_gt`` as distance.

        Returns:
            (total_loss, enc_loss, dec_loss). With a single cloud there are no
            pairs and ``enc_loss`` is zero.

        Raises:
            RuntimeError: if :meth:`estimate_alpha` has not been called.
        """
        if self.approx_gt is None:
            raise RuntimeError("approx_gt is None. Call estimate_alpha(samples) before training.")

        num_pc = batch_pc.size(0)
        enc_pc = self.encoder(batch_pc)
        dec_pc = self.decoder(enc_pc)
        pairs = [(i, j) for i in range(num_pc) for j in range(i + 1, num_pc)]

        if pairs:
            pw_pc = torch.stack([self.approx_gt(batch_pc[i], batch_pc[j]) for i, j in pairs])
            pw_enc = torch.stack([torch.mean((enc_pc[i] - enc_pc[j]) ** 2) for i, j in pairs])
            enc_loss = torch.mean((pw_pc - pw_enc) ** 2)
        else:
            # No pairs: torch.stack([]) would raise; the decoder term still trains.
            enc_loss = enc_pc.new_zeros(())

        pw_dec = torch.stack([self.approx_gt(batch_pc[i], dec_pc[i]) for i in range(num_pc)])
        dec_loss = torch.mean(pw_dec)
        total_loss = enc_loss + self.coeff_dec * dec_loss

        return total_loss, enc_loss, dec_loss

    def _compute_stats_on_batch(self, batch_pc):
        """Pearson correlation and R² between ``approx_gt`` and embedding distances.

        Returns:
            (corr, r2) as floats. Both are NaN with fewer than 2 pairs
            (fewer than 3 clouds), because ``pearsonr`` needs at least two points.

        Raises:
            RuntimeError: if :meth:`estimate_alpha` has not been called.
        """
        if self.approx_gt is None:
            raise RuntimeError("approx_gt is None. Call estimate_alpha(samples) before training.")

        num_pc = batch_pc.size(0)
        pairs = [(i, j) for i in range(num_pc) for j in range(i + 1, num_pc)]
        if len(pairs) < 2:
            return float("nan"), float("nan")

        # Evaluation only: no autograd graph is needed through the encoder.
        with torch.no_grad():
            enc_pc = self.encoder(batch_pc)

        pw_pc = torch.stack([self.approx_gt(batch_pc[i], batch_pc[j]) for i, j in pairs])
        pw_enc = torch.stack([torch.mean((enc_pc[i] - enc_pc[j]) ** 2) for i, j in pairs])
        pw_pc = pw_pc.detach().cpu().numpy()
        pw_enc = pw_enc.cpu().numpy()

        # Prefer the functions stored by Wormhole.__init__, like the base class.
        pearson_fn = self._pearsonr or pearsonr
        r2_fn = self._r2_score or r2_score
        corr, _ = pearson_fn(pw_pc, pw_enc)
        r2 = r2_fn(pw_pc, pw_enc)
        return float(corr), float(r2)