import sys
import torch
import torch.nn.functional as F
import os
import wandb
import pandas as pd
import numpy as np
import warnings
import sklearn.exceptions
import collections
from collections import OrderedDict
import argparse

from utils import fix_randomness, starting_logs, AverageMeter, DictAsObject
from algorithms.algorithms import get_algorithm_class
from models.models import get_backbone_class
from trainers.abstract_trainer import AbstractTrainer

from torchmetrics import Accuracy, AUROC, F1Score
from dataloader.dataloader import data_generator, few_shot_data_generator, data_generator_bound, data_generator_baseline, get_data_transforms
from configs.data_model_configs import get_dataset_class
from configs.hparams import get_hparams_class
from configs.sweep_params import sweep_alg_hparams

# Suppress sklearn's UndefinedMetricWarning so metric computation does not
# flood the logs with these warnings.
warnings.filterwarnings("ignore", category=sklearn.exceptions.UndefinedMetricWarning)

# Module-level argument parser; no arguments are registered on it in the
# visible part of this file.
parser = argparse.ArgumentParser()

class Trainer_MoSSDA(AbstractTrainer):
    """
    Training / testing driver for the MoSSDA method.

    Attributes read from ``self`` but not assigned in this class (``da_method``,
    ``backbone``, ``dataset_configs``, ``hparams``, ``device``, ``data_path``,
    ``home_path``, ``exp_log_dir``, ``run_description``, ``num_runs``, ...) and
    the helpers ``initialize_algorithm``, ``calculate_metrics``,
    ``calculate_risks``, ``append_results_to_tables`` and ``add_mean_std_table``
    are expected to come from ``AbstractTrainer`` (not visible here).
    """

    def __init__(self, args):
        """Initialise the base trainer and declare the result-table columns."""
        super().__init__(args)

        self.results_columns = ['scenario', 'run', 'acc', 'f1_score', 'auroc']
        self.risks_columns = ['scenario', 'run', 'src_risk', 'few_shot_risk', 'trg_risk']

    def train_model(self):
        """
        Build the algorithm for ``self.da_method`` and train it.

        Returns:
            Tuple ``(last_head, last_classifier, best_head, best_classifier)``.
            Note this order differs from the order returned by
            ``algorithm.update`` (last_head, best_head, last_classifier,
            best_classifier).
        """
        # Get the algorithm and the backbone network SSDA_classifier
        algorithm_class = get_algorithm_class(self.da_method)
        backbone_fe = get_backbone_class(self.backbone)

        self.algorithm = algorithm_class(
            backbone_fe, self.dataset_configs, self.num_epochs,
            self.post_epochs, self.hparams, self.device,
        )
        self.algorithm.to(self.device)

        # Training the model
        (self.last_head, self.best_head,
         self.last_classifier, self.best_classifier) = self.algorithm.update(
            self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)

        return self.last_head, self.last_classifier, self.best_head, self.best_classifier

    def evaluate(self, test_loader):
        """
        Run feature extractor -> projection head -> classifier over a loader.

        Args:
            test_loader: iterable yielding ``(data, labels, _)`` triples.

        Side effects:
            Sets ``self.loss`` (mean of the per-batch cross-entropy losses),
            ``self.full_preds`` (concatenated classifier outputs) and
            ``self.full_labels`` (concatenated flattened labels).
        """
        feature_extractor = self.algorithm.feature_extractor.to(self.device)
        head = self.algorithm.projection_head.to(self.device)
        classifier = self.algorithm.classifier.to(self.device)

        feature_extractor.eval()
        head.eval()
        classifier.eval()

        total_loss, preds_list, labels_list = [], [], []

        with torch.no_grad():
            for data, labels, _ in test_loader:
                # Move the batch to the same device as the modules instead of
                # hard-coding .cuda(), so CPU / non-default GPU runs work too.
                data = data.float().to(self.device, non_blocking=True)
                labels = labels.view(-1).long().to(self.device, non_blocking=True)

                # forward pass
                features = feature_extractor(data)
                out_feat = head(features)
                predictions = classifier(out_feat)

                # compute loss
                loss = F.cross_entropy(predictions, labels)
                total_loss.append(loss.item())

                # append predictions and labels
                preds_list.append(predictions.detach())
                labels_list.append(labels)

        self.loss = torch.tensor(total_loss).mean()
        self.full_preds = torch.cat(preds_list)
        self.full_labels = torch.cat(labels_list)

    def load_data(self, src_id, trg_id):
        """
        Create the source/target train and test loaders plus the few-shot loader.

        Args:
            src_id: identifier of the source domain.
            trg_id: identifier of the target domain.
        """
        self.src_train_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "train", "source", self.unlabeled_ratio)
        self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "test", "source", self.unlabeled_ratio)

        self.trg_train_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "train", "target", self.unlabeled_ratio)
        self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "test", "target", self.unlabeled_ratio)
        # set 5 to other value if you want other k-shot FST
        self.few_shot_dl_5 = few_shot_data_generator(self.trg_test_dl, self.dataset_configs, 5)

    def create_save_dir(self, save_dir):
        """Create ``save_dir`` (including missing parents) if it does not exist."""
        # exist_ok avoids the check-then-create race of os.path.exists + mkdir.
        os.makedirs(save_dir, exist_ok=True)

    def save_tables_to_file(self, table_results, name):
        """Write ``table_results`` to ``<exp_log_dir>/<run_description>_<name>.csv``."""
        table_results.to_csv(os.path.join(self.exp_log_dir, f"{self.run_description}_{name}.csv"))

    def load_checkpoint(self, model_dir):
        """
        Load ``checkpoint.pt`` from ``<home_path>/<model_dir>``.

        Returns:
            Tuple ``(last_head, last_classifier, best_head, best_classifier)``
            taken from the saved dictionary.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(os.path.join(self.home_path, model_dir, 'checkpoint.pt'), map_location=self.device)
        last_head = checkpoint['last_head']
        last_classifier = checkpoint['last_classifier']
        best_head = checkpoint['best_head']
        best_classifier = checkpoint['best_classifier']

        return last_head, last_classifier, best_head, best_classifier

    def save_checkpoint(self, home_path, log_dir, last_head, best_head, last_classifier, best_classifier):
        """Save the last/best head and classifier objects to ``<home_path>/<log_dir>/checkpoint.pt``."""
        save_dict = {
            "last_head": last_head,
            "last_classifier": last_classifier,
            "best_head": best_head,
            "best_classifier": best_classifier,
        }
        save_path = os.path.join(home_path, log_dir, "checkpoint.pt")
        torch.save(save_dict, save_path)

    def fit(self):
        """Train MoSSDA on every scenario/run and save the results and risks tables."""
        table_results = pd.DataFrame(columns=self.results_columns)
        table_risks = pd.DataFrame(columns=self.risks_columns)

        for src_id, trg_id in self.dataset_configs.scenarios:
            for run_id in range(self.num_runs):

                # Seed with the run index so each run is reproducible.
                fix_randomness(run_id)

                self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, src_id, trg_id, run_id)
                self.loss_avg_meters = collections.defaultdict(AverageMeter)
                self.load_data(src_id, trg_id)

                # initiate the domain adaptation algorithm
                self.initialize_algorithm()

                (self.last_head, self.best_head,
                 self.last_classifier, self.best_classifier) = self.algorithm.update(
                    self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)

                self.save_checkpoint(self.home_path, self.scenario_log_dir, self.last_head, self.best_head, self.last_classifier, self.best_classifier)

                metrics = self.calculate_metrics()
                risks = self.calculate_risks()

                scenario = f"{src_id}_to_{trg_id}"
                table_results = self.append_results_to_tables(table_results, scenario, run_id, metrics)
                table_risks = self.append_results_to_tables(table_risks, scenario, run_id, risks)

        # Calculate and append mean and std to tables
        table_results = self.add_mean_std_table(table_results, self.results_columns)
        table_risks = self.add_mean_std_table(table_risks, self.risks_columns)

        # Save tables to file if needed
        self.save_tables_to_file(table_results, "results")
        self.save_tables_to_file(table_risks, "risks")

    def test(self):
        """
        Evaluate the saved "last" and "best" checkpoints on the target test set.

        For every scenario/run the checkpoint is loaded, both model variants
        are evaluated, per-scenario mean/std tables are written to CSV and an
        overall summary is printed.
        """
        # Results dataframes
        last_results = pd.DataFrame(columns=self.results_columns)
        best_results = pd.DataFrame(columns=self.results_columns)

        # Cross-domain scenarios
        for src_id, trg_id in self.dataset_configs.scenarios:
            for run_id in range(self.num_runs):

                fix_randomness(run_id)

                scenario = f"{src_id}_to_{trg_id}"
                self.scenario_log_dir = os.path.join(self.exp_log_dir, f"{scenario}_run_{run_id}")
                self.loss_avg_meters = collections.defaultdict(AverageMeter)

                self.load_data(src_id, trg_id)
                self.initialize_algorithm()
                last_head_chk, last_classifier_chk, best_head_chk, best_classifier_chk = self.load_checkpoint(self.scenario_log_dir)

                # Testing the last model
                # NOTE(review): weights are loaded into `algorithm.network` while
                # evaluate() reads `feature_extractor`/`projection_head` --
                # presumably `network` wraps both; confirm in the algorithm class.
                self.algorithm.network.load_state_dict(last_head_chk)
                self.algorithm.classifier.load_state_dict(last_classifier_chk)
                self.evaluate(self.trg_test_dl)
                last_metrics = self.calculate_metrics()
                last_results = self.append_results_to_tables(last_results, scenario, run_id, last_metrics)

                # Testing the best model
                self.algorithm.network.load_state_dict(best_head_chk)
                self.algorithm.classifier.load_state_dict(best_classifier_chk)
                self.evaluate(self.trg_test_dl)
                best_metrics = self.calculate_metrics()
                # Record the best-model metrics (previously last_metrics was
                # appended here, so the "best" table duplicated the "last" one).
                best_results = self.append_results_to_tables(best_results, scenario, run_id, best_metrics)

        last_scenario_mean_std = last_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])
        best_scenario_mean_std = best_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])

        # Save tables to file if needed
        self.save_tables_to_file(last_scenario_mean_std, 'last_results')
        self.save_tables_to_file(best_scenario_mean_std, 'best_results')

        # printing summary
        summary_last = {metric: np.mean(last_results[metric]) for metric in self.results_columns[2:]}
        summary_best = {metric: np.mean(best_results[metric]) for metric in self.results_columns[2:]}
        for summary_name, summary in [('Last', summary_last), ('Best', summary_best)]:
            for key, val in summary.items():
                print(f'{summary_name} : {key}\t: {val:2.4f}')


class Trainer_Baselines(AbstractTrainer):
    """
    Training / testing driver for the baseline methods.

    When ``self.da_method == 'MoSSDA_all'`` the MoSSDA data pipeline and the
    projection-head evaluation path (``evaluate_``) are used; every other
    method uses the baseline data generator and ``evaluate``.

    Attributes and helpers not defined in this class (``da_method``,
    ``hparams_class``, ``device``, ``initialize_algorithm``,
    ``calculate_metrics``, ...) are expected to come from ``AbstractTrainer``
    (not visible here).
    """

    def __init__(self, args):
        """Initialise the base trainer, the table columns and the merged hparams."""
        super().__init__(args)
        self.results_columns = ['scenario', 'run', 'acc', 'f1_score', 'auroc']
        self.risks_columns = ['scenario', 'run', 'src_risk', 'few_shot_risk', 'trg_risk']

        # Merge algorithm-specific, base and train hyper-parameters; on key
        # clashes the later dictionaries win.
        self.hparams = {**self.hparams_class.alg_hparams[self.da_method],
                        **self.hparams_class.base_params, **self.hparams_class.train_params,}

    def train_model(self):
        """
        Build the algorithm for ``self.da_method`` and train it.

        Returns:
            Tuple ``(last_head, last_classifier, best_head, best_classifier)``.
        """
        # Get the algorithm and the backbone network
        algorithm_class = get_algorithm_class(self.da_method)
        backbone_fe = get_backbone_class(self.backbone)

        self.algorithm = algorithm_class(
            backbone_fe, self.dataset_configs, self.num_epochs,
            self.post_epochs, self.hparams, self.device,
        )
        self.algorithm.to(self.device)

        # Training the model
        (self.last_head, self.best_head,
         self.last_classifier, self.best_classifier) = self.algorithm.update(
            self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)

        # Fixed typo: `eslf.last_classifier` raised NameError.
        return self.last_head, self.last_classifier, self.best_head, self.best_classifier

    def load_data(self, src_id, trg_id, rot=False):
        """
        Create the train/test loaders for one scenario.

        Args:
            src_id: identifier of the source domain.
            trg_id: identifier of the target domain.
            rot: forwarded to ``data_generator_baseline`` for the training
                loaders (``fit`` sets it to True for the 'PAC' method).
        """
        data_transforms = get_data_transforms(self.dataset_configs)
        if self.da_method == 'MoSSDA_all':
            self.src_train_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "train", "source", self.unlabeled_ratio)
            # Target training data must come from the target domain id
            # (the original passed src_id here).
            self.trg_train_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "train", "target", self.unlabeled_ratio)
        else:
            self.src_train_dl = data_generator_baseline(self.data_path, src_id, self.dataset_configs, self.hparams,
                                                       dtype="train", domain="source", transform=data_transforms["train"],
                                                       strong_transform=data_transforms["strong"], unlabeled_ratio=self.unlabeled_ratio, rot=rot)
            # Same fix as above: use trg_id for the target-domain loader.
            self.trg_train_dl = data_generator_baseline(self.data_path, trg_id, self.dataset_configs, self.hparams,
                                                       dtype="train", domain="target", transform=data_transforms["train"],
                                                       strong_transform=data_transforms["strong"], unlabeled_ratio=self.unlabeled_ratio, rot=rot)

        self.src_test_dl = data_generator(self.data_path, src_id, self.dataset_configs, self.hparams, "test", "source", self.unlabeled_ratio)
        self.trg_test_dl = data_generator(self.data_path, trg_id, self.dataset_configs, self.hparams, "test", "target", self.unlabeled_ratio)
        self.few_shot_dl_5 = few_shot_data_generator(self.trg_test_dl, self.dataset_configs, 5)

    def evaluate_(self, test_loader):
        """
        Evaluate with the projection head in the path
        (feature extractor -> projection head -> classifier).

        Args:
            test_loader: iterable yielding ``(data, labels, _)`` triples.

        Side effects:
            Sets ``self.loss``, ``self.full_preds`` and ``self.full_labels``.
        """
        feature_extractor = self.algorithm.feature_extractor.to(self.device)
        head = self.algorithm.projection_head.to(self.device)
        classifier = self.algorithm.classifier.to(self.device)

        feature_extractor.eval()
        head.eval()
        classifier.eval()

        total_loss, preds_list, labels_list = [], [], []

        with torch.no_grad():
            for data, labels, _ in test_loader:
                # Follow self.device rather than hard-coding .cuda().
                data = data.float().to(self.device, non_blocking=True)
                labels = labels.view(-1).long().to(self.device, non_blocking=True)

                # forward pass
                features = feature_extractor(data)
                out_feat = head(features)
                predictions = classifier(out_feat)

                # compute loss
                loss = F.cross_entropy(predictions, labels)
                total_loss.append(loss.item())

                # append predictions and labels
                preds_list.append(predictions.detach())
                labels_list.append(labels)

        self.loss = torch.tensor(total_loss).mean()
        self.full_preds = torch.cat(preds_list)
        self.full_labels = torch.cat(labels_list)

    def evaluate(self, test_loader):
        """
        Evaluate without a projection head (feature extractor -> classifier).

        Args:
            test_loader: iterable yielding ``(data, labels, _)`` triples.

        Side effects:
            Sets ``self.loss``, ``self.full_preds`` and ``self.full_labels``.
        """
        feature_extractor = self.algorithm.feature_extractor.to(self.device)
        classifier = self.algorithm.classifier.to(self.device)

        feature_extractor.eval()
        classifier.eval()

        total_loss, preds_list, labels_list = [], [], []

        with torch.no_grad():
            for data, labels, _ in test_loader:
                # Follow self.device rather than hard-coding .cuda().
                data = data.float().to(self.device, non_blocking=True)
                labels = labels.view(-1).long().to(self.device, non_blocking=True)

                # forward pass
                features = feature_extractor(data)
                predictions = classifier(features)

                # compute loss
                loss = F.cross_entropy(predictions, labels)
                total_loss.append(loss.item())

                # append predictions and labels
                preds_list.append(predictions.detach())
                labels_list.append(labels)

        self.loss = torch.tensor(total_loss).mean()
        self.full_preds = torch.cat(preds_list)
        self.full_labels = torch.cat(labels_list)

    def create_save_dir(self, save_dir):
        """Create ``save_dir`` (including missing parents) if it does not exist."""
        # exist_ok avoids the check-then-create race of os.path.exists + mkdir.
        os.makedirs(save_dir, exist_ok=True)

    def save_tables_to_file(self, table_results, name):
        """Write ``table_results`` to ``<exp_log_dir>/<run_description>_<name>.csv``."""
        table_results.to_csv(os.path.join(self.exp_log_dir, f"{self.run_description}_{name}.csv"))

    def load_checkpoint(self, model_dir):
        """
        Load ``checkpoint.pt`` from ``<home_path>/<model_dir>``.

        Returns:
            Tuple ``(last_head, last_classifier, best_head, best_classifier)``
            taken from the saved dictionary.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(os.path.join(self.home_path, model_dir, 'checkpoint.pt'), map_location=self.device)
        last_head = checkpoint['last_head']
        last_classifier = checkpoint['last_classifier']
        best_head = checkpoint['best_head']
        best_classifier = checkpoint['best_classifier']

        return last_head, last_classifier, best_head, best_classifier

    def save_checkpoint(self, home_path, log_dir, last_head, best_head, last_classifier, best_classifier):
        """Save the last/best head and classifier objects to ``<home_path>/<log_dir>/checkpoint.pt``."""
        save_dict = {
            "last_head": last_head,
            "last_classifier": last_classifier,
            "best_head": best_head,
            "best_classifier": best_classifier,
        }
        save_path = os.path.join(home_path, log_dir, "checkpoint.pt")
        torch.save(save_dict, save_path)

    def fit(self):
        """Train the baseline on every scenario/run and save the results and risks tables."""
        table_results = pd.DataFrame(columns=self.results_columns)
        table_risks = pd.DataFrame(columns=self.risks_columns)

        for src_id, trg_id in self.dataset_configs.scenarios:
            for run_id in range(self.num_runs):
                # Seed with the run index so each run is reproducible.
                fix_randomness(run_id)
                self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, src_id, trg_id, run_id)
                self.loss_avg_meters = collections.defaultdict(AverageMeter)
                # Only the 'PAC' method requests the `rot` option of the loader.
                rot = self.da_method == 'PAC'
                self.load_data(src_id, trg_id, rot)

                # initiate the domain adaptation algorithm
                self.initialize_algorithm()

                (self.last_head, self.best_head,
                 self.last_classifier, self.best_classifier) = self.algorithm.update(
                    self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)

                self.save_checkpoint(self.home_path, self.scenario_log_dir, self.last_head, self.best_head, self.last_classifier, self.best_classifier)

                metrics = self.calculate_metrics()
                risks = self.calculate_risks()

                scenario = f"{src_id}_to_{trg_id}"
                table_results = self.append_results_to_tables(table_results, scenario, run_id, metrics)
                table_risks = self.append_results_to_tables(table_risks, scenario, run_id, risks)

        # calculate and append mean and std to tables
        table_results = self.add_mean_std_table(table_results, self.results_columns)
        table_risks = self.add_mean_std_table(table_risks, self.risks_columns)

        # Save tables to file if needed
        self.save_tables_to_file(table_results, 'results')
        self.save_tables_to_file(table_risks, 'risks')

    def test(self):
        """
        Evaluate the saved "last" and "best" checkpoints on the target test set.

        For every scenario/run the checkpoint is loaded, both model variants
        are evaluated, per-scenario mean/std tables are written to CSV and an
        overall summary is printed.
        """
        # Results dataframe
        last_results = pd.DataFrame(columns=self.results_columns)
        best_results = pd.DataFrame(columns=self.results_columns)

        # 'MoSSDA_all' stores its weights under `network` and is evaluated
        # through the projection head; other methods use `feature_extractor`.
        uses_projection_head = self.da_method == 'MoSSDA_all'

        # Cross-domain scenarios
        for src_id, trg_id in self.dataset_configs.scenarios:
            for run_id in range(self.num_runs):
                fix_randomness(run_id)
                scenario = f"{src_id}_to_{trg_id}"
                self.scenario_log_dir = os.path.join(self.exp_log_dir, f"{scenario}_run_{run_id}")
                self.loss_avg_meters = collections.defaultdict(AverageMeter)

                self.load_data(src_id, trg_id)
                self.initialize_algorithm()
                last_head_chk, last_classifier_chk, best_head_chk, best_classifier_chk = self.load_checkpoint(self.scenario_log_dir)

                # Testing the last model
                if uses_projection_head:
                    self.algorithm.network.load_state_dict(last_head_chk)
                    self.algorithm.classifier.load_state_dict(last_classifier_chk)
                    self.evaluate_(self.trg_test_dl)
                else:
                    self.algorithm.feature_extractor.load_state_dict(last_head_chk)
                    self.algorithm.classifier.load_state_dict(last_classifier_chk)
                    self.evaluate(self.trg_test_dl)
                last_metrics = self.calculate_metrics()
                last_results = self.append_results_to_tables(last_results, scenario, run_id, last_metrics)

                # Testing the best model
                if uses_projection_head:
                    self.algorithm.network.load_state_dict(best_head_chk)
                    self.algorithm.classifier.load_state_dict(best_classifier_chk)
                    self.evaluate_(self.trg_test_dl)
                else:
                    self.algorithm.feature_extractor.load_state_dict(best_head_chk)
                    self.algorithm.classifier.load_state_dict(best_classifier_chk)
                    self.evaluate(self.trg_test_dl)

                best_metrics = self.calculate_metrics()
                # Record the best-model metrics (previously last_metrics was
                # appended here, so the "best" table duplicated the "last" one).
                best_results = self.append_results_to_tables(best_results, scenario, run_id, best_metrics)

        last_scenario_mean_std = last_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])
        best_scenario_mean_std = best_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])

        # Save tables to file if needed
        self.save_tables_to_file(last_scenario_mean_std, 'last_results')
        self.save_tables_to_file(best_scenario_mean_std, 'best_results')

        # printing summary
        summary_last = {metric: np.mean(last_results[metric]) for metric in self.results_columns[2:]}
        summary_best = {metric: np.mean(best_results[metric]) for metric in self.results_columns[2:]}
        for summary_name, summary in [('Last', summary_last), ('Best', summary_best)]:
            for key, val in summary.items():
                print(f'{summary_name} : {key}\t: {val:2.4f}')


