import torch
import numpy as np
from typing import OrderedDict
import json

def empty_cache(ratio):
    """Free PyTorch's cached CUDA memory when its utilisation is low.

    Queries device 0 for the bytes currently allocated to tensors and the
    bytes reserved by the caching allocator, and calls
    ``torch.cuda.empty_cache()`` when ``allocated / reserved`` is below
    ``ratio``.

    Args:
        ratio: utilisation threshold; ``None`` disables the check entirely.
    """
    if ratio is not None:
        in_use = torch.cuda.memory_allocated(0)
        pooled = torch.cuda.memory_reserved(0)
        # Nothing reserved means nothing to release (and avoids dividing by 0).
        if pooled > 0 and in_use / pooled < ratio:
            torch.cuda.empty_cache()


def get_memory_usage(print_info=False):
    """Return the number of bytes torch reports as allocated on CUDA device 0.

    Args:
        print_info: when true, also print the allocated and reserved amounts
            in MB (flushed immediately).

    Returns:
        The value of ``torch.cuda.memory_allocated(0)``.
    """
    allocated = torch.cuda.memory_allocated(0)
    reserved = torch.cuda.memory_reserved(0)
    if print_info:
        report = (
            ("allocated: %.2f MB", allocated),
            ("reserved:  %.2f MB", reserved),
        )
        for template, num_bytes in report:
            print(template % (num_bytes / 1024 / 1024), flush=True)
    return allocated


def compute_tensor_bytes(tensors):
    """Compute the bytes used by a list of tensors.

    The size of each tensor is ``numel() * element_size()``, so every dtype
    is accounted for (a hand-written dtype table would silently count
    unlisted dtypes such as float64, int64, uint8 or bool as zero bytes).

    Args:
        tensors: a single tensor, or a list/tuple of tensors.

    Returns:
        int: total number of bytes occupied by the elements of all tensors;
        0 for an empty list/tuple.
    """
    if not isinstance(tensors, (list, tuple)):
        tensors = [tensors]

    return sum(x.numel() * x.element_size() for x in tensors)


class GlobalExpRecorder:
    """Collect key/value experiment results and append them to a JSON-lines file.

    Values are kept in insertion order in ``val_dict``; ``dump`` serialises
    the whole dictionary as one JSON object per call.
    """

    def __init__(self):
        self.val_dict = OrderedDict()

    def record(self, key, value, float_round=6):
        """Store ``value`` under ``key``, normalising numeric scalars.

        NumPy integers become Python ``int``; Python and NumPy floats become
        Python ``float`` rounded to ``float_round`` digits. The conversion to
        built-in types keeps the stored values JSON-serialisable (``json``
        cannot encode e.g. ``np.float32``). Other values are stored as-is.

        Args:
            key: dictionary key for the result.
            value: the value to store.
            float_round: number of decimal digits kept for floats.
        """
        if isinstance(value, np.integer):
            value = int(value)
        elif isinstance(value, (float, np.floating)):
            value = round(float(value), float_round)

        self.val_dict[key] = value

    def dump(self, filename):
        """Append the recorded values to ``filename`` as a single JSON line."""
        with open(filename, "a") as fout:
            fout.write(json.dumps(self.val_dict) + '\n')
        print("Save exp results to %s" % filename)

    def clear(self):
        """Discard all recorded values."""
        self.val_dict.clear()

# Module-level GlobalExpRecorder instance, created when this module is imported.
exp_recorder = GlobalExpRecorder()


# Zero-argument factory: each call builds a new Meter backed by three fresh
# AverageMeter instances (iteration, epoch and run aggregators, in that order).
DIFF_METER = lambda: Meter(AverageMeter(), AverageMeter(), AverageMeter())


class AverageMeter(object):
    """Weighted running average of recorded values.

    ``val`` holds the weighted sum of everything recorded since the last
    ``reset`` and ``n`` the total weight; the average is ``val / n``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all accumulated values."""
        self.n = 0
        self.val = 0

    def record(self, val, n=1, float_round=5):
        """Add ``val`` with weight ``n`` to the running sum.

        Python/NumPy float values are rounded to ``float_round`` digits
        before being accumulated; other values are accumulated unchanged.
        """
        if isinstance(val, (float, np.float32, np.float64)):
            val = round(val, float_round)
        self.n += n
        self.val += val * n

    def get_val(self):
        """Return ``(average, n)`` in a form suitable for reporting.

        Float averages (Python or NumPy) are returned as built-in ``float``
        rounded to 5 digits, and tensors as nested lists, so the result can
        be passed to ``json.dumps`` (which rejects NumPy scalars such as
        ``np.float32``). Returns ``(None, 0)`` when nothing was recorded.
        """
        if self.n == 0:
            return None, 0

        v = self.val / self.n
        if isinstance(v, (float, np.floating)):
            # Convert before rounding so the result is a plain Python float.
            v = round(float(v), 5)
        elif isinstance(v, torch.Tensor):
            v = v.tolist()
        return v, self.n

    def get_data(self):
        """Return the raw ``(average, n)`` without rounding or conversion.

        Returns ``(None, 0)`` when nothing was recorded.
        """
        if self.n == 0:
            return None, 0
        return self.val / self.n, self.n


