import torch
import numpy as np

class Logger_classify(object):
    """Logger for the node classification task.

    Reports train/valid/test accuracy (or ROC-AUC) either for a single run or
    aggregated over every run.

    ``print_statistics`` indexes ``self.results`` as
    ``results[run][epoch][column]`` with columns ``(train, valid, test, x)``;
    the fourth column is not used in the report. All reported values are
    scaled by 100.
    """

    def __init__(self, info=None):
        self.info = info
        self.results = []

    def add_result(self, result):
        """Append one entry to ``self.results`` (must have length 4)."""
        # NOTE(review): this only checks len(result) == 4, while
        # print_statistics treats every entry of self.results as a 2-D
        # (epochs x 4) table -- confirm how callers populate self.results.
        assert len(result) == 4
        self.results.append(result)

    @staticmethod
    def _format_stat(values, with_std):
        """Format the mean of ``values``, followed by ``± std`` when requested."""
        if with_std:
            return f'{values.mean():.2f} ± {values.std():.2f}'
        return f'{values.mean():.2f}'

    def print_statistics(self, run=None):
        """Summarise classification scores for one run or across all runs.

        Args:
            run: index of the run to report. ``None`` aggregates over every
                run stored in ``self.results``.

        Returns:
            If ``run`` is given: ``(text, metrics)``. The "chosen" epoch is the
            one with the highest validation score (column 1).
            If ``run`` is ``None``: ``(text, metrics, final_test)``, where
            ``final_test`` is a tensor holding each run's test score at its
            best-validation epoch.

        Side effect: sets ``self.test`` to the reported final test score
        (a tensor), which :meth:`output` later writes.
        """
        metrics = {}
        text = ""
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            # Model selection: the epoch with the highest validation score.
            argmax = result[:, 1].argmax().item()
            text += f'Run {run + 1:02d}:\n'
            text += f'Highest Train: {result[:, 0].max():.2f}\n'
            text += f'Highest Valid: {result[:, 1].max():.2f}\n'
            text += f'Highest Test: {result[:, 2].max():.2f}\n'
            text += f'Chosen epoch: {argmax+1}\n'
            text += f'Final Train: {result[argmax, 0]:.2f}\n'
            text += f'Final Test: {result[argmax, 2]:.2f}\n'
            self.test = result[argmax, 2]
            metrics['run'] = run + 1
            metrics['classification_acc_highest_train'] = result[:, 0].max()
            metrics['classification_acc_highest_valid'] = result[:, 1].max()
            metrics['classification_acc_highest_test'] = result[:, 2].max()
            metrics['classification_chosen_epoch'] = argmax + 1
            metrics['classification_final_train'] = result[argmax, 0]
            # Previously missing here although it is printed above and is
            # reported in the aggregate mode.
            metrics['classification_final_test'] = result[argmax, 2]
            return text, metrics

        result = 100 * torch.tensor(self.results)

        # Per run: (highest train, highest test, highest valid,
        #           train at best-valid epoch, test at best-valid epoch).
        best_results = []
        for r in result:
            best_valid_epoch = r[:, 1].argmax()
            best_results.append((
                r[:, 0].max().item(),
                r[:, 2].max().item(),
                r[:, 1].max().item(),
                r[best_valid_epoch, 0].item(),
                r[best_valid_epoch, 2].item(),
            ))
        best_result = torch.tensor(best_results)

        # A single run has no sample standard deviation, so only the mean is
        # printed in that case.
        with_std = best_result.shape[0] != 1
        rows = (
            (0, 'Highest Train', 'classification_acc_highest_train'),
            (1, 'Highest Test', 'classification_acc_highest_test'),
            (2, 'Highest Valid', 'classification_acc_highest_valid'),
            (3, '  Final Train', 'classification_final_train'),
            (4, '   Final Test', 'classification_final_test'),
        )
        text += 'All runs:\n'
        for col, label, key in rows:
            values = best_result[:, col]
            text += f'{label}: {self._format_stat(values, with_std)}\n'
            metrics[key] = values.mean()
        self.test = best_result[:, 4].mean()
        return text, metrics, best_result[:, 4]

    def output(self, out_path, info):
        """Append ``info`` and the last reported test score to ``out_path``.

        ``print_statistics`` must have been called first, because this reads
        ``self.test``.
        """
        # Explicit encoding: the summary text may contain the '±' character.
        with open(out_path, 'a', encoding='utf-8') as f:
            f.write(info)
            f.write(f'test acc:{self.test}\n')


