import torch.nn.functional as F
import torch
import os
import torch.distributed as dist
from tqdm import tqdm


class CentroidHelper():
    """Track per-class centroids of a model's softmax outputs.

    Row ``i`` of ``centroids`` is the mean softmax output the model produces for
    samples whose target is class ``i``; ``unnormalized_centroids`` holds the
    matching per-class sums. Both are ``(n_classes, n_classes)`` tensors on
    ``self.device``. When ``torch.distributed`` is initialized, the sums are
    all-reduced across processes before normalizing.
    """

    def __init__(self, opt, n_classes, momentum=0.9,):
        """Create the helper and load previously saved initial centroids if present.

        Args:
            opt: options object; reads ``device``, ``dataset``, ``model`` and the
                optional ``train_mode`` (defaults to False). ``compute_centroids``
                also reads ``trial_folder`` when called with ``save=True``.
            n_classes: number of classes (also the width of the model output).
            momentum: weight kept on the old centroids in ``update_centroids``.
        """
        self.n_classes = n_classes
        self.opt = opt
        self.device = opt.device
        self.train_mode = getattr(opt, 'train_mode', False)

        # Both buffers live on self.device so model outputs can be accumulated
        # into them directly.
        self.centroids = torch.zeros((n_classes, n_classes), dtype=torch.float32, device=self.device)
        self.unnormalized_centroids = torch.zeros((n_classes, n_classes), dtype=torch.float32, device=self.device)
        self.initial_centroids = None
        self.momentum = momentum
        self.save_path = "./save/{}/teacher/{}/centroids.pt".format(opt.dataset, opt.model)

        self.is_initialed = False
        if os.path.exists(self.save_path):
            self.load_initial_centroids()
            self.is_initialed = True
        else:
            print("==> [Error]: Initial centroid does not exist. Do initialize_centroids() first.")

    def load_initial_centroids(self, save_path=None):
        """Load centroids written by ``save_initial_centroids``.

        Args:
            save_path: file to read; defaults to ``self.save_path``.
        """
        path = self.save_path if save_path is None else save_path
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        data = torch.load(path)
        self.centroids = data["centroids"].to(self.device)
        initial = data.get("initial_centroids")
        self.initial_centroids = initial.to(self.device) if initial is not None else None
        unnormalized = data.get("unnormalized_centroids")
        if unnormalized is not None:
            self.unnormalized_centroids = unnormalized.to(self.device)
        else:
            # Keep a real tensor so later accumulation/lookup does not hit None.
            self.unnormalized_centroids = torch.zeros_like(self.centroids)
        self.is_initialed = True

    def save_initial_centroids(self, save_path=None):
        """Save centroids (moved to CPU) to ``save_path`` or ``self.save_path``.

        Missing parent directories are created.
        """
        path = self.save_path if save_path is None else save_path
        save_data = {"centroids": self.centroids.cpu(),
                     "initial_centroids": self.initial_centroids.cpu() if self.initial_centroids is not None else None,
                     "unnormalized_centroids": self.unnormalized_centroids.cpu() if self.unnormalized_centroids is not None else None}
        directory = os.path.dirname(path)
        if directory:
            # torch.save does not create directories; a fresh run has none yet.
            os.makedirs(directory, exist_ok=True)
        torch.save(save_data, path)

    def initialize_centroids(self, model, train_loader, save=True):
        """Compute initial centroids from ``train_loader`` with ``model`` in eval mode.

        Sets ``unnormalized_centroids``, ``centroids`` and ``initial_centroids``,
        marks the helper as initialized, and saves to ``self.save_path`` when
        ``save`` is True.
        """
        print("initialize centroids for {}".format(self.opt.model))
        model.eval()
        self.unnormalized_centroids = self._accumulate_outputs(
            model, train_loader, progress_desc="Initializing centroids")
        self.centroids = self._normalize(self.unnormalized_centroids)
        self.initial_centroids = self.centroids.clone()
        # Without this, update_centroids would skip the momentum blend after a
        # fresh initialization (unlike after loading from file).
        self.is_initialed = True
        if save:
            self.save_initial_centroids()

    def update_centroids(self, model, mini_loader):
        """Update ``centroids`` from ``mini_loader`` using momentum.

        ``centroids = momentum * old + (1 - momentum) * new`` once initialized
        (the cifar100 dataset always takes the new value). Classes with no
        samples in ``mini_loader`` keep their previous centroid instead of
        becoming NaN.
        """
        self._set_mode(model)
        self.unnormalized_centroids = self._accumulate_outputs(model, mini_loader, warn_if_local=True)
        new_centroids = self._normalize(self.unnormalized_centroids)

        # NOTE(review): cifar100 deliberately bypasses momentum; reason not visible here.
        if self.is_initialed and self.opt.dataset != "cifar100":
            blended = self.momentum * self.centroids + (1 - self.momentum) * new_centroids
        else:
            blended = new_centroids

        has_samples = self.unnormalized_centroids.sum(1, keepdim=True) > 0
        self.centroids = torch.where(has_samples, blended, self.centroids)

    def compute_centroids(self, model, val_loader, save=False):
        """Compute centroids over ``val_loader`` without touching ``self.centroids``.

        ``self.unnormalized_centroids`` is overwritten with the new sums.

        Returns:
            ``(new_centroids, unnormalized_sums)``; rows for classes absent from
            ``val_loader`` are zero.
        """
        self._set_mode(model)
        self.unnormalized_centroids = self._accumulate_outputs(model, val_loader, warn_if_local=True)
        new_centroids = self._normalize(self.unnormalized_centroids)

        # Only validate the output folder when saving was actually requested.
        if save and not os.path.exists(self.opt.trial_folder):
            print(f"[Error]: Path {self.opt.trial_folder} does not exist.")
        # NOTE(review): writing 'val_centroid.pt' into opt.trial_folder was disabled
        # in the original code, so ``save`` currently only checks the path.

        return new_centroids, self.unnormalized_centroids

    def get_centroids(self, target):
        """Return the centroid row for each class index in ``target``.

        Assumes ``target`` is a 1-D tensor of class indices, giving
        ``[batch_size, num_classes]``.
        """
        return torch.index_select(self.centroids, 0, target).to(target.device)  # [batch_size, num_classes]

    def get_unnormalized_centroids(self, target):
        """Return the unnormalized (summed) row for each class index in ``target``."""
        return torch.index_select(self.unnormalized_centroids, 0, target).to(target.device)

    def _set_mode(self, model):
        """Put ``model`` in train mode if ``self.train_mode`` else eval mode."""
        if self.train_mode:
            model.train()
        else:
            model.eval()

    def _accumulate_outputs(self, model, loader, progress_desc=None, warn_if_local=False):
        """Return per-class sums of ``softmax(model(image))`` over ``loader``.

        Row ``c`` of the ``(n_classes, n_classes)`` result is the sum over all
        samples with target ``c``. Its row sum is therefore (up to rounding) that
        class's sample count. Sums are all-reduced across ranks when
        ``torch.distributed`` is initialized.
        """
        sums = torch.zeros((self.n_classes, self.n_classes), dtype=torch.float32, device=self.device)
        batches = tqdm(loader, desc=progress_desc, leave=False) if progress_desc else loader
        with torch.no_grad():
            for image, target in batches:
                image = image.to(self.device).float()
                # Assumes 1-D integer class labels; index_add_ needs an integer index.
                target = target.to(self.device).long()
                probs = F.softmax(model(image), dim=1)
                # Adds each sample's probability row into its class row in one call.
                sums.index_add_(0, target, probs.to(sums.dtype))

        if dist.is_initialized():
            dist.barrier()
            dist.all_reduce(sums, op=dist.ReduceOp.SUM)
        elif warn_if_local:
            print("Warning: Distributed environment not initialized. Results may not be synchronized across GPUs.")
        return sums

    @staticmethod
    def _normalize(unnormalized):
        """Divide each row by its sum; all-zero rows (unseen classes) stay zero, not NaN."""
        row_sums = unnormalized.sum(1, keepdim=True)
        return unnormalized / row_sums.clamp_min(torch.finfo(unnormalized.dtype).tiny)