class Meter(object):
    """Three-level aggregation: iteration -> epoch -> run.

    New values go into the iteration aggregator. Resetting a level reads its
    raw average via ``get_data()``, clears it, and records that average
    (with its count) into the next level up. The aggregators only need to
    provide ``record``, ``get_val``, ``get_data`` and ``reset``.
    """

    def __init__(self, iteration_aggregator, epoch_aggregator, run_aggregator):
        self.iteration_aggregator = iteration_aggregator
        self.epoch_aggregator = epoch_aggregator
        self.run_aggregator = run_aggregator

    @staticmethod
    def _roll_up(source, sink):
        """Clear ``source`` and record its previous average into ``sink``."""
        mean, count = source.get_data()
        source.reset()
        # An empty source reports None and contributes nothing upstream.
        if mean is not None:
            sink.record(mean, n=count)

    def record(self, val, n=1):
        """Record ``val`` with weight ``n`` at the iteration level."""
        self.iteration_aggregator.record(val, n=n)

    def get_iteration(self):
        """Return the reportable value of the iteration aggregator."""
        value, _ = self.iteration_aggregator.get_val()
        return value

    def reset_iteration(self):
        """Move the iteration-level average into the epoch aggregator."""
        self._roll_up(self.iteration_aggregator, self.epoch_aggregator)

    def get_epoch(self):
        """Return the reportable value of the epoch aggregator."""
        value, _ = self.epoch_aggregator.get_val()
        return value

    def reset_epoch(self):
        """Move the epoch-level average into the run aggregator."""
        self._roll_up(self.epoch_aggregator, self.run_aggregator)

    def get_run(self):
        """Return the reportable value of the run aggregator."""
        value, _ = self.run_aggregator.get_val()
        return value

    def reset_run(self):
        """Clear the run aggregator."""
        self.run_aggregator.reset()


class GlobalOptimizationRecorder(object):
    """Per-tensor metric log that is appended to a JSON-lines report file.

    ``log_data`` maps tensor name -> metric name -> meter. Meters are
    expected to offer ``record``, ``get_iteration``, ``reset_iteration``,
    ``get_epoch`` and ``reset_epoch``. ``init`` must be called before
    ``end_iteration`` / ``end_epoch`` because it sets ``print_freq`` and
    ``raport_path``.
    """

    # Metrics that ``init`` registers for every tensor, in this order.
    _DEFAULT_METRICS = (
        "q_diff_norm",
        "q_diff_norm_with_lr",
        "q_diff_bin_change_num",
        "q_diff_bin_change_ratio",
        "q_diff_p_change_num",
        "q_diff_p_change_ratio",
        "q_diff_hit_max_num",
        "q_diff_hit_max_ratio",
        "q_diff_proba_hist",
        "q_diff_cos_sim_update",
    )

    def __init__(self) -> None:
        self.log_data = {}
        self.epoch = -1
        self.iteration = -1

    def set_start_epoch(self, start_epoch=-1):
        """Overwrite the current epoch counter and print the new value."""
        self.epoch = start_epoch
        print(f'set start epoch: {start_epoch}')

    def init(self, print_freq, raport_path, param_names):
        """Configure reporting and create the default meters.

        Args:
            print_freq: an iteration record is written whenever the iteration
                counter is a multiple of this value.
            raport_path: file that records are appended to.
            param_names: tensor names to track; each gets an empty metric
                table which is then filled with the default metrics.
        """
        self.print_freq = print_freq
        self.raport_path = raport_path
        self.log_data.update({tensor_name: {} for tensor_name in param_names})
        for metric_name in self._DEFAULT_METRICS:
            self.register_metric(metric_name, DIFF_METER)

    def register_metric(self, metric_name, meter):
        """Attach a new meter (built by calling ``meter()``) to every tensor."""
        for metrics in self.log_data.values():
            metrics[metric_name] = meter()

    def log_metric(self, tensor_name, metric_name, val, n=1):
        """Record ``val`` for one metric of one tensor.

        Unknown tensor names are reported on stdout and otherwise ignored.
        """
        if tensor_name in self.log_data:
            self.log_data[tensor_name][metric_name].record(val, n=n)
        else:
            print(f'tensor {tensor_name} not in GOR')

    def start_iteration(self):
        """Advance the iteration counter by one."""
        self.iteration += 1

    def end_iteration(self):
        """Write an iteration record and roll the iteration meters up.

        Does nothing unless the iteration counter is a multiple of
        ``print_freq``.
        """
        if self.iteration % self.print_freq != 0:
            return

        snapshot = self._collect(lambda meter: meter.get_iteration())
        self._append_record({
            "epoch": self.epoch,
            "label": "iteration",
            "iteration": self.iteration,
            "data": snapshot,
        })

        for metrics in self.log_data.values():
            for meter in metrics.values():
                meter.reset_iteration()

    def start_epoch(self):
        """Advance the epoch counter, zero the iteration counter and roll the
        epoch meters up."""
        self.epoch += 1
        self.iteration = 0

        for metrics in self.log_data.values():
            for meter in metrics.values():
                meter.reset_epoch()

    def end_epoch(self):
        """Write an epoch record with the epoch-level value of every meter."""
        snapshot = self._collect(lambda meter: meter.get_epoch())
        self._append_record({
            "epoch": self.epoch,
            "label": "epoch",
            "data": snapshot,
        })

    def _collect(self, read):
        """Build ``{tensor: {metric: read(meter)}}`` for every meter."""
        return {
            tensor_name: {
                metric_name: read(meter)
                for metric_name, meter in metrics.items()
            }
            for tensor_name, metrics in self.log_data.items()
        }

    def _append_record(self, record):
        """Append ``record`` to the report file as one JSON line."""
        with open(self.raport_path, "a") as fout:
            fout.write(json.dumps(record) + '\n')

# Module-level GlobalOptimizationRecorder instance, created when this module is
# imported; its init() must be called before records can be written.
GOR = GlobalOptimizationRecorder()