import spgl1
import os
import pandas as pd
import torch, einops
import numpy as np
from torch import nn
import torch.nn.functional as F
from natsort import natsorted

# from sae_model_perturb import latent_qs_perturbed

# Turn autograd off as soon as this module is imported; nothing below needs gradients.
# NOTE(review): this is an import-time side effect on torch's grad mode (which torch
# tracks per thread), so it also affects the importing caller's own code — confirm intended.
torch.set_grad_enabled(False)


def rescale_np(ref, x):
    """Min-max rescale each column of ``x`` onto the per-column range of ``ref``.

    Column extremes are taken along axis 0 for both inputs. Each column of
    ``x`` is mapped linearly so that its minimum lands on ``ref``'s column
    minimum and its maximum on ``ref``'s column maximum.

    Both torch tensors (the original behaviour) and NumPy arrays are accepted.

    Args:
        ref: tensor/array whose per-column min/max define the target range.
        x: tensor/array to rescale; its per-column min/max define the source range.

    Returns:
        The rescaled values, computed elementwise with broadcasting over rows.
    """
    def _col_min_max(a):
        # torch reductions return (values, indices) named tuples and take ``dim``;
        # NumPy reductions return the values directly and take ``axis``.
        if isinstance(a, torch.Tensor):
            return a.min(dim=0).values, a.max(dim=0).values
        return a.min(axis=0), a.max(axis=0)

    to_min, to_max = _col_min_max(ref)
    x_min, x_max = _col_min_max(x)
    # The 1e-8 epsilon avoids division by zero for constant columns of ``x``.
    return (to_max - to_min) * (x - x_min) / (x_max - x_min + 1e-8) + to_min


def get_perturb_delta(model_sae, deltaS_fp, q_idx_to_perturb=1, score_change=1, saveDelta=False,
                      loadDelta=True, device_name='cpu'):
    if loadDelta:
        deltaS = torch.load(deltaS_fp,map_location=torch.device(device_name)).to(device_name)
    else:
        W_proj = model_sae.W_proj.to('cpu').detach().numpy()
        delta_q = np.zeros(model_sae.W_proj.shape[1])
        delta_q[q_idx_to_perturb] = score_change
        deltaS, resid, grad, info = spgl1.spg_bp(W_proj.T, delta_q, verbosity=0, iter_lim=5000)
        # deltaS, resid, grad, info = spgl1.spg_bpdn(W_proj.T, delta_q, verbosity=0, iter_lim=5000,sigma=0)
        solver_status = info['stat']
        if not solver_status:
            print(f'Solver status: {solver_status == 1}, for: {deltaS}')
        deltaS = torch.tensor(deltaS).type(torch.float32).to(device_name)
        if saveDelta:
            torch.save(deltaS, deltaS_fp)

    return deltaS


