import os.path
from data_generation.synthetic_data_generation_PrefShap import *
import os
import sys
import numpy as np
import torch
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
from utils.utils import *
from pref_shap.pref_shap import *
from pref_shap.rkhs_shap import *
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import shutil
import dill
from sklearn.linear_model import Ridge
sns.set()

# First positional command-line argument, parsed at import time (so importing
# this module without that argument raises).  ``return_feature_names`` uses it
# as the number of ``feature_<i>`` labels to generate for synthetic datasets.
dim = int(sys.argv[1])

def cumsum_thingy_2(cumsum_indices,shapley_vals):
    """Sum groups of adjacent columns of ``shapley_vals``.

    ``cumsum_indices`` is a sequence of column boundaries; every consecutive
    pair ``(start, stop)`` selects ``shapley_vals[:, start:stop]``, which is
    summed along dim 1 (keeping the dim) into a single output column.  The
    grouped columns are concatenated along dim 1.

    An empty (falsy) ``cumsum_indices`` means "no grouping": the input object
    is returned unchanged.
    """
    if not cumsum_indices:
        return shapley_vals
    boundaries = zip(cumsum_indices[:-1], cumsum_indices[1:])
    grouped = [shapley_vals[:, lo:hi].sum(1, keepdim=True) for lo, hi in boundaries]
    return torch.cat(grouped, dim=1)

def return_feature_names(job):
    """Return the feature metadata associated with a dataset name.

    Args:
        job: dataset identifier (e.g. ``'pokemon'``, ``'chameleon_wl'``,
            ``'synthetic_data_<N>_0'``).

    Returns:
        A 4-tuple ``(sum_count, feature_names, do_sum, coeffs)``:

        * ``sum_count`` -- list of cumulative column boundaries used to merge
          adjacent columns into one feature (empty list when nothing is
          merged);
        * ``feature_names`` -- list of display names, one per feature;
        * ``do_sum`` -- ``True`` when ``sum_count`` should be applied;
        * ``coeffs`` -- ``10 ** np.linspace(-7, -2, 0)``.

    Raises:
        ValueError: if ``job`` is not a known dataset.  (Previously the
            function fell off the end and returned ``None``, which surfaced
            as an opaque unpacking ``TypeError`` in every caller.)

    The synthetic dataset names depend on ``sys.argv[2]`` (embedded in the
    name) and ``sys.argv[3]`` (non-zero additionally enables the
    ``synthetic_block_data_*`` name); the number of generated ``feature_<i>``
    labels comes from the module-level ``dim``.
    """
    # NOTE(review): ``num=0`` makes this an empty array for every dataset --
    # confirm that no regularisation grid is intended.
    coeffs = 10 ** np.linspace(-7, -2, 0)

    if job in ('pokemon', 'pokemon_wl'):
        # Seven scalar stats followed by a 19-column block that is merged
        # into the single feature 'Type'.
        widths = [1, 1, 1, 1, 1, 1, 1, 19]
        boundaries = np.cumsum([0] + widths).tolist()
        names = ['HP', 'Attack', 'Defense', 'Sp. Atk', 'Sp. Def', 'Speed', 'Legendary', 'Type']
        return boundaries, names, True, coeffs

    if job in ('chameleon', 'chameleon_wl'):
        names = ['ch.res', 'jl.res', 'tl.res', 'mass.res', 'SVL', 'prop.main',
                 'prop.patch']
        return [], names, False, coeffs

    if job in ('alan_data_5000_100', 'alan_data_5000_1000'):
        names = [r'$x^{[0]}$'] + [r'$x^{AB}$', r'$x^{AC}$', r'$x^{BC}$'] + ['in A', 'in B', 'in C']
        return [], names, False, coeffs

    if job in ('alan_data_5000_1000_10_10', 'toy_data_5000_4_2'):
        names = [f'important_{i}' for i in range(1, 3)] + [f'Unimportant_{i}' for i in range(3, 5)]
        return [], names, False, coeffs

    if job == 'synthetic_data_sigmoid_4950_0':
        return [], [f'feature_{i}' for i in range(1, dim + 1)], False, coeffs

    n_total, n_important = 5, 3
    if job == f'hard_data_10000_1000_{n_total}_{n_important}':
        names = ([f'important_{i}' for i in range(1, n_important + 1)]
                 + [f'Unimportant_{i}' for i in range(n_important + 1, n_total + 1)])
        return [], names, False, coeffs

    # Dataset names that depend on the command line are resolved last so the
    # static names above work without those arguments being present.
    size_tag = int(sys.argv[2])
    synthetic_jobs = [f'synthetic_data_{size_tag}_0']
    if int(sys.argv[3]) != 0:
        synthetic_jobs.append(f'synthetic_block_data_{size_tag}_0')
    if job in synthetic_jobs:
        return [], [f'feature_{i}' for i in range(1, dim + 1)], False, coeffs

    raise ValueError(f'Unknown job {job!r}: no feature names are registered for it')


