# Copyright 2021 Zhongyang Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import inspect
import torch
import json
import importlib
from torch.nn import functional as F
import torch.optim.lr_scheduler as lrs
import torchvision.utils as vutils
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import subprocess
from torchmetrics import StructuralSimilarityIndexMeasure
import torchvision.transforms as T
from PIL import Image
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

class MInterface(pl.LightningModule):
    """Lightning wrapper around a dynamically loaded ``GoldenRoad`` model.

    Trains the wrapped model on batches of ``(x1, x2, label)``, periodically
    saves reconstruction grids, and at the end of validation epochs (once
    ``epoch_counter`` is non-zero) dumps the collected latent codes to
    ``npz_path`` and runs an external disentanglement-metrics script whose
    JSON output is logged.

    Keys read from ``**kargs`` (besides hyper-parameters forwarded to the
    model): ``gpu`` (device the model is moved to) and ``npz_path`` (where
    latents are written), plus the optional ``metric_script``,
    ``metric_dataset``, ``metric_every_n_epochs`` and ``recon_save_dir``.
    """

    # Script run by ``on_validation_epoch_end`` unless ``metric_script`` is given.
    _DEFAULT_METRIC_SCRIPT = (
        '/home/star/Projects/g2/gyh/gyh/DisDiff-main/run_para_metrics_yunhai.py'
    )
    # Result file read back from ``dirname(npz_path)`` after the script runs.
    _METRIC_JSON_RELPATH = os.path.join('dis_metrics', '480000.json')
    # Directory for reconstruction grids unless ``recon_save_dir`` is given.
    _DEFAULT_RECON_DIR = './reconstructions_SQONMFCurve'
    # Maximum number of (original, reconstruction) pairs shown per grid.
    _N_RECON_IMAGES = 20

    def __init__(self, model_name, loss, lr, **kargs):
        super().__init__()
        self.args = kargs
        self.save_hyperparameters()
        self.load_model()
        self.configure_loss()
        self.all_W_features = []
        self.all_labels = []
        self.all_representations = []  # storage for representation variables
        # Incremented in on_train_epoch_end and passed to the model's forward.
        self.epoch_counter = 0
        # Per-batch latent arrays collected in validation_step.
        self.latents_list = []
        # Most recent disentanglement metrics; re-logged on epochs where they
        # are not recomputed.
        self.dci_informativeness_train = 0
        self.dci_informativeness_test = 0
        self.dci_disentanglement = 0
        self.dci_completeness = 0
        self.MIG_discrete_mig = 0
        self.SAP_score = 0
        self.modularity_score = 0
        self.explicitness_score_train = 0
        self.explicitness_score_test = 0
        # data_range=1.0 assumes images in [0, 1] -- TODO confirm with the dataset.
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

    def forward(self, img):
        """Run the wrapped model on ``img`` together with the epoch counter.

        ``img`` is the ``[x1, x2]`` list built by the step methods; callers
        unpack the result as ``(z_quant, loss, reconstruction_loss, x_rec)``.
        """
        return self.model(img, self.epoch_counter)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log it with a weight-orthogonality score.

        Every 100th batch also saves a grid of inputs vs. reconstructions,
        reusing this step's forward output instead of running the model again.
        """
        x1, x2, _ = batch
        _, loss, _, x_rec = self([x1, x2])
        if batch_idx % 100 == 0:
            self._save_reconstruction_grid(x1, x_rec, batch_idx)
        with torch.no_grad():
            cosine = self.calculate_orthogonality(self.model.nmf.weight.data.cpu().numpy())
        self.log('cosine', cosine, on_step=True, on_epoch=False)
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def _save_reconstruction_grid(self, original, reconstructed, batch_idx):
        """Write a two-row grid (inputs on top, reconstructions below) to disk and the logger."""
        n_images = min(self._N_RECON_IMAGES, original.size(0))
        with torch.no_grad():
            grid = vutils.make_grid(
                torch.cat([original[:n_images], reconstructed[:n_images].detach()], dim=0),
                nrow=n_images,
                normalize=True,
                scale_each=True,
                padding=2,
            )

        # Only loggers exposing ``add_image`` (e.g. TensorBoard) get the grid;
        # without a logger this is skipped instead of raising.
        experiment = getattr(self.logger, 'experiment', None)
        if hasattr(experiment, 'add_image'):
            experiment.add_image('val_reconstructions', grid, self.current_epoch)

        save_dir = self.args.get('recon_save_dir', self._DEFAULT_RECON_DIR)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"epoch_{self.current_epoch}_batch_{batch_idx}.png")
        vutils.save_image(grid, save_path)
        print(f"Saved image at {save_path}")

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Call ``self.model.nmf.clamp_weights()`` after every training batch.

        Lightning passes ``(outputs, batch, batch_idx)`` positionally; none of
        them are needed here.
        """
        self.model.nmf.clamp_weights()

    def on_train_epoch_end(self):
        """Advance the epoch counter and anneal the quantizer temperature."""
        self.epoch_counter += 1
        self.model.quantizer.anneal_temperature()
        return super().on_train_epoch_end()

    def calculate_similarity_accuracy(self, features_online, features_target):
        """Return the fraction of rows whose most cosine-similar target is the same index.

        Both inputs are 2-D ``[batch_size, feature_dim]`` tensors (implied by
        the row-wise normalisation and the matrix product below).
        """
        features_online = F.normalize(features_online, dim=1)
        features_target = F.normalize(features_target, dim=1)
        similarity_matrix = torch.mm(features_online, features_target.T)

        max_similarity_indices = torch.argmax(similarity_matrix, dim=1)
        expected = torch.arange(similarity_matrix.size(0), device=features_online.device)
        correct_matches = (max_similarity_indices == expected).sum().item()
        return correct_matches / similarity_matrix.size(0)

    def compute_and_save_correlation(self, save_path, axis):
        """Save the cosine-similarity matrix of the NMF weight vectors to CSV.

        Args:
            save_path: destination CSV path; its parent directory is created.
            axis: ``'col'`` compares columns of the weight; any other value
                compares rows.

        Returns:
            Mean absolute off-diagonal similarity (0.0 when there are fewer
            than two vectors, instead of the NaN a mean over nothing gives).
        """
        W = self.model.nmf.weight.detach().float().cpu()
        if axis == 'col':
            W = W.t()
        W = W.contiguous().view(W.shape[0], -1)  # [num_vec, dim]

        Wn = F.normalize(W, p=2, dim=1)
        corr = (Wn @ Wn.t()).clamp(-1.0, 1.0)  # [num_vec, num_vec], in [-1, 1]

        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs('') raises for a bare filename
            os.makedirs(save_dir, exist_ok=True)
        labels = [f'v{i}' for i in range(corr.size(0))]
        df = pd.DataFrame(corr.numpy(), index=labels, columns=labels)
        df.to_csv(save_path, float_format='%.6f')

        n = corr.size(0)
        if n < 2:
            return 0.0
        mask = ~torch.eye(n, dtype=torch.bool)
        return corr.abs()[mask].mean().item()

    def compute_mean_cosine_similarity(self, W):
        """Return the mean off-diagonal cosine similarity between the columns of ``W``.

        Returns 0.0 when ``W`` has fewer than two columns.
        """
        W_norm = F.normalize(W, dim=0)  # unit-norm columns
        cosine_similarity_matrix = torch.mm(W_norm.T, W_norm)  # [num_cols, num_cols]

        num_features = cosine_similarity_matrix.size(0)
        if num_features < 2:
            return 0.0
        mask = ~torch.eye(num_features, dtype=torch.bool, device=W.device)
        return cosine_similarity_matrix[mask].mean().item()

    def calculate_orthogonality(self, matrix):
        """Return the share of the Gram matrix ``W @ W.T`` that lies off the diagonal.

        The score is ``(sum(G) - trace(G)) / sum(G)`` with ``G = W @ W.T``
        (inner products between the rows of ``W``). It is scale-invariant, and
        smaller values mean the rows are closer to orthogonal.

        Parameters:
        matrix (numpy.ndarray): 2-D input matrix W.

        Returns:
        float: the off-diagonal share, or 0.0 when ``sum(G)`` is zero (e.g. an
        all-zero matrix), where the ratio is undefined.
        """
        gram_matrix = np.dot(matrix, matrix.T)
        total = np.sum(gram_matrix)
        if total == 0:
            return 0.0
        return (total - np.trace(gram_matrix)) / total

    def _is_metric_epoch(self):
        """Return True when latents/SSIM should be collected and metrics recomputed.

        Skipped while ``epoch_counter`` is 0; afterwards runs every
        ``metric_every_n_epochs`` epochs (default 1, i.e. every epoch).
        ``validation_step`` and ``on_validation_epoch_end`` both use this, so
        they always agree.
        """
        interval = max(1, int(self.args.get('metric_every_n_epochs', 1)))
        return self.epoch_counter != 0 and self.epoch_counter % interval == 0

    def validation_step(self, batch, batch_idx):
        """Run the model and, on metric epochs, collect latents and update SSIM."""
        x1, x2, _ = batch
        z_quant, _, _, x_rec = self([x1, x2])
        if not self._is_metric_epoch():
            return
        self.latents_list.append(z_quant.detach().cpu().numpy())
        self.val_ssim.update(x_rec, x1)
        self.log('val_ssim', self.val_ssim, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Reuse ``validation_step`` for testing.

        NOTE(review): latents appended here are not consumed by any
        test-epoch hook in this class; they stay in ``latents_list`` until the
        next ``on_validation_epoch_end``.
        """
        return self.validation_step(batch, batch_idx)

    def save_augmented_images(self, x1, x2, batch_idx, save_dir):
        """Save every image of both views in a batch as individual PNG files."""
        os.makedirs(save_dir, exist_ok=True)
        for i in range(x1.size(0)):
            x1_path = os.path.join(save_dir, f"batch_{batch_idx}_view1_img_{i}.png")
            vutils.save_image(x1[i], x1_path)
            x2_path = os.path.join(save_dir, f"batch_{batch_idx}_view2_img_{i}.png")
            vutils.save_image(x2[i], x2_path)

    def on_validation_epoch_end(self):
        """Compute and log disentanglement metrics, or re-log cached ones.

        On metric epochs with collected latents, the latents are saved to
        ``npz_path`` (shape ``(num_samples, latent_dim, 1)``), the external
        metric script is run and its results are logged. Otherwise the latent
        buffer is cleared and the last known metrics are logged again.
        """
        if not (self._is_metric_epoch() and self.latents_list):
            self.latents_list.clear()
            self._log_cached_metrics()
            return

        latents = np.concatenate(self.latents_list, axis=0)  # (num_samples, c * rank)
        npz_path = self.args['npz_path']
        np.savez(npz_path, latents=np.expand_dims(latents, axis=-1))
        self.latents_list.clear()

        result_dict = self._run_metric_script(npz_path)
        self._log_metric_results(result_dict)

    def _run_metric_script(self, npz_path):
        """Run the external metric script on ``npz_path`` and return its JSON results.

        The JSON file is deleted after reading, so a later epoch cannot read
        this epoch's results if the script then fails to write a new file.

        Raises:
            RuntimeError: if the script did not write the expected JSON file.
        """
        out_dir = os.path.dirname(npz_path)
        command = [
            'python', self.args.get('metric_script', self._DEFAULT_METRIC_SCRIPT),
            '-l', out_dir,
            '-n', npz_path,
            '-d', self.args.get('metric_dataset', 'shapes3d'),  # shapes3d mpi3d cars3d
        ]
        result = subprocess.run(command, capture_output=True, text=True, check=False)
        print(result.stdout)
        print(result.stderr)

        json_file_path = os.path.join(out_dir, self._METRIC_JSON_RELPATH)
        if not os.path.isfile(json_file_path):
            raise RuntimeError(
                f"Metric script exited with code {result.returncode} and did not "
                f"write {json_file_path}; stderr:\n{result.stderr}"
            )
        with open(json_file_path, 'r', encoding='utf-8') as f:
            result_dict = json.load(f)
        os.remove(json_file_path)
        print(f"Deleted file: {json_file_path}")
        return result_dict

    def _log_metric_results(self, result_dict):
        """Log every metric group present in ``result_dict`` and cache the headline ones."""
        if 'beta_VAE' in result_dict:
            self.log("beta_VAE_train_accuracy", result_dict['beta_VAE']['train_accuracy'], on_epoch=True)
            self.log("beta_VAE_eval_accuracy", result_dict['beta_VAE']['eval_accuracy'], on_epoch=True)

        if 'dci' in result_dict:
            dci_results = result_dict['dci']
            self.dci_informativeness_train = dci_results['informativeness_train']
            self.dci_informativeness_test = dci_results['informativeness_test']
            self.dci_disentanglement = dci_results['disentanglement']
            self.dci_completeness = dci_results['completeness']
            self.log("dci_informativeness_train", self.dci_informativeness_train, on_epoch=True)
            self.log("dci_informativeness_test", self.dci_informativeness_test, on_epoch=True)
            self.log("dci_disentanglement", self.dci_disentanglement, on_epoch=True)
            self.log("dci_completeness", self.dci_completeness, on_epoch=True)
            dci = (self.dci_informativeness_test + self.dci_disentanglement + self.dci_completeness) / 3
            self.log("dci", float(dci), on_epoch=True)

        if 'MIG' in result_dict:
            self.MIG_discrete_mig = result_dict['MIG']['discrete_mig']
            self.log("MIG_discrete_mig", self.MIG_discrete_mig, on_epoch=True)

        if 'factor_VAE' in result_dict:
            factor_results = result_dict['factor_VAE']
            self.log("factor_VAE_train_accuracy", factor_results['train_accuracy'], on_epoch=True)
            self.log("factor_VAE_eval_accuracy", factor_results['eval_accuracy'], on_epoch=True)
            self.log("factor_VAE_num_active_dims", factor_results['num_active_dims'], on_epoch=True)

        if 'sap' in result_dict:
            self.SAP_score = result_dict['sap']['SAP_score']
            self.log("SAP_score", self.SAP_score, on_epoch=True)

        if 'modularity_explicitness' in result_dict:
            mod_results = result_dict['modularity_explicitness']
            self.modularity_score = mod_results['modularity_score']
            self.explicitness_score_test = mod_results['explicitness_score_test']
            self.explicitness_score_train = mod_results['explicitness_score_train']
            self.log("modularity_score", self.modularity_score, on_epoch=True)
            self.log("explicitness_score_train", self.explicitness_score_train, on_epoch=True)
            self.log("explicitness_score_test", self.explicitness_score_test, on_epoch=True)

    def _log_cached_metrics(self):
        """Re-log the most recent headline metrics so they exist every validation epoch."""
        dci = (self.dci_informativeness_test + self.dci_disentanglement + self.dci_completeness) / 3
        self.log("dci", dci, on_epoch=True)
        self.log("MIG_discrete_mig", self.MIG_discrete_mig, on_epoch=True)
        self.log("modularity_score", self.modularity_score, on_epoch=True)
        self.log("explicitness_score_test", self.explicitness_score_test, on_epoch=True)

    def configure_optimizers(self):
        """Build AdamW and, optionally, a step or cosine learning-rate scheduler.

        Reads from hparams: ``weight_decay`` (default 0), ``lr`` (default
        3e-4), ``lr_scheduler`` (absent/``None``, ``'step'`` or ``'cosine'``),
        plus ``lr_decay_steps`` and ``lr_decay_rate``/``lr_decay_min_lr`` for
        the chosen scheduler.

        Raises:
            ValueError: if ``lr_scheduler`` is set to an unknown value.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.get('lr', 3e-4),  # default learning rate
            weight_decay=self.hparams.get('weight_decay', 0),
        )

        scheduler_name = self.hparams.get('lr_scheduler')
        if scheduler_name is None:
            return optimizer
        if scheduler_name == 'step':
            scheduler = lrs.StepLR(optimizer,
                                   step_size=self.hparams.lr_decay_steps,
                                   gamma=self.hparams.lr_decay_rate)
        elif scheduler_name == 'cosine':
            scheduler = lrs.CosineAnnealingLR(optimizer,
                                              T_max=self.hparams.lr_decay_steps,
                                              eta_min=self.hparams.lr_decay_min_lr)
        else:
            raise ValueError('Invalid lr_scheduler type!')
        return [optimizer], [scheduler]

    def configure_loss(self):
        """Set ``self.loss_function`` from the ``loss`` hparam (case-insensitive).

        Raises:
            ValueError: if the loss name is not one of mse, l1, bce, ce.
        """
        loss_functions = {
            'mse': F.mse_loss,
            'l1': F.l1_loss,
            'bce': F.binary_cross_entropy,
            'ce': F.cross_entropy,
        }
        try:
            self.loss_function = loss_functions[self.hparams.loss.lower()]
        except KeyError:
            raise ValueError("Invalid Loss Type!") from None

    def load_model(self):
        """Import ``.<model_name>`` relative to this package and build its ``GoldenRoad`` class.

        The model is instantiated from matching hparams and moved to
        ``self.args['gpu']``.

        Raises:
            ValueError: if the module cannot be imported or lacks ``GoldenRoad``;
                the original error is chained as the cause.
        """
        name = self.hparams.model_name
        try:
            module = importlib.import_module('.' + name, package=__package__)
            Model = getattr(module, 'GoldenRoad')
        except (ImportError, TypeError, AttributeError) as err:
            # Narrow catch: unrelated bugs inside the model module (e.g. a
            # SyntaxError) now surface instead of being masked.
            raise ValueError(
                f'Invalid Module File Name or Invalid Class Name {name}!') from err
        self.model = self.instancialize(Model).to(self.args['gpu'])

    def instancialize(self, Model, **other_args):
        """Instantiate ``Model`` with the hparams matching its ``__init__`` parameters.

        Keyword arguments in ``other_args`` override the hparam values.
        """
        # inspect.getargspec was removed in Python 3.11; getfullargspec keeps
        # the same ``.args`` list.
        class_args = inspect.getfullargspec(Model.__init__).args[1:]
        args1 = {arg: self.hparams[arg] for arg in class_args if arg in self.hparams}
        args1.update(other_args)
        return Model(**args1)