def load_pert_responses(paths, sample_config, bools, ds_type):
    """Collect baseline and steered sampled-response scores into long and wide tables.

    When ``bools.loadMe`` is true, both tables are read back from the CSVs a
    previous run wrote under ``paths.save_path``. Otherwise they are rebuilt:

    * Baseline scores come from the first CSV in
      ``<paths.original_logits_path>/<label letters>`` whose name contains
      ``sample_config.model_name``. They are averaged per ``rsp_gr_org_cols``
      group and every perturbation column is set to 0. Only subjects present
      under ``paths.responses_path`` are kept.
    * Steered scores come from every ``*responses*.csv`` in
      ``<paths.responses_path>/<sub>/<q>/<spec>/``. Here ``spec`` is the first
      entry whose name contains ``sample_config.model_name_rp``. Scores are
      averaged per ``rsp_gr_cols`` group, and rows with ``hs_steer_mlt == 0``
      are dropped.

    Args:
        paths: object exposing ``responses_path``, ``original_logits_path`` and ``save_path``.
        sample_config: object exposing ``label_letters``, ``model_name`` and ``model_name_rp``.
        bools: object exposing the ``loadMe`` and ``saveMe`` flags.
        ds_type: tag inserted into the cached CSV file names.

    Returns:
        ``(logits_responses_wide, logits_responses)``. The wide table has one
        ``steer_mlt_<m>`` score column per steering multiplier. It also has
        ``diff_steer_mlt_<m>`` columns holding steered minus the
        ``steer_mlt_0.0`` baseline. The long table has one row per
        (group, multiplier).

    Raises:
        FileNotFoundError: no baseline CSV matches ``sample_config.model_name``.
        ValueError: no steered ``*responses*.csv`` file was found.
    """
    rsp_cols = ['sub', 'sample_ts', 'question', 'score', 'model', 'latent_q_score_change', 'latent_p_force',
                'q_to_perturb', 'hs_steer_mlt', 'instr_name', 'qs', 'label_perm', 'nSamples']
    rsp_gr_cols = ['sub', 'question', 'model', 'latent_q_score_change', 'latent_p_force', 'q_to_perturb',
                   'hs_steer_mlt', 'instr_name', 'qs', 'label_perm', 'nSamples']
    rsp_gr_org_cols = ['sub', 'question', 'model', 'instr_name', 'qs', 'label_perm', 'nSamples']
    pert_cols = ['latent_q_score_change', 'latent_p_force', 'q_to_perturb', 'hs_steer_mlt']

    save_path = paths.save_path
    wide_fp = f"{save_path}logits_base_and_pert_responses_wide_{ds_type}_{sample_config.model_name}.csv"
    long_fp = f"{save_path}logits_base_and_pert_responses_{ds_type}_{sample_config.model_name}.csv"

    if bools.loadMe:
        return pd.read_csv(wide_fp), pd.read_csv(long_fp)

    responses_path = paths.responses_path
    subs = natsorted([d for d in os.listdir(responses_path) if '.DS_Store' not in d])

    # Baseline (unsteered) scores.
    original_logits_path_dir = f"{paths.original_logits_path}/{''.join(sample_config.label_letters)}"
    original_logits_csv = [f for f in os.listdir(original_logits_path_dir)
                           if '.csv' in f and sample_config.model_name in f]
    if not original_logits_csv:
        # Without a baseline there is no steer_mlt_0.0 column to diff against.
        raise FileNotFoundError(
            f"No baseline CSV containing '{sample_config.model_name}' in {original_logits_path_dir}")
    original_logit_responses = pd.read_csv(f'{original_logits_path_dir}/{original_logits_csv[0]}')
    original_logit_responses_avg = original_logit_responses.groupby(rsp_gr_org_cols, as_index=False)['score'].mean()
    original_logit_responses_avg[pert_cols] = 0
    original_logit_responses_avg = original_logit_responses_avg[original_logit_responses_avg['sub'].isin(subs)]

    # Steered scores: one group-mean table per responses CSV.
    perturbed_logit_responses = []
    for sub in subs:
        sub_path = f"{responses_path}/{sub}"
        for q in natsorted([d for d in os.listdir(sub_path) if '.DS_Store' not in d]):
            sub_q_path = f"{sub_path}/{q}/"
            spec = [d for d in os.listdir(sub_q_path) if '.DS_Store' not in d and sample_config.model_name_rp in d]
            if not spec:
                continue
            spec_dir = f"{sub_q_path}/{spec[0]}"
            for csv_file in os.listdir(spec_dir):
                if '.csv' not in csv_file or 'responses' not in csv_file:
                    continue
                tmp_df = pd.read_csv(f"{spec_dir}/{csv_file}")[rsp_cols]
                perturbed_logit_responses.append(tmp_df.groupby(rsp_gr_cols, as_index=False)['score'].mean())
    if not perturbed_logit_responses:
        raise ValueError(f"No '*responses*.csv' files found under {responses_path}")
    perturbed_logit_responses = pd.concat(perturbed_logit_responses)
    perturbed_logit_responses = perturbed_logit_responses[perturbed_logit_responses['hs_steer_mlt'] != 0]

    logits_responses = pd.concat([original_logit_responses_avg, perturbed_logit_responses]).sort_values(
        ['sub', 'question', 'hs_steer_mlt']).reset_index(drop=True)
    # NOTE: baseline rows carry hs_steer_mlt == 0. After concatenation with the
    # (presumably float) steered column, that value stringifies as 'steer_mlt_0.0'.
    logits_responses['hs_steer_mlt'] = logits_responses['hs_steer_mlt'].apply(lambda x: 'steer_mlt_' + str(x))
    logits_responses_wide = pd.pivot(logits_responses, index=rsp_gr_org_cols, columns='hs_steer_mlt',
                                     values='score').reset_index()
    steer_cols = [c for c in logits_responses_wide.columns if 'steer_mlt' in c and c != 'steer_mlt_0.0']
    for steer_col in steer_cols:
        logits_responses_wide['diff_' + steer_col] = (logits_responses_wide[steer_col]
                                                      - logits_responses_wide['steer_mlt_0.0'])

    if bools.saveMe:
        logits_responses_wide.to_csv(wide_fp, index=False)
        logits_responses.to_csv(long_fp, index=False)

    return logits_responses_wide, logits_responses