def cumsum_thingy(cumsum_indices,shapley_vals):
    """Sum groups of adjacent rows of ``shapley_vals``.

    Row-wise counterpart of ``cumsum_thingy_2``: every consecutive pair
    ``(start, stop)`` in ``cumsum_indices`` selects
    ``shapley_vals[start:stop, :]``, which is summed along dim 0 (keeping the
    dim) into one output row.  The rows are concatenated along dim 0.
    """
    row_groups = []
    for start, stop in zip(cumsum_indices[:-1], cumsum_indices[1:]):
        row_groups.append(shapley_vals[start:stop, :].sum(0, keepdim=True))
    return torch.cat(row_groups, dim=0)
    
def sigmoid(l, r, dim):
    """Evaluate the bilinear form ``l @ A @ r`` for a block-diagonal ``A``.

    ``A`` is assembled by ``block_diagonal_matrix`` from ``int(dim/2)`` copies
    of the 2x2 block ``[[0, -1], [1, 0]]``.

    NOTE(review): despite the function name, the value returned is the raw
    bilinear score; its logistic transform is computed below but never
    returned.  Confirm which of the two callers actually expect.
    """
    rotation = np.array([[0, -1], [1, 0]])
    A = block_diagonal_matrix(rotation, int(dim / 2))

    score = np.dot(np.dot(l, A), r)

    # Computed but unused; kept so the function's behaviour (including any
    # NumPy overflow warnings from ``exp``) is unchanged.
    _logistic = 1 / (1 + np.exp(-score))

    return score

def block_diagonal_matrix(block, num_blocks):
    """Return a matrix with ``block`` repeated ``num_blocks`` times on the diagonal.

    Off-diagonal tiles are zero arrays shaped and typed like ``block``; the
    tiles are assembled with ``np.block``.
    """
    filler = np.zeros_like(block)
    grid = []
    for diag_pos in range(num_blocks):
        tiles = [filler] * num_blocks
        tiles[diag_pos] = block
        grid.append(tiles)
    return np.block(grid)

def transformed_features(l, r, dim):
    """Return the pairwise antisymmetric features of two batches of items.

    For every index pair ``i < j`` (enumerated row by row over the strict
    upper triangle: ``(0,1), (0,2), ..., (1,2), ...``) the output column is

        ``l[:, j] * r[:, i] - l[:, i] * r[:, j]``

    which equals ``sum((l @ A_ij) * r, dim=1)`` for the matrix ``A_ij`` that
    is zero except for ``A_ij[i, j] = -1`` and ``A_ij[j, i] = 1``.

    Args:
        l: left items, ``(n, dim)`` tensor or array-like.
        r: right items, tensor or array-like whose last dimension is ``dim``
            and which broadcasts against ``l``.
        dim: number of feature columns per item.

    Returns:
        A ``float32`` tensor of shape ``(n, dim * (dim - 1) // 2)`` on the
        device of ``l`` (CPU when ``l`` is not a tensor).

    Raises:
        ValueError: if the last dimension of ``l`` or ``r`` is not ``dim``.

    Changes relative to the previous implementation:
      * the unused block-diagonal pre-computation is gone -- it built a
        ``2 * (dim // 2)`` wide matrix and therefore crashed for odd ``dim``;
      * non-tensor inputs are converted instead of having ``.clone()`` called
        on them;
      * no dependency on the ``math`` module;
      * one vectorised expression replaces a dense ``dim x dim`` matmul per
        index pair.
    """
    l = torch.as_tensor(l, dtype=torch.float32)
    r = torch.as_tensor(r, dtype=torch.float32, device=l.device)
    if l.shape[-1] != dim or r.shape[-1] != dim:
        raise ValueError(
            f'expected {dim} feature columns, got l: {tuple(l.shape)}, r: {tuple(r.shape)}'
        )

    # Row-major enumeration of all (i, j) with i < j.
    first, second = torch.triu_indices(dim, dim, offset=1, device=l.device)
    return l[..., second] * r[..., first] - l[..., first] * r[..., second]


