import torch

class LoggerSingleRun(object):
    """Accumulate per-epoch metrics for a single training run and summarize them.

    Each recorded entry is indexed as ``(train, valid, test)``. Values are
    multiplied by 100 before reporting (presumably they are fractions in
    [0, 1] -- TODO confirm with callers).
    """

    def __init__(self, use_wandb=False):
        self.use_wandb = use_wandb
        self.results = []
        if self.use_wandb:
            # Only verifies that wandb is importable (ImportError otherwise);
            # the imported name is local to __init__ and not used afterwards.
            import wandb

    def add_result(self, result):
        """Append one epoch's ``(train, valid, test)`` metrics to the history."""
        self.results.append(result)

    def print_statistics(self):
        """Print a summary selected by the best validation epoch.

        Returns:
            ``(highest_valid, final_test)`` as 0-dim tensors, both scaled by
            100. ``final_test`` is the test metric at the epoch with the
            highest validation metric.
        """
        table = torch.tensor(self.results) * 100
        valid = table[:, 1]
        best_epoch = valid.argmax().item()
        highest_valid = valid.max()
        final_test = table[best_epoch, 2]

        rows = (
            ('Highest Train', table[:, 0].max()),
            ('Highest Valid', highest_valid),
            ('  Final Train', table[best_epoch, 0]),
            ('   Final Test', final_test),
        )
        for label, value in rows:
            print(f'{label}: {value:.2f}')
        return highest_valid, final_test

class LoggerSingleRunUnseen(object):
    """Accumulate per-epoch metrics for one run with seen/unseen test splits.

    Each recorded entry is indexed as ``(train, valid, seen_test,
    unseen_test)``. Values are multiplied by 100 before reporting
    (presumably fractions in [0, 1] -- TODO confirm with callers).
    """

    def __init__(self, use_wandb=False):
        self.use_wandb = use_wandb
        self.results = []
        if self.use_wandb:
            # Only verifies that wandb is importable (ImportError otherwise);
            # the imported name is local to __init__ and not used afterwards.
            import wandb

    def add_result(self, result):
        """Append one epoch's metrics to the history."""
        self.results.append(result)

    def print_statistics(self):
        """Print a summary selected by the best validation epoch.

        Returns:
            ``(highest_valid, final_seen_test, final_unseen_test)`` as 0-dim
            tensors scaled by 100. Both test values are taken at the epoch
            with the highest validation metric.
        """
        table = torch.tensor(self.results) * 100
        valid = table[:, 1]
        best_epoch = valid.argmax().item()
        highest_valid = valid.max()
        final_seen_test = table[best_epoch, 2]
        final_unseen_test = table[best_epoch, 3]

        rows = (
            ('    Highest Train', table[:, 0].max()),
            ('    Highest Valid', highest_valid),
            ('      Final Train', table[best_epoch, 0]),
            ('  Final Seen Test', final_seen_test),
            ('Final Unseen Test', final_unseen_test),
        )
        for label, value in rows:
            print(f'{label}: {value:.2f}')
        return highest_valid, final_seen_test, final_unseen_test
        


class Logger(object):
    """Collect per-epoch ``(train, valid, test)`` metrics across several runs.

    Values are multiplied by 100 before printing (presumably fractions in
    [0, 1] -- TODO confirm with callers).
    """

    def __init__(self, runs, info=None):
        # ``info`` is stored for callers; it is not read by this class.
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Record one epoch's ``(train, valid, test)`` metrics for ``run``.

        Raises:
            AssertionError: if ``result`` does not have exactly 3 items or
                ``run`` is not in ``[0, runs)``.
        """
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    @staticmethod
    def _summarize_run(result):
        """Return (highest train, highest valid, train@best valid, test@best valid) as floats."""
        best_epoch = result[:, 1].argmax()
        return (
            result[:, 0].max().item(),
            result[:, 1].max().item(),
            result[best_epoch, 0].item(),
            result[best_epoch, 2].item(),
        )

    def print_statistics(self, run=None):
        """Print statistics for one run, or mean ± std over all runs.

        Args:
            run: zero-based run index to report on. When ``None``, every run
                with at least one recorded result is summarized by its best
                validation epoch, and the mean and (unbiased) std across runs
                are printed.

        Raises:
            ValueError: in aggregate mode when no run has any results.
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 1].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.2f}')
            print(f'Highest Valid: {result[:, 1].max():.2f}')
            print(f'  Final Train: {result[argmax, 0]:.2f}')
            print(f'   Final Test: {result[argmax, 2]:.2f}')
            return

        # Build one tensor per run instead of one tensor over all runs.
        # Runs can have different epoch counts (e.g. early stopping), and
        # runs that have not started yet are empty. Either case makes the
        # nested list ragged, and torch.tensor(self.results) would raise.
        best_results = [
            self._summarize_run(100 * torch.tensor(run_results))
            for run_results in self.results
            if run_results
        ]
        if not best_results:
            raise ValueError('No results have been recorded for any run.')

        best_result = torch.tensor(best_results)

        print('All runs:')
        r = best_result[:, 0]
        print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
        r = best_result[:, 1]
        print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
        r = best_result[:, 2]
        print(f'  Final Train: {r.mean():.2f} ± {r.std():.2f}')
        r = best_result[:, 3]
        print(f'   Final Test: {r.mean():.2f} ± {r.std():.2f}')
