from pathlib import Path
import numpy as np
import os
import torch
from _llm_based.objects.model_specs import model_specs


def model_params():
    """Return the per-architecture parameter-name templates.

    Each value is ``[layer_prefix, {'attn': [...], 'mlp': [...]}]``; a full
    parameter name is ``layer_prefix + f"{layer_idx}." + suffix`` (this is how
    ``get_param`` assembles it). A fresh structure is built on every call.
    """
    gpt2 = [
        'transformer.h.',
        {
            'attn': ['attn.c_attn.weight', 'attn.c_proj.weight'],
            'mlp': ['mlp.c_fc.weight', 'mlp.c_proj.weight'],
        },
    ]
    mistral = [
        'model.layers.',
        {
            'attn': [
                'self_attn.k_proj.weight',
                'self_attn.q_proj.weight',
                'self_attn.v_proj.weight',
                'self_attn.o_proj.weight',
            ],
            'mlp': ['mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'],
        },
    ]
    return {'GPT2': gpt2, 'MistralOo': mistral}


def get_param(model_instance, sample_config, layer_idx=0, param_type='mlp', param_idx=0):
    """Look up one layer parameter by name template and enable its gradient.

    The name is ``params_dict[0] + f"{layer_idx}." + params_dict[1][param_type][param_idx]``,
    where ``sample_config.params_dict`` has the ``[prefix, {'attn': [...], 'mlp': [...]}]``
    layout produced by ``model_params()``.

    Args:
        model_instance: a ``torch.nn.Module``.
        sample_config: object exposing ``params_dict``.
        layer_idx: layer index inserted into the name.
        param_type: key into the template dict (e.g. ``'attn'`` or ``'mlp'``).
        param_idx: index into the list of suffixes for ``param_type``.

    Returns:
        ``(param, param_name)``. ``param`` is the module's parameter with
        ``requires_grad`` set to True, or ``None`` when ``param_name`` is not a
        key of ``model_instance.state_dict()``.

    Raises:
        KeyError / IndexError if ``param_type`` / ``param_idx`` are not in the template.
    """
    param_prefix = sample_config.params_dict[0]
    param_suffix = sample_config.params_dict[1][param_type][param_idx]
    param_name = param_prefix + f"{layer_idx}." + param_suffix
    print(f"Parameter name: {param_name}")

    # state_dict() is a mapping, so this is a hash lookup; there is no need to
    # materialise its keys into a list (or to use no_grad: no graph is built).
    # Note: state_dict() also lists buffers, for which get_parameter() raises
    # AttributeError.
    if param_name not in model_instance.state_dict():
        print("\tParameter name not valid!")
        return None, param_name

    param = model_instance.get_parameter(param_name)
    param.requires_grad = True
    return param, param_name


class ModelInfo:
    """Model list plus the plotting colour assigned to each model."""

    def __init__(self, models=None):
        """
        Args:
            models: optional iterable of model names; defaults to
                ``['Mistral-7B-OpenOrca-GPTQ']``.
        """
        self._colours = ['tab:blue']
        self._colours_comp = ['tab:blue', 'tab:red', 'tab:purple']
        self.models = ['Mistral-7B-OpenOrca-GPTQ'] if models is None else list(models)
        # Cycle through the palette so that having more models than colours
        # does not raise IndexError.
        n_colours = len(self._colours)
        self.col_dict = {name: self._colours[i % n_colours] for i, name in enumerate(self.models)}


