import os
import math
import time
import random
import datetime
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict


class LossForPrint:
    """Accumulate named scalar losses and render them as a single log line.

    Values are stored per key in ``loss_dict`` in first-seen order; that
    order is also the order used by :meth:`process_print`.
    """

    def __init__(self):
        # Ordered so the printed line follows the order keys were first added.
        self.loss_dict = OrderedDict()

    @staticmethod
    def _as_scalar(value):
        """Return ``value.item()`` when that method exists, else ``value``.

        This lets callers pass either objects exposing ``.item()`` (as the
        original code required) or plain Python numbers.
        """
        item = getattr(value, "item", None)
        return item() if callable(item) else value

    def update(self, current_loss):
        """Add every entry of ``current_loss`` to the running totals.

        Args:
            current_loss: Mapping from loss name to a scalar. Each value may
                be a plain number or any object with an ``.item()`` method.
        """
        for key, value in current_loss.items():
            scalar = self._as_scalar(value)
            if key in self.loss_dict:
                self.loss_dict[key] += scalar
            else:
                self.loss_dict[key] = scalar

    def get_loss(self, key):
        """Return the accumulated value for ``key``.

        Raises:
            KeyError: If ``key`` has never been passed to :meth:`update`.
        """
        return self.loss_dict[key]

    def compute(self, func):
        """Replace every stored value ``v`` with ``func(v)`` in place.

        Args:
            func: Callable taking the stored value and returning its
                replacement (e.g. a division by the number of steps).
        """
        # Only values are reassigned; the key set is untouched, so iterating
        # the dict while assigning is safe.
        for key in self.loss_dict:
            self.loss_dict[key] = func(self.loss_dict[key])

    def process_print(self, prefix):
        """Format all stored losses as ``"{prefix}_{key}={value:.4f}"`` items.

        Items are separated by two spaces. An empty accumulator yields an
        empty string instead of raising ``StopIteration``.

        Args:
            prefix: Text prepended (with an underscore) to every key.

        Returns:
            The formatted line, or ``""`` when nothing has been accumulated.
        """
        return "  ".join(
            f"{prefix}_{key}={value:.4f}"
            for key, value in self.loss_dict.items()
        )

    def clear(self):
        """Drop all accumulated losses."""
        self.loss_dict.clear()


def set_deterministic(seed):
    """Seed every RNG used here and switch PyTorch to deterministic mode.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    configures cuDNN for deterministic behaviour, exports the related
    environment variables, and enables
    ``torch.use_deterministic_algorithms``.

    Args:
        seed: Value passed unchanged to each seeding function and written
            (as a string) to ``PYTHONHASHSEED``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Pin cuDNN to deterministic kernels and turn off its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE(review): Python reads PYTHONHASHSEED at interpreter startup, so
    # this presumably only affects processes launched after this call.
    # CUBLAS_WORKSPACE_CONFIG is the setting PyTorch's reproducibility notes
    # ask for when deterministic algorithms are enabled on CUDA.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })
    torch.use_deterministic_algorithms(True)