def get_shapley_vals(job,model_string,fold,block,train_params,num_matches,post_method,interventional,hello):
    """Build the weighted regression problem behind the Shapley values of one fold.

    Loads the stored best model from
    ``{job}_results_{block}/{model_string}/run_{fold}.pickle``, re-creates the
    data split through ``train_GP`` and fits either ``rkhs_shap`` (when
    ``model_string == 'SGD_base'``) or ``pref_shap`` (otherwise).

    Args:
        job: dataset name; forwarded to ``return_feature_names`` and used in
            the results path.
        model_string: model identifier; selects the explainer and is part of
            the results path.  ``'SGD_krr_pgp'`` turns on ``pgp`` in
            ``pref_shap.fit``.
        fold: fold index; part of the results path and stored in the returned
            frame's ``fold`` column.
        block: block identifier passed to ``train_GP`` and used in file names.
        train_params: keyword configuration handed to ``train_GP``.
        num_matches: only used to truncate the label tensor whose shape is
            printed.
        post_method: forwarded to the explainer constructor.
        interventional: forwarded to the explainer constructor.
        hello: unused; kept for interface compatibility.

    Returns:
        ``(cooking_dict, df)`` where ``cooking_dict`` holds the CPU tensors
        ``'Y'``, ``'weights'``, ``'Z'`` (plus ``'klr'`` for the non
        ``SGD_base`` path) and the sample count ``'n'``, and ``df`` is a
        DataFrame with one column per feature name plus ``'fold'``.

    Side effects:
        Requires the device ``cuda:0``.  The non ``SGD_base`` path writes
        ``alpha_{block}.txt``, ``xltrain.txt``, ``xrtrain.txt`` and
        ``test_pairs_fold_{fold}.txt`` into the working directory.
    """
    # Imported locally: this file never imports ``pickle`` explicitly and
    # would otherwise depend on one of the wildcard imports leaking it.
    import pickle

    # SECURITY: pickle/dill execute arbitrary code while loading -- only point
    # this at result files produced by trusted training runs.
    with open( f'{job}_results_{block}/{model_string}/run_{fold}.pickle' , 'rb') as handle:
        loaded_model = pickle.load(handle)
    best_model = dill.loads(loaded_model)
    ls = best_model['ls']
    alpha = best_model['alpha'].float()
    ind_points = best_model['inducing_points'].float()

    print("block = ", block)
    model = best_model['model'].to('cuda:0')
    c = train_GP(block, train_params=train_params, flag = True)
    print("c.X_tr shape = ",c.X_tr.shape)
    c.load_and_split_data(True)
    if model_string=='SGD_base':
        inner_kernel=RBF_multiple_ls(d=ind_points.shape[1])
        inner_kernel._set_lengthscale(ls)
        inner_kernel=inner_kernel.to('cuda:0')
        alpha=alpha.to('cuda:0')
        ps = rkhs_shap(model=model, alpha=alpha, k=inner_kernel, X=ind_points, max_S=2500,
                       rff_mode=False, eps=1e-3, cg_max_its=10, lamb=1e-3, max_inv_row=0, cg_bs=25, post_method=post_method,
                       interventional=interventional, device='cuda:0')
        x = torch.from_numpy(c.X_val).float()
        # Fall back to the training split when the validation split has fewer
        # than 100 rows.
        if x.shape[0]<100:
            x = torch.from_numpy(c.X_tr).float()
        Y_target, weights, Z = ps.fit(x)
        cooking_dict = {'Y': Y_target.cpu(), 'weights': weights.cpu(), 'Z': Z.cpu(),
                        'n': x.shape[0]}
        sum_count, features_names, do_sum, coeffs = return_feature_names(job)
        # The two column halves of ``x`` are combined per feature by addition.
        chunk_l, chunk_r = torch.chunk(x, dim=1, chunks=2)
        if do_sum:
            chunk_l = cumsum_thingy_2(sum_count, chunk_l)
            chunk_r = cumsum_thingy_2(sum_count, chunk_r)
        data = chunk_r+chunk_l
        print(features_names)
        df = pd.DataFrame(data.numpy(), columns=features_names )
        # Fix: use the ``fold`` argument; the old code read an undefined local
        # ``f`` and only worked through the script's module-level loop variable.
        df['fold'] = fold
        return cooking_dict,df
    else:
        pgp = model_string=='SGD_krr_pgp'
        print("ind_points shape = ", ind_points.shape)
        # Split before ``ind_points`` is moved to the GPU, so the two halves
        # handed to ``pref_shap`` stay on the original device.
        x_ind_l,x_ind_r  = torch.chunk(ind_points,dim=1,chunks=2)
        print("x_ind_l shape = ", x_ind_l.shape)
        print("x_ind_r shape = ", x_ind_r.shape)
        print("d=x_ind_l.shape[1]= ",x_ind_l.shape[1])
        inner_kernel=RBF_multiple_ls(d=x_ind_l.shape[1])
        inner_kernel._set_lengthscale(ls)
        inner_kernel=inner_kernel.to('cuda:0')
        alpha=alpha.to('cuda:0')
        ind_points = ind_points.to('cuda:0')

        # Diagnostics only: alpha-weighted combinations of the inducing points
        # and of their pairwise transformed features.
        print("alpha shape = ", alpha.shape)
        print("centers shape = ", ind_points.shape)
        linear_weights = torch.matmul(ind_points.t(), alpha)
        print("function weights for the transformed features (d^2) = ", linear_weights, linear_weights.shape)
        items=ind_points
        pair_width = items.shape[1]
        print("dim = ", pair_width)
        left = items[:,0:int(pair_width/2)]
        right = items[:,int(pair_width/2):pair_width]
        final_weights = torch.matmul(transformed_features(left, right, int(pair_width/2)).t(), alpha)

        print("final weights = ", final_weights, final_weights.shape)

        # Explain training and validation pairs together.
        print("c.left_tr = ",c.left_tr.shape)
        print("c.left_test = ",c.left_test.shape)
        print("c.left_val = ",c.left_val.shape)
        data_total_left = np.concatenate((c.left_tr,c.left_val),axis=0)
        data_total_right = np.concatenate((c.right_tr,c.right_val),axis=0)
        label_total = np.concatenate((c.y_block_tr_0, c.y_block_val_0.reshape(-1, 1)), axis=0)

        x, x_prime = torch.from_numpy(data_total_left).float(), torch.from_numpy(data_total_right).float()
        y = torch.from_numpy(label_total).float().unsqueeze(-1)

        print("data_total_left = ",data_total_left.shape)
        print("x unique shape = ",np.unique(x, axis = 0).shape)
        print("x prime.shape = ",x_prime.shape)
        print("x prime unique shape = ",np.unique(x_prime, axis = 0).shape)
        shap_l,shap_r = x[0:, :], x_prime[0:, :]

        np.savetxt(f"alpha_{block}.txt", alpha.cpu().numpy())

        xltrain = shap_l
        xrtrain = shap_r

        np.savetxt("xltrain.txt", xltrain)
        np.savetxt("xrtrain.txt", xrtrain)

        testdata = np.concatenate((shap_l, shap_r), axis = 1)
        print("testdata", testdata.shape)
        np.savetxt(f"test_pairs_fold_{fold}.txt", testdata)

        ps = pref_shap(model=model, alpha=alpha, k=inner_kernel, X_l=x_ind_l, X_r=x_ind_r, X=c.S, max_S=2500,
                       rff_mode=False, eps=1e-3, cg_max_its=10, lamb=1e-3, max_inv_row=0, cg_bs=25, post_method=post_method,
                       interventional=interventional, device='cuda:0')

        shap_l,shap_r = xltrain, xrtrain
        y = y[0:num_matches]

        print("shap_l shape = ", shap_l.shape)
        print("xltrain = ", xltrain.shape)
        print("xrtrain = ", xrtrain.shape)
        print("shap_l = ", shap_l.shape)
        print("shap_r = ", shap_r.shape)
        Y_target, weights, Z, y_pred =  ps.fit(shap_l,shap_r,pgp=pgp)

        cooking_dict = {'Y':Y_target.cpu(), 'weights':weights.cpu(),'Z':Z.cpu(), 'klr':y_pred.cpu(),
                        'n':shap_l.shape[0]}

        print("y shape = ", y.shape)

        # The left item is always treated as the "winner"; a label-based
        # assignment (y * right + (1 - y) * left) is not used here.
        winners =  shap_l
        loosers =  shap_r
        print("winners shape = ", winners.shape)
        print("return_feature_names output:", return_feature_names(job))
        sum_count, features_names, do_sum, coeffs = return_feature_names(job)
        diff_abs = winners - loosers
        data = cumsum_thingy_2(sum_count, diff_abs)
        print("data shape = ", data)
        df = pd.DataFrame(data.numpy(), columns=features_names)
        # Fix: use the ``fold`` argument instead of the script-level ``f``.
        df['fold'] = fold
        return cooking_dict,df