# print_statistics returns a (summary text, metrics dict, best_result) tuple.
class Logger_detect(object):
    """Logger for the OOD detection task.

    Reports test AUROC / AUPR-out / FPR95 read from one selected row of the
    recorded results, scaled by 100.
    """

    def __init__(self, info=None):
        self.info = info
        self.results = []

    def add_result(self, result):
        """Record one row of metrics; its length must be a multiple of 3."""
        # Presumably one (AUROC, AUPR, FPR) triple per OOD test set -- confirm
        # with the caller.
        assert len(result) % 3 == 0
        self.results.append(result)

    def print_statistics(self):
        """Summarise the row whose last column is smallest.

        The original author's note suggests a table of shape (200, 9); the
        last column is presumably an FPR95 value, where lower is better --
        confirm against the caller.

        Returns:
            ``(text, metrics, best_result)`` where ``metrics`` holds the first
            three entries of the selected row, and ``best_result`` is that
            whole row (tensor, scaled by 100).
        """
        table = 100 * torch.tensor(self.results)
        chosen_row = table[table[:, -1].argmin()]

        metrics = {
            'END_AUROC': chosen_row[0],
            'END_AUPR_out': chosen_row[1],
            'END_FPR95': chosen_row[2],
        }
        text = ''.join(f'{name}: {value:.2f}\n' for name, value in metrics.items())
        return text, metrics, chosen_row

# def save_result(results, cfg):
#     if cfg["dataset"] in ('cora', 'amazon-photo', 'coauthor-cs'):
#         filename = f'../results/{cfg["dataset"]}-{cfg["ood_type"]}.csv'
#     else:
#         filename = f'../results/{cfg["dataset"]}.csv'

#     if cfg["model"] == 'gnnsafe':
#         if cfg["use_prop"]:
#             name = 'gnnsafe++' if cfg["use_reg"] else 'gnnsafe'
#         else:
#             name = 'gnnsafe++ w/o prop' if cfg["use_reg"] else 'gnnsafe w/o prop'
#     else:
#         name = cfg["model"]

#     print(f"Saving results to {filename}")
#     with open(f"{filename}", 'a+') as write_obj:
#         write_obj.write(f"{name} {cfg['name']}\n")
#         if cfg["print_args"]:
#             write_obj.write(f'{cfg}\n')
#         auroc, aupr, fpr = [], [], []
#         for k in range(results.shape[1] // 3):
#             r = results[:, k * 3]
#             auroc.append(r.mean())
#             write_obj.write(f'OOD Test {k + 1} Final AUROC: {r.mean():.2f} ')
#             r = results[:, k * 3 + 1]
#             aupr.append(r.mean())
#             write_obj.write(f'OOD Test {k + 1} Final AUPR: {r.mean():.2f} ')
#             r = results[:, k * 3 + 2]
#             fpr.append(r.mean())
#             write_obj.write(f'OOD Test {k + 1} Final FPR: {r.mean():.2f}\n')
#         if k > 0: # for multiple OODTe datasets, return the averaged metrics
#             write_obj.write(f'OOD Test Averaged Final AUROC: {np.mean(auroc):.2f} ')
#             write_obj.write(f'OOD Test Averaged Final AUPR: {np.mean(aupr):.2f} ')
#             write_obj.write(f'OOD Test Averaged Final FPR: {np.mean(fpr):.2f}\n')
#         r = results[:, -1]
#         write_obj.write(f'IND Test Score: {r.mean():.2f}\n')




