import logging
from typing import Union, Optional
import os
import numpy as np
import yaml
from tqdm import tqdm
import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.nn.parallel import DistributedDataParallel as DDP

import logger

class ValueMeter:
    """Collect a running history of scalar values and report their mean.

    ``torch.Tensor`` inputs are unpacked with ``.item()``, ``np.ndarray``
    inputs are converted with ``float()``, and anything else is stored as-is.
    """

    def __init__(self):
        # Every value passed to ``update``, in insertion order.
        self.value_list = []

    def update(self, value):
        """Record ``value``, converting tensors/arrays to Python scalars first."""
        if isinstance(value, torch.Tensor):
            value = value.item()
        elif isinstance(value, np.ndarray):
            value = float(value)
        self.value_list.append(value)

    def avg(self):
        """Return ``np.mean`` over all recorded values."""
        return np.mean(self.value_list)

class MetricHanlder:
    """Track a metric across epochs and report when a new best is reached.

    Args:
        primary_metric: key read from the metric when ``update`` receives a dict.
        is_max: ``True`` if larger values are better, ``False`` if smaller
            values are better.
        return_epoch: if ``True``, ``update`` returns ``(is_best, best_epoch)``
            instead of only ``is_best``.
    """

    def __init__(self, primary_metric: str = "accuracy", is_max: bool = True, return_epoch: bool = False):
        self.primary_metric = primary_metric
        self.is_max = is_max
        self.return_epoch = return_epoch

        self.curr_metric = -1
        self.curr_epoch = -1
        # Start from the worst possible value for the chosen direction so the
        # first update is always a new best. A fixed 0 could never be beaten
        # when minimising a positive metric, nor when maximising a metric that
        # is always negative.
        self.best_metric = float("-inf") if is_max else float("inf")
        self.best_epoch = -1

    def update(self, metric: Union[float, int, dict], epoch: Optional[int] = None) -> Union[bool, tuple]:
        """Record ``metric`` for an epoch and report whether it is a new best.

        Args:
            metric: a number, a scalar-like convertible with ``float()``
                (e.g. NumPy scalar, 0-d tensor), or a dict from which
                ``self.primary_metric`` is read.
            epoch: epoch index to record; if ``None`` the internal epoch
                counter is incremented by one.

        Returns:
            ``is_best`` (bool), or ``(is_best, best_epoch)`` when
            ``return_epoch`` was set.

        Raises:
            KeyError: if a dict metric lacks ``primary_metric``.
            TypeError / ValueError: if the metric cannot be converted to a float.
        """
        self.curr_epoch = self.curr_epoch + 1 if epoch is None else epoch

        value = metric[self.primary_metric] if isinstance(metric, dict) else metric
        if not isinstance(value, (float, int)):
            # Unsupported types used to be ignored silently, leaving a stale
            # value to be re-compared; convert scalar-likes explicitly instead.
            value = float(value)
        self.curr_metric = value

        if self.is_max:
            is_best = self.curr_metric > self.best_metric
        else:
            is_best = self.curr_metric < self.best_metric

        if is_best:
            self.best_metric = self.curr_metric
            self.best_epoch = self.curr_epoch

        if self.return_epoch:
            return is_best, self.best_epoch
        return is_best


