import os
import json
import torch
from datetime import datetime
import shutil
import yaml

from disentangle.model import MLPAE_1 as MLPAE


class Accumulator:
    """Collect per-step values for a fixed number of metrics and report their means.

    Call ``add`` once per step with one value per metric. Call ``output`` to get
    the mean of each metric formatted with two decimals; this also clears the
    collected values.
    """

    def __init__(self, cnt):
        # One list of collected values per metric.
        self.log_list = [[] for _ in range(cnt)]

    def add(self, *args):
        """Append one value to each metric; exactly ``cnt`` values are expected."""
        assert len(self.log_list) == len(args), (
            f'expected {len(self.log_list)} values, got {len(args)}'
        )
        for i, value in enumerate(args):
            self.log_list[i].append(value)

    def output(self):
        """Return each metric's mean as a '.2f' string and reset the accumulators.

        A metric with no collected values (e.g. ``output`` called twice in a
        row) is reported as ``'nan'`` instead of raising ZeroDivisionError.
        """
        result = []
        for i, sub_list in enumerate(self.log_list):
            data = sum(sub_list) / len(sub_list) if sub_list else float('nan')
            result.append(format(data, '.2f'))
            # Replace (not clear) the inner list so previously returned
            # references are left untouched, as before.
            self.log_list[i] = []
        return result


class Saver:
    """Create a uniquely named experiment directory and persist checkpoints/logs into it.

    ``config`` may be a mapping (e.g. a dict loaded from YAML) or an object that
    exposes the same keys as attributes; both access styles are supported.
    """

    # Root directory holding one sub-directory per experiment run.
    dir_path = './disentangle/outputs/model/'
    # JSON file that accumulates one summary entry per finished run.
    log_name = './disentangle/outputs/logs.json'
    # Template config copied into each run directory, keyed by config['model_name'].
    config_files = {
        'gpt2-xl': './disentangle/config_x.yaml',
        'gpt-j': './disentangle/config_j.yaml',
        'llama3': './disentangle/config_l.yaml',
    }

    @staticmethod
    def _cfg(config, key):
        """Read ``key`` from a mapping-style or attribute-style config object."""
        try:
            return config[key]
        except (TypeError, KeyError):
            # Not subscriptable, or the value is only exposed as an attribute.
            if hasattr(config, key):
                return getattr(config, key)
            raise

    def __init__(self, config):
        """Create the run directory under ``dir_path`` and copy the model's template config.

        The directory is named from the config; if that name is taken, the
        first free ``<name>__<k>`` (k = 1, 2, ...) is used instead.

        Raises:
            ValueError: if ``config['model_name']`` has no entry in ``config_files``.
        """
        self.config = config
        model_name = self._cfg(config, 'model_name')
        # Validate before touching the filesystem so an unsupported model does
        # not leave an empty experiment directory behind.
        if model_name not in self.config_files:
            raise ValueError(f'unsupported model_name: {model_name!r}')
        src_config = self.config_files[model_name]

        os.makedirs(self.dir_path, exist_ok=True)

        # make experiment directory
        exp_name = self._cfg(config, 'exp_name')
        from_layer = self._cfg(config, 'from_layer')
        to_layer = self._cfg(config, 'to_layer')
        relation_layer = self._cfg(config, 'relation_layer')
        hidden_dim = self._cfg(config, 'hidden_dims')[-1]
        base = f'{exp_name}_{model_name}_{from_layer}-{to_layer}s_{relation_layer}r_{hidden_dim}H'

        # Probe with mkdir itself: counting existing names collides after a
        # run directory is deleted, and mkdir-until-free avoids check/create races.
        expdir, suffix = base, 0
        while True:
            try:
                os.mkdir(os.path.join(self.dir_path, expdir))
                break
            except FileExistsError:
                suffix += 1
                expdir = f'{base}__{suffix}'
        self.expdir = expdir

        # copy config
        shutil.copy(src_config, os.path.join(self.dir_path, expdir, 'config.yaml'))

    def save(self, model, ep, eval_str_loss):
        """Save ``model``'s state dict as ``<ep>.pt``.

        On the last epoch (``ep == config['epoch'] - 1``), this also writes
        ``model.pt`` and appends a summary entry to ``log_name``.
        """
        run_dir = os.path.join(self.dir_path, self.expdir)
        state = model.state_dict()
        torch.save(state, os.path.join(run_dir, f'{ep}.pt'))

        if ep != self._cfg(self.config, 'epoch') - 1:
            return

        # log experiment
        torch.save(state, os.path.join(run_dir, 'model.pt'))
        if os.path.exists(self.log_name):
            with open(self.log_name, 'r', encoding='utf-8') as f:
                logs = json.load(f)
        else:
            logs = []

        logs.append({
            'name': self.expdir,
            'date': datetime.now().strftime('%Y-%m-%d, %H:%M:%S'),
            'description': self._cfg(self.config, 'exp_dsc'),
            'eval_loss': eval_str_loss
        })
        with open(self.log_name, 'w', encoding='utf-8') as f:
            json.dump(logs, f, indent=4)

    @staticmethod
    def init_ae_model(exp_name_or_config):
        """Build an MLPAE; if given an experiment name, also load its saved config and weights.

        Args:
            exp_name_or_config: a run directory name under ``dir_path`` (str), or
                a config mapping containing ``'hidden_dims'``.

        Returns:
            ``(ae_config, ae_model)``.
        """
        do_load = isinstance(exp_name_or_config, str)
        if do_load:
            run_dir = os.path.join(Saver.dir_path, exp_name_or_config)
            # NOTE: config files are produced by this project; prefer
            # yaml.safe_load if they may ever come from untrusted sources.
            with open(os.path.join(run_dir, 'config.yaml'), 'r', encoding='utf-8') as f:
                ae_config = yaml.load(f, Loader=yaml.FullLoader)
        else:
            ae_config = exp_name_or_config

        ae_model = MLPAE(hidden_dims=ae_config['hidden_dims'])
        if do_load:
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict copies tensors onto the model's own device.
            state_dict = torch.load(os.path.join(run_dir, 'model.pt'), map_location='cpu')
            ae_model.load_state_dict(state_dict)

        return ae_config, ae_model