class Logger_ood(object):
    """Logger for OOD detection that also tracks in-distribution test accuracy.

    Each appended result is one row of metrics. ``print_statistics`` reads
    column 5 (test accuracy) and columns 6-10 (AUROC, AUPR-in, AUPR-out,
    FPR95, detection accuracy); values are scaled by 100 when reported.
    """

    def __init__(self, info=None) -> None:
        self.info = info
        self.results = []

    def add_result(self, result):
        """Record one row of metrics (no validation is performed)."""
        self.results.append(result)

    def print_statistics(self):
        """Report the best value of every tracked metric across all rows.

        Each metric is optimised independently (max, except FPR95, which uses
        min), so the reported numbers may come from different rows.

        Returns:
            ``(text, metrics, values)`` where ``values`` is
            ``[test_acc, auroc, aupr_in, aupr_out, fpr95, detection_acc]`` as
            Python floats. The same numbers are also stored on ``self``.
        """
        table = 100 * torch.tensor(self.results)

        def best(col, pick_min=False):
            column = table[:, col]
            return (column.min() if pick_min else column.max()).item()

        auroc = best(6)
        aupr_in = best(7)
        aupr_out = best(8)
        fpr95 = best(9, pick_min=True)  # a lower false-positive rate is better
        detection_acc = best(10)
        test_acc = best(5)

        detection = {
            'END_AUROC': auroc,
            'END_AUPR_in': aupr_in,
            'END_AUPR_out': aupr_out,
            'END_FPR95': fpr95,
            'END_DETECTION_acc': detection_acc,
        }
        metrics = {"END_Test_acc": test_acc}
        metrics.update(detection)

        pieces = [
            "----------------END RESULT----------------  ",
            f'END Final Test: {test_acc:.2f}\n',
            'Detection Task\n',
        ]
        pieces += [f'{name}: {value:.2f}\n' for name, value in detection.items()]

        self.test_acc = test_acc
        self.auc = auroc
        self.aupr_in = aupr_in
        self.aupr_out = aupr_out
        self.fpr95 = fpr95
        self.detection_acc = detection_acc

        summary = [test_acc, auroc, aupr_in, aupr_out, fpr95, detection_acc]
        return ''.join(pieces), metrics, summary



class Logger_misclassification(object):
    """Logger for misclassification detection that also tracks test accuracy.

    Each appended result is one row of metrics. ``print_statistics`` reads
    column 2 (test accuracy) and columns 4-7 (AUROC, AUPR for correct
    predictions, AUPR for errors, FPR); values are scaled by 100 when
    reported.
    """

    def __init__(self, info=None) -> None:
        self.info = info
        self.results = []

    def add_result(self, result):
        """Record one row of metrics (no validation is performed)."""
        self.results.append(result)

    def print_statistics(self):
        """Report the best value of every tracked metric across all rows.

        Each metric is optimised independently (max, except FPR, which uses
        min), so the reported numbers may come from different rows.

        Returns:
            ``(text, metrics, values)`` where ``values`` is
            ``[test_acc, auroc, aupr_cor, aupr_err, fpr95]`` as Python
            floats. The same numbers are also stored on ``self``.
        """
        table = 100 * torch.tensor(self.results)

        def best(col, pick_min=False):
            column = table[:, col]
            return (column.min() if pick_min else column.max()).item()

        auroc = best(4)
        aupr_cor = best(5)
        aupr_err = best(6)
        fpr = best(7, pick_min=True)  # a lower false-positive rate is better
        test_acc = best(2)

        task_metrics = {
            'END_AUROC': auroc,
            'END_AUPR_cor': aupr_cor,
            'END_AUPR_err': aupr_err,
            'END_FPR95': fpr,
        }
        metrics = {"END_Test_acc": test_acc}
        metrics.update(task_metrics)

        pieces = [
            "----------------END RESULT----------------  ",
            f'END Final Test: {test_acc:.2f}\n',
            'Misclassification Task\n',
        ]
        pieces += [f'{name}: {value:.2f}\n' for name, value in task_metrics.items()]

        self.test_acc = test_acc
        self.auc = auroc
        self.aupr_cor = aupr_cor
        self.aupr_err = aupr_err
        self.fpr95 = fpr

        summary = [test_acc, auroc, aupr_cor, aupr_err, fpr]
        return ''.join(pieces), metrics, summary


