import pingouin as pg
from pathlib import Path
import numpy as np

from sklearn.decomposition import KernelPCA
from sklearn.decomposition import PCA
import umap
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler
from factor_analyzer import FactorAnalyzer
from utils.utils import *
from utils.sampling_utils import *
import pickle


class AnalysisConfig:
    def __init__(self, sample_config):
        """Create an empty analysis state and the output directories.

        :param sample_config: sampling configuration object; its ``responses_path``
            attribute is read here to derive the output directories.
        """
        self.sample_config = sample_config

        # Raw inputs, populated by load_files().
        self.h_states = None
        self.h_states_cat = None
        self.responses = None
        self.ids = None

        # Train/test partitions, populated by split() and get_layer_split().
        self.x_train = None
        self.y_train = None
        self.x_test = None
        self.y_test = None
        self.test_len = None
        self.x_train_layer = None
        self.x_test_layer = None

        # Per-layer regression outputs, populated by dim_red_method().
        self.x_train_layer_dimred = None
        self.x_test_layer_dimred = None
        self.y_test_layer_pred_dimred = None
        self.reg_score = None
        self.reg_corr = None

        # Best-layer bookkeeping. store_scores is not assigned anywhere in this
        # class (presumably the caller sets it before get_best_layer_preds());
        # initialising it here avoids an AttributeError when it is read.
        self.store_scores = None
        self.store_scores_foldavg = None
        self.best_layer_spec = None
        self.best_layer_hs_cat = None
        self.y_test_best_layer = None
        self.reg_best_layer_dimred_preds = None
        self.reg_best_layer_dimred_score = None

        self.layer_hs_dimred_sim_avg = None
        self.dim_red_methods = None

        # Sizes; all but nF are populated by load_files().
        self.nS = None
        self.nQ = None
        self.nR = None
        self.nL = None
        self.nC = None
        self.nD = None
        self.nF = 5

        # Output locations mirror the responses path: 'responses' -> 'h_results'
        # for result files, then 'files' -> 'plots' for figures.
        self.h_analysis_path = re.sub('responses', 'h_results', self.sample_config.responses_path)
        self.sample_plots_path = re.sub('files', 'plots', self.h_analysis_path)
        Path(self.h_analysis_path).mkdir(parents=True, exist_ok=True)
        Path(self.sample_plots_path).mkdir(parents=True, exist_ok=True)

    def load_files(self):
        tensor_files = [self.sample_config.states_path + el for el in os.listdir(self.sample_config.states_path) if
                        '.pt' in el]
        response_files = [self.sample_config.responses_path + el for el in os.listdir(self.sample_config.responses_path)
                          if
                          '.csv' in el]

        for t in range(len(tensor_files)):
            tensor_file = tensor_files[t]
            sample_ts = tensor_file.split('-')[-1].split('.')[0]
            response_file = [el for el in response_files if sample_ts in el]
            if len(response_file) == 1:
                response_file = response_file[0]
                hs = torch.load(tensor_file, map_location=torch.device('cpu'))[:, 1:, :].numpy()[:, np.newaxis, :, :]
                rs = pd.read_csv(response_file).iloc[:, 2:3].to_numpy()
                if t == 0:
                    store_hs = hs
                    store_rs = rs
                else:
                    store_hs = np.concatenate((store_hs, hs), axis=1)
                    store_rs = np.concatenate((store_rs, rs), axis=1)
            else:
                print('error matching files')

        self.responses = store_rs
        self.h_states = store_hs
        self.h_states_cat = self.h_states.reshape((-1, self.h_states.shape[2], self.h_states.shape[3]))
        self.ids = np.arange(self.h_states_cat.shape[0])

        self.nL = self.h_states.shape[-2]
        self.nQ = self.h_states.shape[0]
        self.nS = self.h_states.shape[1]
        self.nR = self.nQ * self.nS
        self.nC = 30
        self.nD = self.h_states.shape[-1]

    def std_norm(self):
        # standardise responses
        self.responses = self.responses.reshape((-1, 1))
        rs_mean = self.responses.mean()
        rs_std = self.responses.std()
        self.responses = (self.responses - rs_mean) / rs_std

        # standardise hidden states
        hs_mean = self.h_states_cat.mean(axis=0)[np.newaxis, :]
        hs_std = self.h_states_cat.std(axis=0)[np.newaxis, :]
        self.h_states_cat = (self.h_states_cat - hs_mean) / hs_std

        # normalise hidden states
        mag = np.linalg.norm(self.h_states_cat, axis=2)[:, :, np.newaxis]
        self.h_states_cat = self.h_states_cat / mag
        self.h_states_cat = self.h_states_cat.reshape(self.h_states_cat.shape)

    def split(self):
        np.random.shuffle(self.ids)
        train_len = int(self.h_states_cat.shape[0] * 0.8)
        train_ids = self.ids[:train_len]
        test_ids = self.ids[train_len:]
        self.test_len = len(test_ids)
        self.x_train, self.y_train = self.h_states_cat[train_ids], self.responses[train_ids]
        self.x_test, self.y_test = self.h_states_cat[test_ids], self.responses[test_ids]

    def get_layer_split(self, l):
        self.x_train_layer, self.x_test_layer = self.x_train[:, l, :], self.x_test[:, l, :]

    def dim_red_method(self, name):
        """Reduce the current layer's features with ``name`` and fit a linear regression.

        :param name: ``'base'`` to use the layer features unchanged, otherwise one of
            ``'pca'``, ``'kpca'`` or ``'umap'`` (any other value raises ``KeyError``).

        Stores the (possibly reduced) train/test features, the last-column test
        predictions, the regression score on the test set and the correlation between
        the last response column and the predictions.
        """
        reducers = {
            'pca': PCA(n_components=self.nC),
            'kpca': KernelPCA(n_components=self.nC, kernel='cosine'),
            'umap': umap.UMAP(n_components=self.nC, metric='correlation'),
        }
        if name == 'base':
            # No reduction: regress directly on the layer features.
            self.x_train_layer_dimred = self.x_train_layer
            self.x_test_layer_dimred = self.x_test_layer
        else:
            reducer = reducers[name]
            # Fit the reducer on the training split only, then project the test split.
            self.x_train_layer_dimred = reducer.fit_transform(self.x_train_layer)
            self.x_test_layer_dimred = reducer.transform(self.x_test_layer)

        model = LinearRegression().fit(self.x_train_layer_dimred, self.y_train)
        predictions = model.predict(self.x_test_layer_dimred)
        self.y_test_layer_pred_dimred = predictions[:, -1]
        self.reg_score = model.score(self.x_test_layer_dimred, self.y_test)
        self.reg_corr = np.corrcoef(self.y_test[:, -1], self.y_test_layer_pred_dimred)[0, 1]

    def get_dimred_layer_sim(self, l, name):
        """Compute the sample-averaged inner-product similarity for one layer.

        :param l: index of the layer taken from ``h_states_cat`` (axis 1).
        :param name: ``'pca'``, ``'kpca'`` or ``'umap'``; any other value leaves
            ``layer_hs_dimred_sim_avg`` untouched.

        The reduced layer features are reshaped to ``(nQ, nS, components)``; for every
        index along the second axis the ``nQ x nQ`` Gram matrix is computed, and the
        mean over that axis is stored in ``layer_hs_dimred_sim_avg``.
        """
        reducers = {
            'pca': PCA(n_components=self.nC),
            'kpca': KernelPCA(n_components=self.nC, kernel='cosine'),
            'umap': umap.UMAP(n_components=self.nC, metric='correlation'),
        }

        layer_features = self.h_states_cat[:, l, :]
        if name not in reducers:
            return

        reduced = reducers[name].fit_transform(layer_features)
        reduced = reduced.reshape(self.nQ, self.nS, -1)

        similarities = np.zeros((self.nQ, self.nQ, self.nS))
        for s in range(self.nS):
            sample_block = reduced[:, s, :]
            similarities[:, :, s] = sample_block @ sample_block.T
        self.layer_hs_dimred_sim_avg = similarities.mean(axis=2)

    def get_best_layer_preds(self):
        """Refit the best (layer, model) combination on a fresh 80/20 split.

        When ``store_scores`` is available, ``r^2`` is averaged per (layer, model), the
        best combination is stored in ``best_layer_spec`` and that layer's hidden
        states in ``best_layer_hs_cat``. The observations are then reshuffled, the
        features are reduced with the best model (left unchanged for ``'base'``, in line
        with ``dim_red_method``) and a linear regression is fitted. Test predictions and
        the test score are stored in ``reg_best_layer_dimred_preds`` and
        ``reg_best_layer_dimred_score``.

        Prints a message and returns if no best layer is available. An unknown model
        name leaves the prediction attributes untouched.
        """
        # store_scores is assigned from outside this class; tolerate it never being set.
        store_scores = getattr(self, 'store_scores', None)
        if store_scores is not None:
            self.store_scores_foldavg = store_scores.groupby(['layer', 'model'])['r^2'].mean().reset_index()
            best_row = self.store_scores_foldavg['r^2'].idxmax()
            self.best_layer_spec = self.store_scores_foldavg.loc[best_row]
            self.best_layer_hs_cat = self.h_states_cat[:, self.best_layer_spec['layer'], :]

        if self.best_layer_hs_cat is None:
            print('Best layer not found')
            return

        np.random.shuffle(self.ids)
        train_len = int(self.best_layer_hs_cat.shape[0] * 0.8)
        train_ids = self.ids[:train_len]
        test_ids = self.ids[train_len:]

        x_train, y_train = self.best_layer_hs_cat[train_ids], self.responses[train_ids]
        x_test, self.y_test_best_layer = self.best_layer_hs_cat[test_ids], self.responses[test_ids]

        # Build only the reducer that is actually needed.
        reducer_factories = {
            'pca': lambda: PCA(n_components=self.nC),
            'kpca': lambda: KernelPCA(n_components=self.nC, kernel='cosine'),
            'umap': lambda: umap.UMAP(n_components=self.nC, metric='correlation'),
        }
        model_name = self.best_layer_spec['model']
        if model_name in reducer_factories:
            dimred = reducer_factories[model_name]()
            # Fit on the training split only, then project the test split.
            x_train = dimred.fit_transform(x_train)
            x_test = dimred.transform(x_test)
        elif model_name != 'base':
            return

        reg = LinearRegression().fit(x_train, y_train)
        self.reg_best_layer_dimred_preds = reg.predict(x_test)
        self.reg_best_layer_dimred_score = reg.score(x_test, self.y_test_best_layer)


class ComputeFunctions:
    def __init__(self, responses_config):
        """Initialise empty result slots and choose the factor rotation.

        :param responses_config: configuration object; ``n_factors`` is read here.
        """
        self._responses_config = responses_config

        # Data frames
        self.df_wide = None
        self.df_wide_std = None
        self.df_long = None

        # Summary tables
        self.totals = None
        self.counts = None

        # Covariances and reliability
        self.cov = None
        self.cov_std = None
        self.ca = None

        # Factor-analysis results, unrotated and rotated
        self.fa_matrix = None
        self.fa_rot_matrix = None
        self.f_variances = None
        self.f_rot_variances = None
        self.fs_weights = None
        self.fs_rot_weights = None
        self.fs_scores = None
        self.fs_rot_scores = None

        # 'promax' rotation by default ('varimax' is the disabled alternative);
        # no rotation when a single factor is requested.
        single_factor = self._responses_config.n_factors == 1
        self.rotation_type = None if single_factor else 'promax'
        self.std_issue = False

    def compute_cov_std(self):
        """Standardise ``df_wide`` and store the covariance of the result in ``cov_std``.

        Columns that are entirely zero after scaling are replaced with small Gaussian
        noise (mean 0, scale 0.001) and the substitution is reported on stdout.
        ``std_issue`` is always reset to ``False``; the exception handling that would
        set it to ``True`` is currently disabled.
        """
        scaled = StandardScaler().fit_transform(self.df_wide)
        self.df_wide_std = pd.DataFrame(scaled)

        zero_columns = (self.df_wide_std == 0).all()
        if zero_columns.any():
            noise_shape = self.df_wide_std.loc[:, zero_columns].shape
            noise = np.random.normal(loc=0, scale=0.001, size=noise_shape)
            self.df_wide_std.loc[:, zero_columns] = noise
            print(f"{self._responses_config.qs_name}, {self._responses_config.context_instance}")
            print('\tSubstituted all zero-columns with normal random vector with mean 0 and scale 0.001')

        self.cov_std = self.df_wide_std.cov()
        self.std_issue = False

    def factor_analysis_fn(self, df, rotation_type=None):
        """Fit a factor analysis on ``df`` and derive factor-score weights and scores.

        :param df: data to fit; it is also right-multiplied by the score weights.
        :param rotation_type: rotation passed to ``FactorAnalyzer`` (``None`` for none).
        :return: list ``[loadings, fs_scores, fs_weights, f_variances]``.
        """
        analyzer = FactorAnalyzer(rotation=rotation_type, n_factors=self._responses_config.n_factors)
        analyzer.fit(df)

        loadings = analyzer.loadings_
        f_variances = analyzer.get_factor_variance()

        # Score weights are the transpose of (L' Psi^-1 L)^-1 L' Psi^-1, where Psi is
        # the diagonal matrix of uniquenesses.
        psi_inv = np.diag(1 / analyzer.get_uniquenesses())
        information = loadings.T @ psi_inv @ loadings
        fs_weights = (np.linalg.inv(information) @ loadings.T @ psi_inv).T
        fs_scores = df @ fs_weights
        return [loadings, fs_scores, fs_weights, f_variances]

    def compute_fa_matrices(self):
        """Run the factor analysis on ``df_wide_std`` without and with rotation."""
        # Unrotated solution (rotation_type defaults to None).
        unrotated = self.factor_analysis_fn(self.df_wide_std)
        self.fa_matrix, self.fs_scores, self.fs_weights, self.f_variances = unrotated

        # Rotated solution using the rotation chosen in __init__.
        rotated = self.factor_analysis_fn(self.df_wide_std, self.rotation_type)
        self.fa_rot_matrix, self.fs_rot_scores, self.fs_rot_weights, self.f_rot_variances = rotated

    def compute_ca(self):
        if len(self.cov_std) > 0:
            try:
                self.ca = np.round(pg.cronbach_alpha(self.df_wide)[0], 2)
            except:
                self.ca = np.nan()
                print('Cronbach_alpha issue')


class DatasetObject(ComputeFunctions):
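    '''
    Dataset-backed responses: loads a wide-format CSV, derives summary statistics,
    covariances and factor-analysis results, and can split the rows into train/test
    sets and turn them into chat conversations.
    '''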

    def __init__(self, paths, responses_config):
        """Store the path and response configurations and reset derived results.

        :param paths: object providing the project directories.
        :param responses_config: configuration forwarded to ``ComputeFunctions``.
        """
        super().__init__(responses_config)
        self._paths = paths
        # fail_ratio, time and ranges are intentionally not tracked for datasets.
        self.stats = None
        self.sample_size = None

    def get_dfs(self):
        """Load the dataset CSV into ``df_wide`` and melt it into ``df_long``.

        The first file in ``responses_path`` whose name contains
        ``<dataset_name>.csv`` is read. Column labels are cast to int, and the long
        frame gets a leading ``sample_ts`` column copied from its index. A message is
        printed when the directory or a matching file is missing.
        """
        responses_dir = self._responses_config.responses_path
        if not os.path.exists(responses_dir):
            print('no files found for this configuration')
            return

        wanted = self._responses_config.dataset_name + '.csv'
        candidates = [responses_dir + fname for fname in os.listdir(responses_dir) if wanted in fname]
        if len(candidates) == 0:
            print('no files found for this configuration')
            return

        wide = pd.read_csv(candidates[0])
        wide.columns = wide.columns.astype(int)
        self.df_wide = wide

        long = pd.melt(wide, ignore_index=False, var_name='question', value_name='score')
        self.df_long = long
        self.df_long.insert(0, 'sample_ts', self.df_long.index)

    def return_objects(self, mode='run', save_objects=False):
        '''
        Load the data and compute all derived objects for this configuration:
        per-sample totals, per-question score counts, summary statistics, sample size,
        covariances, Cronbach's alpha and the factor-analysis matrices.

        :param mode: only 'run' is implemented; any other value is a no-op.
        :param save_objects: currently ignored.
        :return: None; results are stored as attributes.
        '''
        # TODO: saving the computed attributes to, and loading them from, a pickle at
        # objects_save_path + objects_save_fname is not implemented yet.
        if mode != 'run':
            return

        self.get_dfs()
        # get_dfs() leaves df_long as None when no file is found; len(None) would raise.
        if self.df_long is None or len(self.df_long) == 0:
            print('Cannot compute - No data for this configuration')
            return

        source_name = self._responses_config.source_name

        # Total score per sample.
        self.totals = self.df_long.groupby(['sample_ts'], as_index=False).sum().drop(
            columns=['question']).rename(columns={'score': 'total'})[['sample_ts', 'total']]
        self.totals.insert(self.totals.shape[1], 'source', source_name)

        # Number of occurrences of every score per question (scores as rows).
        self.counts = self.df_long.groupby(['question', 'score'], as_index=False).count().rename(
            columns={'sample_ts': 'count'}).pivot(index=['question'], columns='score', values='count').T

        self.stats = self.totals['total'].agg(
            ['mean', 'median', 'count', 'std', q3, q1, my_iqr, 'skew', 'kurt']).rename({'count': 'sample_size'})
        self.sample_size = 'N=' + str(int(self.stats['sample_size']))
        self.totals.insert(self.totals.shape[1], 'source_wn', source_name + ' ' + self.sample_size)

        self.cov = self.df_wide.cov()
        self.compute_cov_std()
        self.compute_ca()
        self.compute_fa_matrices()

    def split_data(self, ratio=0.85, load=True, save=False):
        """Create or load the train/test split of ``df_wide``.

        :param ratio: fraction of rows assigned to the training set; also part of the
            CSV file names.
        :param load: if True, read existing split files (a missing file is silently
            skipped); if False, draw a new random split.
        :param save: when drawing a new split, also write both parts to CSV.
        """
        base = f"{self._paths.dataset_objects_save_dir}{self._responses_config.dataset_name}"
        train_fname = f"{base}_train_{ratio}.csv"
        test_fname = f"{base}_test_{ratio}.csv"

        if load:
            if os.path.exists(train_fname):
                self.df_wide_train = pd.read_csv(train_fname, index_col=0)
                self.train_idx = list(self.df_wide_train.index)
            if os.path.exists(test_fname):
                self.df_wide_test = pd.read_csv(test_fname, index_col=0)
                self.test_idx = list(self.df_wide_test.index)
            return

        n_samples = int(self.stats['sample_size'])
        n_train = int(n_samples * ratio)

        shuffled = np.random.permutation(n_samples)
        self.train_idx = shuffled[:n_train]
        self.test_idx = shuffled[n_train:]

        self.df_wide_train = self.df_wide.iloc[self.train_idx, :]
        self.df_wide_test = self.df_wide.iloc[self.test_idx, :]
        if save:
            self.df_wide_train.to_csv(train_fname)
            self.df_wide_test.to_csv(test_fname)

    def _build_conversations(self, questions, messages, n_rows, df_wide):
        """Build one multi-turn conversation per dataset row.

        :param questions: question texts; the first one is assumed to be contained in
            ``messages`` already (only its answer is appended) — TODO confirm.
        :param messages: shared conversation prefix (list of role/content dicts).
        :param n_rows: maximum number of conversations to build.
        :param df_wide: wide frame whose rows hold the recorded answers.
        :return: tuple ``(answers, conversations)``.
        """
        answer_texts = inv_lists[self._responses_config.qs_name]
        answers = [[answer_texts[a] for a in row] for _, row in df_wide.iterrows()]

        # answer to the first question
        conversations = [messages + [{'role': 'assistant', 'content': ans[0]}]
                         for _, ans in zip(range(n_rows), answers)]

        # append each remaining question followed by the recorded answer
        for i, qs in enumerate(questions):
            if i == 0:
                continue
            conversations = [conv + [{'role': 'user', 'content': qs},
                                     {'role': 'assistant', 'content': ans[i]}]
                             for conv, ans in zip(conversations, answers)]
        return answers, conversations

    def create_conv(self, questions, messages, tokenizer, sample_config):
        """Create formatted train and test conversations from the recorded answers.

        Requires ``train_idx``/``df_wide_train`` and ``test_idx``/``df_wide_test`` to be
        set (see ``split_data``). Sets ``answers_*``, ``messages_*`` and
        ``messages_formatted_*`` for both the train and the test split.

        :param questions: question texts in asking order.
        :param messages: shared conversation prefix (list of role/content dicts).
        :param tokenizer: passed through to ``format_messages``.
        :param sample_config: passed through to ``format_messages``.
        """
        # Training set
        self.answers_train, self.messages_train = self._build_conversations(
            questions, messages, len(self.train_idx), self.df_wide_train)
        self.messages_formatted_train = [format_messages(message, tokenizer, sample_config) for message in
                                         self.messages_train]

        # Test set
        self.answers_test, self.messages_test = self._build_conversations(
            questions, messages, len(self.test_idx), self.df_wide_test)
        self.messages_formatted_test = [format_messages(message, tokenizer, sample_config) for message in
                                        self.messages_test]


class ResponsesObject(ComputeFunctions):
    def __init__(self, paths, responses_config):
        """Store the path and response configurations and reset derived results.

        :param paths: object providing the project directories.
        :param responses_config: configuration forwarded to ``ComputeFunctions``.
        """
        super().__init__(responses_config)
        self._paths = paths
        # Derived values, filled in later.
        self.stats = self.fail_ratio = self.time = None
        self.ranges = self.sample_size = None

    def get_dfs(self):
        '''
        Load all response CSVs from the responses directory into the long frame
        (columns sample_ts, question, score, temp, top_p; question and score cast to
        int), then pivot it to the wide frame with one row per sample and one column
        per question.

        Prints a message when the directory or the CSV files are missing.

        :return:
        '''
        responses_dir = self._responses_config.responses_path
        if not os.path.exists(responses_dir):
            print('no files found for this configuration')
            return

        csv_paths = [responses_dir + fname for fname in os.listdir(responses_dir) if '.csv' in fname]
        if len(csv_paths) == 0:
            print('no files found for this configuration')
            return

        frames = [pd.read_csv(path) for path in csv_paths]
        keep = ['sample_ts', 'question', 'score', 'temp', 'top_p']
        self.df_long = pd.concat(frames, ignore_index=True)[keep]

        int_cols = ['question', 'score']
        self.df_long[int_cols] = self.df_long[int_cols].astype(int)

        pivoted = self.df_long.pivot(index='sample_ts', columns='question', values='score')
        self.df_wide = pivoted.reset_index(drop=True)

    # def get_fail_ratio(self):
    #     fail_ratio_file = self._responses_config.results_path + 'fail_ratio_' + self._responses_config.model_name_l + '.csv'
    #     if os.path.exists(fail_ratio_file):
    #         self.fail_ratio = pd.read_csv(fail_ratio_file)
    #         self.fail_ratio = pd.melt(self.fail_ratio, id_vars=['temp_val'], var_name='top_p').rename(
    #             columns={'temp_val': 'temp'}).astype(float)
    #         self.fail_ratio = self.fail_ratio[(self.fail_ratio['top_p'] == self._responses_config.top_p) & (
    #                 self.fail_ratio['temp'] == self._responses_config.temp)]['value'].values[0]

    # def get_time(self):
    #     if self._responses_config.type == 'model' or self._responses_config.type == 'context':
    #         time_file = self._responses_config.results_path + 'time_sum_total_' + self._responses_config.model_name_l + '.csv'
    #         if os.path.exists(time_file):
    #             self.time = pd.read_csv(time_file)
    #             self.time = self.time[(self.time['top_p'] == self._responses_config.top_p) & (
    #                     self.time['temp'] == self._responses_config.temp)]['q1'].values[0]
    #     else:
    #         pass

    def return_objects(self, mode='run', save_objects=False):
        '''
        Load the responses and compute all derived objects for this configuration:
        per-sample totals, per-question score counts, summary statistics, sample size,
        covariances, Cronbach's alpha and the factor-analysis matrices.

        :param mode: only 'run' is implemented; any other value is a no-op.
        :param save_objects: currently ignored.
        :return: None; results are stored as attributes.
        '''
        # TODO: saving the computed attributes to, and loading them from, a pickle at
        # objects_save_path + objects_save_fname is not implemented yet; the same holds
        # for the fail-ratio and timing lookups.
        if mode != 'run':
            return

        self.get_dfs()
        # get_dfs() leaves df_long as None when no file is found; len(None) would raise.
        if self.df_long is None or len(self.df_long) == 0:
            print('Cannot compute - No data for this configuration')
            return

        source_name = self._responses_config.source_name

        # Total score per sample.
        self.totals = self.df_long.groupby(['sample_ts'], as_index=False).sum().drop(
            columns=['question']).rename(columns={'score': 'total'})[['sample_ts', 'total']]
        self.totals.insert(self.totals.shape[1], 'source', source_name)

        # Number of occurrences of every score per question (scores as rows).
        self.counts = self.df_long.groupby(['question', 'score'], as_index=False).count().rename(
            columns={'sample_ts': 'count'}).pivot(index=['question'], columns='score', values='count').T

        self.stats = self.totals['total'].agg(
            ['mean', 'median', 'count', 'std', q3, q1, my_iqr, 'skew', 'kurt']).rename({'count': 'sample_size'})
        self.sample_size = 'N=' + str(int(self.stats['sample_size']))
        self.totals.insert(self.totals.shape[1], 'source_wn', source_name + ' ' + self.sample_size)

        self.cov = self.df_wide.cov()
        self.compute_cov_std()
        self.compute_ca()
        self.compute_fa_matrices()

    def get_ranges(self):
        """Store the overall [min, max] of the covariance and loading matrices in ``ranges``.

        Does nothing while ``df_long`` has not been loaded.
        """
        if self.df_long is None:
            return

        def _extent(matrix):
            # Two chained reductions collapse both axes.
            return [matrix.min().min(), matrix.max().max()]

        self.ranges = {
            'cov_std': _extent(self.cov_std),
            'cov': _extent(self.cov),
            'fa': _extent(self.fa_matrix),
            'fa_rot': _extent(self.fa_rot_matrix),
        }


class Trainer:
    def __init__(self, trainer_set, model_instance, sample_config, toker, bools, method='gd', lr=5e-3):
        """Store the data, the model and the update hyper-parameters.

        :param trainer_set: mapping with the keys ``'train_set'``, ``'test_set'`` and
            ``'eval_freq'``.
        :param model_instance: model to be trained.
        :param sample_config: sampling configuration (used to derive output paths).
        :param toker: tokenizer applied to every training/test item.
        :param bools: flag container (``saveTensors`` and ``saveBucket`` are read in ``train``).
        :param method: name of the update rule.
        :param lr: learning rate.
        """
        # Data
        self.train_data = trainer_set['train_set']
        self.test_data = trainer_set['test_set']
        self.eval_freq = trainer_set['eval_freq']

        # Collaborators
        self.model = model_instance
        self.tokenizer = toker
        self.sample_config = sample_config
        self.bools = bools

        # Update rule
        self.method = method
        self.lr = lr

        # State that train() resets
        self.parameter = None
        self.grad_accum = None

    def get_grad_object(self):
        """Return the gradient used for the update step, or ``None`` if unavailable.

        Only the ``'gd'`` method is implemented and returns ``parameter.grad``; the
        methods ``'top_p_gd'``, ``'top_k_svd_gd'`` and ``'top_k-one_svd_gd'`` are
        placeholders that currently yield ``None``.
        """
        if self.parameter is None:
            return None
        if self.method == 'gd':
            return self.parameter.grad
        # Other methods are not implemented yet.
        return None

    def step(self):
        if self.parameter is not None:
            with torch.no_grad():
                tmp_grad_object = self.get_grad_object()
                if self.grad_accum is not None:
                    self.grad_accum += tmp_grad_object.detach().to('cpu') / len(self.train_data)
                else:
                    self.grad_accum = tmp_grad_object.detach().to('cpu') / len(self.train_data)
                self.parameter.copy_(self.parameter - self.lr * tmp_grad_object)

    def train(self, parameter, param_name):
        """Train a single parameter with one update per training item.

        For every item in ``train_data`` the model loss is computed with the inputs as
        labels, back-propagated, and ``step`` is applied. Every ``eval_freq`` iterations
        and after the last item, the loss on the next item of ``test_data`` is recorded
        without gradients. Losses are written to CSV; weights and accumulated gradients
        are saved (and optionally uploaded) depending on ``bools``.

        :param parameter: tensor to update; if ``None`` no training is performed.
        :param param_name: name used in output paths and file names.
        :return: tuple ``(store_loss, loss_df, store_val_loss, val_loss_df)``; the data
            frames are ``None`` when ``parameter`` is ``None``.
        """
        self.grad_accum = None
        self.parameter = parameter
        self.param_name = param_name

        # setup paths
        self.weights_path_fn()
        self.grads_path_fn()
        self.loss_path_fn()
        self.responses_path_fn()
        self.outputs_path_fn()
        self.time_path_fn()
        self.results_path_fn()
        self.loss_plot_path_fn()

        self.model.zero_grad()
        store_loss = []
        store_val_loss = []
        store_val_loss_idx = []
        val_loss_df = None
        loss_df = None
        if self.parameter is None:
            return store_loss, loss_df, store_val_loss, val_loss_df

        n_train = len(self.train_data)
        j = 0  # index of the next test item used for validation
        # Passing total lets the progress bar show completion, which enumerate() hides.
        for i, data in tqdm(enumerate(self.train_data), total=n_train):
            text_ids = self.tokenizer(data, return_tensors='pt', padding=True).input_ids.to(
                device_name)

            outputs = self.model(text_ids, labels=text_ids)
            loss = outputs.loss / len(text_ids)
            store_loss.append(loss.item())
            print(f"\ttrain loss: {loss}")

            loss.backward()
            self.step()
            self.model.zero_grad()
            del loss, outputs, text_ids

            # validation loss
            if ((i + 1) % self.eval_freq == 0) or (i + 1 == n_train):
                with torch.no_grad():
                    store_val_loss_idx.append(i)
                    # NOTE(review): raises IndexError if test_data holds fewer items
                    # than there are evaluations — confirm the caller guarantees this.
                    text_ids = self.tokenizer(self.test_data[j], return_tensors='pt', padding=True).input_ids.to(
                        device_name)

                    outputs = self.model(text_ids, labels=text_ids)
                    val_loss = outputs.loss / len(text_ids)
                    store_val_loss.append(val_loss.item())
                    print(f"\tval loss: {val_loss}")
                    j += 1
                    del val_loss, outputs, text_ids
                    flush()

        # save losses
        loss_df = pd.DataFrame(
            {'iter': np.arange(len(store_loss)), 'loss': store_loss, 'param': self.param_name, 'lr': self.lr})
        loss_df.to_csv(f"{self.loss_path}{self.param_name}_train.csv", index=False)

        val_loss_df = pd.DataFrame(
            {'iter': store_val_loss_idx, 'loss': store_val_loss, 'param': self.param_name, 'lr': self.lr})
        val_loss_df.to_csv(f"{self.loss_path}{self.param_name}_test.csv", index=False)

        # save weights and gradients
        if self.bools.saveTensors:
            weights_file = f"{self.weights_path}{self.param_name}_weights_ts-{n_train}.pt"
            grads_file = f"{self.grads_path}{self.param_name}_grad_ts-{n_train}.pt"
            torch.save(copy.deepcopy(self.parameter), weights_file)
            torch.save(copy.deepcopy(self.grad_accum), grads_file)

            # Upload to bucket
            if self.bools.saveBucket:
                # NOTE(review): these paths omit the '_ts-<n>' suffix used when saving
                # above — confirm what upload_blob_from_memory expects here.
                w_fpath = f"{self.weights_path}{self.param_name}_weights.pt"
                g_fpath = f"{self.grads_path}{self.param_name}_grad.pt"
                w_fname = re.sub('/', '^^', weights_file)
                g_fname = re.sub('/', '^^', grads_file)
                # Uploads are best-effort, but only ordinary errors are swallowed
                # (not KeyboardInterrupt/SystemExit) and the cause is reported.
                try:
                    print('=' * 30 + '\n Uploading weights\n')
                    upload_blob_from_memory('llm-bucket-res', w_fpath, w_fname)
                    print(
                        '\nUploaded weights\n' + '=' * 30)
                except Exception as err:
                    print(f"Weight saving error: {err}")
                try:
                    print('=' * 30 + '\n Uploading gradients\n')
                    upload_blob_from_memory('llm-bucket-res', g_fpath, g_fname)
                    print(
                        '\nUploaded gradients\n' + '=' * 30)
                except Exception as err:
                    print(f"Grad saving error: {err}")

        return store_loss, loss_df, store_val_loss, val_loss_df

    def grads_path_fn(self):
        self.grads_path = self.sample_config._paths.files_dir + '/'.join(
            ['grads', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
             self.method + '_' + str(self.lr)]) + '/'
        Path(self.grads_path).mkdir(parents=True, exist_ok=True)
        self.sample_config.grads_path = self.grads_path
        # return self.grads_path

    def weights_path_fn(self):
        """Build, create and register the weights directory.

        The path is ``files_dir`` followed by the segments ``weights``,
        ``qs_name``, ``context_path``, ``model_name_rp`` and
        ``<method>_<lr>``, joined with ``/`` and terminated by ``/``.
        ``files_dir`` is prepended as-is (no separator is inserted after it).

        The directory is created (parents included, existing directories are
        tolerated) and the path is stored on both ``self.weights_path`` and
        ``self.sample_config.weights_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.files_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('weights', cfg.qs_name, cfg.context_path, cfg.model_name_rp, run_tag)

        target = base_dir + '/'.join(segments) + '/'
        self.weights_path = target
        Path(target).mkdir(parents=True, exist_ok=True)
        cfg.weights_path = target

    def loss_path_fn(self):
        """Build, create and register the loss-files directory.

        The path is ``files_dir`` followed by the segments ``loss``,
        ``qs_name``, ``context_path``, ``model_name_rp`` and
        ``<method>_<lr>``, joined with ``/`` and terminated by ``/``.
        ``files_dir`` is prepended as-is (no separator is inserted after it).

        The directory is created (parents included, existing directories are
        tolerated) and the path is stored on both ``self.loss_path`` and
        ``self.sample_config.loss_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.files_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('loss', cfg.qs_name, cfg.context_path, cfg.model_name_rp, run_tag)

        target = base_dir + '/'.join(segments) + '/'
        self.loss_path = target
        Path(target).mkdir(parents=True, exist_ok=True)
        cfg.loss_path = target

    def loss_plot_path_fn(self):
        """Build, create and register the loss-plots directory.

        Unlike the file-oriented path builders, the base here is
        ``plots_dir``. The path is ``plots_dir`` followed by the segments
        ``loss``, ``qs_name``, ``context_path``, ``model_name_rp`` and
        ``<method>_<lr>``, joined with ``/`` and terminated by ``/``.
        ``plots_dir`` is prepended as-is (no separator is inserted after it).

        The directory is created (parents included, existing directories are
        tolerated) and the path is stored on both ``self.loss_plot_path`` and
        ``self.sample_config.loss_plot_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.plots_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('loss', cfg.qs_name, cfg.context_path, cfg.model_name_rp, run_tag)

        target = base_dir + '/'.join(segments) + '/'
        self.loss_plot_path = target
        Path(target).mkdir(parents=True, exist_ok=True)
        cfg.loss_plot_path = target

    def responses_path_fn(self):
        """Build, create and register the responses directory.

        The path is ``files_dir`` followed by the segments ``responses``,
        ``qs_name``, ``context_path``, ``model_name_rp``, ``model_name_rhp``,
        ``<method>_<lr>`` and ``param_name``, joined with ``/`` and terminated
        by ``/``. ``files_dir`` is prepended as-is (no separator is inserted
        after it).

        Both the directory and a ``fails/`` sub-directory inside it are
        created (parents included, existing directories are tolerated). The
        path is stored on both ``self.responses_path`` and
        ``self.sample_config.responses_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.files_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('responses', cfg.qs_name, cfg.context_path, cfg.model_name_rp,
                    cfg.model_name_rhp, run_tag, self.param_name)

        target = base_dir + '/'.join(segments) + '/'
        self.responses_path = target
        for directory in (target, target + 'fails/'):
            Path(directory).mkdir(parents=True, exist_ok=True)
        cfg.responses_path = target

    def outputs_path_fn(self):
        """Build, create and register the outputs directory.

        The path is ``files_dir`` followed by the segments ``outputs``,
        ``qs_name``, ``context_path``, ``model_name_rp``, ``model_name_rhp``,
        ``<method>_<lr>`` and ``param_name``, joined with ``/`` and terminated
        by ``/``. ``files_dir`` is prepended as-is (no separator is inserted
        after it).

        Both the directory and a ``fails/`` sub-directory inside it are
        created (parents included, existing directories are tolerated). The
        path is stored on both ``self.outputs_path`` and
        ``self.sample_config.outputs_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.files_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('outputs', cfg.qs_name, cfg.context_path, cfg.model_name_rp,
                    cfg.model_name_rhp, run_tag, self.param_name)

        target = base_dir + '/'.join(segments) + '/'
        self.outputs_path = target
        for directory in (target, target + 'fails/'):
            Path(directory).mkdir(parents=True, exist_ok=True)
        cfg.outputs_path = target

    def time_path_fn(self):
        """Build, create and register the timing directory.

        The path is ``files_dir`` followed by the segments ``time``,
        ``qs_name``, ``context_path``, ``model_name_rp``, ``model_name_rhp``,
        ``<method>_<lr>`` and ``param_name``, joined with ``/`` and terminated
        by ``/``. ``files_dir`` is prepended as-is (no separator is inserted
        after it).

        Both the directory and a ``fails/`` sub-directory inside it are
        created (parents included, existing directories are tolerated). The
        path is stored on both ``self.time_path`` and
        ``self.sample_config.time_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.files_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('time', cfg.qs_name, cfg.context_path, cfg.model_name_rp,
                    cfg.model_name_rhp, run_tag, self.param_name)

        target = base_dir + '/'.join(segments) + '/'
        self.time_path = target
        for directory in (target, target + 'fails/'):
            Path(directory).mkdir(parents=True, exist_ok=True)
        cfg.time_path = target

    def results_path_fn(self):
        """Build, create and register the results directory.

        The path is ``files_dir`` followed by the segments ``results``,
        ``qs_name``, ``context_path``, ``model_name_rp``, ``<method>_<lr>``
        and ``param_name``, joined with ``/`` and terminated by ``/``.
        ``files_dir`` is prepended as-is (no separator is inserted after it).

        The directory is created (parents included, existing directories are
        tolerated) and the path is stored on both ``self.results_path`` and
        ``self.sample_config.results_path``. Nothing is returned.
        """
        cfg = self.sample_config
        base_dir = cfg._paths.files_dir
        run_tag = self.method + '_' + str(self.lr)
        segments = ('results', cfg.qs_name, cfg.context_path, cfg.model_name_rp,
                    run_tag, self.param_name)

        target = base_dir + '/'.join(segments) + '/'
        self.results_path = target
        Path(target).mkdir(parents=True, exist_ok=True)
        cfg.results_path = target

    # @property
    # def weights_path(self):
    #     self._weights_path = self.sample_config._paths.files_dir + '/'.join(
    #         ['weights', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
    #          self.method+'_'+str(self.lr)]) + '/'
    #     Path(self._weights_path).mkdir(parents=True, exist_ok=True)
    #     self.sample_config.weights_path = self._weights_path
    #     return self._weights_path
    #
    # @property
    # def loss_path(self):
    #     self._loss_path = self.sample_config._paths.files_dir + '/'.join(
    #         ['loss', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
    #          self.method+'_'+str(self.lr)]) + '/'
    #     Path(self._loss_path).mkdir(parents=True, exist_ok=True)
    #     self.sample_config.loss_path = self._loss_path
    #     return self._loss_path
    #
    # @property
    # def responses_path(self):
    #     self._responses_path = self.sample_config._paths.files_dir + '/'.join(
    #         ['responses', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
    #          self.sample_config.model_name_rhp,
    #          self.method+'_'+str(self.lr)]) + '/'
    #     Path(self._responses_path).mkdir(parents=True, exist_ok=True)
    #     Path(self._responses_path + 'fails/').mkdir(parents=True, exist_ok=True)
    #     self.sample_config.responses_path = self._responses_path
    #     return self._responses_path
    #
    # @property
    # def outputs_path(self):
    #     self._outputs_path = self.sample_config._paths.files_dir + '/'.join(
    #         ['outputs', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
    #          self.sample_config.model_name_rhp,
    #          self.method+'_'+str(self.lr)]) + '/'
    #     Path(self._outputs_path).mkdir(parents=True, exist_ok=True)
    #     Path(self._outputs_path + 'fails/').mkdir(parents=True, exist_ok=True)
    #     self.sample_config.outputs_path = self._outputs_path
    #     return self._outputs_path
    #
    # @property
    # def time_path(self):
    #     self._time_path = self.sample_config._paths.files_dir + '/'.join(
    #         ['time', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
    #          self.sample_config.model_name_rhp,
    #          self.method+'_'+str(self.lr)]) + '/'
    #     Path(self._time_path).mkdir(parents=True, exist_ok=True)
    #     Path(self._time_path + 'fails/').mkdir(parents=True, exist_ok=True)
    #     self.sample_config.time_path = self._time_path
    #     return self._time_path
    #
    # @property
    # def results_path(self):
    #     self._results_path = self.sample_config._paths.files_dir + '/'.join(
    #         ['results', self.sample_config.qs_name, self.sample_config.context_path, self.sample_config.model_name_rp,
    #          self.method+'_'+str(self.lr)]) + '/'
    #     Path(self._results_path).mkdir(parents=True, exist_ok=True)
    #     self.sample_config.results_path = self._results_path
    #     return self._results_path