def _expected_score_rows(root, subs, model_name_rp, id_cols):
    """Build one expected-score row per ``*logits*.csv`` under ``root/<sub>/<q>/<spec>/``.

    ``spec`` is the first entry of ``root/<sub>/<q>/`` whose name contains
    ``model_name_rp``. For each logits file, the expected score is
    ``probs @ label_scores`` over the file's rows. The ``id_cols`` values are
    taken from the file's first row.

    Raises:
        ValueError: no logits CSV was found under ``root``.
    """
    rows = []
    for sub in subs:
        sub_path = f"{root}/{sub}"
        for q in natsorted([d for d in os.listdir(sub_path) if '.DS_Store' not in d]):
            sub_q_path = f"{sub_path}/{q}/"
            spec = [d for d in os.listdir(sub_q_path) if '.DS_Store' not in d and model_name_rp in d]
            if not spec:
                continue
            spec_dir = f"{sub_q_path}/{spec[0]}"
            for csv_file in os.listdir(spec_dir):
                if '.csv' not in csv_file or 'logits' not in csv_file:
                    continue
                tmp_df = pd.read_csv(f"{spec_dir}/{csv_file}")
                expected_score = tmp_df['probs'].to_numpy() @ tmp_df['label_scores'].to_numpy()
                # ``assign`` returns a new frame, avoiding SettingWithCopyWarning on the slice.
                rows.append(tmp_df[id_cols].head(1).assign(score=expected_score))
    if not rows:
        raise ValueError(f"No '*logits*.csv' files found under {root}")
    return pd.concat(rows)


def load_pert_logits(paths, sample_config, bools, ds_type):
    """Collect baseline and steered expected scores into long and wide tables.

    When ``bools.loadMe`` is true, both tables are read back from the CSVs a
    previous run wrote under ``paths.save_path``. Otherwise each ``*logits*.csv``
    contributes one expected score (``probs @ label_scores``). Two sources are
    walked:

    * steered runs under ``paths.responses_path``, identified by
      ``rsp_gr_cols``; rows with ``hs_steer_mlt == 0`` are dropped.
    * baseline runs under ``<paths.files_path_o><label letters>/subjects``,
      identified by ``rsp_gr_org_cols``; every perturbation column is set to 0.

    Subjects are those listed under ``paths.responses_path``.

    Args:
        paths: object exposing ``responses_path``, ``files_path_o`` and ``save_path``.
        sample_config: object exposing ``label_letters``, ``model_name`` and ``model_name_rp``.
        bools: object exposing the ``loadMe`` and ``saveMe`` flags.
        ds_type: tag inserted into the cached CSV file names.

    Returns:
        ``(logits_responses_wide, logits_responses)``. The wide table has one
        ``steer_mlt_<m>`` column per steering multiplier. It also has
        ``diff_steer_mlt_<m>`` columns holding steered minus the
        ``steer_mlt_0.0`` baseline.

    Raises:
        ValueError: either directory tree contains no logits CSVs.
    """
    rsp_gr_cols = ['sub', 'question', 'model', 'latent_q_score_change', 'latent_p_force', 'q_to_perturb',
                   'hs_steer_mlt', 'instr_name', 'qs', 'label_perm', 'nSamples']
    rsp_gr_org_cols = ['sub', 'question', 'model', 'instr_name', 'qs', 'label_perm', 'nSamples']
    pert_cols = ['latent_q_score_change', 'latent_p_force', 'q_to_perturb', 'hs_steer_mlt']

    wide_fp = f"{paths.save_path}logits_base_and_pert_exp_scores_wide_{ds_type}_{sample_config.model_name}.csv"
    long_fp = f"{paths.save_path}logits_base_and_pert_exp_scores_{ds_type}_{sample_config.model_name}.csv"

    if bools.loadMe:
        return pd.read_csv(wide_fp), pd.read_csv(long_fp)

    subs = natsorted([d for d in os.listdir(paths.responses_path) if '.DS_Store' not in d])
    original_logits_path_dir = f"{paths.files_path_o}{''.join(sample_config.label_letters)}/subjects"

    perturbed_logit_responses = _expected_score_rows(paths.responses_path, subs,
                                                     sample_config.model_name_rp, rsp_gr_cols)
    perturbed_logit_responses = perturbed_logit_responses[perturbed_logit_responses['hs_steer_mlt'] != 0]

    original_logit_responses = _expected_score_rows(original_logits_path_dir, subs,
                                                    sample_config.model_name_rp, rsp_gr_org_cols)
    original_logit_responses[pert_cols] = 0

    logits_responses = pd.concat([original_logit_responses, perturbed_logit_responses]).sort_values(
        ['sub', 'question', 'hs_steer_mlt']).reset_index(drop=True)
    # NOTE: baseline rows carry hs_steer_mlt == 0. After concatenation with the
    # (presumably float) steered column, that value stringifies as 'steer_mlt_0.0'.
    logits_responses['hs_steer_mlt'] = logits_responses['hs_steer_mlt'].apply(lambda x: 'steer_mlt_' + str(x))
    logits_responses_wide = pd.pivot(logits_responses, index=rsp_gr_org_cols, columns='hs_steer_mlt',
                                     values='score').reset_index()
    steer_cols = [c for c in logits_responses_wide.columns if 'steer_mlt' in c and c != 'steer_mlt_0.0']
    for steer_col in steer_cols:
        logits_responses_wide['diff_' + steer_col] = (logits_responses_wide[steer_col]
                                                      - logits_responses_wide['steer_mlt_0.0'])

    if bools.saveMe:
        logits_responses_wide.to_csv(wide_fp, index=False)
        logits_responses.to_csv(long_fp, index=False)

    return logits_responses_wide, logits_responses