def get_shapley_vals_2(cooking_dict,job,post_method,block,fold,m='SGD_krr'):
    """Turn a stored regression problem into Shapley values and save them.

    Two results are produced from ``cooking_dict``:

    1. ``construct_values`` is called on its ``'Y'``, ``'Z'``, ``'weights'``
       and ``'klr'`` entries; each returned output is reshaped into a
       long-format frame with columns ``'d'`` (1-based feature index),
       ``'shapley_vals'`` and ``'lambda'`` (the key of the output).
    2. A weighted ``Ridge`` regression (penalty ``1e-10``) of ``'Y'`` on
       ``'Z'`` is fitted; its coefficient columns are multiplied by the
       per-column standard deviation of the stacked left/right items read
       from ``original_{train,val}_{left,right}_items_{fold}.txt``.

    Args:
        cooking_dict: dict as produced by ``get_shapley_vals``.
        job: dataset name, resolved through ``return_feature_names``.
        post_method: forwarded to ``construct_values``.
        block: only used in output file names.
        fold: selects the item files to read; used in an output file name.
        m: model identifier; ``'SGD_base'`` splits every output into two
            halves along dim 0 and combines them.

    Returns:
        ``(plot, features_names, clf.coef_, func_val.t())`` -- the long-format
        frame, the feature names, the rescaled Ridge coefficients and the
        difference between the last and the first row of ``'Y'``.

    Side effects:
        Writes ``Y_target_{block}.txt``, ``prefshap_block0.txt``,
        ``prefshap_block1.txt`` and ``prefshap_block_{block}_fold_{fold}.txt``
        into the working directory and prints diagnostics.

    NOTE(review): ``cooking_dict['klr']`` is read unconditionally, but the
    dict built by ``get_shapley_vals`` for ``'SGD_base'`` has no ``'klr'``
    key, so ``m='SGD_base'`` looks like it would raise ``KeyError`` -- confirm.
    """
    sum_count,features_names,do_sum,coeffs=return_feature_names(job)
    if m=='SGD_base':
        pass
        # features_names =[f'{g} (l)' for g in features_names]+[f'{g} (r)' for g in features_names]
        # features_names=features_names
        # features_names = features_names + [f'{g}' for g in features_names]
    print(features_names)
    outputs = construct_values(cooking_dict['Y'],cooking_dict['Z'],
                              cooking_dict['weights'],cooking_dict['klr'],coeffs,post_method
                              )
    big_plot=[]
    for key,output in outputs.items():

        if m == 'SGD_base':
            # Split the output into its two halves along dim 0.
            o_u, o_d = torch.chunk(output, dim=0, chunks=2)
            if do_sum:
                p_output_u = cumsum_thingy(sum_count, o_u)
                p_output_d = cumsum_thingy(sum_count, o_d)
                # p_output = p_output_d+p_output_u
                # NOTE(review): the halves are subtracted here but added in
                # the non-summed branch below -- confirm this is intended.
                p_output = p_output_d-p_output_u
                # p_output = torch.cat([p_output_u, p_output_d], dim=0)
            else:
                # p_output = o_d-o_u
                p_output = o_d+o_u
        else:
            if do_sum:
                p_output = cumsum_thingy(sum_count,output)
            else:
                p_output = output
        # Long format: feature index repeated ``n`` times next to the
        # flattened values.  Presumably ``p_output`` is (n_features, n) so the
        # row-major flatten lines up with ``tst`` -- TODO confirm.
        tmp = p_output.cpu().numpy().flatten()
        tst= np.arange(1,len(features_names)+1).repeat(cooking_dict['n'])
        plot =  pd.DataFrame(np.stack([tst,tmp,np.ones_like(tst)*key],axis=1),columns=['d','shapley_vals','lambda'])
        big_plot.append(plot)
    # ``reset_index()`` without ``drop`` keeps the per-frame index as a column.
    plot = pd.concat(big_plot,axis=0).reset_index()

    # Weighted, almost unregularised linear fit of Y on Z.
    clf = Ridge(1e-10)
    Z = cooking_dict['Z']
    Y_target = cooking_dict['Y']
    print("Z = ", Z)
    print("Y_target = ", Y_target.shape)
    print("Y_target = ", Y_target[-1,:])
    print("Y_target = ", Y_target)
    # Last row of Y minus first row of Y.
    func_val = Y_target[-1,:]-Y_target[0,:]
    np.savetxt(f"Y_target_{block}.txt", func_val.t())
    weights = cooking_dict['weights']
    print("weights = ", weights.cpu().numpy().reshape(-1))

    clf.fit(Z, Y_target, weights.cpu().numpy().reshape(-1))
    # ``full_shapley_values_`` and ``SHAP_LM`` are not used further below.
    full_shapley_values_ = np.concatenate([clf.intercept_.reshape(-1, 1), clf.coef_], axis=1)
    SHAP_LM = clf

    print("prefshap values = ", (clf.coef_).shape, clf.coef_)
    
    # Item matrices are read from the working directory.
    left_tr = np.loadtxt(f"original_train_left_items_{fold}.txt")
    left_val = np.loadtxt(f"original_val_left_items_{fold}.txt")
    right_tr = np.loadtxt(f"original_train_right_items_{fold}.txt")
    right_val = np.loadtxt(f"original_val_right_items_{fold}.txt")

    data_left = np.concatenate((left_tr,left_val),axis=0)
    data_right = np.concatenate((right_tr,right_val),axis=0)

    data_total = np.concatenate((data_left,data_right),axis=0)

    # Only ``data_total_stds`` is used below.
    data_left_stds = data_left.std(axis = 0)
    data_right_stds = data_right.std(axis = 0)
    data_total_stds = data_total.std(axis = 0)

    print("clf.coef_ shape = ", clf.coef_.shape)

    # Scale each coefficient column in place by the std of that column over
    # all items; the column count comes from the first command-line argument.
    for f in range(int(sys.argv[1])):
        (clf.coef_)[:,f] = (clf.coef_)[:,f] * data_total_stds[f]

    # NOTE(review): columns 0..3 are hard-coded, so at least four coefficient
    # columns are required -- confirm the intended block layout.
    np.savetxt("prefshap_block0.txt",clf.coef_[:,0]+clf.coef_[:,1])
    np.savetxt("prefshap_block1.txt",clf.coef_[:,2]+clf.coef_[:,3])
    np.savetxt(f"prefshap_block_{block}_fold_{fold}.txt",clf.coef_)
    
    return plot,features_names,clf.coef_,func_val.t()



