import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import MinkowskiEngine as ME
from sklearn.metrics import jaccard_score
from tqdm import tqdm
import csv
import pickle
import open3d as o3d
from knn_cuda import KNN
from sklearn.metrics import davies_bouldin_score

from utils.losses import CELoss, SoftCELoss, DICELoss, SoftDICELoss, HLoss, SCELoss, SoftDICELoss_withweight_upper
from utils.collation import CollateSeparated, CollateStream
from utils.sampler import SequentialSampler
from utils.dataset_online import PairedOnlineDataset, FrameOnlineDataset
from models import MinkUNet18_HEADS, MinkUNet18_SSL, MinkUNet18_MCMC
import time
from pytorch3d.ops import knn_points, knn_gather
import math
import faiss
import copy
from copy import deepcopy
from sklearn.mixture import GaussianMixture
import copy 


from tent import copy_model_and_optimizer, load_model_and_optimizer, softmax_entropy
import tent
import torch.backends.cudnn as cudnn

def configure_model(model, eps, momentum, reset_stats, no_stats):
    """Configure model for adaptation by test-time normalization.

    Every ``nn.BatchNorm1d`` inside ``model`` is put in training mode (batch-wise
    statistics in forward) and receives the given ``eps`` and ``momentum``.
    Optionally its running statistics are reset (``reset_stats``) or disabled
    altogether (``no_stats``). ``model`` is modified in place and returned.

    NOTE: a second ``configure_model`` with the same signature is defined later
    in this module and replaces this one once the module has been imported.
    """
    bn_layers = (mod for mod in model.modules() if isinstance(mod, nn.BatchNorm1d))
    for bn in bn_layers:
        # normalise with the statistics of the current batch
        bn.train()
        # stability epsilon and running-stat update rate
        bn.eps, bn.momentum = eps, momentum
        if reset_stats:
            # forget the source-domain running statistics
            bn.reset_running_stats()
        if no_stats:
            # keep no running state at all: batch statistics only
            bn.track_running_stats = False
            bn.running_mean = bn.running_var = None
    return model



# ---------------------------------------------------------------------------
# Module-level mutable state.
#
# NOTE(review): `.cuda()` executes at import time, so importing this module
# requires a CUDA device. This global `prototypes` (14 x 96) is not referenced
# in the visible part of this file -- OneDomainAdaptation keeps its own
# `self.prototypes` sized by `num_classes`. Confirm whether it is still used.
prototypes = torch.zeros(14, 96).cuda()

# Sliding-window history written by
# OneDomainAdaptation.adaptation_double_pseudo_step (via `global`). One entry
# is appended per adapted frame and the oldest entry is popped once the length
# exceeds `args.pre_label_num`.
Label_Bank = []       # per-point pseudo-labels of each frame
Predict_Bank = []     # logits of the no-dropout forward pass
Backbone_Bank = []    # second output (`out_backbone`) of that forward pass
Coordinate_Bank = []  # global point coordinates, shape [1, N, 3]

# NOTE(review): the names below are not referenced in the visible part of this
# file; presumably used by code further down (evaluation?) -- verify.
Label_Bank_ada = []
Predict_Bank_ada = []
Coordinate_Bank_ada = []
EVAL_FRAME = 0
TRAIN_FRAME = 0
score_list = []

def configure_model(model, eps, momentum, reset_stats, no_stats, *,
                    norm_types=(nn.BatchNorm1d,)):
    """Configure model for adaptation by test-time normalization.

    Every submodule of ``model`` that is an instance of ``norm_types`` is
    switched to training mode, so its forward pass uses batch-wise statistics.
    It also receives the given ``eps`` and ``momentum``. All other submodules
    keep their current train/eval mode.

    This definition replaces an earlier ``configure_model`` defined in this
    module; this is the one in effect after import.

    Args:
        model: module to configure; modified in place.
        eps: value assigned to each selected layer's ``eps``.
        momentum: value assigned to each selected layer's ``momentum``.
        reset_stats: if true, call ``reset_running_stats()`` so the running
            statistics are re-estimated from test data only.
        no_stats: if true, disable running statistics entirely
            (``track_running_stats = False``, buffers set to ``None``) so only
            batch statistics are ever used.
        norm_types: module type, or tuple of types, to configure. Defaults to
            ``(nn.BatchNorm1d,)`` (the previous hard-coded behaviour); pass
            e.g. ``(nn.BatchNorm1d, nn.BatchNorm2d)`` to cover more layers.

    Returns:
        ``model`` itself.
    """
    for module in model.modules():
        if not isinstance(module, norm_types):
            continue
        # use batch-wise statistics in forward
        module.train()
        # configure epsilon for stability, and momentum for updates
        module.eps = eps
        module.momentum = momentum
        if reset_stats:
            # reset state to estimate test stats without train stats
            module.reset_running_stats()
        if no_stats:
            # disable state entirely and use only batch stats
            module.track_running_stats = False
            module.running_mean = None
            module.running_var = None
    return model



def diceCoeffv2(pred, gt, eps=1e-5, ignore_label=-1):
    r"""Soft Dice coefficient of logits ``pred`` against integer labels ``gt``.

    Per class ``c``::

        dice_c = (2 * tp_c + eps) / (2 * tp_c + fp_c + fn_c + eps)

    where tp/fp/fn are computed from the softmax of ``pred`` (over the last
    dim) and the one-hot encoding of ``gt``. Rows whose label equals
    ``ignore_label`` are dropped first (unless ``ignore_label`` is None).
    Returns the mean of ``dice_c`` over the ``pred.shape[1]`` classes. This is
    a similarity (higher is better), not a loss.
    """
    if ignore_label is not None:
        keep = gt != ignore_label
        gt, pred = gt[keep], pred[keep, :]

    target = F.one_hot(gt, num_classes=pred.shape[1])
    prob = F.softmax(pred, dim=-1)

    tp = (target * prob).sum(dim=0)
    fp = prob.sum(dim=0) - tp
    fn = target.sum(dim=0) - tp
    dice_per_class = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return dice_per_class.sum() / target.size(-1)