def save_all(save_path: str, epoch: int, model: Module, criterion: Module, optimizer: Optimizer, scheduler = None):
    """Write a training checkpoint to ``<save_path>/model.ckpt``.

    The checkpoint is a dict with keys ``"epoch"``, ``"model"``,
    ``"criterion"``, ``"optimizer"`` and ``"scheduler"``. A component passed
    as ``None`` is stored as ``None``, which keeps this function symmetric
    with ``load_all``, where criterion/optimizer/scheduler are optional.

    Args:
        save_path: existing directory to write the checkpoint into.
        epoch: epoch number stored alongside the states.
        model: model to save; a ``DistributedDataParallel`` wrapper is
            unwrapped so the checkpoint loads without DDP.
        criterion: loss module (or ``None``).
        optimizer: optimizer (or ``None``).
        scheduler: LR scheduler (or ``None``).
    """
    if isinstance(model, DDP):
        model = model.module

    model_dict = {
        "epoch": epoch,
        "model": model.state_dict(),
        "criterion": None if criterion is None else criterion.state_dict(),
        "optimizer": None if optimizer is None else optimizer.state_dict(),
        "scheduler": None if scheduler is None else scheduler.state_dict(),
    }
    ckpt_path = os.path.join(save_path, "model.ckpt")
    # Write to a sibling temp file, then atomically swap it in, so a crash
    # mid-write never replaces the last good checkpoint with a truncated one.
    tmp_path = ckpt_path + ".tmp"
    torch.save(model_dict, tmp_path)
    os.replace(tmp_path, ckpt_path)

def load_all(device, load_path: str, model: Module, criterion: Optional[Module] = None, optimizer: Optional[Optimizer] = None, scheduler = None) -> int:
    """Restore state from ``<load_path>/model.ckpt`` and return the saved epoch.

    Args:
        device: passed to ``torch.load`` as ``map_location``.
        load_path: directory containing ``model.ckpt``.
        model: model to load into; a ``DistributedDataParallel`` wrapper is
            unwrapped first.
        criterion, optimizer, scheduler: optional components to restore. A
            component is skipped (with a warning) when the checkpoint holds
            no state for it, e.g. because it was saved as ``None``.

    Returns:
        The ``"epoch"`` value stored in the checkpoint.
    """
    ckpt_path = os.path.join(load_path, "model.ckpt")
    model_dict = torch.load(ckpt_path, map_location=device)

    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(model_dict["model"])

    optional_parts = (
        ("criterion", criterion),
        ("optimizer", optimizer),
        ("scheduler", scheduler),
    )
    for key, component in optional_parts:
        if component is None:
            continue
        state = model_dict.get(key)
        if state is None:
            # load_state_dict(None) would crash; keep the component's current state.
            logging.getLogger(__name__).warning(
                "Checkpoint %s has no %s state; leaving it unchanged.", ckpt_path, key
            )
            continue
        component.load_state_dict(state)

    return model_dict["epoch"]
    
def resume_config(load_path: str, config, preserve_list = ("rank",)):
    """Overwrite attributes of ``config`` from ``<load_path>/config.yaml``.

    Only keys that already exist as attributes on ``config`` are applied.
    Keys listed in ``preserve_list`` are never overwritten. Each applied key
    is reported through ``logger.info``. An empty YAML file is treated as
    "nothing to update".

    Args:
        load_path: directory containing ``config.yaml``.
        config: object whose attributes are updated in place.
        preserve_list: keys to keep untouched. The default is an immutable
            tuple, so it cannot be mutated across calls.
    """
    yaml_path = os.path.join(load_path, "config.yaml")
    # Close the handle deterministically; it was previously left open.
    with open(yaml_path) as f:
        # NOTE(review): FullLoader is kept for compatibility with configs
        # written by yaml.dump; do not point this at untrusted files.
        config_dict = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML document loads as None.
    for k, v in (config_dict or {}).items():
        if k in preserve_list or not hasattr(config, k):
            continue
        setattr(config, k, v)
        logger.info(f"Update {k}: {v}")

def give_pbar(loader, rank):
    """Return a plain iterator over ``loader``.

    ``rank`` is accepted for call-site compatibility but is currently unused.
    Every rank, including rank 0, receives the same bare ``iter(loader)``
    with no progress bar.
    """
    del rank  # intentionally unused; progress bars are disabled on all ranks
    batches = iter(loader)
    return batches


if __name__ == "__main__":
    # Quick smoke demo: feed a fixed sequence of scores and print
    # (is_best, best_epoch) after each update.
    demo_handler = MetricHanlder(return_epoch=True)
    for score in (0.2, 0.4, 0.3, 0.8):
        print(demo_handler.update(score))
    