# Standard library imports
import os
import uuid
from copy import deepcopy

# Third party library imports
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm


class Explainer(nn.Module):
    def __init__(self, clf, device, bkgd, num_channels):
        """Build the mask-producing U-Net around a frozen copy of ``clf``.

        Args:
            clf: Classifier to explain. It is deep-copied, switched to
                evaluation mode and has every parameter frozen, so the
                caller's instance is left untouched.
            device: Device on which the Laplacian kernel is created and to
                which the training routines move the model and the data.
            bkgd: Background tensor; its last two dimensions define the mask
                height and width.
            num_channels: Number of image channels, used both for the U-Net
                input and when reshaping samples for the classifier.
        """
        super().__init__()
        self.model = smp.Unet(
            encoder_name="resnet18",
            encoder_weights=None,
            # Follow the declared channel count instead of a hard-coded 3.
            in_channels=num_channels,
            classes=1,
            activation=None,
        )
        self.clf = deepcopy(clf)
        self.device = device
        self.bkgd = bkgd
        self.num_channels = num_channels
        self.h, self.w = self.bkgd.shape[-2:]
        self.num_mask_entries = self.h * self.w

        # Discrete 4-neighbour Laplacian, shaped [1, 1, 3, 3] for F.conv2d.
        # Registered as a non-persistent buffer: it follows .to()/.cuda()
        # with the module but is kept out of state_dict, so checkpoints
        # written before this change still load.
        lap_kernel = torch.tensor(
            [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
            dtype=torch.float32,
            device=self.device,
        ).view(1, 1, 3, 3)
        self.register_buffer("lap_kernel", lap_kernel, persistent=False)

        # Switch classifier to eval mode if needed
        if self.clf.training:
            self.clf.eval()
            print(
                "Warning: Classifier was not in evaluation mode. Switched to evaluation mode."
            )

        # Freeze weights of classifier
        clf_params = list(self.clf.parameters())
        for param in clf_params:
            param.requires_grad = False
        if clf_params:
            print("Froze the classifier parameters")

    def forward(self, x):
        """Run ``x`` through the wrapped segmentation model and return its output."""
        return self.model(x)

    def compute_losses(self, f_xS, f_xSc, f_x, y_hat, y_0, alpha, masks):
        """Compute the explanation objective, mask regularizers and error rates.

        Args:
            f_xS: Classifier outputs on the masked-in inputs, one per image.
            f_xSc: Classifier outputs on the complement inputs, one per image.
            f_x: Classifier outputs on the unmasked inputs. Not used by the
                computation; kept so existing callers keep working.
            y_hat: Per-image target that ``f_xS`` is compared against.
            y_0: Per-image target that ``f_xSc`` is compared against.
            alpha: Weight of the sufficiency term; necessity gets ``1 - alpha``.
            masks: Soft masks indexed as [B, h, w].

        Returns:
            Tuple ``(obj_loss, sp_loss, sm_loss, sh_loss, err_S, err_Sc)``:
            the weighted L1 objective, the mean mask value, the
            total-variation term, the squared-Laplacian term, and the
            fractions of images whose thresholded (>= 0.5) outputs differ
            from ``y_hat`` / ``y_0``.
        """
        batch_size = len(y_hat)

        # Error rates (1 - accuracy) of the thresholded predictions
        err_S = 1 - torch.sum((f_xS >= 0.5) * 1.0 == y_hat) / batch_size
        err_Sc = 1 - torch.sum((f_xSc >= 0.5) * 1.0 == y_0) / batch_size

        # Sufficiency and necessity terms (mean absolute error to the targets)
        suff = torch.mean(torch.abs(f_xS - y_hat))
        necc = torch.mean(torch.abs(f_xSc - y_0))
        obj_loss = alpha * suff + (1 - alpha) * necc

        # Sparsity loss
        sp_loss = torch.mean(masks)

        # TV norm loss for smoothness (neighbouring rows / columns)
        diff_h = torch.abs(masks[:, 1:, :] - masks[:, :-1, :])
        diff_w = torch.abs(masks[:, :, 1:] - masks[:, :, :-1])
        sm_loss = torch.mean(diff_h) + torch.mean(diff_w)

        # Regularity loss: squared Laplacian response of every mask.
        # The kernel is aligned with the masks' device/dtype, and the channel
        # axis added for conv2d is squeezed away again, so the result follows
        # the masks' own spatial size rather than self.h / self.w.
        kernel = self.lap_kernel.to(device=masks.device, dtype=masks.dtype)
        laplacians = F.conv2d(masks.unsqueeze(1), kernel, padding=1).squeeze(1) ** 2

        # Spatial mean per image, then mean over the batch
        spatial_mean_laplacians = torch.mean(laplacians, dim=[1, 2])  # Shape: [B]
        sh_loss = torch.mean(spatial_mean_laplacians)

        return obj_loss, sp_loss, sm_loss, sh_loss, err_S, err_Sc

    def sample(self, x, masks, bkgd, num_samples):
        """Draw hard masks and build the masked / complement images.

        Args:
            x: Image batch indexed as [B, C, h, w].
            masks: Per-pixel mask scores, one map per image; used as logits.
            bkgd: Background values that replace the pixels a mask removes;
                flattened to [-1, h*w] and broadcast against the samples.
            num_samples: Number of masks drawn per image.

        Returns:
            Tuple ``(mask_01, x_S, x_Sc)`` where ``mask_01`` is
            [B, N, h, w], ``x_S`` keeps the selected pixels of ``x`` and
            fills the rest with background, and ``x_Sc`` does the opposite;
            both are [B, N, C, h, w].
        """
        batch_size, channels = x.shape[0], x.shape[1]

        # Flatten the spatial axes. reshape (unlike view) also accepts
        # non-contiguous inputs, and the background follows the images' device.
        bkgd_flat = bkgd.reshape(-1, self.num_mask_entries).to(x.device)
        x_flat = x.reshape(batch_size, channels, self.num_mask_entries)  # [B, C, d^2]
        scores = masks.reshape(batch_size, -1)  # [B, d^2]
        scores = scores.unsqueeze(1).expand(-1, num_samples, -1)  # [B, N, d^2]

        # Two-class logits per pixel; class 0 means "keep the pixel".
        # NOTE(review): softmax over (s, -s) gives keep-probability
        # sigmoid(2 * s); confirm this is intended wherever callers treat
        # sigmoid(s) as the mask probability.
        logits_S = torch.stack((scores, -scores), dim=-1)  # [B, N, d^2, 2]
        keep = F.gumbel_softmax(logits_S, tau=1, hard=True)[..., 0]  # [B, N, d^2]

        # Blend image and background with the sampled masks
        keep_b = keep.unsqueeze(2)  # [B, N, 1, d^2]
        drop_b = 1 - keep_b
        x_b = x_flat.unsqueeze(1)  # [B, 1, C, d^2]
        x_S_flat = keep_b * x_b + drop_b * bkgd_flat
        x_Sc_flat = drop_b * x_b + keep_b * bkgd_flat

        # Restore the spatial layout
        mask_01 = keep.reshape(-1, num_samples, self.h, self.w)  # [B, N, h, w]
        sample_shape = (batch_size, num_samples, channels, self.h, self.w)
        x_S = x_S_flat.reshape(sample_shape)
        x_Sc = x_Sc_flat.reshape(sample_shape)
        return mask_01, x_S, x_Sc

    def compute_mask_preds(self, x_S, x_Sc):
        """Classify every sampled image and average the outputs per original image.

        Both inputs are indexed as [B, N, ...]; the N samples of each image
        are folded into the batch axis for the classifier call and the
        resulting outputs are averaged back over N.

        Returns:
            Tuple ``(f_xS, f_xSc)`` holding one averaged output per image.
        """

        def averaged_output(samples):
            batch_size, num_samples = samples.shape[0], samples.shape[1]
            stacked = samples.view(
                batch_size * num_samples, self.num_channels, self.h, self.w
            )
            outputs = self.clf(stacked)  # [B*N, 1]
            per_image = outputs.view(batch_size, num_samples, 1)
            return per_image.mean(dim=1).squeeze(-1)

        # Masked-in samples are evaluated first, then the complements
        return averaged_output(x_S), averaged_output(x_Sc)

    def train_model(
        self,
        alpha,
        _y_0,
        sp_mult,
        sm_mult,
        sh_mult,
        dataloaders,
        optimizer,
        num_samples,
        num_epochs,
        save_path,
        model_name,
        log_step=50,
        log=False,
    ):
        """Train the explainer, validating and checkpointing after every epoch.

        Args:
            alpha: Weight of the sufficiency term; necessity gets ``1 - alpha``.
            _y_0: Scalar target that complement predictions are compared against.
            sp_mult: Weight of the sparsity loss in the total loss.
            sm_mult: Weight of the total-variation loss in the total loss.
            sh_mult: Weight of the Laplacian loss in the total loss.
            dataloaders: Mapping with "train" and "val" entries; every batch
                is unpacked as ``(images, _)``.
            optimizer: Optimizer stepped once per training batch.
            num_samples: Number of masks sampled per image.
            num_epochs: Number of train + validation passes.
            save_path: Directory for the checkpoint, or ``None`` to skip saving.
            model_name: Checkpoint file name inside ``save_path``.
            log_step: Number of training batches between progress reports.
            log: When truthy, metrics are also sent to Weights & Biases.
        """
        # Put explainer and background on device
        self.to(self.device)
        bkgd = self.bkgd.to(self.device)

        # Wandb logging
        if log:
            run_id = uuid.uuid4().hex[:4]
            wandb.init(
                project="celebahq_explainer_final_training",
                entity="beepulbharti",
                name=f"explainer-{run_id}",
            )

        # Make sure the checkpoint directory exists before the first save
        if save_path:
            os.makedirs(save_path, exist_ok=True)

        metric_names = (
            "total_loss",
            "obj_loss",
            "err_S",
            "err_Sc",
            "sp_loss",
            "sm_loss",
            "sh_loss",
        )
        train_reported = ("total_loss", "obj_loss", "sp_loss", "sm_loss", "sh_loss")

        # Any finite first validation score must count as an improvement
        best_explainer_loss = float("inf")
        for epoch in range(num_epochs):
            for op in ("train", "val"):
                is_train = op == "train"
                if is_train:
                    self.train()
                    self.clf.eval()  # self.train() also flips the frozen classifier
                else:
                    self.eval()

                dataloader = dataloaders[op]
                running = dict.fromkeys(metric_names, 0.0)

                # Context-manager form: the caller's grad mode is restored on
                # exit instead of staying disabled after the last validation.
                with torch.set_grad_enabled(is_train):
                    for i, data in enumerate(tqdm(dataloader)):
                        # Load images
                        x, _ = data
                        x = x.to(self.device)

                        # Compute original predictions and the two targets
                        f_x = self.clf(x).squeeze(-1).detach()
                        y_hat = (f_x >= 0.5) * 1.0
                        y_0 = torch.full_like(y_hat, _y_0, dtype=torch.float)

                        # Compute masks via explainer
                        masks_scores = self(x).squeeze(1)  # [B, h, w]
                        masks = torch.sigmoid(masks_scores)

                        # Sample masks and new samples using gumbel
                        _, x_S, x_Sc = self.sample(x, masks_scores, bkgd, num_samples)

                        # Evaluate f_xS and f_xSc
                        f_xS, f_xSc = self.compute_mask_preds(x_S, x_Sc)

                        # Evaluate losses
                        (
                            obj_loss,
                            sp_loss,
                            sm_loss,
                            sh_loss,
                            err_S,
                            err_Sc,
                        ) = self.compute_losses(f_xS, f_xSc, f_x, y_hat, y_0, alpha, masks)
                        total_loss = (
                            obj_loss
                            + sp_mult * sp_loss
                            + sm_mult * sm_loss
                            + sh_mult * sh_loss
                        )

                        # Add losses
                        batch_values = (
                            total_loss,
                            obj_loss,
                            err_S,
                            err_Sc,
                            sp_loss,
                            sm_loss,
                            sh_loss,
                        )
                        for metric, value in zip(metric_names, batch_values):
                            running[metric] += value.item()

                        if not is_train:
                            continue

                        # Optimize and log
                        optimizer.zero_grad()
                        total_loss.backward()
                        optimizer.step()

                        if (i + 1) % log_step == 0:
                            results_dict = {
                                f"train_{metric}": running[metric] / log_step
                                for metric in train_reported
                            }
                            for key, value in results_dict.items():
                                print(f"{key}: {value}")

                            if log:
                                wandb.log(results_dict)

                            # Start a fresh reporting window
                            running = dict.fromkeys(metric_names, 0.0)

                if is_train:
                    continue

                # Validation
                results_dict = {
                    f"val_{metric}": running[metric] / len(dataloader)
                    for metric in metric_names
                }
                results_dict["epoch"] = epoch + 1

                for key, value in results_dict.items():
                    if key != "epoch":
                        print(f"{key}, {value}")

                print(f"Epoch {epoch + 1} completed")

                if log:
                    wandb.log(results_dict)

                # Lower is better for every term. err_Sc is measured against
                # y_0, the same target the training loss pulls f(x_Sc)
                # towards, so it is added rather than subtracted.
                stopping_loss = (
                    alpha * results_dict["val_err_S"]
                    + (1 - alpha) * results_dict["val_err_Sc"]
                    + results_dict["val_sp_loss"]
                    + results_dict["val_sm_loss"]
                    + results_dict["val_sh_loss"]
                )

                if stopping_loss < best_explainer_loss:
                    best_explainer_loss = stopping_loss

                    if save_path is not None:
                        print("Saving new model")
                        torch.save(
                            self.state_dict(),
                            os.path.join(save_path, model_name),
                        )

    def train_sweep(
        self,
        dataloaders,
        optimizer,
        num_samples,
        num_epochs,
        config,
        log_step=50,
        _y_0=0.0,
    ):
        """Train with hyper-parameters taken from ``config`` and return the last validation metrics.

        This routine uses the same ``sample`` / ``compute_mask_preds`` /
        ``compute_losses`` helpers as ``train_model`` and writes no checkpoint.

        Args:
            dataloaders: Mapping with "train" and "val" entries; every batch
                is unpacked as ``(images, _)``.
            optimizer: Optimizer stepped once per training batch.
            num_samples: Number of masks sampled per image.
            num_epochs: Number of train + validation passes.
            config: Object exposing ``alpha``, ``l1_mult`` (weight of the
                sparsity loss) and ``tv_mult`` (weight of the total-variation
                loss). An optional ``sh_mult`` weights the Laplacian loss and
                defaults to 0.
            log_step: Number of training batches between progress reports.
            _y_0: Scalar target that complement predictions are compared
                against. NOTE(review): the 0.0 default is an assumption;
                confirm it against the value passed to ``train_model``.

        Returns:
            Dict with the metrics of the last validation pass, including
            ``stopping_loss`` and ``best_stopping_loss``; empty when
            ``num_epochs`` is 0.
        """
        # Put explainer and background on device
        self.to(self.device)
        bkgd = self.bkgd.to(self.device)

        # Area constraint vector: only built when an ``area`` attribute has
        # been attached to the instance; nothing in this routine reads it.
        area = getattr(self, "area", None)
        if area is not None:
            tau = int(area * self.num_mask_entries)
            ones_tau = torch.zeros(self.num_mask_entries, device=self.device)
            ones_tau[:tau] = 1
            self.ones_tau = ones_tau

        sh_mult = getattr(config, "sh_mult", 0.0)
        metric_names = ("total_loss", "obj_loss", "err_S", "err_Sc", "l1_loss", "tv_loss")
        train_reported = ("total_loss", "obj_loss", "l1_loss", "tv_loss")

        # Training loop
        best_stopping_loss = float("inf")
        results_dict = {}
        for epoch in range(num_epochs):
            for op in ("train", "val"):
                is_train = op == "train"
                if is_train:
                    self.train()
                    self.clf.eval()  # self.train() also flips the frozen classifier
                else:
                    self.eval()

                dataloader = dataloaders[op]
                running = dict.fromkeys(metric_names, 0.0)

                # Context-manager form restores the caller's grad mode on exit
                with torch.set_grad_enabled(is_train):
                    for i, data in enumerate(tqdm(dataloader)):
                        # Load images
                        x, _ = data
                        x = x.to(self.device)

                        # Compute original predictions and the two targets
                        f_x = self.clf(x).squeeze(-1).detach()
                        y_hat = (f_x >= 0.5) * 1.0
                        y_0 = torch.full_like(y_hat, _y_0, dtype=torch.float)

                        # Compute masks via explainer
                        masks_scores = self(x).squeeze(1)  # [B, h, w]
                        masks = torch.sigmoid(masks_scores)

                        # Sample masks using gumbel softmax
                        _, x_S, x_Sc = self.sample(x, masks_scores, bkgd, num_samples)

                        # Evaluate f_xS and f_xSc
                        f_xS, f_xSc = self.compute_mask_preds(x_S, x_Sc)

                        # Evaluate losses
                        (
                            obj_loss,
                            l1_loss,
                            tv_loss,
                            sh_loss,
                            err_S,
                            err_Sc,
                        ) = self.compute_losses(
                            f_xS, f_xSc, f_x, y_hat, y_0, config.alpha, masks
                        )
                        total_loss = (
                            obj_loss
                            + config.l1_mult * l1_loss
                            + config.tv_mult * tv_loss
                            + sh_mult * sh_loss
                        )

                        batch_values = (total_loss, obj_loss, err_S, err_Sc, l1_loss, tv_loss)
                        for metric, value in zip(metric_names, batch_values):
                            running[metric] += value.item()

                        if not is_train:
                            continue

                        optimizer.zero_grad()
                        total_loss.backward()
                        optimizer.step()

                        if (i + 1) % log_step == 0:
                            train_log = {
                                f"train_{metric}": running[metric] / log_step
                                for metric in train_reported
                            }
                            for key, value in train_log.items():
                                print(f"{key}: {value}")

                            # Start a fresh reporting window
                            running = dict.fromkeys(metric_names, 0.0)

                if is_train:
                    continue

                num_batches = len(dataloader)
                results_dict = {
                    "val_total_loss": running["total_loss"] / num_batches,
                    "val_objective_loss": running["obj_loss"] / num_batches,
                    "val_err_S": running["err_S"] / num_batches,
                    "val_err_Sc": running["err_Sc"] / num_batches,
                    "val_l1_loss": running["l1_loss"] / num_batches,
                    "val_tv_loss": running["tv_loss"] / num_batches,
                    "epoch": epoch + 1,
                }

                # Lower is better for every term: err_Sc is measured against
                # y_0, the target the training loss pulls f(x_Sc) towards.
                stopping_loss = (
                    config.alpha * results_dict["val_err_S"]
                    + (1 - config.alpha) * results_dict["val_err_Sc"]
                    + results_dict["val_l1_loss"]
                    + results_dict["val_tv_loss"]
                )
                best_stopping_loss = min(best_stopping_loss, stopping_loss)
                results_dict["stopping_loss"] = stopping_loss
                results_dict["best_stopping_loss"] = best_stopping_loss

                for key, value in results_dict.items():
                    if key != "epoch":
                        print(f"{key}, {value}")

        return results_dict


"""
class exp(nn.Module):
    def __init__(self, h, w, clf):
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Dimensions of original images
        self.h = h
        self.w = w
        self.num_mask_entries = self.h * self.w

        # Deepcopy classifier to explain
        self.clf = deepcopy(clf)
        if self.clf.training:
            self.clf.eval()
            print(
                "Warning: Classifier was not in evaluation mode. Switched to evaluation mode."
            )

        # Freeze weights of classifier
        j = 0
        for param in self.clf.parameters():
            j += 1
            param.requires_grad = False
        if j > 0:
            print("Froze the classifier parameters")

        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, self.num_mask_entries)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.resnet(x)
        x = self.sigmoid(x)
        return x

    def compute_losses(self, f_xS, f_xSc, y_hat, alpha, masks):
        err_S = 1 - torch.sum((f_xS >= 0.5) * 1.0 == y_hat) / len(y_hat)
        err_Sc = 1 - torch.sum((f_xSc >= 0.5) * 1.0 == y_hat) / len(y_hat)

        suff = F.binary_cross_entropy(f_xS, y_hat, reduction="mean")
        necc = -F.binary_cross_entropy(f_xSc, y_hat, reduction="mean")
        obj_loss = alpha * suff + (1 - alpha) * necc

        abs_masks = torch.abs(masks)
        l1_loss = torch.mean(
            abs_masks[:, :, :, :].sum(dim=(2, 3)) / (self.num_mask_entries)
        )

        abs_h_diffs = abs(masks[:, :, 1:, :] - masks[:, :, :-1, :])
        tv_h_loss = abs_h_diffs[:, :, :, :].sum(dim=(2, 3))

        abs_w_diffs = abs(masks[:, :, :, 1:] - masks[:, :, :, :-1])
        tv_w_loss = abs_w_diffs[:, :, :, :].sum(dim=(2, 3))

        tv_loss = torch.mean((tv_w_loss + tv_h_loss) / self.num_mask_entries)

        return obj_loss, l1_loss, tv_loss, err_S, err_Sc

    def _compute_masked_x(self, x, masks, bkgd_samples):
        x_S = masks * x + (1 - masks) * bkgd_samples
        x_Sc = (1 - masks) * x + masks * bkgd_samples
        return x_S, x_Sc

    def _sample_masks(self, x, masks, bkgd, N=10):
        masks_reshape = masks.unsqueeze(1).expand(-1, N, -1)  # [B, N, d^2]
        p_S = torch.stack((masks_reshape, 1 - masks_reshape), dim=-1)  # [B, N, d^2, 2]
        p_Sc = torch.stack((1 - masks_reshape, masks_reshape), dim=-1)  # [B, N, d^2, 2]
        mask_S = F.gumbel_softmax(torch.logit(p_S, eps=0.01), tau=1, hard=True)[
            :, :, :, 0
        ]  # [B, N, d^2, 1]
        mask_Sc = F.gumbel_softmax(torch.logit(p_Sc, eps=0.01), tau=1, hard=True)[
            :, :, :, 0
        ]  # [B, N, d^2, 1]
        x_S = mask_S.unsqueeze(2) * x.unsqueeze(1) + (1 - mask_S.unsqueeze(2)) * bkgd
        x_Sc = (1 - mask_S.unsqueeze(2)) * x.unsqueeze(1) + mask_S.unsqueeze(2) * bkgd
        return mask_S, mask_Sc, x_S, x_Sc

    def train_model(
        self,
        alpha,
        l1_mult,
        tv_mult,
        dataloaders,
        optimizer,
        bkgd,
        num_epochs,
        save_path,
        model_name,
        log=False,
        device="cuda",
    ):
        
        self.to(device)
        log_step = 50
        if log == True:
            run_id = uuid.uuid4().hex[:4]
            wandb.init(
                project="CelebAHQ_v2", entity="beepulbharti", name=f"explainer-{run_id}"
            )

        ops = ["train", "val"]
        best_explainer_loss = 100
        N = 5

        for epoch in range(num_epochs):
            for op in ops:
                if op == "train":
                    torch.set_grad_enabled(True)
                    self.train()
                    self.clf.eval()
                elif op == "val":
                    torch.set_grad_enabled(False)
                    self.eval()

                dataloader = dataloaders[op]
                running_obj_loss = 0.0
                running_err_S = 0.0
                running_err_Sc = 0.0
                running_l1_loss = 0.0
                running_tv_loss = 0.0
                running_total_loss = 0.0

                for i, data in enumerate(tqdm(dataloader)):
                    # Load images
                    x, _ = data
                    x = x.to(device)

                    # Compute original predictions
                    f_x = (self.clf(x)).squeeze(-1).detach()
                    y_hat = (f_x >= 0.5) * 1.0

                    # Compute masks via explainer
                    masks_flat = self(x)  # [B, d^2]

                    # Sample masks using gumbel softmax
                    x_flat = x.view(
                        x.shape[0], x.shape[1], self.num_mask_entries
                    )  # [B, 3, d^2]
                    bkgd_flat = bkgd.view(-1, self.num_mask_entries)
                    masks_S_flat, masks_Sc_flat, x_S_flat, x_Sc_flat = (
                        self.sample_masks(x_flat, masks_flat, bkgd_flat, N)
                    )

                    masks_S = masks_S_flat.view(-1, N, self.h, self.w)
                    masks = masks_flat.view(-1, self.h, self.w).unsqueeze(1)

                    # Without gumbel sampling (z = x*m + (1-m)*b)
                    # x_S, x_Sc = self._compute_masked_x(x, masks_reshape, bkgd_samples)

                    # Evaluate f_x, f_xS, and f_xSc
                    x_S = x_S_flat.view(x.shape[0], N, x.shape[1], self.h, self.w)
                    x_Sc = x_Sc_flat.view(x.shape[0], N, x.shape[1], self.h, self.w)

                    f_xS = self.clf(
                        x_S.view(x.shape[0] * N, 3, self.h, self.w)
                    )  # [B*N, 1]
                    f_xS = f_xS.view(x.shape[0], N, 1).mean(dim=1).squeeze(-1)

                    f_xSc = self.clf(
                        x_Sc.view(x.shape[0] * N, 3, self.h, self.w)
                    )  # [B*N, 1]
                    f_xSc = f_xSc.view(x.shape[0], N, 1).mean(dim=1).squeeze(-1)

                    # Evaluate losses
                    obj_loss, l1_loss, tv_loss, err_S, err_Sc = self.compute_losses(
                        f_xS, f_xSc, y_hat, alpha, masks_S
                    )

                    total_loss = obj_loss + l1_mult * l1_loss + tv_mult * tv_loss

                    running_obj_loss += obj_loss.item()
                    running_err_S += err_S.item()
                    running_err_Sc += err_Sc.item()
                    running_l1_loss += l1_loss.item()
                    running_tv_loss += tv_loss.item()
                    running_total_loss += total_loss.item()

                    if op == "train":
                        optimizer.zero_grad()
                        total_loss.backward()
                        optimizer.step()

                        if (i + 1) % log_step == 0:
                            results_dict = {
                                "train_total_loss": running_total_loss / log_step,
                                "train_obj_loss": running_obj_loss / log_step,
                                "train_l1_loss": running_l1_loss / log_step,
                                "train_tv_loss": running_tv_loss / log_step,
                            }
                            for key, value in results_dict.items():
                                print(f"{key}: {value}")

                            if log == True:
                                wandb.log(results_dict)

                            running_obj_loss = 0
                            running_err_S = 0
                            running_err_Sc = 0
                            running_l1_loss = 0
                            running_tv_loss = 0
                            running_total_loss = 0

                if op == "val":
                    results_dict = {
                        "val_total_loss": running_total_loss / len(dataloader),
                        "val_objective_loss": running_obj_loss / len(dataloader),
                        "val_err_S": running_err_S / len(dataloader),
                        "val_err_Sc": running_err_Sc / len(dataloader),
                        "val_l1_loss": running_l1_loss / len(dataloader),
                        "val_tv_loss": running_tv_loss / len(dataloader),
                        "epoch": epoch + 1,
                    }

                    for key, value in results_dict.items():
                        if key != "epoch":
                            print(f"{key}, {value}")

                    if log == True:
                        wandb.log(results_dict)

                    stopping_loss = (
                        alpha * results_dict["val_err_S"]
                        - (1 - alpha) * results_dict["val_err_Sc"]
                        + results_dict["val_l1_loss"]
                        + results_dict["val_tv_loss"]
                    )

                    if stopping_loss <= best_explainer_loss:
                        best_explainer_loss = stopping_loss
                        if save_path is not None:
                            print("Saving new model")
                            torch.save(
                                self.state_dict(),
                                os.path.join(save_path, model_name),
                            )
"""