if __name__ == '__main__':
    # Script entry point.  Positional command-line arguments:
    #   argv[1]  number of features (also read at import time as ``dim``)
    #   argv[2]  size tag embedded in the dataset name, also passed on as
    #            ``num_matches``
    #   argv[3]  block identifier
    #
    # For each of the 5 folds the regression problem is built and stored,
    # Shapley values are derived from it, and per-fold global importances
    # (mean absolute value per feature) are collected.  Finally the mean and
    # standard deviation over folds are saved and drawn as a bar plot.
    num_matches = int(sys.argv[2])
    job_name = f'synthetic_data_{num_matches}_0'
    block = int(sys.argv[3])

    for job in [job_name]:
        print("job = ", job)
        for m in ['SGD_krr']:
            model = m
            interventional = False
            hello = 0
            prefshap_list = []
            func_val_list = []

            res_name = f'{interventional}_{job}_{model}'
            os.makedirs(res_name, exist_ok=True)

            # Fix: created once per model, before the fold loop.  It used to
            # be re-created inside the loop, so ``data_folds.csv`` only ever
            # contained the most recent fold.
            abs_data_container = []

            # NOTE: the loop variable is deliberately the module-level name
            # ``f``; keep it that way as long as any helper in this file may
            # look up a global ``f``.
            for f in range(5):  # Loop over folds
                fold = f
                train_params = {
                    'dataset': job,
                    'fold': fold,
                    'epochs': 100,
                    'patience': 10,
                    'model_string': model,
                    'bs': 512,
                    'double_up': False,
                    'm_factor': 10.0,
                    'seed': 42,
                    'folds': 5,
                    'block': block
                }

                print("hello =", hello)
                cooking_dict, abs_data = get_shapley_vals(
                    job=job,
                    model_string=model,
                    fold=fold,
                    block=block,
                    train_params=train_params,
                    num_matches=num_matches,
                    post_method='OLS',
                    interventional=interventional,
                    hello=hello
                )
                abs_data_container.append(abs_data)
                torch.save(cooking_dict, f'{res_name}/cooking_dict_{fold}.pt')
                hello += 1

                # Rewritten after every fold so the file always holds all
                # folds processed so far.
                big_df = pd.concat(abs_data_container, axis=0).reset_index(drop=True)
                big_df.to_csv(f'{res_name}/data_folds.csv')

                for post_method in ['ridge']:
                    cooking_dict = torch.load(f'{res_name}/cooking_dict_{fold}.pt')
                    data, features_names, prefshap2, func_val = get_shapley_vals_2(
                        cooking_dict, job, post_method, block, fold, model
                    )

                    # Global importance of this fold: mean absolute value per
                    # coefficient column.
                    prefshap_avg = np.mean(np.abs(prefshap2), axis=0)
                    prefshap_avg = np.atleast_1d(prefshap_avg)  # Ensure 1D array
                    np.savetxt(f"prefshap_global_original_fn_fold_{f}.txt", prefshap_avg)

                    func_val_np = func_val.cpu().numpy()
                    func_val_avg = np.mean(func_val_np, axis=0)
                    func_val_avg = np.atleast_1d(func_val_avg)  # Ensure 1D array

                    prefshap_list.append(prefshap_avg)
                    func_val_list.append(func_val_avg)

                    data['fold'] = f
                    # Replace the 1-based feature index by the feature name.
                    data['d'] = data['d'].apply(lambda x: features_names[int(x-1)])

                    # Appends to an existing file, so results from earlier
                    # runs in the same directory are kept as well.
                    plot_df_name = f'{res_name}/{res_name}_{post_method}.csv'
                    if os.path.exists(plot_df_name):
                        data.to_csv(plot_df_name, mode='a', header=False, index=False)
                    else:
                        data.to_csv(plot_df_name, index=False)

            # === Final Aggregation ===

            prefshap_stack = np.stack(prefshap_list)  # one row per fold
            func_val_stack = np.stack(func_val_list)  # one row per fold

            prefshap_avg_final = np.mean(prefshap_stack, axis=0)
            prefshap_std_final = np.std(prefshap_stack, axis=0)

            func_val_avg_final = np.mean(func_val_stack, axis=0)
            func_val_std_final = np.std(func_val_stack, axis=0)

            print("Final PrefShap average:", prefshap_avg_final)
            print("Final PrefShap std:", prefshap_std_final)
            print("Final function value average:", func_val_avg_final)
            print("Final function value std:", func_val_std_final)

            # Save
            np.savetxt(f"prefshap_global_block_{block}_mean.txt", prefshap_avg_final)
            np.savetxt(f"prefshap_global_block_{block}_std.txt", prefshap_std_final)
            np.savetxt("function_value_avg.txt", func_val_avg_final)
            np.savetxt("function_value_std.txt", func_val_std_final)

            # Bar plot of the fold-averaged importances with std error bars.
            x = np.arange(len(prefshap_avg_final))
            plt.figure(figsize=(10, 5))
            plt.bar(x, prefshap_avg_final, yerr=prefshap_std_final, capsize=5)
            plt.xlabel("Feature Index")
            plt.ylabel("Global Shapley Importance")
            plt.title("Average PrefShap Importance with Std Dev (All Folds)")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig("prefshap_importance_barplot.png")
            plt.show()
