import secrets
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from src.global_constants import *
from src.utils.yparams import YParams
from src.models import build_model


class OutDf:
    """Accumulate result rows and persist them as one CSV file per instance.

    Each instance writes to its own randomly named ``<number>.csv`` inside
    ``dir_path``. Several workers of a multiprocessed computation can therefore
    save into the same directory without overwriting one another, and
    :meth:`load` concatenates every CSV found in that directory.
    """

    def __init__(self, dir_path):
        self.path = self.init_path(dir_path)
        # Column name -> list of values; stays empty until the first ``add``.
        self.d = {}

    @staticmethod
    def init_path(dir_path):
        """Return an unused path ``dir_path/<11-digit number>.csv``.

        The number comes from ``secrets`` (OS entropy) instead of NumPy's global
        RNG. Workers seeded identically, or forked from the same parent, would
        otherwise draw the same name and silently overwrite each other's CSV.
        Drawing from the global RNG would also shift the caller's random
        stream. This also avoids ``np.random.randint``'s int32 overflow for an
        upper bound of 1e11 on platforms where the default integer is 32-bit.

        Raises:
            ValueError: if a file with the drawn name already exists.
        """
        random_id = secrets.randbelow(9 * 10**10) + 10**10  # in [1e10, 1e11)
        path = Path(dir_path) / f"{random_id}.csv"
        if path.is_file():
            raise ValueError("Extremely rare collision occured.")
        return path

    def add(self, **kwargs):
        """Append one row of results given as ``column=value`` keyword arguments.

        The first call fixes the set of columns. Every later call must pass
        exactly those columns so all columns keep the same length; a ragged
        dict would make :meth:`save` fail.

        Raises:
            ValueError: if the keys differ from the columns fixed by the first
                call. Nothing is appended in that case.
        """
        if not self.d:
            self.d = {k: [v] for k, v in kwargs.items()}
            return
        # Validate before mutating so a bad call never leaves a partial row.
        if kwargs.keys() != self.d.keys():
            raise ValueError(
                f"Wrong key. Expected {sorted(self.d)}, got {sorted(kwargs)}."
            )
        for k, v in kwargs.items():
            self.d[k].append(v)

    def load(self):
        """Concatenate every CSV in this instance's directory into one DataFrame.

        The result includes files saved by other instances that share the
        directory. Files are read in sorted name order, and the result gets a
        fresh 0..n-1 index; each file's own 0-based index would otherwise
        repeat. Returns an empty DataFrame if the directory does not exist or
        contains no CSV files.
        """
        dir_path = self.path.parent
        if not dir_path.is_dir():
            return pd.DataFrame()
        csv_files = sorted(dir_path.glob('*.csv'))
        if not csv_files:
            return pd.DataFrame()
        return pd.concat(
            [pd.read_csv(fname) for fname in csv_files], ignore_index=True
        )

    def save(self):
        """Write this instance's accumulated rows to ``self.path``.

        Does nothing if no row has been added. Creates the directory if needed
        and overwrites any earlier save made by this same instance.
        """
        if not self.d:
            return
        self.path.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(self.d).to_csv(self.path, index=False)


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a *leading* ``prefix`` removed from each key.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` or
    ``DistributedDataParallel`` carry a ``module.`` prefix on every key. The
    prefix is stripped repeatedly to handle nested wrappers, but only at the
    start of the key. A plain ``str.replace`` would also mangle inner names,
    e.g. ``encoder.submodule.weight`` -> ``encoder.subweight``.
    """
    stripped = {}
    for key, value in state_dict.items():
        while prefix and key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def load_model(dirname, epoch=None, cuda=True, param_path=None, **kwargs):
    """Rebuild a trained model from the experiment directory ``LOG_PATH / dirname``.

    Args:
        dirname: experiment directory name, resolved against ``LOG_PATH``.
        epoch: if None, load ``best_ckpt.tar`` (the best model in terms of its
            val loss). Otherwise load ``ckpt.tar_epoch{epoch}``. Both files are
            looked up in ``<exp_path>/training_checkpoints``.
        cuda: if True, move the model to the GPU. If False, map the checkpoint
            tensors to the CPU while loading.
        param_path: hyperparameter YAML file; defaults to
            ``<exp_path>/hyperparams.yaml``.
        **kwargs: hyperparameter overrides, set on the loaded params before the
            model is built.

    Returns:
        ``(model, params)``, with ``model`` switched to eval mode.

    Raises:
        KeyError: if the checkpoint has no ``'model_state'`` entry. Errors from
            ``torch.load`` (e.g. a missing file) and from ``load_state_dict``
            propagate unchanged.
    """
    model_file = "best_ckpt.tar" if epoch is None else f"ckpt.tar_epoch{epoch}"
    exp_path = LOG_PATH / dirname

    # params
    if param_path is None:
        param_path = exp_path / 'hyperparams.yaml'
    params = YParams(param_path)
    for k, v in kwargs.items():
        params[k] = v

    # model skeleton
    model = build_model(params).to(dtype=torch.float)
    if cuda:
        model = model.cuda()

    # checkpoint
    checkpt_path = exp_path / 'training_checkpoints' / model_file
    print("Loading: ", checkpt_path)
    # None keeps torch.load's default behaviour (tensors go back to the device
    # they were saved from); on CPU-only runs force everything onto the CPU.
    map_location = None if cuda else torch.device('cpu')
    chkpt = torch.load(checkpt_path, map_location=map_location)
    state_dict = _strip_module_prefix(chkpt['model_state'])

    # fill model
    model.load_state_dict(state_dict)
    model.eval()

    return model, params
