import os
import logging
import torch
import torch.nn.functional as F

from datasets.data_loading import get_source_loader


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class PrototypesHandler:
    """Compute, cache and load source-domain feature prototypes.

    A prototype is the mean (and optionally the covariance) of the
    L2-normalised features returned by ``model_forward_func`` on the source
    data. Prototypes can be class-wise ("categorical"), global (over all
    samples), or both. Computed prototypes are optionally saved to
    ``protos_dir`` and re-used on later calls.
    """

    def __init__(self, cfg, model_forward_func, num_classes, device, src_data_preprocess,
                 force_protos_recalculation=False, protos_dir=None, only_mean=True, categorical_protos=True,
                 global_protos=False, save_protos=True, projector=False, num_projector_layers=None,
                 projection_dim=None, max_num_samples_for_protos=50000):
        """Store the configuration; no data is loaded or computed here.

        Args:
            cfg: Config object. The fields read by this class are
                ``CORRUPTION.DATASET``, ``CKPT_DIR``, ``DATA_DIR``,
                ``TEST.BATCH_SIZE``, ``TEST.WINDOW_LENGTH``,
                ``MODEL.CKPT_PATH``, ``MODEL.ARCH``, ``SOURCE.NUM_SAMPLES``,
                ``SOURCE.PERCENTAGE`` and ``SOURCE.NUM_WORKERS``.
            model_forward_func: Callable mapping a batch of inputs (already
                moved to ``device``) to a feature tensor.
            num_classes: Number of classes, i.e. number of class-wise
                prototypes.
            device: Device used for the forward pass and, if it is a CUDA
                device, for storing the prototypes (``str`` or
                ``torch.device``).
            src_data_preprocess: Forwarded as ``preprocess`` to
                ``get_source_loader``.
            force_protos_recalculation: Ignore cached files and recompute.
            protos_dir: Directory of the cached prototype files. Defaults to
                ``<cfg.CKPT_DIR>/prototypes``.
            only_mean: If true, covariances are neither computed nor loaded.
            categorical_protos: Provide class-wise prototypes.
            global_protos: Provide global prototypes.
            save_protos: Save freshly computed prototypes to ``protos_dir``.
            projector: Whether the features come from a projector; this only
                affects the cache file names.
            num_projector_layers: Required if ``projector`` is true.
            projection_dim: Required if ``projector`` is true.
            max_num_samples_for_protos: Feature extraction stops once at
                least this many source samples have been processed.

        Raises:
            ValueError: If ``projector`` is true but ``num_projector_layers``
                or ``projection_dim`` is missing.
        """
        self.cfg = cfg
        self.dataset_name = cfg.CORRUPTION.DATASET
        self.projector = projector
        self.model_forward_func = model_forward_func
        self.only_mean = only_mean
        self.num_projector_layers = num_projector_layers
        self.projection_dim = projection_dim
        self.protos_dir = protos_dir
        self.categorical_protos = categorical_protos
        self.global_protos = global_protos
        self.save_protos = save_protos
        self.num_classes = num_classes
        self.device = device
        self.max_num_samples_for_protos = max_num_samples_for_protos
        self.force_protos_recalculation = force_protos_recalculation
        self.src_data_preprocess = src_data_preprocess

        # Prototypes stay None until they are loaded or calculated.
        self.categorical_mu = None
        self.categorical_cov = None
        self.global_mu = None
        self.global_cov = None

        if self.protos_dir is None:
            self.protos_dir = os.path.join(self.cfg.CKPT_DIR, "prototypes")

        if self.projector and (self.num_projector_layers is None or self.projection_dim is None):
            raise ValueError("Provide prototype parameters to load the right prototypes when projector is used!")

    def get_prototypes(self):
        """Return the requested prototypes, loading cached files when possible.

        Only the prototype kinds that are enabled (``categorical_protos`` /
        ``global_protos``) are looked up on disk. If any enabled kind is not
        cached (or recalculation is forced) all enabled kinds are recomputed
        and, if ``save_protos`` is set, saved.

        Returns:
            Tuple ``(categorical_mu, categorical_cov, global_mu, global_cov)``;
            entries that were not requested are ``None``.
        """
        means_path, covs_path, mean_global_path, cov_global_path = self.get_proto_paths()

        need_calculating_protos = False

        if self.categorical_protos:
            cached = self._load_cached_protos(means_path, covs_path,
                                              "Loading class-wise source prototypes...")
            if cached is None:
                need_calculating_protos = True
            else:
                self.categorical_mu, self.categorical_cov = cached

        if self.global_protos:
            cached = self._load_cached_protos(mean_global_path, cov_global_path,
                                              "Loading global source prototype...")
            if cached is None:
                need_calculating_protos = True
            else:
                self.global_mu, self.global_cov = cached

        if need_calculating_protos:
            self.calculate_prototypes()
            if self.save_protos:
                self.save_prototypes(means_path, covs_path, mean_global_path, cov_global_path)

        self.handle_protos_device()

        if self.categorical_mu is not None:
            logger.debug("Class-wise prototypes are stored on %s", self.categorical_mu.device)

        return self.categorical_mu, self.categorical_cov, self.global_mu, self.global_cov

    def _load_cached_protos(self, mean_path, cov_path, log_message):
        """Load a cached (mean, covariance) pair, or return None if unusable.

        The covariance file is only required and loaded when ``only_mean`` is
        false; otherwise the returned covariance is ``None``.
        """
        if self.force_protos_recalculation or not os.path.exists(mean_path):
            return None

        need_cov = not self.only_mean
        if need_cov and not os.path.exists(cov_path):
            return None

        logger.info(log_message)
        # Load onto the CPU first; handle_protos_device() decides the final device.
        mu = torch.load(mean_path, map_location="cpu")
        cov = torch.load(cov_path, map_location="cpu") if need_cov else None
        return mu, cov

    def calculate_prototypes(self):
        """Compute the enabled prototypes from the source data.

        Sets ``categorical_mu`` / ``global_mu`` and, unless ``only_mean`` is
        set, ``categorical_cov`` / ``global_cov``. All results are CPU tensors.

        Raises:
            RuntimeError: If no source sample was collected, or if class-wise
                prototypes are requested and some class has no sample.
        """
        logger.info("Extracting source prototypes...")
        self.load_src_loader()

        class_feats = self._collect_source_features()
        missing_classes = [idx for idx, feats in enumerate(class_feats) if feats is None]
        available_feats = [feats for feats in class_feats if feats is not None]

        if not available_feats:
            raise RuntimeError("No source samples were collected for prototype extraction!")

        with torch.no_grad():
            if self.categorical_protos:
                if missing_classes:
                    raise RuntimeError(
                        f"No source samples found for classes {missing_classes}; "
                        "cannot compute class-wise prototypes. Consider increasing max_num_samples_for_protos."
                    )
                self.categorical_mu = torch.stack([feats.mean(dim=0) for feats in class_feats])
                if not self.only_mean:
                    # torch.cov expects variables in rows, hence the transpose.
                    self.categorical_cov = torch.stack([torch.cov(feats.T) for feats in class_feats])

            if self.global_protos:
                all_feats = torch.cat(available_feats, dim=0)
                self.global_mu = all_feats.mean(dim=0).squeeze()
                if not self.only_mean:
                    self.global_cov = torch.cov(all_feats.T).squeeze()

    def _collect_source_features(self):
        """Extract normalised source features and group them by label.

        Returns:
            List of length ``num_classes``; entry ``c`` is a CPU tensor with
            one row per collected sample of class ``c``, or ``None`` if no
            sample of that class was seen.
        """
        chunks_per_class = [[] for _ in range(self.num_classes)]
        num_seen = 0

        with torch.no_grad():
            for data in self.src_loader:
                inputs, labels = data[0], data[1]

                feat = self.model_forward_func(inputs.to(self.device))
                feat = F.normalize(feat, dim=-1).cpu()
                labels = labels.cpu()

                # Keep one tensor per (batch, class) instead of one per sample.
                for label in labels.unique():
                    chunks_per_class[int(label)].append(feat[labels == label, :])

                num_seen += feat.shape[0]
                if num_seen >= self.max_num_samples_for_protos:
                    break

        return [torch.cat(chunks, dim=0) if chunks else None for chunks in chunks_per_class]

    def load_src_loader(self):
        """Create the source data loader (``self.src_loader``) and an iterator over it."""
        batch_size_src = self.cfg.TEST.BATCH_SIZE if self.cfg.TEST.BATCH_SIZE > 1 else self.cfg.TEST.WINDOW_LENGTH
        # os.cpu_count() may return None when the count cannot be determined.
        num_workers = min(self.cfg.SOURCE.NUM_WORKERS, os.cpu_count() or 1)
        _, self.src_loader = get_source_loader(dataset_name=self.cfg.CORRUPTION.DATASET,
                                               adaptation='source',  # ignore method specific transforms
                                               preprocess=self.src_data_preprocess,
                                               data_root_dir=self.cfg.DATA_DIR,
                                               batch_size=batch_size_src,
                                               ckpt_path=self.cfg.MODEL.CKPT_PATH,
                                               num_samples=self.cfg.SOURCE.NUM_SAMPLES,
                                               percentage=self.cfg.SOURCE.PERCENTAGE,
                                               workers=num_workers)
        self.src_loader_iter = iter(self.src_loader)

    def handle_protos_device(self):
        """Move the prototypes to ``self.device`` if it is a CUDA device.

        If the GPU runs out of memory while moving, all prototypes are kept on
        the CPU instead. Any other error is propagated. Nothing happens for
        non-CUDA devices.
        """
        # torch.device() accepts both strings and torch.device objects.
        if torch.device(self.device).type != "cuda":
            return

        proto_names = ("categorical_mu", "categorical_cov", "global_mu", "global_cov")
        moved = {}
        try:
            for name in proto_names:
                tensor = getattr(self, name)
                moved[name] = tensor if tensor is None else tensor.to(self.device)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            logger.info("OOM! Prototypes stored on CPU!")
            # Drop the copies that already reached the GPU before freeing the cache.
            moved.clear()
            torch.cuda.empty_cache()
            for name in proto_names:
                tensor = getattr(self, name)
                moved[name] = tensor if tensor is None else tensor.cpu()

        # Reassign the (possibly moved) tensors to the instance attributes.
        for name, tensor in moved.items():
            setattr(self, name, tensor)

    def save_prototypes(self, means_path, covs_path, mean_global_path, cov_global_path):
        """Save every prototype that is not ``None`` to its given path.

        The prototype directory (``self.protos_dir``) is created if needed.
        """
        os.makedirs(self.protos_dir, exist_ok=True)

        targets = (
            (self.categorical_mu, means_path),
            (self.categorical_cov, covs_path),
            (self.global_mu, mean_global_path),
            (self.global_cov, cov_global_path),
        )
        for tensor, path in targets:
            if tensor is not None:
                torch.save(tensor, path)

    def get_proto_paths(self):
        """Build the cache file paths of all four prototype tensors.

        File names have the form
        ``protos[_proj<L>lays_<D>dim]_<kind>_<dataset>_<tag>.pth`` where the
        projector part is only present if ``self.projector`` is set and
        ``<tag>`` is ``cfg.MODEL.ARCH``, except for ``domainnet126`` where it
        is the second ``_``-separated token of the checkpoint file name.

        Returns:
            Tuple ``(means_path, covs_path, mean_global_path, cov_global_path)``.
        """
        ckpt_path = self.cfg.MODEL.CKPT_PATH
        arch_name = self.cfg.MODEL.ARCH

        if self.dataset_name == "domainnet126":
            # Raises IndexError if the checkpoint file name contains no '_'.
            model_tag = ckpt_path.split(os.sep)[-1].split('_')[1]
        else:
            model_tag = arch_name

        prefix = "protos"
        if self.projector:
            prefix = f"protos_proj{self.num_projector_layers}lays_{self.projection_dim}dim"

        kinds = ("means", "covs", "global_mean", "global_cov")
        means_path, covs_path, mean_global_path, cov_global_path = (
            os.path.join(self.protos_dir, f"{prefix}_{kind}_{self.dataset_name}_{model_tag}.pth")
            for kind in kinds
        )

        return means_path, covs_path, mean_global_path, cov_global_path

