import torch

class Logger(object):
    """Collect per-epoch metrics for several runs and summarise them.

    Every recorded result is a 4-item sequence. ``print_statistics`` reads
    index 0 as the train metric, index 1 as the validation metric and index 2
    as the test metric (all multiplied by 100 for display). Index 3 is the
    quantity that is minimised to pick an epoch when ``mode`` is not
    ``'max_acc'`` (presumably a validation loss -- confirm with callers).
    """

    def __init__(self, runs, info=None, warmup=0):
        """Create an empty logger with one result list per run.

        Args:
            runs: Number of runs to allocate result lists for.
            info: Stored on the instance as-is; not used by this class.
            warmup: Number of leading *runs* that the aggregate summary of
                ``print_statistics`` leaves out.
        """
        self.info = info
        self.warmup = warmup
        self.results = [[] for _ in range(runs)]
        # Set by print_statistics(run=None). Defined up front so output() can
        # tell "not computed yet" apart from a real value.
        self.test = None

    def add_result(self, run, result):
        """Append one 4-item ``result`` to the list of run number ``run``."""
        assert len(result) == 4
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None, mode='max_acc'):
        """Build a summary string for one run or for all runs.

        Args:
            run: Zero-based run index. When given, the summary of that run is
                printed and returned. When ``None``, the summary over
                ``self.results[self.warmup:]`` is returned (not printed) and
                the mean final test value is stored in ``self.test``.
            mode: ``'max_acc'`` picks the epoch with the highest validation
                value (column 1); any other value picks the epoch with the
                lowest column-3 value.

        Returns:
            The formatted summary string.
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            if mode == 'max_acc':
                ind = result[:, 1].argmax().item()
            else:
                ind = result[:, 3].argmin().item()

            print_str = (
                f'Run {run + 1:02d}:'
                f'Highest Train: {result[:, 0].max():.2f} '
                f'Highest Valid: {result[:, 1].max():.2f} '
                f'Highest Test: {result[:, 2].max():.2f} '
                f'Chosen epoch: {ind+1}\n'
                f'Final Train: {result[ind, 0]:.2f} '
                f'Final Test: {result[ind, 2]:.2f}'
            )
            print(print_str)
            return print_str

        best_results = []
        # NOTE: only the last aggregated run's value survives the loop, and
        # it stays 0 when mode is not 'max_acc'.
        max_val_epoch = 0
        for r in self.results[self.warmup:]:
            r = 100 * torch.tensor(r)
            train1 = r[:, 0].max().item()
            test1 = r[:, 2].max().item()
            valid = r[:, 1].max().item()
            # Locate the selected epoch once and reuse the index.
            if mode == 'max_acc':
                chosen = r[:, 1].argmax().item()
                max_val_epoch = chosen
            else:
                chosen = r[:, 3].argmin().item()
            train2 = r[chosen, 0].item()
            test2 = r[chosen, 2].item()
            best_results.append((train1, test1, valid, train2, test2))

        best_result = torch.tensor(best_results)
        highest_train = best_result[:, 0]
        highest_test = best_result[:, 1]
        final_test = best_result[:, 4]

        # The run count in the label is the total number of runs, including
        # any skipped through ``warmup``.
        print_str = (
            f'{len(self.results)} runs: '
            f'Highest Train: {highest_train.mean():.2f} ± {highest_train.std():.2f} '
            f'Highest val epoch:{max_val_epoch}\n'
            f'Highest Test: {highest_test.mean():.2f} ± {highest_test.std():.2f} '
            f'Final Test: {final_test.mean():.2f} ± {final_test.std():.2f}'
        )

        self.test = final_test.mean()
        return print_str

    def output(self, out_path, info):
        """Append ``info`` and the stored mean final test value to a file.

        Args:
            out_path: Path of the file to append to.
            info: Text written verbatim before the ``test acc`` line.

        Raises:
            AttributeError: If ``print_statistics(run=None)`` has not been
                called yet, so there is no test value to write. The check
                happens before the file is opened, so nothing is written.
        """
        test_acc = getattr(self, 'test', None)
        if test_acc is None:
            raise AttributeError(
                "Logger has no aggregated test result yet; call "
                "print_statistics() with run=None before output()"
            )
        with open(out_path, 'a') as f:
            f.write(info)
            f.write(f'test acc:{test_acc}\n')

