import os
import uuid
import json
from abc import ABC, abstractmethod

import torch
from collections import OrderedDict
from loguru import logger
from offlinerl.utils.exp import init_exp_logger
from offlinerl.utils.io import create_dir, download_helper, read_json


class BaseAlgo(ABC):
    """Common scaffolding for offline RL algorithm trainers.

    Sets up an experiment logger and on-disk output directories, and provides
    helpers for logging per-epoch metrics, checkpointing the objects returned
    by :meth:`get_policy` / :meth:`get_transition`, and soft target-network
    updates. Subclasses must implement :meth:`train`, :meth:`get_policy` and
    :meth:`get_transition`.
    """

    def __init__(self, args):
        """Initialise the experiment logger and output directories.

        Args:
            args: Hyper-parameter mapping; ``args.keys()`` and ``args[key]``
                are used, and the whole object is recorded as ``hparams``.
                Optional keys:

                * ``exp_name`` -- experiment name. When absent, a dash-free
                  uuid1 hex string is generated.
                * ``aim_path`` -- repository path passed to
                  ``init_exp_logger``. It is used only when it is set and
                  exists on disk; otherwise ``repo=None`` is passed.
        """
        logger.info('Init AlgoTrainer')
        exp_name = self._resolve_exp_name(args)
        repo = self._resolve_repo(args)
        if repo is None and "aim_path" in args.keys():
            # A missing/unset aim_path falls back to the default repo instead
            # of aborting initialisation; warn so the fallback is visible.
            logger.warning('aim_path {!r} does not exist; using the default repo',
                           args["aim_path"])

        self.repo = repo
        self.exp_logger = init_exp_logger(repo=repo, experiment_name=exp_name)
        # Run directory of this experiment; all artefacts are written below it.
        self.index_path = self.exp_logger.repo.index_path
        self.models_save_dir = os.path.join(self.index_path, "models")
        self.transitions_save_dir = os.path.join(self.index_path, "transitions")
        # Maps str(epoch) -> result dict, in insertion (epoch) order.
        self.metric_logs = OrderedDict()
        self.metric_logs_path = os.path.join(self.index_path, "metric_logs.json")
        create_dir(self.models_save_dir)
        create_dir(self.transitions_save_dir)

        self.exp_logger.set_params(args, name='hparams')

    @staticmethod
    def _resolve_exp_name(args):
        """Return ``args["exp_name"]`` or, when absent, a fresh dash-free uuid1 hex string."""
        if "exp_name" in args.keys():
            return args["exp_name"]
        return uuid.uuid1().hex

    @staticmethod
    def _resolve_repo(args):
        """Return ``args["aim_path"]`` if it is set and exists on disk, else ``None``."""
        if "aim_path" not in args.keys():
            return None
        aim_path = args["aim_path"]
        # Guard against None/"" so os.path.exists is never handed a non-path.
        if aim_path and os.path.exists(aim_path):
            return aim_path
        return None

    def log_res(self, epoch, result):
        """Log one epoch's metrics and checkpoint the current policy.

        Each metric is written to the console logger and tracked on the
        experiment logger under the first whitespace-separated token of its
        key. The full history in ``self.metric_logs`` is then rewritten to
        ``metric_logs.json``, and the policy is saved to ``models/<epoch>.pt``.

        Args:
            epoch: Epoch identifier; ``str(epoch)`` is used as the JSON key
                and checkpoint file stem.
            result: Mapping of metric name to value. Values must be
                JSON-serializable, since ``json.dump`` raises otherwise.
        """
        logger.info('Epoch : {}', epoch)
        for k, v in result.items():
            logger.info('{} : {}', k, v)
            self.exp_logger.track(v, name=k.split(" ")[0], epoch=epoch,)

        self.metric_logs[str(epoch)] = result
        with open(self.metric_logs_path, "w", encoding="utf-8") as f:
            json.dump(self.metric_logs, f)
        self.save_model(os.path.join(self.models_save_dir, str(epoch) + ".pt"))

    def log_transition(self, task):
        """Save the current transition object to ``transitions/<task>.pt``."""
        self.save_transition(os.path.join(self.transitions_save_dir, str(task) + ".pt"))

    @abstractmethod
    def train(self,
              history_buffer,
              eval_fn=None,):
        """Run training on ``history_buffer``; implemented by subclasses."""
        pass

    def _sync_weight(self, net_target, net, soft_target_tau=5e-3):
        """Soft-update ``net_target`` toward ``net`` in place.

        For each parameter pair: ``target <- (1 - tau) * target + tau * source``.
        ``.data`` is used, so the update is not recorded by autograd.
        Parameters are paired positionally via ``zip``. The two nets are
        presumably the same architecture; if parameter counts differ, the
        extras are silently skipped.
        """
        for o, n in zip(net_target.parameters(), net.parameters()):
            o.data.copy_(o.data * (1.0 - soft_target_tau) + n.data * soft_target_tau)

    @abstractmethod
    def get_policy(self,):
        """Return the object that :meth:`save_model` serializes."""
        pass

    @abstractmethod
    def get_transition(self,):
        """Return the object that :meth:`save_transition` serializes."""
        pass

    def save_model(self, model_path):
        """Serialize ``self.get_policy()`` to ``model_path`` with ``torch.save``."""
        torch.save(self.get_policy(), model_path)

    def load_model(self, model_path, map_location=None):
        """Load and return an object saved with :meth:`save_model`.

        Args:
            model_path: Path to the checkpoint file.
            map_location: Forwarded to ``torch.load``. For example, pass
                ``"cpu"`` to load GPU-saved tensors on a CPU-only machine.
                ``None`` keeps torch's default.

        NOTE(review): ``torch.load`` unpickles the file, so only load trusted
        checkpoints. Newer PyTorch releases default ``weights_only=True``,
        which may reject whole pickled modules. Confirm against the torch
        version in use.
        """
        model = torch.load(model_path, map_location=map_location)
        return model

    def save_transition(self, transition_path):
        """Serialize ``self.get_transition()`` to ``transition_path`` with ``torch.save``."""
        torch.save(self.get_transition(), transition_path)

    def load_transition(self, transition_path, map_location=None):
        """Load and return an object saved with :meth:`save_transition`.

        ``map_location`` is forwarded to ``torch.load``; ``None`` keeps
        torch's default. The same trusted-source caveat as
        :meth:`load_model` applies.
        """
        transition = torch.load(transition_path, map_location=map_location)
        return transition