class Trainer_boundary(AbstractTrainer):
    """
    Training / testing driver for the MoSSDA boundary settings
    (NoAdapt, TargetOnly).

    Attributes and helpers not defined in this class (``da_method``,
    ``device``, ``hparams``, ``initialize_algorithm``, ``calculate_metrics``,
    ...) are expected to come from ``AbstractTrainer`` (not visible here).
    """

    def __init__(self, args):
        """Initialise the base trainer and declare the result-table columns."""
        super().__init__(args)

        self.results_columns = ['scenario', 'run', 'acc', 'f1_score', 'auroc']
        self.risks_columns = ['scenario', 'run', 'src_risk', 'few_shot_risk', 'trg_risk']

    def train_model(self):
        """
        Build the algorithm for ``self.da_method`` and train it.

        Returns:
            Tuple ``(last_head, last_classifier, best_head, best_classifier)``.
            Note this order differs from the order returned by
            ``algorithm.update`` (last_head, best_head, last_classifier,
            best_classifier).
        """
        # Get the algorithm and the backbone network
        algorithm_class = get_algorithm_class(self.da_method)
        backbone_fe = get_backbone_class(self.backbone)

        self.algorithm = algorithm_class(
            backbone_fe, self.dataset_configs, self.num_epochs,
            self.post_epochs, self.hparams, self.device,
        )
        self.algorithm.to(self.device)

        # Training the model
        (self.last_head, self.best_head,
         self.last_classifier, self.best_classifier) = self.algorithm.update(
            self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)

        return self.last_head, self.last_classifier, self.best_head, self.best_classifier

    def evaluate(self, test_loader):
        """
        Run feature extractor -> classifier over a loader.

        Args:
            test_loader: iterable yielding ``(data, labels, _)`` triples.

        Side effects:
            Sets ``self.loss`` (mean of the per-batch cross-entropy losses),
            ``self.full_preds`` (concatenated classifier outputs) and
            ``self.full_labels`` (concatenated flattened labels).
        """
        network = self.algorithm.feature_extractor.to(self.device)
        classifier = self.algorithm.classifier.to(self.device)

        network.eval()
        classifier.eval()

        total_loss, preds_list, labels_list = [], [], []

        with torch.no_grad():
            for data, labels, _ in test_loader:
                # Move the batch to the same device as the modules instead of
                # hard-coding .cuda(), so CPU / non-default GPU runs work too.
                data = data.float().to(self.device, non_blocking=True)
                labels = labels.view(-1).long().to(self.device, non_blocking=True)

                # forward pass
                out_feat = network(data)
                predictions = classifier(out_feat)

                # compute loss
                loss = F.cross_entropy(predictions, labels)
                total_loss.append(loss.item())

                # append predictions and labels
                preds_list.append(predictions.detach())
                labels_list.append(labels)

        self.loss = torch.tensor(total_loss).mean()
        self.full_preds = torch.cat(preds_list)
        self.full_labels = torch.cat(labels_list)

    def load_data(self, src_id, trg_id):
        """
        Create the source/target train and test loaders plus the few-shot loader.

        Args:
            src_id: identifier of the source domain.
            trg_id: identifier of the target domain.
        """
        self.src_train_dl = data_generator_bound(self.data_path, src_id, self.dataset_configs, self.hparams, "train", "source", self.unlabeled_ratio)
        self.src_test_dl = data_generator_bound(self.data_path, src_id, self.dataset_configs, self.hparams, "test", "source", self.unlabeled_ratio)

        self.trg_train_dl = data_generator_bound(self.data_path, trg_id, self.dataset_configs, self.hparams, "train", "target", self.unlabeled_ratio)
        self.trg_test_dl = data_generator_bound(self.data_path, trg_id, self.dataset_configs, self.hparams, "test", "target", self.unlabeled_ratio)
        # set 5 to other value if you want other k-shot FST
        self.few_shot_dl_5 = few_shot_data_generator(self.trg_test_dl, self.dataset_configs, 5)

    def create_save_dir(self, save_dir):
        """Create ``save_dir`` (including missing parents) if it does not exist."""
        # exist_ok avoids the check-then-create race of os.path.exists + mkdir.
        os.makedirs(save_dir, exist_ok=True)

    def save_tables_to_file(self, table_results, name):
        """Write ``table_results`` to ``<exp_log_dir>/<run_description>_<name>.csv``."""
        table_results.to_csv(os.path.join(self.exp_log_dir, f"{self.run_description}_{name}.csv"))

    def load_checkpoint(self, model_dir):
        """
        Load ``checkpoint.pt`` from ``<home_path>/<model_dir>``.

        Returns:
            Tuple ``(last_head, last_classifier, best_head, best_classifier)``
            taken from the saved dictionary.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(os.path.join(self.home_path, model_dir, 'checkpoint.pt'), map_location=self.device)
        last_head = checkpoint['last_head']
        last_classifier = checkpoint['last_classifier']
        best_head = checkpoint['best_head']
        best_classifier = checkpoint['best_classifier']

        return last_head, last_classifier, best_head, best_classifier

    def save_checkpoint(self, home_path, log_dir, last_head, best_head, last_classifier, best_classifier):
        """Save the last/best head and classifier objects to ``<home_path>/<log_dir>/checkpoint.pt``."""
        save_dict = {
            "last_head": last_head,
            "last_classifier": last_classifier,
            "best_head": best_head,
            "best_classifier": best_classifier,
        }
        save_path = os.path.join(home_path, log_dir, "checkpoint.pt")
        torch.save(save_dict, save_path)

    def fit(self):
        """Train the boundary model on every scenario/run and save the results and risks tables."""
        table_results = pd.DataFrame(columns=self.results_columns)
        table_risks = pd.DataFrame(columns=self.risks_columns)

        for src_id, trg_id in self.dataset_configs.scenarios:
            for run_id in range(self.num_runs):

                # Seed with the run index so each run is reproducible.
                fix_randomness(run_id)

                self.logger, self.scenario_log_dir = starting_logs(self.dataset, self.da_method, self.exp_log_dir, src_id, trg_id, run_id)
                self.loss_avg_meters = collections.defaultdict(AverageMeter)
                self.load_data(src_id, trg_id)

                # initiate the domain adaptation algorithm
                self.initialize_algorithm()

                (self.last_head, self.best_head,
                 self.last_classifier, self.best_classifier) = self.algorithm.update(
                    self.src_train_dl, self.trg_train_dl, self.loss_avg_meters, self.logger)

                self.save_checkpoint(self.home_path, self.scenario_log_dir, self.last_head, self.best_head, self.last_classifier, self.best_classifier)

                metrics = self.calculate_metrics()
                risks = self.calculate_risks()

                scenario = f"{src_id}_to_{trg_id}"
                table_results = self.append_results_to_tables(table_results, scenario, run_id, metrics)
                table_risks = self.append_results_to_tables(table_risks, scenario, run_id, risks)

        # Calculate and append mean and std to tables
        table_results = self.add_mean_std_table(table_results, self.results_columns)
        table_risks = self.add_mean_std_table(table_risks, self.risks_columns)

        # Save tables to file if needed
        self.save_tables_to_file(table_results, "results")
        self.save_tables_to_file(table_risks, "risks")

    def test(self):
        """
        Evaluate the saved "last" and "best" checkpoints on the target test set.

        For every scenario/run the checkpoint is loaded, both model variants
        are evaluated, per-scenario mean/std tables are written to CSV and an
        overall summary is printed.
        """
        # Results dataframes
        last_results = pd.DataFrame(columns=self.results_columns)
        best_results = pd.DataFrame(columns=self.results_columns)

        # Cross-domain scenarios
        for src_id, trg_id in self.dataset_configs.scenarios:
            for run_id in range(self.num_runs):

                fix_randomness(run_id)

                scenario = f"{src_id}_to_{trg_id}"
                self.scenario_log_dir = os.path.join(self.exp_log_dir, f"{scenario}_run_{run_id}")
                self.loss_avg_meters = collections.defaultdict(AverageMeter)

                self.load_data(src_id, trg_id)
                self.initialize_algorithm()
                last_head_chk, last_classifier_chk, best_head_chk, best_classifier_chk = self.load_checkpoint(self.scenario_log_dir)

                # Testing the last model
                # NOTE(review): weights are loaded into `algorithm.network` while
                # evaluate() reads `algorithm.feature_extractor` -- presumably
                # they refer to the same module; confirm in the algorithm class.
                self.algorithm.network.load_state_dict(last_head_chk)
                self.algorithm.classifier.load_state_dict(last_classifier_chk)
                self.evaluate(self.trg_test_dl)
                last_metrics = self.calculate_metrics()
                last_results = self.append_results_to_tables(last_results, scenario, run_id, last_metrics)

                # Testing the best model
                self.algorithm.network.load_state_dict(best_head_chk)
                self.algorithm.classifier.load_state_dict(best_classifier_chk)
                self.evaluate(self.trg_test_dl)
                best_metrics = self.calculate_metrics()
                # Record the best-model metrics (previously last_metrics was
                # appended here, so the "best" table duplicated the "last" one).
                best_results = self.append_results_to_tables(best_results, scenario, run_id, best_metrics)

        last_scenario_mean_std = last_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])
        best_scenario_mean_std = best_results.groupby('scenario')[['acc', 'f1_score', 'auroc']].agg(['mean', 'std'])

        # Save tables to file if needed
        self.save_tables_to_file(last_scenario_mean_std, 'last_results')
        self.save_tables_to_file(best_scenario_mean_std, 'best_results')

        # printing summary
        summary_last = {metric: np.mean(last_results[metric]) for metric in self.results_columns[2:]}
        summary_best = {metric: np.mean(best_results[metric]) for metric in self.results_columns[2:]}
        for summary_name, summary in [('Last', summary_last), ('Best', summary_best)]:
            for key, val in summary.items():
                print(f'{summary_name} : {key}\t: {val:2.4f}')