class SampleConfig:
    """Sampling configuration for one model plus the on-disk layout of its results.

    Every key/value of ``model_specs[model_name]`` is copied onto the instance.
    The code below relies on it providing at least ``L``, ``model_name``,
    ``revision`` and ``hf_host`` (presumably the layer count, HuggingFace model
    id, revision and host; confirm in ``model_specs``).

    The ``*_path`` properties build directory names from the current
    subject/question/condition and create the directories when accessed.

    Note: ``get_param`` reads ``params_dict`` from this object, but it is not
    set here; assign it (e.g. ``model_params()[model_name]``) before use.
    """

    def __init__(self, paths, model_name='GPT2', qs_name='sds', gen_qs_name=None, instr_name='instr1',
                 skip_sui=True, temp=1.25, top_p=0.896, max_new_tok=10):
        self._prompts_path = None
        # Current selection; the path properties only build paths once these
        # differ from their defaults.
        self.subj = 'sub0'
        self.which_q = 'lvl0_q0'
        self.which_gen_q = None

        self._paths = paths
        self._prefix_path = None

        self.model_name_sshort = model_name  # key used to index model_specs

        # set model specs for hugging face
        for key, value in model_specs[model_name].items():
            setattr(self, key, value)

        self.midL = int(self.L * 0.6)
        # NOTE(review): set() iteration order is not guaranteed to be ascending.
        # Kept as-is in case saved artifacts are indexed by this order; consider
        # sorted() if nothing depends on it.
        self.layer_list = list(set(np.append(np.arange(2, self.L, 3), self.L).tolist()))

        self.hsT = 20  # how many timepoints for hidden states trajectory

        self.type = 'model'
        self.save_states = False

        self._model_name_r = None  # model name + revision
        self._source_name = None  # short model name (no revision)

        self.instr_name = instr_name
        self.temp = temp
        self.top_p = top_p
        self._max_new_tok = None

        self._cond = None  # condition temp^topp
        self._id_filter = None  # filter tmp^topp^model^prompt

        self._model_name_rp = None  # model name + revision + prompt
        self._model_name_rhp = None  # model name + revision + hyper-param pair + prompt
        self._model_name_hf = None  # HuggingFace model name
        self._prompt_name = None

        self.nSamples = 2
        self.batchSize = 10

        # filled in by get_remaining_samples()
        self.remSamples = None
        self.remBatchSize = None
        self.remBatches = None

        self.runMe = True
        self.qs_name = qs_name
        self.skip_sui = skip_sui
        self.gen_qs_name = gen_qs_name  # name of the questionnaire to sample answers to

        self.openq_fname = None  # name of file with open questions
        self.customq_fname = None  # name of file with custom closed questions
        self.context_name = None  # name of the context used before sampling response (either level or arbitrary perm.)
        self.gen_sample_name = None  # name of the set of samples for dir creation

        self._responses_path = None
        self._outputs_path = None
        self._states_path = None

        self.max_new_tok = max_new_tok

    @staticmethod
    def _make_dirs(path, with_fails=True):
        """Create ``path`` (and, if requested, ``path + 'fails/'``) if missing."""
        Path(path).mkdir(parents=True, exist_ok=True)
        if with_fails:
            Path(path + 'fails/').mkdir(parents=True, exist_ok=True)

    def get_remaining_samples(self, printMe=True):
        """Compute how many samples and batches are still to be generated.

        Counts the ``.csv`` entries directly inside ``responses_path`` (entries
        in its ``fails/`` subdirectory are deliberately not counted) and sets
        ``remSamples``, ``remBatchSize`` and ``remBatches``.
        """
        try:
            responses_dir = self.responses_path
            # responses_path is None when no path layout matches; os.listdir(None)
            # would silently list the current working directory instead.
            fnames = [] if responses_dir is None else os.listdir(responses_dir)
        except (OSError, TypeError, AttributeError):
            # Best effort: an unresolvable or missing directory counts as no samples yet.
            fnames = []
        count_ok_el = sum('.csv' in fname for fname in fnames)

        remaining = max(0, self.nSamples - count_ok_el)
        self.remSamples = remaining
        self.remBatchSize = min(self.batchSize, remaining)
        # Ceiling division: a partial final batch still counts as one batch.
        self.remBatches = -(-remaining // self.batchSize) if remaining > 0 else 0
        if printMe:
            print(
                f"Rem samples: {self.remSamples}, Rem batches: {self.remBatches}, Rem batch size: {self.remBatchSize}")

    @property
    def model_name_r(self):
        """``model_name^revision``."""
        self._model_name_r = '^'.join([self.model_name, self.revision])
        return self._model_name_r

    @property
    def model_name_hf(self):
        """``hf_host/model_name``."""
        self._model_name_hf = '/'.join([self.hf_host, self.model_name])
        return self._model_name_hf

    @property
    def source_name(self):
        """The part of ``model_name_r`` before the first ``^``."""
        self._source_name = self.model_name_r.split('^')[0]
        return self._source_name

    @property
    def cond(self):
        """Sampling condition string ``tpv-<temp>^topp-<top_p>``."""
        self._cond = 'tpv-' + str(self.temp) + '^topp-' + str(self.top_p)
        return self._cond

    @property
    def model_name_rp(self):
        """``model_name^revision^instr_name``."""
        self._model_name_rp = '^'.join([self.model_name_r, self.instr_name])
        return self._model_name_rp

    @property
    def model_name_rhp(self):
        """``model_name^revision^<cond>^instr_name``."""
        self._model_name_rhp = '^'.join(
            [self.model_name_r, self.cond, self.instr_name])
        return self._model_name_rhp

    @property
    def id_filter(self):
        """``<cond>^model_name_rp``."""
        # Use the cond property: self._cond is None until cond has been read.
        self._id_filter = '^'.join([self.cond, self.model_name_rp])
        return self._id_filter

    @property
    def outputs_path(self):
        """Directory (with ``fails/``) for raw outputs of the current selection, created on access.

        Returns the previously computed value (initially ``None``) when the
        current selection matches none of the layouts below.
        """
        if self.gen_qs_name is None:
            # for sampling closed version question of the open Q-A as context
            if self.subj != 'sub0' and self.which_q != 'lvl0_q0':
                self._outputs_path = f"{self._paths.files_dir}outputs/subjects/{self.subj}/{self.which_q}/{self.model_name_rhp}/"
                self._make_dirs(self._outputs_path)
        elif self.context_name is None:
            # for sampling generalisation questionnaire question given specific open Q-A as context
            if self.subj != 'sub0' and self.which_q != 'lvl0_q0':
                self._outputs_path = f"{self._paths.files_dir}outputs/{self.gen_qs_name}/subjects/{self.subj}/{self.which_q}/{self.which_gen_q}/{self.model_name_rhp}/"
                self._make_dirs(self._outputs_path)
        elif self.subj != 'sub0':
            # for sampling arbitrary set of questions given arbitrary context
            self._outputs_path = f"{self._paths.files_dir}{self.context_name}/{self.gen_sample_name}/outputs/subjects/{self.subj}/{self.model_name_rhp}/"
            self._make_dirs(self._outputs_path)
        return self._outputs_path

    @property
    def responses_path(self):
        """Directory (with ``fails/``) for parsed responses of the current selection, created on access.

        Returns the previously computed value (initially ``None``) when the
        current selection matches none of the layouts below.
        """
        if self.gen_qs_name is None:
            if self.subj != 'sub0':
                self._responses_path = f"{self._paths.files_dir}responses/subjects/{self.subj}/{self.which_q}/{self.model_name_rhp}/"
                self._make_dirs(self._responses_path)
        elif self.context_name is None:
            if self.subj != 'sub0' and self.which_q != 'lvl0_q0':
                self._responses_path = f"{self._paths.files_dir}responses/{self.gen_qs_name}/subjects/{self.subj}/{self.which_q}/{self.which_gen_q}/{self.model_name_rhp}/"
                self._make_dirs(self._responses_path)
        elif self.subj != 'sub0':
            # for sampling arbitrary set of questions given arbitrary context
            self._responses_path = f"{self._paths.files_dir}{self.context_name}/{self.gen_sample_name}/responses/subjects/{self.subj}/{self.model_name_rhp}/"
            self._make_dirs(self._responses_path)
        return self._responses_path

    @property
    def states_path(self):
        """Directory (no ``fails/``) for hidden states of the current selection, created on access.

        Returns the previously computed value (initially ``None``) when the
        current selection matches none of the layouts below.
        """
        if self.gen_qs_name is None:
            if self.subj != 'sub0':
                self._states_path = f"{self._paths.files_dir}states/subjects/{self.subj}/{self.which_q}/{self.model_name_rhp}/"
                self._make_dirs(self._states_path, with_fails=False)
        elif self.context_name is None:
            if self.subj != 'sub0' and self.which_q != 'lvl0_q0':
                self._states_path = f"{self._paths.files_dir}states/{self.gen_qs_name}/subjects/{self.subj}/{self.which_q}/{self.which_gen_q}/{self.model_name_rhp}/"
                self._make_dirs(self._states_path, with_fails=False)
        elif self.subj != 'sub0':
            # for sampling arbitrary set of questions given arbitrary context
            self._states_path = f"{self._paths.files_dir}{self.context_name}/{self.gen_sample_name}/states/subjects/{self.subj}/{self.model_name_rhp}/"
            self._make_dirs(self._states_path, with_fails=False)
        return self._states_path

    @property
    def prompts_path(self):
        """``paths.prompts_path`` formatted as a string."""
        self._prompts_path = f"{self._paths.prompts_path}"
        return self._prompts_path


class ModelConfig(SampleConfig):
    """SampleConfig extended with questionnaire metadata looked up in ``qs_config``.

    ``qs_config`` must expose the mappings ``qs_max_total`` and ``qs_factors``,
    both keyed by questionnaire name.
    """

    def __init__(self, paths, qs_config, model_name='GPT2', qs_name='sds', temp=1.25, top_p=0.896):
        super().__init__(paths, model_name=model_name, qs_name=qs_name, temp=temp, top_p=top_p)

        self._qs_config = qs_config
        self._responses_suppath = None
        self._n_factors = None  # number of factors for the questionnaire

    @property
    def responses_suppath(self):
        """Parent directory of responses: ``files_dir + responses/qs/context/model_rp/model_rhp/``.

        The directory is not created here.
        NOTE(review): ``context_path`` is not defined by SampleConfig; presumably
        it is assigned on the instance elsewhere. Confirm this, otherwise the
        property raises AttributeError.
        """
        root = self._paths.files_dir
        components = [
            'responses',
            self.qs_name,
            self.context_path,
            self.model_name_rp,
            self.model_name_rhp,
        ]
        self._responses_suppath = root + '/'.join(components) + '/'
        return self._responses_suppath

    @property
    def s_max(self):
        """``qs_config.qs_max_total`` entry for the current questionnaire."""
        max_totals = self._qs_config.qs_max_total
        self._s_max = max_totals[self.qs_name]
        return self._s_max

    @property
    def n_factors(self):
        """``qs_config.qs_factors`` entry for the current questionnaire."""
        factors = self._qs_config.qs_factors
        self._n_factors = factors[self.qs_name]
        return self._n_factors


class SampleUpdateConfig(SampleConfig):
    """SampleConfig whose output directories are additionally keyed by an update run.

    Call ``update_method()`` before reading any path property. The paths embed
    ``method``, ``lr`` and ``param_name``, and ``method + '_'`` raises TypeError
    while ``method`` is still None.

    NOTE(review): the path properties read ``self.context_path``, which
    SampleConfig does not define. Presumably it is assigned on the instance
    elsewhere; confirm this, otherwise they raise AttributeError.
    """

    def __init__(self, paths, qs_config, model_name='GPT2', qs_name='sds', temp=1.25, top_p=0.896):
        # SampleConfig.__init__ already stores qs_name.
        super().__init__(paths, model_name=model_name, qs_name=qs_name, temp=temp, top_p=top_p)

        self._responses_suppath = None
        self._qs_config = qs_config
        self._n_factors = None  # number of factors for the questionnaire
        self._results_path = None
        self.parameter = None
        # Read by every path property; previously never initialised here.
        # update_method() sets the real value.
        self.param_name = None
        self.method = None
        self.lr = None

    def update_method(self, method, param_name, lr, qs_name='same'):
        """Select the update run that the path properties refer to.

        Args:
            method: method name, used in the ``<method>_<lr>`` directory.
            param_name: parameter name, used as the last directory component.
            lr: stringified into the ``<method>_<lr>`` directory (presumably a learning rate).
            qs_name: questionnaire name; ``'same'`` keeps the current one.
        """
        self.param_name = param_name
        self.method = method
        self.lr = lr
        if qs_name != 'same':
            self.qs_name = qs_name

    def _run_dir(self, root, include_rhp=True):
        """Return ``files_dir + root/qs/context/model_rp[/model_rhp]/<method>_<lr>/param_name/``."""
        files_dir = self._paths.files_dir
        parts = [root, self.qs_name, self.context_path, self.model_name_rp]
        if include_rhp:
            parts.append(self.model_name_rhp)
        parts.append(self.method + '_' + str(self.lr))
        parts.append(self.param_name)
        return files_dir + '/'.join(parts) + '/'

    @property
    def responses_path(self):
        """Responses directory for the current run (plus ``fails/``), created on access."""
        self._responses_path = self._run_dir('responses')
        Path(self._responses_path).mkdir(parents=True, exist_ok=True)
        Path(self._responses_path + 'fails/').mkdir(parents=True, exist_ok=True)
        return self._responses_path

    @property
    def outputs_path(self):
        """Outputs directory for the current run (plus ``fails/``), created on access."""
        self._outputs_path = self._run_dir('outputs')
        Path(self._outputs_path).mkdir(parents=True, exist_ok=True)
        Path(self._outputs_path + 'fails/').mkdir(parents=True, exist_ok=True)
        return self._outputs_path

    @property
    def results_path(self):
        """Results directory for the current run (no ``model_rhp`` level, no ``fails/``), created on access."""
        self._results_path = self._run_dir('results', include_rhp=False)
        Path(self._results_path).mkdir(parents=True, exist_ok=True)
        return self._results_path
