import os
import torch
from torch import nn
from torch.nn import functional as F

import time 

from models.encoder import Encoder
from utils.losses import *

import numpy as np
from tqdm import tqdm
import sklearn
import scipy
import copy
import pdb 
import faiss
import logging
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.datautils import *

from sklearn.manifold import TSNE

from utils import *
      
from torchvision.utils import save_image

from utils.clustering import Kmeans, ClusterDataset
from utils import compute_flops

class USSL(nn.Module):

    def __init__(self, args):
        """Build the encoder and, for the 'simclr' problem, all training state.

        Args:
            args: Run configuration namespace. This constructor reads
                ``arch``, ``proj_dim``, ``image_size`` and ``problem`` and,
                when ``problem == 'simclr'``, also ``root``, ``temperature``,
                ``gpu``, ``datasets``, ``iters`` and the ``encoder_*``
                optimizer settings.
        """
        super().__init__()

        self.args = args

        # Build the backbone, then move it to the GPU and wrap it for
        # (distributed) data-parallel execution.
        backbone = Encoder(arch = self.args.arch,
                           proj_dim = self.args.proj_dim,
                           input_size = self.args.image_size)
        self.encoder = self.set_parallel_device(backbone)

        n_params = sum(p.numel() for p in self.parameters())
        logging.info(f"Number of parameters in model {n_params}")

        # Everything below is only needed when training.
        if self.args.problem != 'simclr':
            return

        n_domains = len(self.args.datasets)

        self.writer = SummaryWriter(log_dir = self.args.root)
        self.encoder_criterion = UConLoss(self.args.temperature).cuda(self.args.gpu)
        self.clustering_criterion = ClusteringLoss(n_domains, use_v2_loss = True)
        self.clustering = Kmeans(n_domains)

        # Clustering state: no accepted clustering yet; a new clustering is
        # only accepted once every entry reaches the stored percentage.
        self.cluster_result = None
        self.outlier_threshold = 1
        self.highest_percs = [0.92] * n_domains
        self.cur_iter = 0

        self.set_loaders()

        optimizer_and_scheduler = utils.get_optimizer_scheduler(
            opt = self.args.encoder_opt,
            LR = self.args.encoder_lr,
            weight_decay = self.args.encoder_weight_decay,
            lr_schedule = self.args.encoder_lr_schedule,
            warmup = self.args.encoder_warmup,
            T_max = self.args.iters,
            model = self.encoder,
            last_epoch = self.cur_iter - 1,
        )
        self.encoder_optimizer, self.encoder_scheduler = optimizer_and_scheduler

        # NOTE: a separate optimizer/scheduler for a clustering phase was
        # prototyped here and is currently disabled.
            
    def set_parallel_device(self, component):
        component = component.cuda(self.args.gpu)
        
        if self.args.dist == 'ddp':
            if self.args.sync_bn:
                component = nn.SyncBatchNorm.convert_sync_batchnorm(component)
            dist.barrier()
            component = DDP(component, [self.args.gpu], find_unused_parameters=True)

        else:
            component = nn.DataParallel(component)
            
        return component
        
    def get_loader(self, dataset, sampler):
        """Create a data loader for ``dataset`` driven by ``sampler``.

        The batch size is ``self.args.batch_size``, capped at the dataset
        length so a small dataset still yields a batch.

        Args:
            dataset: Dataset to iterate over (must support ``len``).
            sampler: Sampler forwarded to ``get_dataloader``.

        Returns:
            Whatever ``get_dataloader`` returns for these settings.
        """
        # Never ask for more samples per batch than the dataset contains.
        effective_batch_size = min(len(dataset), self.args.batch_size)
        return get_dataloader(dataset = dataset,
                              sampler = sampler,
                              batch_size = effective_batch_size,
                              workers = self.args.workers)
    
    def get_samples(self, loaders, datasets, samplers):
        """Draw one batch from every loader, restarting exhausted loaders.

        Args:
            loaders: List of loader iterators; an exhausted entry is replaced
                in place by a fresh iterator over the same dataset.
            datasets: Datasets matching ``loaders`` position by position.
            samplers: Samplers matching ``loaders`` position by position.

        Returns:
            ``(all_images, all_indices, loaders)`` where ``all_images[i]`` is
            ``batch[0]`` of loader ``i`` concatenated along dim 0 and moved to
            ``self.args.gpu``, ``all_indices[i]`` is that batch's ``batch[1]``,
            and ``loaders`` is the (possibly refreshed) input list.
        """
        is_ddp = self.args.dist == 'ddp'
        all_images, all_indices = [], []

        for pos, loader in enumerate(loaders):
            if is_ddp:
                samplers[pos].set_epoch(self.cur_iter)

            try:
                batch = next(loader)
            except StopIteration:
                # Loader ran dry: rebuild it and take the first batch of the new pass.
                loader = iter(self.get_loader(datasets[pos], samplers[pos]))
                loaders[pos] = loader
                batch = next(loader)

            all_images.append(torch.cat(batch[0], dim = 0).cuda(self.args.gpu))
            all_indices.append(batch[1])

        return all_images, all_indices, loaders
    
    def set_loaders(self):
        """Create the per-dataset train/test sets, samplers and the clustering loader.

        For every name in ``self.args.datasets`` this builds an augmented
        training set, a test set and a second copy of the training set that
        uses only the clean transform. It then builds concatenated versions
        for combined training and for clustering. When
        ``ignore_domain_labels`` or ``combine_datasets`` is set, training uses
        the single concatenated set instead of the per-dataset ones.
        """
        train_sets = self.train_sets = []
        test_sets = self.test_sets = []
        samplers = self.trainsamplers = []
        clean_train_sets = self.clean_train_sets = []

        for d, d_name in enumerate(self.args.datasets):
            # k_shot entries are strings; the literal "None" means "use everything".
            shots = self.args.k_shot[d]
            k_shot = int(shots) if shots != "None" else None

            aug_transform = get_simclr_transform(dataset_name = d_name,
                                                 image_size = self.args.image_size,
                                                 color_dist_s = self.args.color_dist_s,
                                                 scale_lower = self.args.scale_lower,
                                                 use_color_dist = self.args.use_color_dist,
                                                 use_rotation = self.args.use_rotation)
            clean_transform = get_clean_transform(dataset_name = d_name,
                                                  image_size = self.args.image_size)
            transform = ContrastiveLearningViewGenerator(aug_transform, self.args.multiplier, clean_transform)

            # Augmented training set plus its test split.
            train_set, test_set = get_dataset(dataset_name = d_name,
                                              data_root = self.args.data_root,
                                              train_transform = transform,
                                              test_transform = clean_transform,
                                              k_shot = k_shot)
            # Same training data without augmentation (used for clustering).
            clean_train_set, _ = get_dataset(dataset_name = d_name,
                                             data_root = self.args.data_root,
                                             train_transform = clean_transform,
                                             test_transform = clean_transform,
                                             k_shot = k_shot)

            train_sets.append(train_set)
            clean_train_sets.append(clean_train_set)
            test_sets.append(test_set)

            samplers.append(get_sampler(dataset = train_set, dist = self.args.dist))

        self.combined_train_set = CustomConcatDataset(self.train_sets)
        self.combined_trainsampler = get_sampler(dataset = self.combined_train_set,
                                                 dist = self.args.dist)

        merge_domains = self.args.ignore_domain_labels or self.args.combine_datasets
        if merge_domains:
            self.train_sets = [self.combined_train_set]
            self.trainsamplers = [self.combined_trainsampler]

        self.clustering_train_set = CustomConcatDataset(self.clean_train_sets)
        self.clustering_trainsampler = get_sampler(dataset = self.clustering_train_set,
                                                   dist = self.args.dist)
        self.clustering_train_loader = self.get_loader(self.clustering_train_set,
                                                       self.clustering_trainsampler)
    
    def load(self):
        logging.info(f"Loading checkpoint from {self.args.ckpt}")
        print(f"Loading checkpoint from {self.args.ckpt}")
        map_location = {'cuda:%d' % 0: 'cuda:%d' % self.args.rank}
        state_dict = torch.load(self.args.ckpt, map_location=map_location)

        self.encoder.load_state_dict(state_dict['encoder'])
        self.encoder_optimizer.load_state_dict(state_dict['encoder_optimizer'])
        self.encoder_scheduler.load_state_dict(state_dict['encoder_scheduler'])
        self.cur_iter = state_dict['iter']
        
    def save(self, path):
        ckpt = {
            'encoder':self.encoder.state_dict(),
            'encoder_optimizer':self.encoder_optimizer.state_dict(),
            'encoder_scheduler':self.encoder_scheduler.state_dict(),
            'cluster_result':self.cluster_result,
            'iter':self.cur_iter
        }
        fname = os.path.join(path, 'checkpoint-%d.pth.tar' % self.cur_iter)
        torch.save(ckpt, fname)
    
    def compute_features(self):
        """Embed every sample of the clustering set with the current encoder.

        The encoder is switched to eval mode for the pass and always switched
        back to train mode afterwards, even if the pass raises. Row ``i`` of
        the result is the second output of the encoder for the sample whose
        dataset index (``batch[1]``) is ``i``.

        Under DDP each rank fills only the rows it was given and the per-rank
        tensors are summed with an all-reduce. A row written by more than one
        rank would be counted several times by that sum, so the number of
        writers per row is reduced alongside and divided out. (This happens if
        the sampler pads by repeating indices, as
        ``torch.utils.data.DistributedSampler`` does by default; presumably
        ``get_sampler`` returns such a sampler, to be confirmed.)

        Returns:
            CPU tensor of shape ``(len(self.clustering_train_set), proj_dim)``.
            Rows that no rank visited stay zero.
        """
        logging.info('Computing features...')
        self.encoder.eval()
        try:
            n_samples = len(self.clustering_train_set)
            features = torch.zeros(n_samples, self.encoder.module.proj_dim).cuda(self.args.gpu)
            # 1 for every row this rank has written, 0 elsewhere.
            writes = torch.zeros(n_samples).cuda(self.args.gpu)

            with torch.no_grad():
                for batch in self.clustering_train_loader:
                    sample_idx = batch[1]
                    features[sample_idx] = self.encoder(batch[0].cuda(self.args.gpu, non_blocking=True))[1]
                    writes[sample_idx] = 1

            if self.args.dist == 'ddp':
                dist.barrier()
                dist.all_reduce(features, op=dist.ReduceOp.SUM)
                dist.all_reduce(writes, op=dist.ReduceOp.SUM)
                # Average rows that several ranks contributed to; the clamp
                # leaves unvisited (all-zero) rows unchanged.
                features /= writes.clamp(min=1).unsqueeze(1)
        finally:
            self.encoder.train()

        return features.cpu()
        
    def eval_clustering(self, cluster_result, ignore_outliers = False):
        label_to_cluster = torch.zeros((len(self.args.datasets), len(self.args.datasets)))
        for _, batch in enumerate(self.clustering_train_loader):
            index = batch[1]
            domain_labels = batch[3]
            
            if ignore_outliers:
                index, _, batch_index = np.intersect1d(cluster_result['select_idx'].cpu().numpy(), index.cpu().numpy(), return_indices = True)
                index = torch.LongTensor(index).cuda(self.args.gpu)
                batch_index = torch.LongTensor(batch_index).cuda(self.args.gpu)
                domain_labels = batch[3][batch_index]
            
            for domain_label in range(len(self.args.datasets)):  
                label_index = index[domain_labels == domain_label]
                
                for cluster_id in range(len(self.args.datasets)):
                    count = torch.nonzero(cluster_result['im2cluster'][label_index] == cluster_id, as_tuple = False).shape[0]
                    label_to_cluster[domain_label][cluster_id] += count
                     
        logging.info(f"Label to Cluster (samples): {label_to_cluster}")
        print(f"Label to Cluster (samples): {label_to_cluster}")
        percs = (label_to_cluster/label_to_cluster.sum(1)).max(0).values.cpu().numpy().tolist()
        logging.info(f"Percentages: {percs}")
        print(f"Percentages: {percs}")
        
        return percs
            

    def train(self, mode = True):
        """Run the training loop until ``self.cur_iter`` reaches ``self.args.iters``.

        Each iteration draws one batch per training loader, encodes it,
        applies the contrastive criterion and steps the optimizer/scheduler.
        When ``ignore_domain_labels`` is set, every ``clustering_freq``
        iterations the clean training set is re-clustered; a clustering is
        accepted (and checkpointed on rank 0) only if every reported
        percentage is at least the best seen so far. With
        ``restart_training`` the encoder and optimizer are re-initialised
        once, on the first accepted clustering.

        Args:
            mode: Present because this method overrides ``nn.Module.train``.
                ``nn.Module.eval()`` is implemented as ``self.train(False)``,
                so ``mode=False`` is forwarded to the base class (switching
                to evaluation mode) instead of starting a training run.

        Returns:
            ``None`` after a training run; ``self`` (the base-class contract)
            when ``mode`` is false.
        """
        if not mode:
            # Called through eval()/train(False): only toggle the mode flags.
            return super().train(mode)

        compute_flops.register_handles(self)

        train_loaders = [iter(self.get_loader(train_set, sampler))
                         for train_set, sampler in zip(self.train_sets, self.trainsamplers)]

        torch.cuda.empty_cache()

        restarted_training = False

        # Values logged until the first optimisation step. The clustering_*
        # entries are never updated (the clustering-loss path is disabled).
        clustering_loss = 0
        clustering_acc = 0
        loss = 0
        loss_aug = 0
        loss_dataset = 0
        loss_ood = 0

        while self.cur_iter < self.args.iters:
            torch.cuda.empty_cache()

            self.cur_iter += 1

            start_time = time.time()
            images, index, train_loaders = self.get_samples(train_loaders, self.train_sets, self.trainsamplers)
            data_time = time.time() - start_time

            have_clusters = self.cluster_result is not None

            if self.args.ignore_domain_labels and have_clusters:
                # Single combined batch: regroup its samples by assigned cluster.
                _, z = self.encoder(images[0])
                features, batch_sizes = self._group_features_by_cluster(z, index[0])
            else:
                features = torch.Tensor([]).cuda(self.args.gpu)
                batch_sizes = []
                for d_images in images:
                    torch.cuda.empty_cache()
                    batch_sizes.append(d_images.shape[0]//2)
                    _, z = self.encoder(d_images)
                    torch.cuda.empty_cache()
                    features = torch.cat([features, z], dim = 0)
                    torch.cuda.empty_cache()

            # Until a clustering exists, label-free training treats the batch
            # as a single dataset and disables the out-of-domain term.
            unclustered = self.args.ignore_domain_labels and not have_clusters
            reg_ood = 0 if unclustered else self.args.ussl_reg_ood
            n_datasets = 1 if unclustered or self.args.combine_datasets else len(self.args.datasets)

            if features.shape[0] > 0:
                loss, loss_aug, loss_dataset, loss_ood, = self.encoder_criterion(features = features,
                                                         batch_sizes = batch_sizes,
                                                         n_views = self.args.multiplier,
                                                         n_datasets = n_datasets,
                                                         reg_aug = self.args.ussl_reg_aug,
                                                         reg_dataset = self.args.ussl_reg_dataset,
                                                         reg_ood = reg_ood)

                self.encoder_optimizer.zero_grad()
                loss.backward()
                self.encoder_optimizer.step()
                if self.encoder_scheduler is not None:
                    self.encoder_scheduler.step()

            if self.args.ignore_domain_labels and (self.cur_iter % self.args.clustering_freq == 0):
                cluster_result = self.clustering.cluster(self.compute_features(), self.args.remove_outliers, self.outlier_threshold)
                percs = self.eval_clustering(cluster_result)

                if self.args.remove_outliers:
                    # Logged for inspection only; the result is not used.
                    self.eval_clustering(cluster_result, ignore_outliers = True)

                # Accept only if no entry regressed. Written as ">=" so that a
                # NaN percentage rejects the clustering instead of passing.
                compute_new_clusters = all(p >= best for p, best in zip(percs, self.highest_percs))

                if compute_new_clusters:
                    self.highest_percs = percs
                    logging.info(f"Computing new clusters")
                    self.cluster_result = cluster_result
                    if self.args.remove_outliers:
                        self.outlier_threshold *= 0.9

                    if self.args.rank == 0:
                        self.save(self.args.root)
                        torch.cuda.empty_cache()

                    if self.args.restart_training and not restarted_training:
                        # Re-initialise encoder, optimizer and iteration count
                        # once, keeping the freshly accepted clustering.
                        restarted_training = True
                        logging.info("Reinitializing encoder and optimizer")
                        self.encoder = Encoder(arch = self.args.arch,
                          proj_dim = self.args.proj_dim,
                          input_size = self.args.image_size)
                        self.encoder = self.set_parallel_device(self.encoder)
                        self.cur_iter = 0
                        self.encoder_optimizer, self.encoder_scheduler = utils.get_optimizer_scheduler(opt = self.args.encoder_opt,
                                                                     LR = self.args.encoder_lr,
                                                                     weight_decay = self.args.encoder_weight_decay,
                                                                     lr_schedule = self.args.encoder_lr_schedule,
                                                                     warmup = self.args.encoder_warmup,
                                                                     T_max = self.args.iters,
                                                                     model = self.encoder,
                                                                     last_epoch=self.cur_iter-1
                                                                    )

            torch.cuda.empty_cache()

            logs = {
                'loss': loss,
                'loss_aug': loss_aug,
                'loss_dataset': loss_dataset,
                'loss_ood': loss_ood,
                'clustering_loss': clustering_loss,
                'clustering_acc': clustering_acc,
                'encoder_learning_rate': self.encoder_optimizer.param_groups[0]['lr']
            }
            # Metrics and timings cover the current iteration only.
            train_logs = [{k: utils.tonp(v) for k, v in logs.items()}]

            it_time = time.time() - start_time

            last_iter = self.cur_iter == self.args.iters

            if (self.cur_iter % self.args.save_freq == 0 or last_iter):
                if self.args.rank == 0:
                    self.save(self.args.root)
                    torch.cuda.empty_cache()

            if (self.cur_iter % self.args.log_freq == 0 or last_iter) and self.args.rank == 0:
                metrics = utils.agg_all_metrics(train_logs)
                self.writer.add_scalar('encoder_learning_rate', self.encoder_optimizer.param_groups[0]['lr'], global_step=self.cur_iter)
                self.writer.add_scalar('data_time', data_time, global_step=self.cur_iter)
                self.writer.add_scalar('it_time', it_time, global_step=self.cur_iter)

                for key in metrics:
                    self.writer.add_scalar(key, metrics[key], global_step=self.cur_iter)

                logging.info(f"Iter: {self.cur_iter}/{self.args.iters}, {metrics}, Data time: {data_time}, It time: {it_time}, FLOPS: {compute_flops.calculate_flops()}B")

        if self.args.dist == 'ddp':
            dist.destroy_process_group()

    def _group_features_by_cluster(self, z, index):
        """Regroup one batch of projections so samples are ordered by cluster.

        Args:
            z: Encoder projections for the batch. Row ``i`` and row
                ``i + z.shape[0] // 2`` are treated as the two views of batch
                sample ``i`` (assumes the views are stacked along dim 0; TODO
                confirm for ``multiplier != 2``).
            index: Dataset indices of the batch samples, one per batch
                position.

        Returns:
            ``(features, batch_sizes)``: for each cluster id in turn, the rows
            of both views of its samples, and the number of samples taken.
        """
        half = z.shape[0] // 2
        # Position of every retained sample inside the batch (= row of z).
        batch_pos = torch.arange(index.shape[0])

        if self.args.remove_outliers:
            # intersect1d returns the kept indices sorted, so their positions
            # in the batch must be tracked explicitly to select matching rows.
            kept, _, pos = np.intersect1d(self.cluster_result['select_idx'].cpu().numpy(), index.cpu().numpy(), return_indices = True)
            index = torch.LongTensor(kept).cuda(self.args.gpu)
            batch_pos = torch.LongTensor(pos)

        assignments = self.cluster_result['im2cluster'][index]

        features = torch.Tensor([]).cuda(self.args.gpu)
        batch_sizes = []
        for dom in range(len(self.args.datasets)):
            members = torch.where(assignments == dom)[0]
            rows = batch_pos.to(members.device)[members]
            features = torch.cat([features, z[torch.cat([rows, rows + half])]])
            batch_sizes.append(rows.shape[0])

        return features, batch_sizes
            
            
            
#                 _, z = self.encoder(images)
#                 clustering_loss, clustering_acc = self.clustering_criterion(z, batch[1], self.cluster_result)

#                 self.clustering_optimizer.zero_grad()
#                 clustering_loss.backward()
#                 self.clustering_optimizer.step()
#                 if self.clustering_scheduler is not None:
#                     self.clustering_scheduler.step()


#             if self.cluster_result:
#                 _, z = self.encoder(images[0])
#                 for dom in range(len(self.args.datasets)):
#                     features = torch.cat([features, z[torch.cat([torch.where(self.cluster_result['im2cluster'][index] == dom)[0], torch.where(self.cluster_result['im2cluster'][index] == dom)[0] + z.shape[0] // 2])]])
#                     batch_sizes.append(torch.where(self.cluster_result['im2cluster'][index] == dom)[0].shape[0])


#             if self.args.ignore_domain_labels:
#                 if not clustering_mode and self.cur_iter > self.args.simclr_warmup and self.cur_iter - self.args.simclr_warmup < self.args.clustering_iters:
#                     logging.log("CLUSTERING MODE")
#                     print("CLUSTERING MODE")
#                     clustering_mode = True
#                 elif clustering_mode and self.cur_iter - self.args.simclr_warmup >= self.args.clustering_iters:
#                     logging.log("USSL MODE")
#                     print("USSL MODE")
#                     clustering_mode = False
                    

#                 if len(self.previous_centroids) == 2:
#                     self.previous_centroids = self.previous_centroids[1:]
#                 if len(self.previous_centroids) < 2:
#                     self.previous_centroids.append(self.cluster_result['centroids'])

#                 if len(self.previous_centroids) == 2:
#                     for dom in range(len(self.args.datasets)):
#                         prev_centroids_dom = [c[dom].unsqueeze(0) for c in self.previous_centroids]
#                         prev_centroids_dom = F.normalize(torch.cat(prev_centroids_dom, dim = 0), dim = 1)
#                         sim = torch.mm(prev_centroids_dom, prev_centroids_dom.T)
#                 self.cluster_result = None
