import os
import torch
import sys
import time


class Logger(object):
    """Collect per-epoch metrics for several runs and print summary statistics.

    Adapted from https://github.com/snap-stanford/ogb/

    Every entry passed to ``add_result`` is indexed as ``(train, valid, test)``
    and multiplied by 100 when reported (presumably fractions in [0, 1] --
    confirm against callers).

    Args:
        runs: Number of independent runs whose results will be recorded.
        info: Optional config object. It must expose ``save_result``; when that
            is falsy it must also expose ``dataset`` and ``split``, and every
            message given to ``write`` is additionally appended to
            ``results/<dataset>/<MM_DD>/<HH_MM_SS>_<split>.log``. With
            ``info=None`` output only goes to stdout.
    """

    def __init__(self, runs, info=None):
        self.info = info
        self.results = [[] for _ in range(runs)]
        self.final_test = []
        self.file = None  # log file handle; stays None when not mirroring output

        if info is not None and not info.save_result:
            # Take one timestamp for both the directory and the file name so the
            # two parts cannot disagree if the clock ticks over between calls.
            now = time.localtime()
            filepath = f"results/{info.dataset}/" + time.strftime("%m_%d/", now)
            os.makedirs(filepath, exist_ok=True)
            filepath += time.strftime("%H_%M_%S", now) + f"_{info.split}.log"
            self.file = open(filepath, "a+")
            self.savedout = sys.stdout

    def write(self, *msg):
        """Print ``msg`` to stdout and, if a log file is open, append it there too."""
        print(*msg)
        if self.file is not None:
            # Print straight into the file instead of swapping sys.stdout, so
            # stdout can never be left redirected if the write fails.
            print(*msg, file=self.file)
            self.file.flush()

    def flush(self):
        """No-op; ``write`` already flushes the log file after every message."""
        pass

    def close(self):
        """Close the log file if one is open. Safe to call more than once."""
        if self.file is not None:
            self.file.close()
            self.file = None

    def add_result(self, run, result):
        """Append one epoch's ``(train, valid, test)`` result to run ``run``."""
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None, mode='max_acc'):
        """Report the chosen epoch of one run, or aggregate over all reported runs.

        Args:
            run: Index of the run to report. When given, the epoch is selected
                on the validation column (index 1): the highest value if
                ``mode == 'max_acc'``, the lowest otherwise. That epoch's test
                value is recorded for the final summary.
            mode: Epoch-selection mode; see ``run``.

        Returns:
            ``None`` when ``run`` is given. With ``run=None``: a
            ``(num_reported, 1)`` tensor of the recorded test values (also kept
            in ``self.final_test``); their mean is stored in ``self.test``.
            ``torch.std`` is unbiased by default, so with a single reported run
            the printed std is NaN.
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            valid = result[:, 1]
            if mode == 'max_acc':
                ind_ori = valid.argmax().item()
            else:
                ind_ori = valid.argmin().item()
            self.write(f'Run {run + 1:02d}:')
            self.write(f'Highest Train Ori: {result[:, 0].max():.2f}')
            self.write(f'Highest Valid Ori: {result[:, 1].max():.2f}')
            self.write(f'Highest Test Ori: {result[:, 2].max():.2f}')
            self.write(f'Ori Chosen epoch: {ind_ori}')
            self.write(f'Ori Final Train: {result[ind_ori, 0]:.2f}')
            self.write(f'Ori Final Test: {result[ind_ori, 2]:.2f}')

            # A previous summary call turned final_test into a tensor; go back
            # to a list so per-run reporting can continue.
            if torch.is_tensor(self.final_test):
                self.final_test = self.final_test.tolist()
            # Store a plain float so the list converts cleanly to a tensor later.
            self.final_test.append([result[ind_ori, 2].item()])
        else:
            # as_tensor returns an existing tensor unchanged, so repeated
            # summary calls do not re-wrap it.
            self.final_test = torch.as_tensor(self.final_test)
            self.write('All runs:')
            r = self.final_test[:, 0]
            self.write(f'Final Ori Test: {r.mean():.2f} ± {r.std():.2f}')

            self.test = r.mean()
            return self.final_test


def save_model(args, model, optimizer, run):
    """Save model and optimizer state to ``models/<dataset>/<postfix>_<run>.pt``.

    The checkpoint is a dict with ``'model_state_dict'`` and
    ``'optimizer_state_dict'`` entries. The directory is created if missing.
    """
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. parallel runs) create the same directory concurrently.
    os.makedirs(f'models/{args.dataset}', exist_ok=True)
    model_path = f'models/{args.dataset}/{args.postfix}_{run}.pt'
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }, model_path)


def load_model(args, model, run):
    """Restore ``model`` weights from ``models/<dataset>/<postfix>_<run>.pt``.

    The checkpoint is mapped onto the CPU and only its ``'model_state_dict'``
    entry is loaded into ``model`` (in place); any stored optimizer state is
    ignored.

    Returns:
        The same ``model`` object.
    """
    state = torch.load(
        f'models/{args.dataset}/{args.postfix}_{run}.pt',
        map_location='cpu',
    )['model_state_dict']
    model.load_state_dict(state)
    return model


def save_result(args, results):
    """Append one hyper-parameter configuration and its scores to a CSV file.

    The file is ``results/<dataset>_grid/MMFormer_<split or 'CL'>.csv`` (the
    split name is used when ``args.semi`` is truthy). A header row is written
    whenever the file is empty.

    Args:
        args: Config object providing the hyper-parameters written in each row.
        results: 2-D tensor whose first column holds the per-run test scores;
            the row records their mean and (unbiased) std.
    """
    postfix = args.split if args.semi else "CL"
    dirpath = f'results/{args.dataset}_grid'
    filename = f'{dirpath}/MMFormer_{postfix}.csv'
    # exist_ok avoids the check-then-create race between parallel runs.
    os.makedirs(dirpath, exist_ok=True)

    print(f"Saving results to {filename}")
    with open(filename, 'a+') as f:
        # In append mode the initial position is the end of the file, so
        # tell() == 0 means the file is new or empty and needs a header.
        if f.tell() == 0:
            f.write(
                "h_dim,layers,n_head,num_clusters,global_nodes_per_class,lr,wd,dropout,attn_dropout,norm_type,norm_pos,avg_ori,std_ori \n")
        f.write(
            f"{args.h_dim},{args.layers},{args.n_head},{args.num_clusters},{args.global_nodes_per_class},{args.lr},{args.wd},{args.dropout},{args.attn_dropout}" +
            f",{args.norm_type},{args.norm_pos},{results[:, 0].mean():.2f},{results[:, 0].std():.2f} \n")