@torch.jit.script
def softmax_mean_entropy(x: torch.Tensor) -> torch.Tensor:
    """Entropy (nats) of the softmax distribution averaged over dim 0.

    ``x`` holds logits; softmax is taken over dim 1 and the probabilities are
    averaged over dim 0 before computing the entropy.

    The log of the mean probability is computed as
    ``logsumexp(log_softmax(x), dim=0) - log(N)`` instead of
    ``log(mean(softmax(x)))``. When a class probability underflows to exactly
    0, it then contributes ``0 * finite = 0`` instead of ``0 * log(0) = nan``,
    in both the value and its gradient.
    """
    probs = x.softmax(1).mean(0)
    log_probs = torch.logsumexp(x.log_softmax(1), dim=0) - math.log(float(x.size(0)))
    return -(probs * log_probs).sum()

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Per-row entropy (nats) of the softmax over dim 1 of logits ``x``.

    Uses ``log_softmax`` rather than ``log(softmax)`` for numerical stability.
    NOTE: this definition shadows the ``softmax_entropy`` imported from
    ``tent`` at the top of this module.
    """
    probs = torch.softmax(x, dim=1)
    log_probs = torch.log_softmax(x, dim=1)
    return torch.sum(probs * log_probs, dim=1).neg()

def get_cbst_th_2(preds, vals, p=0.5):
    """Per-class ``p``-quantile thresholds (class-balanced self-training style).

    For every non-negative class id ``c`` present in ``preds``, the values
    ``vals[preds == c]`` are sorted ascending and the element at index
    ``floor((n_c - 1) * p)`` is stored in ``c_th[c]``.

    Args:
        preds: integer class id per element; negative ids (e.g. -1 for
            "ignored") are skipped.
        vals: scores aligned element-wise with ``preds``.
        p: quantile in [0, 1].

    Returns:
        A new tensor (torch default dtype/device) of length
        ``max(non-negative preds) + 1``. Classes that do not appear stay 0.
        If there is no non-negative id, the tensor is empty.
    """
    classes = torch.unique(preds)
    # A negative id would index c_th from the end (and crash on an empty
    # c_th when every id is negative), so only real classes are kept.
    classes = classes[classes >= 0]
    if classes.numel() == 0:
        return torch.zeros(0)

    c_th = torch.zeros(int(classes.max()) + 1)
    for c in classes.tolist():
        vals_c, _ = torch.sort(vals[preds == c], descending=False)
        # Python double-precision floor (no float32 tensor round-trip).
        k = math.floor((vals_c.shape[0] - 1) * p)
        c_th[c] = vals_c[k]
    return c_th

def get_cbst_th_unknown(preds, vals, p=0.01):
    """Per-class low-quantile thresholds; same rule as ``get_cbst_th_2``, default ``p=0.01``.

    For every non-negative class id ``c`` present in ``preds``, the values
    ``vals[preds == c]`` are sorted ascending and the element at index
    ``floor((n_c - 1) * p)`` is stored in ``c_th[c]``.

    Args:
        preds: integer class id per element; negative ids are skipped.
        vals: scores aligned element-wise with ``preds``.
        p: quantile in [0, 1].

    Returns:
        A new tensor (torch default dtype/device) of length
        ``max(non-negative preds) + 1``. Classes that do not appear stay 0.
        If there is no non-negative id, the tensor is empty.
    """
    classes = torch.unique(preds)
    # Negative ids would wrap around when indexing c_th; skip them.
    classes = classes[classes >= 0]
    if classes.numel() == 0:
        return torch.zeros(0)

    c_th = torch.zeros(int(classes.max()) + 1)
    for c in classes.tolist():
        vals_c, _ = torch.sort(vals[preds == c], descending=False)
        k = math.floor((vals_c.shape[0] - 1) * p)
        c_th[c] = vals_c[k]
    return c_th


def label_onehot(inputs, num_class):
    """Return a float one-hot encoding of the 1-D integer labels ``inputs``.

    The output has shape ``[len(inputs), num_class]`` on the same device as
    ``inputs``. Negative labels are clamped to 0, so they are encoded as
    class 0.
    """
    labels = inputs.clamp(min=0)
    onehot = torch.zeros(labels.shape[0], num_class, device=labels.device)
    onehot.scatter_(1, labels.unsqueeze(1), 1.0)
    return onehot


def known_prototype_update(rep, label, prototypes, score, num_segments, threshold=0.0, prototype_alpha=0.99):
    """Update per-class prototypes in place with an exponential moving average.

    For each class ``i`` in ``range(num_segments)`` that has samples, the mean
    of ``rep[label == i]`` is blended into ``prototypes[i]``. An all-zero
    prototype is initialised with that mean. Otherwise the update is
    ``prototypes[i] = alpha * prototypes[i] + (1 - alpha) * mean``.

    Args:
        rep: per-sample features indexed along dim 0 like ``label``
            (presumably ``[N, D]`` with ``D == prototypes.shape[1]``; confirm).
        label: per-sample class ids.
        prototypes: ``[num_segments, D]`` tensor, modified in place.
        score: per-sample class scores. Only read when ``threshold != 0``,
            in which case ``score.squeeze()[mask][:, i]`` must be valid.
        num_segments: number of classes to update.
        threshold: if non-zero, only class-``i`` samples whose score for ``i``
            is strictly greater than the k-th largest such score
            (``k = int(count * threshold)``) contribute to the mean.
        prototype_alpha: EMA decay ``alpha``.

    Returns:
        ``prototypes`` (the same tensor, updated in place).
    """
    alpha = prototype_alpha

    for i in range(num_segments):
        selected = label == i
        count = int(selected.sum())
        if count == 0:
            continue

        with torch.no_grad():
            if threshold != 0.0:
                score_class = score.squeeze()[selected][:, i]
                k = int(count * threshold)
                if k < 1:
                    # topk(..., 0) is empty and [-1] would raise IndexError;
                    # nothing is selected, so leave this prototype untouched.
                    continue
                threshold_score = torch.topk(score_class, k)[0][-1]
                selected[selected.clone()] = score_class > threshold_score
                if not selected.any():
                    # The strict '>' rejected every sample (e.g. k == 1); the
                    # mean of an empty set would write NaN into the prototype.
                    continue

            proto_rep = rep[selected].mean(dim=0)
            if prototypes[i].sum() == 0:
                prototypes[i] = proto_rep
            else:
                prototypes[i] = alpha * prototypes[i] + (1 - alpha) * proto_rep

    return prototypes


class EMA(object):
    """Exponential moving average of a model's ``state_dict``.

    ``self.model`` is the averaged copy. In this file it is used to produce
    pseudo-labels when ``args.use_ema`` is set. Each ``update`` blends another
    model's tensors into it with the fixed decay ``alpha``.
    """

    def __init__(self, model, alpha):
        self.step = 0
        self.model = model
        self.alpha = alpha

    def update(self, model):
        """Blend ``model``'s state into ``self.model`` in place.

        Floating-point tensors become ``alpha * ema + (1 - alpha) * model``.
        Non-floating tensors (e.g. BatchNorm's integer ``num_batches_tracked``)
        are copied instead. Averaging them in float and writing the result back
        into an integer tensor truncates it, so a counter would effectively
        never advance.

        Tensors are paired by position, so both models must share the same
        ``state_dict`` layout.
        """
        decay = self.alpha

        with torch.no_grad():
            # state_dict() tensors share storage with the live parameters and
            # buffers, so the in-place ops below update self.model directly.
            for ema_v, model_v in zip(self.model.state_dict().values(),
                                      model.state_dict().values()):
                if ema_v.is_floating_point():
                    ema_v.mul_(decay).add_(model_v, alpha=1 - decay)
                else:
                    ema_v.copy_(model_v)

        self.step += 1



class OneDomainAdaptation(object):
    r"""
    Segmentation Module for MinkowskiEngine for training on one domain.
    """

    def __init__(self,
                 model,
                 eval_dataset,
                 adapt_dataset,
                 source_model=None,
                 source_model2=None,
                 optimizer_name='SGD',
                 criterion='CELoss',
                 epsilon=0.,
                 ssl_criterion='Cosine',
                 ssl_beta=0.5,
                 seg_beta=1.0,
                 temperature=0.5,
                 lr=1e-3,
                 stream_batch_size=1,
                 adaptation_batch_size=2,
                 weight_decay=1e-5,
                 momentum=0.8,
                 val_batch_size=6,
                 train_num_workers=10,
                 val_num_workers=10,
                 num_classes=7,
                 clear_cache_int=2,
                 scheduler_name='ExponentialLR',
                 pseudor=None,
                 use_random_wdw=False,
                 freeze_list=None,
                 delayed_freeze_list=None,
                 num_mc_iterations=10,
                 use_global=False,
                 args=None):
        """Set up losses and per-run state for online adaptation.

        Every constructor argument is first stored on ``self`` under its own
        name. ``self.criterion`` and ``self.ssl_criterion`` are then replaced
        by the corresponding loss modules. ``self.configure_optimizers()``
        (defined elsewhere in this class) is called, presumably to build the
        optimizer. Allocates prototype tensors with ``.cuda()``, so a CUDA
        device is required.

        Raises:
            NotImplementedError: for an unknown ``criterion`` or
                ``ssl_criterion``.
        """
        super().__init__()

        # Copy every argument onto self. This must stay the first statement
        # after super(): vars() is this frame's locals, so any local bound
        # earlier would be copied too. Dunder names are skipped: since this
        # method calls super(), the implicit ``__class__`` cell can appear in
        # locals(), and assigning it would reset a subclass instance's class.
        for name, value in list(vars().items()):
            if name != "self" and not name.startswith("__"):
                setattr(self, name, value)

        if self.use_global:
            print('--> USING GLOBAL FEATS IN CONTRASTIVE!')

        # Segmentation loss applied to the pseudo-labels during adaptation.
        ignore_label = self.adapt_dataset.ignore_label
        if criterion == 'CELoss':
            self.criterion = CELoss(ignore_label=ignore_label,
                                    weight=None)
        elif criterion == 'WCELoss':
            self.criterion = CELoss(ignore_label=ignore_label,
                                    weight=self.adapt_dataset.weights)
        elif criterion == 'SoftCELoss':
            self.criterion = SoftCELoss(ignore_label=ignore_label)
        elif criterion == 'DICELoss':
            self.criterion = DICELoss(ignore_label=ignore_label)
        elif criterion == 'SoftDICELoss':
            self.criterion = SoftDICELoss(ignore_label=ignore_label,
                                          neg_range=True, eps=self.epsilon)
        elif criterion == 'SCELoss':
            self.criterion = SCELoss(alpha=1, beta=0.1, num_classes=self.num_classes,
                                     ignore_label=ignore_label)
        else:
            raise NotImplementedError(f'Unknown criterion: {criterion!r}')

        # 'Cosine' is this constructor's own default, so accept it as an alias;
        # previously the default value always raised NotImplementedError.
        if self.ssl_criterion in ('CosineSimilarity', 'Cosine'):
            self.ssl_criterion = nn.CosineSimilarity(dim=-1)
        else:
            raise NotImplementedError(f'Unknown ssl_criterion: {self.ssl_criterion!r}')

        self.ignore_label = self.eval_dataset.ignore_label

        self.configure_optimizers()

        self.global_step = 0

        # NOTE(review): presumably assigned by the caller before training.
        self.device = None
        self.max_time_wdw = self.eval_dataset.max_time_wdw

        self.delayed_freeze_list = delayed_freeze_list

        self.entropy = HLoss()

        self.pseudor = pseudor

        self.topk_matches = 0

        self.dataset_name = self.adapt_dataset.name

        self.knn_search = KNN(k=200, transpose_mode=True)

        # Use the configured class count (was hard-coded to 7, which is wrong
        # for any dataset with a different number of classes).
        self.symmetric_ce = SCELoss(alpha=1, beta=0.1, num_classes=self.num_classes)

        # NOTE(review): 96 assumes the backbone feature width -- confirm.
        self.prototypes = torch.zeros(num_classes, 96).cuda()
        self.known_prototypes = torch.zeros(num_classes, 96).cuda()

        self.memory_bank = []
        self.memory_score = []
        self.memory_tscore = []
        self.memory_num = 20

        self.criterion_test = nn.CrossEntropyLoss(ignore_index=self.adapt_dataset.ignore_label)

    def freeze(self):
        """Set ``requires_grad = False`` on parameters matched by ``self.freeze_list``.

        A parameter matches when any entry of ``self.freeze_list`` is a
        substring of its name. Nothing in this method ever re-enables
        gradients. A ``None`` freeze list is a no-op.
        """
        patterns = self.freeze_list
        if patterns is None:
            return
        for param_name, param in self.model.named_parameters():
            if any(pattern in param_name for pattern in patterns):
                param.requires_grad = False

    def delayed_freeze(self, frame):
        """Freeze parameters matched by ``self.delayed_freeze_list`` up to a frame.

        ``self.delayed_freeze_list`` maps a name substring to the last frame
        (inclusive) for which matching parameters get ``requires_grad = False``.
        This method only ever freezes; it never unfreezes once ``frame`` has
        passed the limit. A ``None`` schedule is a no-op.
        """
        schedule = self.delayed_freeze_list
        if schedule is None:
            return
        for param_name, param in self.model.named_parameters():
            still_frozen = any(
                pattern in param_name and frame <= last_frame
                for pattern, last_frame in schedule.items()
            )
            if still_frozen:
                param.requires_grad = False


    def adaptation_double_pseudo_step(self, batch, frame=0, weights_save_path=None):

        self.model.train()
        self.freeze()
        self.source_model.eval()

        coords = batch["coordinates_all"][0]


        global Label_Bank
        global Predict_Bank
        global Backbone_Bank
        global Coordinate_Bank


        ####################KNN_matches####################
        global_pts = batch['global_pts0'].float().cuda()
        global_next_pts = batch['global_pts1'].float().cuda()
        global_pts_knn = global_pts.unsqueeze(0).float().cuda()
        global_next_pts_knn = global_next_pts.unsqueeze(0).float().cuda()
        dists, idx, nn = knn_points(global_pts_knn, global_next_pts_knn, K=1, return_nn=True) # [1, N, K], [1, N, K], [1, N, K, 3]
        dists = dists.squeeze()
        matches_knn = torch.cat((torch.arange(global_pts.shape[0]).unsqueeze(1).cuda(), idx.squeeze(0)), dim=1)
        matches_knn = matches_knn[dists < 0.7]

        matches0 = matches_knn[:, 0]
        matches1 = matches_knn[:, 1]

        batch['matches0'] = matches0
        batch['matches1'] = matches1
        ####################KNN_matches####################


        batch_all = torch.zeros([coords.shape[0], 1])
        coords_all = torch.cat([batch_all, coords], dim=-1)
        feats_all = torch.ones([coords_all.shape[0], 1]).float()
        
        # we assume that data the loader gives frames in pairs
        stensor_all = ME.SparseTensor(coordinates=coords_all.int().to(self.device),
                                     features=feats_all.to(self.device),
                                     quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)

        # t0 = time.perf_counter()
        with torch.no_grad():

            if self.pseudor.metric in ['entropy', 'confidence']:
                out_source, source_feats, _ = self.source_model(stensor_all, is_train=False)
                out_source = out_source.cpu()
            elif self.pseudor.metric in ['mcmc', 'mcmc_cbst', 'mcmc_cbst_easy2hard']:
                if self.args.use_ema:
                    self.ema.model.eval()
                    _, source_feats, _ = self.ema.model(stensor_all, is_train=False)
                    self.ema.model.dropout.train()
                    out_source = []
                    if self.args.use_temporal == False:
                        for i in range(self.num_mc_iterations):
                            out_tmp, _, _ = self.ema.model(stensor_all, is_train=False)
                            out_tmp_clone = out_tmp.clone()
                            out_tmp = F.softmax(out_tmp.cpu(), dim=-1)
                            out_source.append(out_tmp.view([out_tmp.shape[0], 1, -1]))

                        out_source = torch.cat(out_source, dim=1)
                            
                    out_tmp_without_dropout, out_backbone, _ = self.ema.model(stensor_all, is_train=False, without_dropout=True)
                    
                    if self.args.entropy_loss:
                        out_tmp_without_dropout_2, out_backbone, _ = self.source_model2(stensor_all, is_train=False, without_dropout=True)
                else:
                    self.source_model.eval()
                    _, source_feats, _ = self.source_model(stensor_all, is_train=False)
                    self.source_model.dropout.train()
                    out_source = []
                    if self.args.use_temporal == False:
                        for i in range(self.num_mc_iterations):
                            out_tmp, _, _ = self.source_model(stensor_all, is_train=False)
                            out_tmp_clone = out_tmp.clone()
                            out_tmp = F.softmax(out_tmp.cpu(), dim=-1)
                            out_source.append(out_tmp.view([out_tmp.shape[0], 1, -1]))

                        out_source = torch.cat(out_source, dim=1)
                        
                    out_tmp_without_dropout, out_backbone, _ = self.source_model(stensor_all, is_train=False, without_dropout=True)
            else:
                raise NotImplementedError
            

        if self.args.use_temporal:
            source_label = F.softmax(out_tmp_without_dropout, dim=-1)
            pseudo_logits_rep, pseudo_labels_rep = torch.max(source_label, dim=1)
            pseudo_all = pseudo_labels_rep.clone().detach().cpu()
            pseudo_all_logits = pseudo_logits_rep.clone().detach().cpu()

            pseudo0 = pseudo_labels_rep.clone().detach()


            if len(Label_Bank) == 0:
                Label_Bank.append(pseudo0.clone())
                Predict_Bank.append(out_tmp_without_dropout.clone())
                Backbone_Bank.append(out_backbone.clone())
                Coordinate_Bank.append(batch["global_pts0"].unsqueeze(0).float().cuda().clone())  #[1,N,3]
                previous_label = pseudo0.clone()
            else:
                Label_Bank.append(pseudo0.clone())
                Predict_Bank.append(out_tmp_without_dropout.clone())
                Backbone_Bank.append(out_backbone.clone())
                Coordinate_Bank.append(batch["global_pts0"].unsqueeze(0).float().cuda().clone())  #[1,N,3]
                if len(Label_Bank) > self.args.pre_label_num:
                    Label_Bank.pop(0)
                    Predict_Bank.pop(0)
                    Backbone_Bank.pop(0)
                    Coordinate_Bank.pop(0)
                
                previous_label_list = []
                for i in range(len(Label_Bank)-1):
                    global_pts0 = Coordinate_Bank[-1].clone()
                    global_pts1 = Coordinate_Bank[i].clone()
                    K = self.args.pre_label_knn
                    dists, idx, _ = knn_points(global_pts0, global_pts1, K=K, return_nn=True) # [1, N, K], [1, N, K], [1, N, K, 3]
                    previous_label = knn_gather(Label_Bank[i].unsqueeze(0).unsqueeze(-1).clone().cuda(), idx) #[1, N, K, C] 
                    previous_label = previous_label.squeeze(-1) #[N, K]
                    previous_label_list.append(previous_label)
                previous_label_list = torch.cat(previous_label_list, dim=0) #[pre_label_num, N, K]
                previous_label_list = previous_label_list.permute(1, 0, 2) #[N, pre_label_num, K]
                N = previous_label_list.shape[0]
                previous_label_list = previous_label_list.reshape(N, -1) #[N, pre_label_num*K]
                # 取出现次数最多的label
                # torch one-hot
                previous_label_list = previous_label_list.reshape(-1)
                previous_label_list = previous_label_list + 1
                previous_label_list = label_onehot(previous_label_list, self.num_classes+1) #[N*pre_label_num*K, 8]
                previous_label_list = previous_label_list.reshape(N, -1, self.num_classes+1) #[N, pre_label_num*K, 8]
                previous_label_list = previous_label_list.sum(dim=1) #[N, 8]
                previous_label_list = previous_label_list.argmax(dim=-1) #[N]
                previous_label_list = previous_label_list - 1
                previous_label = previous_label_list

            pseudo0[previous_label != pseudo0] = -1


        else:
            batch['model_features0'] = source_feats.cpu()
            pseudo0, _ = self.pseudor.get_pseudo(out_source, batch, frame, return_metric=True)

            source_label = F.softmax(out_tmp_without_dropout, dim=-1)
            pseudo_logits_rep, pseudo_labels_rep = torch.max(source_label, dim=1)
            pseudo_all = pseudo_labels_rep.clone().detach().cpu()
            pseudo_all_logits = pseudo_logits_rep.clone().detach().cpu()


            Label_Bank.append(pseudo0.clone())
            Predict_Bank.append(out_tmp_without_dropout.clone())
            Backbone_Bank.append(out_backbone.clone())
            Coordinate_Bank.append(batch["global_pts0"].unsqueeze(0).float().cuda().clone())  #[1,N,3]
            if len(Label_Bank) > self.args.pre_label_num:
                Label_Bank.pop(0)
                Predict_Bank.pop(0)
                Backbone_Bank.pop(0)
                Coordinate_Bank.pop(0)


        if self.args.superpoint:
            if self.dataset_name == 'nuScenes':
                if self.args.superpoint_multi_preframe:
                    save_path_root = 'XXXXXXXXXX/OSTTA-GOO/OOD_result/nuscenes_dbscan_preframe3/' + self.sequence + '/'
                else:
                    save_path_root = 'XXXXXXXXXX/OSTTA-GOO/OOD_result/nuscenes_dbscan/' + self.sequence + '/'

            else:
                if self.args.superpoint_multi_preframe:
                    save_path_root = 'XXXXXXXXXX/OSTTA-GOO/OOD_result/kitti_dbscan_preframe3/'
                else:
                    save_path_root = 'XXXXXXXXXX/OSTTA-GOO/OOD_result/kitti_dbscan/'


            if frame >= 3:
                dbscan_data = np.load(save_path_root +  str(frame) + ".npy", allow_pickle=True)
                dbscan_data = torch.from_numpy(dbscan_data).long().cuda()

                # dbscan_data[dbscan_data == -1] = max(dbscan_data) + 1
                mask = dbscan_data != -1
                dbscan_data = dbscan_data[mask]
                dbscan_data = F.one_hot(dbscan_data) # [N, M]
                dbscan_data = dbscan_data[:, :].T # [M, N]

                if self.args.superpoint_multi_preframe:
                    pre_label = torch.cat(Predict_Bank[::-1], dim=0)[mask]
                    # self_frame_num = Predict_Bank[-1].shape[0]
                    self_frame_num = mask[:Predict_Bank[-1].shape[0]].sum()
                else:
                    pre_label = out_tmp_without_dropout[mask] # [N,C]

                dbscan_data = dbscan_data.unsqueeze(-1) # [M,, N, 1]
                pre_label = pre_label.unsqueeze(0) # [1, N, C]
                dbscan_data_mask = dbscan_data != 0 # [M, N, 1]

                dbscan_data = dbscan_data[dbscan_data_mask.sum(1).squeeze() != 0, :, :] # [M, N, 1]
                dbscan_data_mask = dbscan_data_mask[dbscan_data_mask.sum(1).squeeze() != 0, :, :] # [M, N, 1]

                pre_label = dbscan_data * pre_label # [M, N, C]
                predict = pre_label.argmax(-1) # [M, N]
                predict_onehot = F.one_hot(predict, num_classes=pre_label.shape[-1]).float() # [M, N, C]
                pre_knn_label = (predict_onehot * dbscan_data_mask).sum(1) / dbscan_data_mask.sum(1) # [M, C]
                dbscan_pre_region_purity = 1.0 - torch.sum(-pre_knn_label * torch.log(pre_knn_label + 1e-6), dim=-1) / math.log(pre_knn_label.shape[-1]) # [M]

                source_tmp = F.softmax(pre_label, dim=-1) # [M, N, C]
                p = source_tmp
                dbscan_pre_region_logit = 1.0 - torch.sum(-p * torch.log(p + 1e-6), dim=-1) / math.log(pre_knn_label.shape[-1])  #[M, N]
                dbscan_pre_region_logit = dbscan_pre_region_logit * dbscan_data_mask.squeeze() # [M, N]
                dbscan_pre_region_logit = dbscan_pre_region_logit.sum(-1) / dbscan_data_mask.squeeze().sum(-1) # [M]

                score = dbscan_pre_region_purity * dbscan_pre_region_logit


                if self.args.superpoint_score == 'entropy':
                    score = dbscan_pre_region_logit
                    
                if self.args.superpoint_score == 'purity':
                    score = dbscan_pre_region_purity

                if self.args.superpoint_score == 'margin':
                    source_tmp = F.softmax(pre_label, dim=-1) # [M, N, C]
                    superpoint_mean = source_tmp.sum(1) / dbscan_data_mask.sum(1) 
                    superpoint_sort = superpoint_mean.sort(dim=-1, descending=False)[0]
                    score = superpoint_sort[:, -1] - superpoint_sort[:, -2]

                if self.args.superpoint_score == 'margin+purity':
                    source_tmp = F.softmax(pre_label, dim=-1) # [M, N, C]
                    superpoint_mean = source_tmp.sum(1) / dbscan_data_mask.sum(1) 
                    superpoint_sort = superpoint_mean.sort(dim=-1, descending=False)[0]
                    score = superpoint_sort[:, -1] - superpoint_sort[:, -2]
                    score = score * dbscan_pre_region_purity

                if self.args.superpoint_score == 'MSP':
                    source_tmp = F.softmax(pre_label, dim=-1) # [M, N, C]
                    superpoint_mean = source_tmp.sum(1) / dbscan_data_mask.sum(1) 
                    superpoint_sort = superpoint_mean.sort(dim=-1, descending=False)[0]
                    score = superpoint_sort[:, -1]

                index = score.argsort()

                if self.args.Gaussian:
                    gm_input = score[index]
                    gm = GaussianMixture(n_components=2).fit(gm_input.detach().cpu().numpy().reshape(-1, 1))
                    filter_ids = gm.predict(gm_input.detach().cpu().numpy().reshape(-1, 1))
                    filter_ids = filter_ids if gm.means_[0, 0] < gm.means_[1, 0] else 1 - filter_ids
                    topk = int(np.where(filter_ids == 1)[0][0] / self.args.Gaussian_rate)
                else:
                    if self.dataset_name == 'nuScenes':
                        topk = math.ceil(index.shape[0] * 0.2)
                    else:
                        topk = int(index.shape[0] * 0.1)

                index_sample = index[:topk]

                if self.args.known_prototype:
                    if self.args.Gaussian:
                        known_topk = int((index.shape[0] - np.where(filter_ids == 1)[0][0]) / self.args.known_prototype_gaussian_rate)
                        # known_topk = int(index.shape[0] * 0.1)
                    else:
                        known_topk = int(index.shape[0] * 0.1)

                    if self.args.superpoint_multi_preframe:
                        out_backbone_feature = torch.cat(Backbone_Bank[::-1], dim=0)[mask]
                    else:
                        out_backbone_feature = out_backbone[mask] 
                    out_backbone_feature = out_backbone_feature.unsqueeze(0) # [1, N, 96]

                    out_backbone_feature = dbscan_data_mask * out_backbone_feature # [M, N, 96]
                    out_backbone_feature = out_backbone_feature.sum(1) / dbscan_data_mask.sum(1) # [M, 96]

                    known_index_sample = index[-known_topk:]
                    known_feature = out_backbone_feature[known_index_sample] #[kM, 96]
                    unknown_feature = out_backbone_feature[index_sample] #[uM, 96]

                    known_pre_knn_label = pre_knn_label[known_index_sample]
                    known_logits_rep, known_labels_rep = torch.max(known_pre_knn_label, dim=1)
                    score = known_logits_rep
                    self.known_prototypes = known_prototype_update(known_feature, known_labels_rep, self.known_prototypes.cuda(), score, num_segments=self.num_classes)

                    known_feature = F.normalize(self.known_prototypes.clone(), dim=-1).permute(1, 0) # [96, kM]
                    # known_feature = F.normalize(known_feature, dim=-1).permute(1, 0) # [96, kM]
                    unknown_feature = F.normalize(unknown_feature, dim=-1) # [uM, 96]
                    sim_mat = torch.mm(unknown_feature, known_feature) # [uM, kM]
                    unknown_values, _ = torch.max(sim_mat, dim=-1) # [uM]

                    index_sample = index_sample[unknown_values < self.args.known_prototype_threshold]


                if self.args.superpoint_multi_preframe:
                    index_mask = dbscan_data_mask.squeeze(-1)[:, :self_frame_num][index_sample] #[M, N]
                    new_pseudo_unknown = index_mask.sum(0).long().cuda().detach()

                    new_pseudo_unknown_tmp = torch.zeros_like(pseudo0)
                    new_pseudo_unknown_tmp[mask[:Predict_Bank[-1].shape[0]]] = new_pseudo_unknown
                    new_pseudo_unknown = new_pseudo_unknown_tmp
                else:
                    index_mask = dbscan_data_mask.squeeze(-1)[index_sample] #[M, N]
                    new_pseudo_unknown = index_mask.sum(0).long().cuda().detach()

                    new_pseudo_unknown_tmp = torch.zeros_like(pseudo0)
                    new_pseudo_unknown_tmp[mask] = new_pseudo_unknown
                    new_pseudo_unknown = new_pseudo_unknown_tmp

                # new_pseudo_unknown[pseudo0 != -1] = 0
                pseudo0[new_pseudo_unknown[:pseudo0.shape[0]]==1] = -1

            else:
                pass


        if (pseudo0 != -1).sum() > 0:
            # we assume that data the loader gives frames in pairs
            stensor0 = ME.SparseTensor(coordinates=batch["coordinates0"].int().to(self.device),
                                       features=batch["features0"].to(self.device),
                                       quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)

            stensor1 = ME.SparseTensor(coordinates=batch["coordinates1"].int().to(self.device),
                                       features=batch["features1"].to(self.device),
                                       quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)

            # Must clear cache at regular interval
            if self.global_step % self.clear_cache_int == 0:
                torch.cuda.empty_cache()

            self.optimizer.zero_grad()

            # forward in mink
            out_seg0, out_en0, out_pred0, out_bck0, _, out_seg1, out_en1, out_pred1, out_bck1, _ = self.model((stensor0, stensor1))

            # segmentation loss for t0
            labels0 = batch['labels0'].long()

            loss_seg_head = self.criterion(out_seg0, pseudo0)

            pseudo0 = pseudo0.cpu()
            labels0 = labels0.cpu()

            # get matches in t0 and t1 (used for selection)
            matches0 = batch['matches0'].to(self.device)
            matches1 = batch['matches1'].to(self.device)
            if not self.use_global:
                # 2FUTURE CONTRASTIVE
                # forward preds (t0 -> t1)
                future_preds = torch.index_select(out_pred0, 0, matches0)
                # forward gt feats and stop grad
                future_gt = torch.index_select(out_en1.detach(), 0, matches1)
                future_neg_cos_sim = -self.ssl_criterion(future_preds, future_gt).cpu()

                if self.topk_matches > 0:
                    # select top-k worst performing matches
                    future_neg_cos_sim = future_neg_cos_sim.topk(self.topk_matches, dim=0).values.mean()
                else:
                    future_neg_cos_sim = future_neg_cos_sim.mean(dim=0)

                # 2PAST CONTRASTIVE
                # backward preds (t1 -> t0)
                past_preds = torch.index_select(out_pred1, 0, matches1)
                # backward gt feats and stop grad
                past_gt = torch.index_select(out_en0.detach(), 0, matches0)
                past_neg_cos_sim = -self.ssl_criterion(past_preds, past_gt).cpu()
                if self.topk_matches > 0:
                    # select top-k worst performing matches
                    past_neg_cos_sim = past_neg_cos_sim.topk(self.topk_matches, dim=0).values.mean()
                else:
                    past_neg_cos_sim = past_neg_cos_sim.mean(dim=0)
            else:
                # 2FUTURE CONTRASTIVE
                # forward preds (t0 -> t1)
                future_preds = out_pred0.mean(dim=0)
                # forward gt feats and stop grad
                future_gt = out_en1.detach().mean(dim=0)
                future_neg_cos_sim = -self.ssl_criterion(future_preds, future_gt).cpu()

                # 2PAST CONTRASTIVE
                # backward preds (t1 -> t0)
                past_preds = out_pred1.mean(dim=0)
                # backward gt feats and stop grad
                past_gt = out_en0.detach().mean(dim=0)
                past_neg_cos_sim = -self.ssl_criterion(past_preds, past_gt).cpu()

            # sum up to total
            ssl_loss = (future_neg_cos_sim + past_neg_cos_sim) * self.ssl_beta

            total_loss = self.seg_beta * loss_seg_head + ssl_loss.cuda()

            if self.args.entropy_loss:
                entropy_loss = self.entropy(out_seg0)
                if True:
                    out_tmp_pls_drop = F.softmax(out_tmp_clone, dim=-1).max(dim=-1)[0]
                    if self.args.use_ema:
                        out_tmp_pls = F.softmax(out_tmp_without_dropout_2, dim=-1).max(dim=-1)[0]
                    else:
                        out_tmp_pls = F.softmax(out_tmp_without_dropout, dim=-1).max(dim=-1)[0]
                    out_seg0_pls = F.softmax(out_seg0, dim=-1).max(dim=-1)[0]

                    entropy_loss = entropy_loss[out_tmp_pls < out_seg0_pls]
                    entropy_loss = entropy_loss.mean()
                else:
                    entropy_loss = entropy_loss.mean()
                entropy_loss = entropy_loss.mean()
                total_loss += self.args.entropy_loss_beta * entropy_loss

            if self.args.unient_loss:
                entropys = softmax_entropy(out_seg0)

                max_cos_sim = F.softmax(out_seg0, dim=-1).max(dim=-1)[0]
                min_value = max_cos_sim.min()
                max_value = max_cos_sim.max()
                max_cos_sim = (max_cos_sim - min_value) / (max_value - min_value)
                score = 1 - max_cos_sim

                if False:
                    gm = GaussianMixture(n_components=2).fit(score.detach().cpu().numpy().reshape(-1, 1))
                    weight = gm.predict_proba(score.detach().cpu().numpy().reshape(-1, 1))
                    weight = weight if gm.means_[0, 0] < gm.means_[1, 0] else 1 - weight
                    entropys_ind = entropys.mul(torch.from_numpy(weight[:, 0]).to(entropys.device))
                    loss = entropys_ind.mean(0)
                    entropys_ood = entropys.mul(torch.from_numpy(weight[:, 1]).to(entropys.device))
                    loss -= entropys_ood.mean(0)
                else:
                    gm = GaussianMixture(n_components=2).fit(score.detach().cpu().numpy().reshape(-1, 1))
                    filter_ids = gm.predict(score.detach().cpu().numpy().reshape(-1, 1))
                    filter_ids = filter_ids if gm.means_[0, 0] < gm.means_[1, 0] else 1 - filter_ids
                    entropys_ind = entropys[filter_ids == 0]
                    loss = entropys_ind.mean(0)
                    entropys_ood = entropys[filter_ids == 1]
                    loss -=  entropys_ood.mean(0)

                total_loss += self.args.unient_loss_beta * loss 


            if self.args.unknown_label_loss:
                if self.dataset_name == 'nuScenes' and frame <= 3:
                    pass
                else:
                    hard_label_unk = torch.zeros((new_pseudo_unknown==1).sum(), out_tmp_without_dropout.shape[-1])
                    hard_label_unk += 1.0
                    hard_label_unk = hard_label_unk / (torch.sum(hard_label_unk, dim=-1, keepdim=True) + 1e-4)
                    hard_label_unk = hard_label_unk.to(self.device)

                    pred_cls = F.softmax(out_seg0, dim=-1)
                    pred_cls = pred_cls[new_pseudo_unknown==1]
                    psd_pred_loss = torch.sum(-hard_label_unk * torch.log(pred_cls + 1e-5), dim=-1).mean()
                    total_loss += self.args.unknown_label_loss_beta * psd_pred_loss


            # backward and optimize
            total_loss.backward()
            self.optimizer.step()


            if self.args.use_ema:
                self.ema.update(self.model)

        else:
            # if no pseudo we skip the frame (happens never basically)
            # we assume that data the loader gives frames in pairs
            stensor0 = ME.SparseTensor(coordinates=batch["coordinates0"].int().to(self.device),
                                       features=batch["features0"].to(self.device),
                                       quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)

            stensor1 = ME.SparseTensor(coordinates=batch["coordinates1"].int().to(self.device),
                                       features=batch["features1"].to(self.device),
                                       quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)

            # Must clear cache at regular interval
            if self.global_step % self.clear_cache_int == 0:
                torch.cuda.empty_cache()

            self.model.eval()
            with torch.no_grad():
                # forward in mink
                out_seg0, out_en0, out_pred0, out_bck0, _, out_seg1, out_en1, out_pred1, out_bck1, _ = self.model((stensor0, stensor1))

            # segmentation loss for t0
            labels0 = batch['labels0'].long()

            loss_seg_head = self.criterion(out_seg0, pseudo0)

            pseudo0 = pseudo0.cpu()
            labels0 = labels0.cpu()

            # get matches in t0 and t1 (used for selection)
            matches0 = batch['matches0'].to(self.device)
            matches1 = batch['matches1'].to(self.device)

            ####################KNN_matches####################
            global_pts = batch['global_pts0'].float().cuda()
            global_next_pts = batch['global_pts1'].float().cuda()
            global_pts_knn = global_pts.unsqueeze(0).float().cuda()
            global_next_pts_knn = global_next_pts.unsqueeze(0).float().cuda()
            dists, idx, nn = knn_points(global_pts_knn, global_next_pts_knn, K=1, return_nn=True) # [1, N, K], [1, N, K], [1, N, K, 3]
            dists = dists.squeeze()
            matches_knn = torch.cat((torch.arange(global_pts.shape[0]).unsqueeze(1).cuda(), idx.squeeze(0)), dim=1)
            matches_knn = matches_knn[dists < 0.7]

            matches0 = matches_knn[:, 0]
            matches1 = matches_knn[:, 1]

            batch['matches0'] = matches0
            batch['matches1'] = matches1
            ####################KNN_matches####################

            # 2FUTURE CONTRASTIVE
            # forward preds (t0 -> t1)
            future_preds = torch.index_select(out_pred0, 0, matches0)
            # forward gt feats and stop grad
            future_gt = torch.index_select(out_en1.detach(), 0, matches1)
            future_neg_cos_sim = -self.ssl_criterion(future_preds, future_gt).cpu()

            if self.topk_matches > 0:
                # select top-k worst performing matches
                future_neg_cos_sim = future_neg_cos_sim.topk(self.topk_matches, dim=0).values.mean()
            else:
                future_neg_cos_sim = future_neg_cos_sim.mean(dim=0)

            # 2PAST CONTRASTIVE
            # backward preds (t1 -> t0)
            past_preds = torch.index_select(out_pred1, 0, matches1)
            # backward gt feats and stop grad
            past_gt = torch.index_select(out_en0.detach(), 0, matches0)
            past_neg_cos_sim = -self.ssl_criterion(past_preds, past_gt).cpu()
            if self.topk_matches > 0:
                # select top-k worst performing matches
                past_neg_cos_sim = past_neg_cos_sim.topk(self.topk_matches, dim=0).values.mean()
            else:
                past_neg_cos_sim = past_neg_cos_sim.mean(dim=0)

            # sum up to total
            ssl_loss = (future_neg_cos_sim + past_neg_cos_sim) * self.ssl_beta

        # increase step
        self.global_step += self.stream_batch_size

        # additional metrics
        _, pred_seg0 = out_seg0.detach().max(1)
        # iou
        iou_tmp = jaccard_score(pred_seg0.cpu().numpy(), labels0.cpu().numpy(), average=None,
                                labels=np.arange(0, self.num_classes),
                                zero_division=0.0)

        # forward preds (t0 -> t1)
        future_preds = torch.index_select(out_bck0.detach(), 0, matches0)
        # forward gt feats and stop grad
        future_gt = torch.index_select(out_bck1.detach(), 0, matches1)
        frame_match_sim = -self.ssl_criterion(future_preds, future_gt).mean()

        # we check pseudo labelling accuracy, not IoU as union of points changes
        valid_idx_pseudo = torch.logical_and(pseudo0 != -1, labels0 != -1)
        pseudo_acc = (pseudo0[valid_idx_pseudo] == labels0[valid_idx_pseudo]).sum() / labels0[valid_idx_pseudo].shape[0]

        present_labels, class_occurs = np.unique(labels0.cpu().numpy(), return_counts=True)
        present_labels = present_labels[present_labels != self.ignore_label]
        present_names = self.adapt_dataset.class2names[present_labels].tolist()
        present_names = [os.path.join('training', p + '_iou') for p in present_names]

        results_dict = dict(zip(present_names, iou_tmp[present_labels].tolist()))

        # check pseudo nums
        valid_pseudo = (pseudo0 != -1)
        pseudo_classes, pseudo_num = torch.unique(pseudo0[valid_pseudo], return_counts=True)
        pseudo_names = self.adapt_dataset.class2names[pseudo_classes].tolist()
        classes_count = dict(zip(pseudo_names, pseudo_num.int().tolist()))
        classes_print = dict()
        for c in self.adapt_dataset.class2names[pseudo_classes]:
            if c in classes_count.keys():
                classes_print[f'training/pseudo_number/{c}'] = classes_count[c]
            else:
                classes_print[f'training/pseudo_number/{c}'] = -1

        results_dict.update(classes_print)
        # degeneration check
        out_en0_dg = out_en0.detach().clone()
        out_en1_dg = out_en1.detach().clone()

        max_val = 1/np.sqrt(out_en0_dg.shape[-1])

        out_en0_dg = F.normalize(out_en0_dg, p=2, dim=-1).std(dim=-1).mean()
        out_en1_dg = F.normalize(out_en1_dg, p=2, dim=-1).std(dim=-1).mean()

        results_dict['training/seg_loss'] = loss_seg_head
        results_dict['training/future_ssl'] = future_neg_cos_sim
        results_dict['training/past_ssl'] = past_neg_cos_sim
        results_dict['training/frame_similarity'] = frame_match_sim
        results_dict['training/future_degeneration'] = out_en0_dg
        results_dict['training/past_degeneration'] = out_en1_dg
        results_dict['training/max_degeneration'] = max_val
        results_dict['training/iou'] = np.mean(iou_tmp[present_labels])
        results_dict['training/lr'] = self.optimizer.param_groups[0]["lr"]
        results_dict['training/pseudo_accuracy'] = pseudo_acc
        results_dict['training/pseudo_number'] = torch.sum(pseudo_num)
        # results_dict['training/mean_mcmc'] = mean_metric
        # results_dict['training/source_similarity'] = source_sim

        return results_dict

    def validation_step(self, batch, is_source=False, save_path=None, frame=None, ada_eval=False):
        """Run the model on one evaluation frame and collect metrics and raw outputs.

        Args:
            batch: collated frame; reads ``'coordinates'``, ``'features'``,
                ``'labels'`` and ``'OOD_labels'`` (``save_pcd`` additionally
                reads ``'global_points'`` when ``save_path`` is given).
            is_source: log metrics under the ``'source'`` prefix instead of
                ``'validation'``.
            save_path: if not None, the predicted point cloud is written under
                ``<save_path>/<phase>``.
            frame: frame index, used for file names and the superpoint check.
            ada_eval: if True, a per-point OOD score weight is appended to the
                returned tuple.

        Returns:
            ``(results_dict, OOD_labels, logits, backbone_feats, bottleneck_feats)``
            moved to CPU, plus ``point_score`` when ``ada_eval`` is True.

        Raises:
            NotImplementedError: when ``ada_eval`` and ``args.superpoint_test_val``
                ask for superpoint scores (``frame > 5``); no such score is
                computed in this method.
        """
        self.model.eval()
        # for multiple dataloaders
        phase = 'source' if is_source else 'validation'

        ood_labels = batch['OOD_labels']

        # clear cache at regular interval
        if self.global_step % self.clear_cache_int == 0:
            torch.cuda.empty_cache()

        # sparsify
        stensor = ME.SparseTensor(coordinates=batch['coordinates'].int().to(self.device),
                                  features=batch['features'].to(self.device))

        # get output
        out, out_bck, out_bottle = self.model(stensor, is_train=False)

        labels = batch['labels'].long()
        loss = self.criterion(out, labels)
        preds = out.max(1)[1].cpu()
        labels = labels.cpu()
        self.global_step += self.stream_batch_size

        # per-class IoU over the full label range
        iou_tmp = jaccard_score(preds.numpy(), labels.numpy(), average=None,
                                labels=np.arange(0, self.num_classes),
                                zero_division=0.)

        # Davies-Bouldin index of backbone features over points with label != -1.
        valid_feats_idx = torch.where(labels != -1)[0].view(-1).long()
        db_feats = out_bck.cpu()[valid_feats_idx].numpy()
        db_labels = labels[valid_feats_idx].numpy()
        n_db_labels = np.unique(db_labels).size
        # davies_bouldin_score raises ValueError unless 2 <= n_labels <= n_samples - 1
        # (e.g. a frame with a single labelled class); report NaN instead of crashing.
        if 2 <= n_db_labels <= db_labels.shape[0] - 1:
            db_index = davies_bouldin_score(db_feats, db_labels)
        else:
            db_index = float('nan')

        present_labels = np.unique(labels.numpy())
        present_labels = present_labels[present_labels != self.ignore_label]
        present_names = self.adapt_dataset.class2names[present_labels].tolist()
        present_names = [os.path.join(phase, p + '_iou') for p in present_names]

        results_dict = dict(zip(present_names, iou_tmp[present_labels].tolist()))
        results_dict[f'{phase}/loss'] = loss.cpu().item()
        results_dict[f'{phase}/iou'] = np.mean(iou_tmp[present_labels])
        results_dict[f'{phase}/db_index'] = db_index

        if save_path is not None:
            self.save_pcd(batch, preds.numpy(), labels.numpy(),
                          os.path.join(save_path, phase), frame, is_global=True)

        outputs = (results_dict, ood_labels.cpu(), out.detach().cpu(),
                   out_bck.detach().cpu(), out_bottle.F.detach().cpu())
        if not ada_eval:
            return outputs

        if self.args.superpoint_test_val and frame is not None and frame > 5:
            # The previous code read an undefined `point_score` here; there is no
            # superpoint scoring in this method, so fail with a clear message.
            raise NotImplementedError(
                'superpoint_test_val: superpoint-based point scores are not computed in validation_step')
        # Uniform weight: every point keeps its raw OOD score.
        point_score = torch.ones_like(ood_labels)
        return outputs + (point_score.cpu(),)

    def configure_optimizers(self):
        """Create ``self.optimizer`` and ``self.scheduler`` from the configured names.

        ``self.optimizer_name`` may be ``'SGD'`` or ``'Adam'``/``'ADAM'`` (both
        spellings are accepted with or without a scheduler).
        ``self.scheduler_name`` may be ``None`` (no scheduler),
        ``'CosineAnnealingLR'``, ``'ExponentialLR'`` or ``'CyclicLR'``.

        Raises:
            NotImplementedError: for an unknown optimizer or scheduler name.
        """
        parameters = self.model.parameters()

        if self.optimizer_name == 'SGD':
            optimizer = torch.optim.SGD(parameters,
                                        lr=self.lr,
                                        momentum=self.momentum,
                                        weight_decay=self.weight_decay)
        elif self.optimizer_name in ('Adam', 'ADAM'):
            optimizer = torch.optim.Adam(parameters,
                                         lr=self.lr,
                                         weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        if self.scheduler_name is None:
            scheduler = None
        elif self.scheduler_name == 'CosineAnnealingLR':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        elif self.scheduler_name == 'ExponentialLR':
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
        elif self.scheduler_name == 'CyclicLR':
            scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=self.lr/10000, max_lr=self.lr,
                                                          step_size_up=5, mode="triangular2")
        else:
            raise NotImplementedError

        self.optimizer = optimizer
        self.scheduler = scheduler

    def get_online_dataloader(self, dataset, is_adapt=False):
        """Build a sequential ``DataLoader`` over ``dataset``.

        Adaptation loaders use ``CollateSeparated`` and a sampler sized by
        ``adaptation_batch_size`` and ``max_time_wdw``; streaming (evaluation)
        loaders use ``CollateStream`` and a sampler sized by
        ``stream_batch_size``. Both collate on CPU without pinned memory.
        """
        cpu = torch.device('cpu')
        if is_adapt:
            collate_fn = CollateSeparated(cpu)
            frame_sampler = SequentialSampler(dataset, is_adapt=True, adapt_batchsize=self.adaptation_batch_size,
                                              max_time_wdw=self.max_time_wdw)
        else:
            collate_fn = CollateStream(cpu)
            frame_sampler = SequentialSampler(dataset, is_adapt=False, adapt_batchsize=self.stream_batch_size)

        return DataLoader(dataset,
                          collate_fn=collate_fn,
                          sampler=frame_sampler,
                          pin_memory=False,
                          num_workers=self.train_num_workers)

    def save_pcd(self, batch, preds, labels, save_path, frame, is_global=False, save_gt=False):
        """Write the frame's point cloud coloured by predicted class to ``<save_path>/preds/<frame>.ply``.

        Args:
            batch: frame batch. Points come from ``batch['global_points'][0]``
                when ``is_global`` is True, otherwise from
                ``batch['coordinates'][:, 1:]`` (first column dropped —
                presumably the batch index; verify against the collate fn).
            preds: per-point predicted class ids (integer array).
            labels: per-point ground-truth class ids; only used when ``save_gt``.
            save_path: output directory root.
            frame: frame index, used as the file name.
            is_global: choose global points instead of voxel coordinates.
            save_gt: if True, additionally write the cloud coloured by
                ``labels`` to ``<save_path>/gt/<frame>.ply``.

        Colours are looked up in ``self.eval_dataset.color_map`` at index
        ``class_id + 1`` (presumably so that the ignore id -1 maps to row 0).
        """
        pts = batch['global_points'][0] if is_global else batch['coordinates'][:, 1:]

        pcd = o3d.geometry.PointCloud()
        # Open3D expects an (N, 3) float64 array; convert explicitly so integer
        # voxel coordinates or CPU tensors are accepted.
        pcd.points = o3d.utility.Vector3dVector(np.asarray(pts, dtype=np.float64))
        color_map = self.eval_dataset.color_map

        if save_gt:
            pcd.colors = o3d.utility.Vector3dVector(color_map[labels + 1])
            gt_dir = os.path.join(save_path, 'gt')
            os.makedirs(gt_dir, exist_ok=True)
            o3d.io.write_point_cloud(os.path.join(gt_dir, str(frame) + '.ply'), pcd)

        pcd.colors = o3d.utility.Vector3dVector(color_map[preds + 1])
        preds_dir = os.path.join(save_path, 'preds')
        os.makedirs(preds_dir, exist_ok=True)
        o3d.io.write_point_cloud(os.path.join(preds_dir, str(frame) + '.ply'), pcd)


class OnlineTrainer(object):

    def __init__(self,
                 pipeline,
                 collate_fn_eval=None,
                 collate_fn_adapt=None,
                 device='cpu',
                 default_root_dir=None,
                 weights_save_path=None,
                 loggers=None,
                 save_checkpoint_every=2,
                 source_checkpoint=None,
                 student_checkpoint=None,
                 boost=True,
                 save_predictions=False,
                 is_double=True,
                 is_pseudo=True,
                 use_mcmc=True,
                 sub_epochs=0,
                 args=None):

        super().__init__()

        if device is not None:
            self.device = torch.device(f'cuda:{device}')
        else:
            self.device = torch.device('cpu')

        self.default_root_dir = default_root_dir
        self.weights_save_path = weights_save_path
        self.loggers = loggers
        self.save_checkpoint_every = save_checkpoint_every
        self.source_checkpoint = source_checkpoint
        self.student_checkpoint = student_checkpoint

        self.pipeline = pipeline
        self.pipeline.device = self.device

        self.is_double = is_double
        self.use_mcmc = use_mcmc
        self.model = self.pipeline.model

        if self.is_double:
            self.source_model = self.pipeline.source_model
            self.source_model2 = self.pipeline.source_model2

        self.eval_dataset = self.pipeline.eval_dataset
        self.adapt_dataset = self.pipeline.adapt_dataset

        self.max_time_wdw = self.eval_dataset.max_time_wdw

        self.eval_dataset.eval()
        self.adapt_dataset.train()

        self.online_sequences = np.arange(self.adapt_dataset.num_sequences())
        self.num_frames = len(self.eval_dataset)

        self.collate_fn_eval = collate_fn_eval
        self.collate_fn_adapt = collate_fn_adapt
        self.collate_fn_eval.device = self.device
        self.collate_fn_adapt.device = self.device

        self.sequence = -1

        self.adaptation_results_dict = {s: [] for s in self.online_sequences}
        self.source_results_dict = {s: [] for s in self.online_sequences}

        # for speed up
        self.eval_dataloader = None
        self.adapt_dataloader = None

        self.boost = boost

        self.save_predictions = save_predictions

        self.is_pseudo = is_pseudo
        self.sub_epochs = sub_epochs
        self.num_classes = self.pipeline.num_classes

        self.dataset_name = self.pipeline.dataset_name
        self.args = args
        self.class2names = np.array(['car', 'bicycle', 'motorcycle',  'truck', 'other-vehicle', 'person',
                        'bicyclist', 'motorcyclist',
                        'road', 'parking', 'sidewalk', 'other-ground',
                        'building', 'fence', 'vegetation', 'trunk',
                        'terrain', 'pole', 'traffic-sign'])


    def adapt_double(self):
        """Evaluate the source model, then adapt online sequence by sequence.

        1. Reset the eval/train OOD metric lists.
        2. Load the source checkpoint; with ``args.use_ema`` wrap
           ``pipeline.source_model`` in an ``EMA`` (decay 0.999).
        3. Run a full evaluation pass and append the mean eval
           AUPR/AUROC/FPR95 to ``<save_dir>/OOD_final.txt``.
        4. For each sequence: reload the model, select the sequence, empty the
           module-level temporal banks and run the adaptation routine
           (``image_online_adaptation_routine`` if ``args.use_image_method``),
           storing its result in ``adaptation_results_dict``.
        5. Append the mean train AUPR/AUROC/FPR95 to ``OOD_final.txt`` and
           save the final results.
        """
        # Temporal banks are module-level state shared with the adaptation routine.
        global Label_Bank, Predict_Bank, Backbone_Bank, Coordinate_Bank

        self.eval_AUPR_list, self.eval_AUROC_list, self.eval_FPR95_list = [], [], []
        self.train_AUPR_list, self.train_AUROC_list, self.train_FPR95_list = [], [], []

        self.load_source_model()
        if self.args.use_ema:
            self.pipeline.ema = EMA(self.pipeline.source_model, 0.999)

        # Baseline: evaluate the source model before any adaptation.
        self.eval(is_adapt=True)

        # Append the averaged evaluation OOD metrics to the summary file.
        with open(os.path.join(self.args.save_dir, 'OOD_final.txt'), 'a') as report:
            report.write('Eval Softmax Adaptation AUPR is: ' + str(np.mean(self.eval_AUPR_list)) + '\n')
            report.write('Eval Softmax AUROC is: ' + str(np.mean(self.eval_AUROC_list)) + '\n')
            report.write('Eval Softmax FPR95 is: ' + str(np.mean(self.eval_FPR95_list)) + '\n')

        for sequence in tqdm(self.online_sequences, desc='Online Adaptation'):
            # Each sequence starts again from the source weights.
            self.reload_model()
            # set sequence in dataset, in weight path and loggers
            self.set_sequence(sequence)

            # The temporal banks are per-sequence; start them empty.
            Label_Bank, Predict_Bank, Backbone_Bank, Coordinate_Bank = [], [], [], []
            self.pipeline.sequence = self.sequence

            routine = (self.image_online_adaptation_routine
                       if self.args.use_image_method
                       else self.online_adaptation_routine)
            self.adaptation_results_dict[sequence] = routine()

        # Append the averaged adaptation-time OOD metrics to the summary file.
        with open(os.path.join(self.args.save_dir, 'OOD_final.txt'), 'a') as report:
            report.write('Train Softmax Adaptation AUPR is: ' + str(np.mean(self.train_AUPR_list)) + '\n')
            report.write('Train Softmax AUROC is: ' + str(np.mean(self.train_AUROC_list)) + '\n')
            report.write('Train Softmax FPR95 is: ' + str(np.mean(self.train_FPR95_list)) + '\n')

        self.save_final_results()

    def eval(self, is_adapt=False):
        """Evaluate the current model on every online sequence and save the results.

        ``is_adapt`` is accepted for API compatibility but does not change the
        behaviour: evaluation results are always saved via ``save_eval_results``.
        """
        # Weights are loaded once and shared by all sequences.
        self.reload_model(is_adapt=False)
        for seq in tqdm(self.online_sequences, desc='Online Evaluation', leave=True):
            self.set_sequence(seq)
            self.source_results_dict[seq] = self.online_evaluation_routine()
        self.save_eval_results()

    def check_frame(self, fr):
        return (fr+1) >= self.pipeline.adaptation_batch_size and fr >= self.max_time_wdw

    def online_adaptation_routine(self):
        """Evaluate and adapt the model frame by frame on the current sequence.

        For every frame of ``self.eval_dataset``:
          * evaluate the current model (``pipeline.validation_step`` with
            ``ada_eval=True``), log the metrics and keep the OOD labels, logits,
            point scores (and backbone features for ``Decoupling_MaxLogit``);
          * once ``check_frame`` allows it, run ``sub_epochs`` pseudo-label
            adaptation steps on the (t-b, t) batch, log the merged step dicts
            and optionally dump per-frame outputs to ``.npy``;
          * before that point, with ``args.use_temporal``, push the predictions
            into the module-level temporal banks (bounded by
            ``args.pre_label_num``).
        Afterwards the sequence-level OOD metrics for ``args.OOD_type`` are
        appended to ``<save_dir>/OOD.txt`` (see ``_report_train_ood``).

        Returns:
            list of the per-frame validation dicts.
        """
        # Temporal banks are module-level state shared with the pipeline.
        global Label_Bank, Predict_Bank, Backbone_Bank, Coordinate_Bank

        # move to device
        self.model.to(self.device)
        if self.is_double:
            self.source_model.to(self.device)
            self.source_model2.to(self.device)

        adaptation_results = []
        unknown_label_list = []
        out_list = []
        out_bck_list = []
        point_score_list = []

        save_path = os.path.join(self.weights_save_path, 'pcd') if self.save_predictions else None

        for f in tqdm(range(len(self.eval_dataset)), desc=f'Seq: {self.sequence}', leave=True):

            # get eval batch (1 frame at a time)
            val_batch = self.get_evaluation_batch(f)
            with torch.no_grad():
                val_dict, unknown_label, out, out_bck, out_bottle, point_score = self.pipeline.validation_step(
                    val_batch, save_path=save_path, frame=f, ada_eval=True)

                # Collect each frame once. (Appending a second copy, as was done
                # under OOD_eval_test, doubles memory without changing any metric.)
                unknown_label_list.append(unknown_label)
                point_score_list.append(point_score)
                if not self.args.unknown_predict:
                    out_list.append(out)
                if self.args.OOD_type == 'Decoupling_MaxLogit':
                    out_bck_list.append(out_bck)

            val_dict['validation/frame'] = f
            self.log(val_dict)

            if self.check_frame(f):
                # get adaptation batch (t-b, t)
                batch = self.get_adaptation_batch(f)
                train_dict = {}
                out_tmp_without_dropout = None
                save_npy = self.args.kitti_result_save_npy or self.args.nusce_result_save_npy

                for _ in range(self.sub_epochs):
                    if not self.is_pseudo:
                        raise NotImplementedError
                    if save_npy:
                        step_dict, out_tmp_without_dropout = self.pipeline.adaptation_double_pseudo_step(
                            batch, f, weights_save_path=self.weights_save_path)
                    else:
                        step_dict = self.pipeline.adaptation_double_pseudo_step(
                            batch, f, weights_save_path=self.weights_save_path)
                    # merge the step metrics (later sub-epochs override earlier keys)
                    if step_dict is not None:
                        train_dict.update(step_dict)
                self.log(train_dict)

                # Dump per-frame outputs; nothing to dump if no adaptation step ran.
                if out_tmp_without_dropout is not None:
                    if self.args.kitti_result_save_npy:
                        self._save_ood_npy(os.path.join(self.weights_save_path, 'OOD_NPY'),
                                           f, val_batch, out_tmp_without_dropout, out)
                    if self.args.nusce_result_save_npy:
                        # str(): sequence ids may be numpy integers, which os.path.join rejects
                        self._save_ood_npy(os.path.join(self.weights_save_path, 'OOD_NPY', str(self.sequence)),
                                           f, val_batch, out_tmp_without_dropout, out)

            elif self.args.use_temporal:
                # Before adaptation starts, keep a sliding window of past predictions.
                pseudo0 = torch.max(out, dim=1)[1].clone().detach().cpu()
                Label_Bank.append(pseudo0.clone())
                Predict_Bank.append(out.clone().cuda())
                Backbone_Bank.append(out_bck.clone().cuda())
                Coordinate_Bank.append(val_batch["global_points"][0].unsqueeze(0).float().cuda().clone())  # [1, N, 3]
                if len(Label_Bank) > self.args.pre_label_num:
                    for bank in (Label_Bank, Predict_Bank, Backbone_Bank, Coordinate_Bank):
                        bank.pop(0)

            adaptation_results.append(val_dict)

        self._report_train_ood(unknown_label_list, out_list, out_bck_list, point_score_list)

        return adaptation_results

    @staticmethod
    def _save_ood_npy(root, frame, val_batch, seg_out, cls_out):
        """Save one frame's inputs, labels and outputs as a pickled dict in ``<root>/<frame>.npy``."""
        os.makedirs(root, exist_ok=True)
        save_dict = {'coordinates': val_batch['coordinates'].cpu().numpy(),
                     'global_points': val_batch['global_points'][0].cpu().numpy(),
                     'labels': val_batch['labels'].cpu().numpy(),
                     'OOD_labels': val_batch['OOD_labels'].cpu().numpy(),
                     'out': seg_out.cpu().numpy(),
                     'cls_out': cls_out.cpu().numpy()}
        np.save(os.path.join(root, str(frame) + '.npy'), save_dict)

    def _report_train_ood(self, unknown_label_list, out_list, out_bck_list, point_score_list):
        """Score the collected frames with ``args.OOD_type`` and append AUPR/AUROC/FPR95 to OOD.txt.

        Points whose OOD label is -1 are ignored. Scores (higher = more unknown):
          * ``'Softmax'``: 1 - max softmax probability (or the class-1
            probability with ``args.unknown_predict``), times the point score;
          * ``'MaxLogit'``: negative max logit;
          * ``'Decoupling_MaxLogit'``: negative (max logit + L2 norm of the
            backbone feature).
        Any other type is a no-op. When the sequence has no unknown point
        (label 1) the curves are undefined, so NaN is written and nothing is
        appended to the ``train_*_list`` metric lists.
        """
        from sklearn.metrics import precision_recall_curve, auc, roc_curve

        ood_type = self.args.OOD_type
        if ood_type not in ('Softmax', 'MaxLogit', 'Decoupling_MaxLogit'):
            return

        unknown_labels = torch.cat(unknown_label_list, dim=0)
        valid = unknown_labels != -1
        logits = torch.cat(out_list, dim=0)[valid]

        if ood_type == 'Softmax':
            probs = torch.softmax(logits, dim=1)
            if self.args.unknown_predict:
                scores = probs[:, 1]
            else:
                scores = 1 - probs.max(dim=1)[0]
            scores *= torch.cat(point_score_list, dim=0)[valid]
            prefix = 'Train Softmax'
        elif ood_type == 'MaxLogit':
            scores = -logits.max(dim=1)[0]
            prefix = 'MaxLogit'
        else:
            bck_norm = torch.cat(out_bck_list, dim=0)[valid].norm(2, dim=1)
            scores = -(logits.max(dim=1)[0] + bck_norm)
            prefix = 'Decoupling_MaxLogit'

        labels_np = unknown_labels[valid].cpu().numpy()
        report_path = os.path.join(self.args.save_dir, 'OOD.txt')

        if (labels_np == 1).sum() == 0:
            with open(report_path, 'a') as report:
                report.write(f'{prefix} Adaptation AUPR is: ' + "NaN" + '\n')
                report.write(f'{prefix} AUROC is: ' + "NaN" + '\n')
                report.write(f'{prefix} FPR95 is: ' + "NaN" + '\n')
            return

        scores_np = scores.detach().cpu().numpy()
        precision, recall, _ = precision_recall_curve(labels_np, scores_np)
        aupr_score = auc(recall, precision)
        fpr, tpr, _ = roc_curve(labels_np, scores_np)
        auroc_score = auc(fpr, tpr)
        # FPR at the first operating point whose TPR exceeds 95 %
        fpr95 = fpr[tpr > 0.95][0]

        with open(report_path, 'a') as report:
            report.write(f'{prefix} Adaptation AUPR is: ' + str(aupr_score) + '\n')
            report.write(f'{prefix} AUROC is: ' + str(auroc_score) + '\n')
            report.write(f'{prefix} FPR95 is: ' + str(fpr95) + '\n')

        self.train_AUPR_list.append(aupr_score)
        self.train_AUROC_list.append(auroc_score)
        self.train_FPR95_list.append(fpr95)

    def online_evaluation_routine(self):
        """Evaluate the model frame by frame on the current sequence and score OOD detection.

        For each frame of ``self.eval_dataset`` the batch returned by
        ``self.get_evaluation_batch`` is passed to
        ``self.pipeline.validation_step(..., is_source=True)``. The returned
        metric dict is tagged with ``'source/frame'``, logged through
        ``self.log`` and collected. When ``args.kitti_source_save_npy`` or
        ``args.nusce_source_save_npy`` is set, the frame's raw output is also
        saved as ``{'source_out': ndarray}`` in a ``.npy`` file.

        After the loop, OOD-detection metrics (AUPR, AUROC, FPR95) are computed
        over all points whose OOD label is not -1. The score used depends on
        ``args.OOD_type``:
          - 'Softmax': 1 - max softmax probability.
          - 'MaxLogit': -max logit.
          - 'Decoupling_MaxLogit': -(max logit + L2 norm of ``out_bck``).
        The metrics are appended, one per line, to ``<args.save_dir>/OOD.txt``
        and to ``self.eval_AUPR_list`` / ``self.eval_AUROC_list`` /
        ``self.eval_FPR95_list``. When no point has label 1 the metrics are
        undefined, so only ``NaN`` lines are written and the lists are left
        untouched.

        Returns:
            list of the per-frame validation dicts, in frame order.
        """

        def _report_ood_metrics(method_name, labels, scores):
            # Compute AUPR / AUROC / FPR95 of `scores` against binary `labels`
            # (1 = OOD), append them to OOD.txt and to the eval metric lists.
            from sklearn.metrics import precision_recall_curve, auc, roc_curve

            log_path = os.path.join(self.args.save_dir, 'OOD.txt')
            if (labels == 1).sum() == 0:
                # Without a single positive the curves are undefined and
                # `fpr[tpr > 0.95]` would be empty (IndexError).
                with open(log_path, 'a') as log_file:
                    log_file.write('Eval ' + method_name + ' Adaptation AUPR is: ' + "NaN" + '\n')
                    log_file.write('Eval ' + method_name + ' AUROC is: ' + "NaN" + '\n')
                    log_file.write('Eval ' + method_name + ' FPR95 is: ' + "NaN" + '\n')
                return

            precision, recall, _ = precision_recall_curve(labels, scores)
            aupr_score = auc(recall, precision)
            fpr, tpr, _ = roc_curve(labels, scores)
            auroc_score = auc(fpr, tpr)
            # roc_curve returns non-decreasing fpr/tpr, so the first hit is the
            # lowest FPR at which TPR exceeds 0.95.
            fpr95 = fpr[tpr > 0.95][0]

            with open(log_path, 'a') as log_file:
                log_file.write('Eval ' + method_name + ' Adaptation AUPR is: ' + str(aupr_score) + '\n')
                log_file.write('Eval ' + method_name + ' AUROC is: ' + str(auroc_score) + '\n')
                log_file.write('Eval ' + method_name + ' FPR95 is: ' + str(fpr95) + '\n')

            self.eval_AUPR_list.append(aupr_score)
            self.eval_AUROC_list.append(auroc_score)
            self.eval_FPR95_list.append(fpr95)

        def _save_source_out(root, frame_idx, logits):
            # Dump one frame's raw output to <root>/<frame_idx>.npy.
            os.makedirs(root, exist_ok=True)
            np.save(os.path.join(root, str(frame_idx) + '.npy'), {'source_out': logits.cpu().numpy()})

        # move model to device
        self.model.to(self.device)
        source_results = []

        ood_type = self.args.OOD_type
        unknown_label_list = []
        out_list = []
        out_bck_list = []
        # NOTE(review): `final_list` is only read by the `args.temp` branch below
        # and nothing in this method fills it. On a plain list the boolean-mask
        # indexing there fails. TODO: confirm how it is meant to be populated.
        final_list = []

        save_path = os.path.join(self.weights_save_path, 'pcd') if self.save_predictions else None

        with torch.no_grad():
            for frame in tqdm(range(len(self.eval_dataset)), desc=f'Seq: {self.sequence}', leave=True):
                val_batch = self.get_evaluation_batch(frame)
                val_dict, unknown_label, out, out_bck, _ = self.pipeline.validation_step(
                    val_batch, is_source=True, save_path=save_path, frame=frame)
                val_dict['source/frame'] = frame
                self.log(val_dict)
                source_results.append(val_dict)

                # Each frame's OOD labels / logits are stored exactly once.
                # Duplicates would only double memory, because ROC/PR metrics
                # are invariant to repeating the whole sample set.
                unknown_label_list.append(unknown_label)
                out_list.append(out)
                if ood_type == 'Decoupling_MaxLogit':
                    out_bck_list.append(out_bck)

                if self.args.kitti_source_save_npy:
                    _save_source_out(os.path.join(self.weights_save_path, 'Source_NPY'), frame, out)

                if self.args.nusce_source_save_npy:
                    _save_source_out(os.path.join(self.weights_save_path, 'Source_NPY', self.sequence), frame, out)

        if not unknown_label_list:
            # Empty sequence: nothing to score (torch.cat would fail on []).
            return source_results

        if ood_type == 'Softmax':
            unknown_labels = torch.cat(unknown_label_list, dim=0)
            out = torch.cat(out_list, dim=0)
            valid = unknown_labels != -1
            if self.args.temp:
                final_list = final_list[valid]
            out = out[valid]
            unknown_labels = unknown_labels[valid].cpu().numpy()
            uncertainty = 1 - torch.max(torch.softmax(out, dim=1), dim=1)[0]
            if self.args.temp:
                uncertainty = 1 - torch.from_numpy(final_list)
            _report_ood_metrics('Softmax', unknown_labels, uncertainty.cpu().detach().numpy())

        elif ood_type == 'MaxLogit':
            unknown_labels = torch.cat(unknown_label_list, dim=0).cpu().numpy()
            out = torch.cat(out_list, dim=0).cpu().numpy()
            valid = unknown_labels != -1
            _report_ood_metrics('MaxLogit', unknown_labels[valid], -np.max(out[valid], axis=1))

        elif ood_type == 'Decoupling_MaxLogit':
            unknown_labels = torch.cat(unknown_label_list, dim=0).cpu().numpy()
            out = torch.cat(out_list, dim=0).cpu().numpy()
            # Moved to CPU so `.numpy()` below also works for GPU outputs.
            out_bck = torch.cat(out_bck_list, dim=0).cpu()
            valid = unknown_labels != -1
            max_logit = np.max(out[valid], axis=1)
            bck_norm = out_bck[torch.from_numpy(valid)].norm(2, dim=1).numpy()
            _report_ood_metrics('Decoupling_MaxLogit', unknown_labels[valid], -(max_logit + bck_norm))

        return source_results

    def set_loggers(self, sequence):
        """Tell every logger in ``self.loggers`` which sequence is now active."""
        for sink in self.loggers:
            sink.set_sequence(sequence)

    def set_sequence(self, sequence):
        """Switch every sequence-dependent component to ``sequence``.

        Stores the id as ``str`` in ``self.sequence`` and re-points
        ``self.weights_save_path`` to the sibling directory named after it.
        Forwards the raw id to the eval and adapt datasets. When
        ``self.boost`` is set, rebuilds the online dataloader iterators.
        Finally updates the loggers.
        """
        self.sequence = str(sequence)
        parent_dir = os.path.dirname(self.weights_save_path)
        self.weights_save_path = os.path.join(parent_dir, self.sequence)

        for dataset in (self.eval_dataset, self.adapt_dataset):
            dataset.set_sequence(sequence)

        if self.boost:
            eval_stream = FrameOnlineDataset(self.eval_dataset)
            self.eval_dataloader = iter(self.pipeline.get_online_dataloader(eval_stream, is_adapt=False))

            adapt_stream = PairedOnlineDataset(self.adapt_dataset, use_random=self.pipeline.use_random_wdw)
            self.adapt_dataloader = iter(self.pipeline.get_online_dataloader(adapt_stream, is_adapt=True))

        # keep the loggers' output paths in sync with the new sequence
        self.set_loggers(sequence)

    def log(self, results_dict):
        """Send ``results_dict`` to each logger in ``self.loggers``, in order."""
        for sink in self.loggers:
            sink.log(results_dict)

    def save_state_dict(self, frame):
        """Write a checkpoint for ``frame``.

        The file ``checkpoint-frame{frame}.pth`` is written inside
        ``self.weights_save_path``. It holds the frame index, the model state
        dict and the pipeline optimizer's state dict.
        """
        payload = dict(frame=frame,
                       model_state_dict=self.model.state_dict(),
                       optimizer_state_dict=self.pipeline.optimizer.state_dict())
        target = os.path.join(self.weights_save_path, f'checkpoint-frame{frame}.pth')
        torch.save(payload, target)

    def reload_model(self, is_adapt=True):
        """Reload the weights of ``self.model`` from a checkpoint file.

        ``self.student_checkpoint`` is used when ``is_adapt`` is true and a
        student checkpoint is set. Otherwise ``self.source_checkpoint`` is used.

        Args:
            is_adapt: prefer the student checkpoint (if any) over the source one.

        Raises:
            NotImplementedError: if the selected checkpoint's extension is
                neither ``.pth`` nor ``.ckpt``.
        """
        def clean_state_dict(state):
            # Clean Lightning key names in place. Keys containing "model" lose
            # every "model." occurrence; all other keys are dropped (the `del`
            # is unconditional).
            for k in list(state.keys()):
                if "model" in k:
                    state[k.replace("model.", "")] = state[k]
                del state[k]
            return state

        def clean_student_state_dict(state):
            # Same idea for checkpoints written by `save_state_dict`: keys
            # containing "seg_model" lose the "seg_model." prefix, others are dropped.
            for k in list(state.keys()):
                if "seg_model" in k:
                    state[k.replace("seg_model.", "")] = state[k]
                del state[k]
            return state

        if self.student_checkpoint is not None and is_adapt:
            checkpoint_path = self.student_checkpoint
            print(f'--> Loading student checkpoint {checkpoint_path}')
        else:
            checkpoint_path = self.source_checkpoint
            print(f'--> Loading source checkpoint {checkpoint_path}')

        cpu = torch.device('cpu')

        # in case of SSL pretraining
        if isinstance(self.model, MinkUNet18_SSL):
            if checkpoint_path.endswith('.pth'):
                self.model.load_state_dict(torch.load(checkpoint_path, map_location=cpu))
            elif checkpoint_path.endswith('.ckpt'):
                ckpt = torch.load(checkpoint_path, map_location=cpu)["state_dict"]
                self.model.load_state_dict(clean_state_dict(ckpt), strict=True)
            else:
                raise NotImplementedError('Invalid source model extension (allowed .pth and .ckpt)')

        # in case of segmentation pretraining
        elif isinstance(self.model, MinkUNet18_HEADS):
            if checkpoint_path.endswith('.pth'):
                ckpt = torch.load(checkpoint_path, map_location=cpu)
                self.model.seg_model.load_state_dict(clean_student_state_dict(ckpt['model_state_dict']))
            elif checkpoint_path.endswith('.ckpt'):
                # The selected checkpoint (not unconditionally the source one) is
                # loaded. A file named "pretrained_model.ckpt" is used as loaded
                # (presumably a bare state dict); other .ckpt files are unwrapped
                # from "state_dict".
                ckpt = torch.load(checkpoint_path, map_location=cpu)
                if not checkpoint_path.endswith('pretrained_model.ckpt'):
                    ckpt = ckpt["state_dict"]
                self.model.seg_model.load_state_dict(clean_state_dict(ckpt), strict=True)
            else:
                raise NotImplementedError('Invalid source model extension (allowed .pth and .ckpt)')

    def reload_model_from_scratch(self):
        """Re-initialize the model weights instead of loading a checkpoint.

        ``MinkUNet18_SSL`` models are re-initialized in place. For
        ``MinkUNet18_HEADS`` the inner ``seg_model`` is re-initialized and
        wrapped in a new ``MinkUNet18_HEADS``, which replaces ``self.model``.
        Any other model type is left untouched.
        """
        current = self.model
        if isinstance(current, MinkUNet18_SSL):
            current.weight_initialization()
            return

        if isinstance(current, MinkUNet18_HEADS):
            inner = current.seg_model
            inner.weight_initialization()
            self.model = MinkUNet18_HEADS(seg_model=inner)

    def load_source_model(self):
        """Load ``self.source_checkpoint`` into ``self.source_model`` and ``self.source_model2``.

        ``.pth`` files are loaded as a plain state dict. ``.ckpt`` files are
        unwrapped from their ``"state_dict"`` entry and then key-cleaned. The
        exception is files named ``pretrained_model.ckpt``, which are used as
        loaded before key-cleaning. When ``self.source_model`` is a
        ``MinkUNet18_MCMC``, the weights are loaded into each model's
        ``seg_model``; otherwise into the models themselves. Loading is strict.

        Raises:
            NotImplementedError: if the extension is neither ``.pth`` nor ``.ckpt``.
        """
        def clean_state_dict(state):
            # Clean Lightning key names in place. Keys containing "model" lose
            # every "model." occurrence; all other keys are dropped (the `del`
            # is unconditional).
            for k in list(state.keys()):
                if "model" in k:
                    state[k.replace("model.", "")] = state[k]
                del state[k]
            return state

        checkpoint_path = self.source_checkpoint
        print(f'--> Loading source checkpoint {checkpoint_path}')

        cpu = torch.device('cpu')
        if checkpoint_path.endswith('.pth'):
            ckpt = torch.load(checkpoint_path, map_location=cpu)
        elif checkpoint_path.endswith('.ckpt'):
            ckpt = torch.load(checkpoint_path, map_location=cpu)
            if not checkpoint_path.endswith('pretrained_model.ckpt'):
                ckpt = ckpt["state_dict"]
            ckpt = clean_state_dict(ckpt)
        else:
            raise NotImplementedError('Invalid source model extension (allowed .pth and .ckpt)')

        # Both source models receive identical weights; for MCMC models the
        # target is their `.seg_model`.
        into_seg_model = isinstance(self.source_model, MinkUNet18_MCMC)
        for model in (self.source_model, self.source_model2):
            target = model.seg_model if into_seg_model else model
            target.load_state_dict(ckpt, strict=True)

    def get_adaptation_batch(self, frame_idx):
        """Return the adaptation batch for ``frame_idx``.

        If ``self.adapt_dataloader`` is set, its next item is returned.
        Otherwise the ``adaptation_batch_size`` frames ending at (and
        including) ``frame_idx`` are read from ``self.adapt_dataset``. Each
        consecutive pair is combined with
        ``adapt_dataset.get_double_data(previous, current)``, and the pairs
        are collated with ``self.collate_fn_adapt``.
        """
        if self.adapt_dataloader is not None:
            return next(self.adapt_dataloader)

        stop = frame_idx + 1
        window = np.arange(stop - self.pipeline.adaptation_batch_size, stop)
        frames = [self.adapt_dataset.__getitem__(idx) for idx in window]
        pairs = [self.adapt_dataset.get_double_data(prev, curr)
                 for prev, curr in zip(frames[:-1], frames[1:])]
        return self.collate_fn_adapt(pairs)

    def get_evaluation_batch(self, frame_idx):
        """Return the evaluation batch for ``frame_idx``.

        If ``self.eval_dataloader`` is set, its next item is returned.
        Otherwise the single frame is read from ``self.eval_dataset``,
        converted with ``get_single_data`` and collated as a one-element list.
        """
        if self.eval_dataloader is not None:
            return next(self.eval_dataloader)

        sample = self.eval_dataset.get_single_data(self.eval_dataset.__getitem__(frame_idx))
        return self.collate_fn_eval([sample])

    def save_final_results(self):
        """Compare source and adapted metrics per sequence and persist them.

        For every sequence in ``self.online_sequences`` both per-frame result
        lists are summarized with ``format_val_dict``. For every metric key the
        following are stored: ``relative_<k>`` (adapted - source),
        ``source_<k>`` and ``adapted_<k>``. The result is written as the
        'final' and 'source' CSVs and pickled.
        """
        final_dict = {}
        for seq in self.online_sequences:
            source_frames = self.source_results_dict[seq]
            adapted_frames = self.adaptation_results_dict[seq]
            # both runs must cover the same frames
            assert len(source_frames) == len(adapted_frames)

            source_summary = self.format_val_dict(source_frames)
            adapted_summary = self.format_val_dict(adapted_frames)

            seq_entry = {}
            for metric, adapted_value in adapted_summary.items():
                source_value = source_summary[metric]
                seq_entry[f'relative_{metric}'] = adapted_value - source_value
                seq_entry[f'source_{metric}'] = source_value
                seq_entry[f'adapted_{metric}'] = adapted_value
            final_dict[seq] = seq_entry

        for phase in ('final', 'source'):
            self.write_csv(final_dict, phase=phase)
        self.save_pickle(final_dict)

    def save_eval_results(self):
        """Summarize the evaluation-only results per sequence and persist them.

        Each sequence's per-frame results in ``self.source_results_dict`` are
        summarized with ``format_val_dict``. Every metric is stored under an
        ``eval_`` prefix, then written as the 'eval' CSV and pickled.
        """
        final_dict = {}
        for seq in self.online_sequences:
            summary = self.format_val_dict(self.source_results_dict[seq])
            final_dict[seq] = {f'eval_{metric}': value for metric, value in summary.items()}

        self.write_csv(final_dict, phase='eval')
        self.save_pickle(final_dict)

    def format_val_dict(self, list_dict):
        """Summarize a sequence's per-frame IoU dicts.

        Every frame dict is renamed in place. A key containing
        ``"validation/"`` (checked first) or ``"source/"`` loses every
        occurrence of that substring. The IoU columns are chosen from
        ``self.num_classes`` and ``self.args``. In the 19-name branch
        (reached when ``args.ignore_class`` is not None), the classes listed in
        ``args.ignore_class`` are removed through ``self.class2names``. Column
        order follows the name lists below. A class missing from a frame
        counts as NaN.

        Args:
            list_dict: list of per-frame metric dicts (mutated in place).

        Returns:
            dict with
              - ``'miou'``: NaN-aware mean of ``per_class_iou``,
              - ``'per_frame_miou'``: NaN-aware mean over the last axis,
              - ``'per_class_iou'``: NaN-aware mean over frames (axis 0),
              - ``'per_class_frame_iou'``: the stacked matrix, shaped
                (n_frames, n_classes) for scalar IoU entries.
        """

        def strip_prefix(frame_metrics):
            for name in list(frame_metrics.keys()):
                for prefix in ("validation/", "source/"):
                    if prefix in name:
                        frame_metrics[name.replace(prefix, "")] = frame_metrics.pop(name)
                        break
            return frame_metrics

        list_dict = [strip_prefix(frame_metrics) for frame_metrics in list_dict]

        args = self.args
        drop_ignored = False
        if self.num_classes == 7:
            names = ('vehicle', 'pedestrian', 'road', 'sidewalk', 'terrain', 'manmade', 'vegetation')
        elif self.num_classes == 3:
            names = ('background', 'vehicle', 'pedestrian')
        elif self.num_classes == 13 and args.OOD_poss == True:  # noqa: E712 (kept: equality, not truthiness)
            names = ('person', 'rider', 'car', 'trunk', 'plants', 'traffic-sign', 'pole',
                     'garbage-can', 'building', 'cone', 'fence', 'bike', 'ground')
        elif args.ignore_class is not None:
            names = ('car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person',
                     'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground',
                     'building', 'fence', 'vegetation', 'trunk', 'terrain', 'pole', 'traffic-sign')
            drop_ignored = True
        else:
            names = ('vehicle', 'pedestrian')

        classes = {name + '_iou': [] for name in names}
        if drop_ignored:
            # Raises KeyError for an unknown or repeated ignored class, like a strict lookup.
            for raw_idx in args.ignore_class:
                classes.pop(self.class2names[int(raw_idx)] + '_iou')

        for frame_metrics in list_dict:
            for key, column in classes.items():
                column.append(frame_metrics[key] if key in frame_metrics else np.nan)

        rows = [np.asarray(column)[np.newaxis, ...] for column in classes.values()]
        all_iou = np.concatenate(rows, axis=0).T

        per_class_iou = np.nanmean(all_iou, axis=0)
        return {'miou': np.nanmean(per_class_iou),
                'per_frame_miou': np.nanmean(all_iou, axis=-1),
                'per_class_iou': per_class_iou,
                'per_class_frame_iou': all_iou}

    def write_csv(self, results_dict, phase='final'):
        if self.num_classes == 7:
            if phase == 'final':
                headers = ['sequence', 'relative_miou', 'relative_vehicle_iou',
                           'relative_pedestrian_iou', 'relative_road_iou',
                           'relative_sidewalk_iou', 'relative_terrain_iou',
                           'relative_manmade_iou', 'relative_vegetation_iou']
                file_name = 'final_main.csv'
            elif phase == 'source':
                headers = ['sequence', 'miou', 'source_vehicle_iou',
                           'source_pedestrian_iou', 'source_road_iou',
                           'source_sidewalk_iou', 'source_terrain_iou',
                           'source_manmade_iou', 'source_vegetation_iou']
                file_name = 'source_main.csv'
            elif phase == 'eval':
                headers = ['sequence', 'miou', 'eval_vehicle_iou',
                           'eval_pedestrian_iou', 'eval_road_iou',
                           'eval_sidewalk_iou', 'eval_terrain_iou',
                           'eval_manmade_iou', 'eval_vegetation_iou']
                file_name = 'evaluation_main.csv'
            else:
                raise NotImplementedError
        elif self.num_classes == 3:
            if phase == 'final':
                headers = ['sequence', 'relative_miou',
                           'relative_background_iou',
                           'relative_vehicle_iou',
                           'relative_pedestrian_iou']
                file_name = 'final_main.csv'
            elif phase == 'source':
                headers = ['sequence', 'miou',
                           'source_background_iou',
                           'source_vehicle_iou',
                           'source_pedestrian_iou']
                file_name = 'source_main.csv'
            elif phase == 'eval':
                headers = ['sequence','miou',
                           'source_backround_iou',
                           'eval_vehicle_iou',
                           'eval_pedestrian_iou']
                file_name = 'evaluation_main.csv'
            else:
                raise NotImplementedError
        elif self.num_classes == 13 and self.args.OOD_poss == True:
            if phase == 'final':
                headers = ['sequence', 'relative_miou',
                           'relative_person',
                           'relative_rider',
                           'relative_car',
                           'relative_trunk',
                           'relative_plants',
                           'relative_traffic-sign',
                           'relative_pole',
                           'relative_garbage-can',
                           'relative_building',
                           'relative_cone',
                           'relative_fence',
                           'relative_bike',
                           'relative_ground']
                file_name = 'final_main.csv'
            elif phase == 'source':
                headers = ['sequence', 'miou',
                           'source_person',
                           'source_rider',
                           'source_car',
                           'source_trunk',
                           'source_plants',
                           'source_traffic-sign',
                           'source_pole',
                           'source_garbage-can',
                           'source_building',
                           'source_cone',
                           'source_fence',
                           'source_bike',
                           'source_ground']
                file_name = 'source_main.csv'

            elif phase == 'eval':
                headers = ['sequence','miou',
                           'eval_person',
                           'eval_rider',
                           'eval_car',
                           'eval_trunk',
                           'eval_plants',
                           'eval_traffic-sign',
                           'eval_pole',
                           'eval_garbage-can',
                           'eval_building',
                           'eval_cone',
                           'eval_fence',
                           'eval_bike',
                           'eval_ground']
                file_name = 'evaluation_main.csv'

        elif self.args.ignore_class is not None:
            if phase == 'final':
                headers = ['sequence', 'relative_miou',
                           'relative_car',
                           'relative_bicycle',
                           'relative_motorcycle',
                           'relative_truck',
                           'relative_other-vehicle',
                           'relative_person',
                           'relative_bicyclist',
                           'relative_motorcyclist',
                           'relative_road',
                           'relative_parking',
                           'relative_sidewalk',
                           'relative_other-ground',
                           'relative_building',
                           'relative_fence',
                           'relative_vegetation',
                           'relative_trunk',
                           'relative_terrain',
                           'relative_pole',
                           'relative_traffic-sign']
                           
                if self.args.ignore_class is not None:
                    for key in self.args.ignore_class:
                        key = int(key)
                        ignore_name = 'relative_' + self.class2names[key]
                        headers.remove(ignore_name)
                    
                file_name = 'final_main.csv'
            elif phase == 'source':
                headers = ['sequence', 'miou',
                           'source_car',
                           'source_bicycle',
                           'source_motorcycle',
                           'source_truck',
                           'source_other-vehicle',
                           'source_person',
                           'source_bicyclist',
                           'source_motorcyclist',
                           'source_road',
                           'source_parking',
                           'source_sidewalk',
                           'source_other-ground',
                           'source_building',
                           'source_fence',
                           'source_vegetation',
                           'source_trunk',
                           'source_terrain',
                           'source_pole',
                           'source_traffic-sign']
                
                if self.args.ignore_class is not None:
                    for key in self.args.ignore_class:
                        key = int(key)
                        ignore_name = 'source_' + self.class2names[key]
                        headers.remove(ignore_name)

                file_name = 'source_main.csv'
            elif phase == 'eval':
                headers = ['sequence','miou',
                           'eval_car',
                           'eval_bicycle',
                           'eval_motorcycle',
                           'eval_truck',
                           'eval_other-vehicle',
                           'eval_person',
                           'eval_bicyclist',
                           'eval_motorcyclist',
                           'eval_road',
                           'eval_parking',
                           'eval_sidewalk',
                           'eval_other-ground',
                           'eval_building',
                           'eval_fence',
                           'eval_vegetation',
                           'eval_trunk',
                           'eval_terrain',
                           'eval_pole',
                           'eval_traffic-sign']
                
                if self.args.ignore_class is not None:
                    for key in self.args.ignore_class:
                        key = int(key)
                        ignore_name = 'eval_' + self.class2names[key]
                        headers.remove(ignore_name)

                file_name = 'evaluation_main.csv'
            else:
                raise NotImplementedError
        elif self.num_classes == 2:
            if phase == 'final':
                headers = ['sequence', 'relative_miou',
                           'relative_vehicle_iou',
                           'relative_pedestrian_iou']
                file_name = 'final_main.csv'
            elif phase == 'source':
                headers = ['sequence', 'miou',
                           'source_vehicle_iou',
                           'source_pedestrian_iou']
                file_name = 'source_main.csv'
            elif phase == 'eval':
                headers = ['sequence','miou',
                           'eval_vehicle_iou',
                           'eval_pedestrian_iou']
                file_name = 'evaluation_main.csv'
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        if self.dataset_name == 'nuScenes':
            cumul = []

        results_dir = os.path.join(os.path.split(self.weights_save_path)[0], 'final_results')
        os.makedirs(results_dir, exist_ok=True)
        with open(os.path.join(results_dir, file_name), 'w', encoding='UTF8', newline='') as f:
            writer = csv.writer(f)

            # write the header
            writer.writerow(headers)

            for seq in results_dict.keys():
                dict_tmp = results_dict[seq]
                if phase == 'final':
                    per_class = dict_tmp['relative_per_class_iou']
                    if self.num_classes == 7:
                        data = [seq,
                                dict_tmp['relative_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100,
                                per_class[3]*100,
                                per_class[4]*100,
                                per_class[5]*100,
                                per_class[6]*100]
                    elif self.num_classes == 3:
                        data = [seq,
                                dict_tmp['relative_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100]

                    elif self.num_classes == 13:
                        data = [seq,
                                dict_tmp['relative_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100,
                                per_class[3]*100,
                                per_class[4]*100,
                                per_class[5]*100,
                                per_class[6]*100,
                                per_class[7]*100,
                                per_class[8]*100,
                                per_class[9]*100,
                                per_class[10]*100,
                                per_class[11]*100,
                                per_class[12]*100
                        ]

                    elif self.args.ignore_class is not None:
                        data = [seq,
                                dict_tmp['relative_miou']*100,
                        ]

                        if self.args.ignore_class is not None:
                            for i in range(19-len(self.args.ignore_class)):
                                data.append(per_class[i]*100)

                    elif self.num_classes == 2:
                        data = [seq,
                                dict_tmp['relative_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100]

                elif phase == 'source':
                    per_class = dict_tmp['source_per_class_iou']
                    if self.num_classes == 7:
                        data = [seq,
                                dict_tmp['source_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100,
                                per_class[3]*100,
                                per_class[4]*100,
                                per_class[5]*100,
                                per_class[6]*100]
                    elif self.num_classes == 3:
                        data = [seq,
                                dict_tmp['source_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100]

                    elif self.num_classes == 13:
                        data = [seq,
                                dict_tmp['source_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100,
                                per_class[3]*100,
                                per_class[4]*100,
                                per_class[5]*100,
                                per_class[6]*100,
                                per_class[7]*100,
                                per_class[8]*100,
                                per_class[9]*100,
                                per_class[10]*100,
                                per_class[11]*100,
                                per_class[12]*100
                        ]

                    elif self.args.ignore_class is not None:
                        data = [seq,
                                dict_tmp['source_miou']*100,
                        ]
                        if self.args.ignore_class is not None:
                            for i in range(19-len(self.args.ignore_class)):
                                data.append(per_class[i]*100)

                    elif self.num_classes == 2:
                        data = [seq,
                                dict_tmp['source_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100]

                elif phase == 'eval':
                    per_class = dict_tmp['eval_per_class_iou']
                    if self.num_classes == 7:
                        data = [seq,
                                dict_tmp['eval_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100,
                                per_class[3]*100,
                                per_class[4]*100,
                                per_class[5]*100,
                                per_class[6]*100]

                    elif self.num_classes == 3:
                        data = [seq,
                                dict_tmp['eval_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100]

                    elif self.num_classes == 13:
                        data = [seq,
                                dict_tmp['eval_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100,
                                per_class[2]*100,
                                per_class[3]*100,
                                per_class[4]*100,
                                per_class[5]*100,
                                per_class[6]*100,
                                per_class[7]*100,
                                per_class[8]*100,
                                per_class[9]*100,
                                per_class[10]*100,
                                per_class[11]*100,
                                per_class[12]*100
                        ]

                    elif self.args.ignore_class is not None:
                        data = [seq,
                                dict_tmp['eval_miou']*100,
                        ]
                        if self.args.ignore_class is not None:
                            for i in range(19-len(self.args.ignore_class)):
                                data.append(per_class[i]*100)
                    
                    elif self.num_classes == 2:
                        data = [seq,
                                dict_tmp['eval_miou']*100,
                                per_class[0]*100,
                                per_class[1]*100]

                # write the data
                writer.writerow(data)

                if self.dataset_name == 'nuScenes':
                    if phase == 'final':
                        first_iou = dict_tmp['relative_miou']
                    elif phase == 'source':
                        first_iou = dict_tmp['source_miou']
                    elif phase == 'eval':
                        first_iou = dict_tmp['eval_miou']

                    if self.num_classes == 7:
                        cumul.append([first_iou*100,
                                      per_class[0]*100,
                                      per_class[1]*100,
                                      per_class[2]*100,
                                      per_class[3]*100,
                                      per_class[4]*100,
                                      per_class[5]*100,
                                      per_class[6]*100])
                    elif self.num_classes == 3:
                        cumul.append([first_iou*100,
                                      per_class[0]*100,
                                      per_class[1]*100,
                                      per_class[2]*100])
                    elif self.num_classes == 2:
                        cumul.append([first_iou*100,
                                      per_class[0]*100,
                                      per_class[1]*100])

            if self.dataset_name == 'nuScenes':
                avg_cumul = np.array(cumul)
                avg_cumul_tmp = np.nanmean(avg_cumul, axis=0)
                if self.num_classes == 7:
                    data = ['Average',
                            avg_cumul_tmp[0],
                            avg_cumul_tmp[1],
                            avg_cumul_tmp[2],
                            avg_cumul_tmp[3],
                            avg_cumul_tmp[4],
                            avg_cumul_tmp[5],
                            avg_cumul_tmp[6],
                            avg_cumul_tmp[7]]
                elif self.num_classes == 3:
                    data = ['Average',
                            avg_cumul_tmp[0],
                            avg_cumul_tmp[1],
                            avg_cumul_tmp[2]]

                elif self.num_classes == 2:
                    data = ['Average',
                            avg_cumul_tmp[0],
                            avg_cumul_tmp[1]]

                # write cumulative results
                writer.writerow(data)
                seq_locs = np.array([self.adapt_dataset.names2locations[self.adapt_dataset.online_keys[s]] for s in results_dict.keys()])

                for location in ['singapore-queenstown', 'boston-seaport', 'singapore-hollandvillage', 'singapore-onenorth']:
                    valid_sequences = seq_locs == location
                    avg_cumul_tmp = np.nanmean(avg_cumul[valid_sequences, :], axis=0)
                    if self.num_classes == 7:
                        data = [location,
                            avg_cumul_tmp[0],
                            avg_cumul_tmp[1],
                            avg_cumul_tmp[2],
                            avg_cumul_tmp[3],
                            avg_cumul_tmp[4],
                            avg_cumul_tmp[5],
                            avg_cumul_tmp[6],
                            avg_cumul_tmp[7]]

                    elif self.num_classes == 3:
                        data = [location,
                            avg_cumul_tmp[0],
                            avg_cumul_tmp[1],
                            avg_cumul_tmp[2]]

                    elif self.num_classes == 2:
                        data = [location,
                            avg_cumul_tmp[0],
                            avg_cumul_tmp[1]]

                    # write cumulative results
                    writer.writerow(data)

    def save_pickle(self, results_dict):
        """Serialize the full results dictionary to ``final_all.pkl``.

        The file is written to a ``final_results`` directory placed next to
        ``self.weights_save_path`` (i.e. in its parent directory). This is the
        same directory the CSV writer uses.

        Args:
            results_dict: Object to pickle. Presumably the per-sequence results
                mapping that is also written to CSV — TODO confirm against callers.
        """
        results_dir = os.path.join(os.path.split(self.weights_save_path)[0], 'final_results')
        # Create the directory if needed, as the CSV writer does. Otherwise this
        # fails with FileNotFoundError when called before any CSV was saved.
        os.makedirs(results_dir, exist_ok=True)
        with open(os.path.join(results_dir, 'final_all.pkl'), 'wb') as handle:
            pickle.dump(results_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
