import torch
import torch.nn as nn
import torch.nn.functional as F

from .odst import ODST
from .odst import GAM_NODE_ODST

import itertools

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import copy

from sklearn.preprocessing import MinMaxScaler , StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import LabelEncoder

from itertools import combinations

from models import nn_utils
import pickle


from sklearn.preprocessing import QuantileTransformer
# Module-level quantile transformer shared by make_fig (used when
# uniform_transform=True).  It is refit via fit_transform on every such call,
# so it keeps the state of the most recent dataset only.
qt = QuantileTransformer(n_quantiles=1000, random_state=0,output_distribution='uniform')


class DenseBlock(nn.Sequential):
    """Run one sub-layer per feature subset and concatenate the outputs.

    ``features_list[i]`` selects the input columns handed to ``layers[i]``:
    an ``int`` picks a single column (reshaped to ``(-1, 1)``), any other
    entry is used directly as a column index on the input.

    Args:
        features_list: Column selectors, one per layer.
        layers: Sub-modules; each is called as ``layer(inputs, training)``.
            They are registered through ``nn.Sequential`` so their
            state-dict keys are ``"0"``, ``"1"``, ...
        device: Device used for the empty result returned when there are
            no layers.
        training: Flag forwarded to every layer as its second argument.
        flatten_output: Accepted for interface compatibility; unused here.
        input_dropout: Stored on the instance; not applied in ``forward``.
        **kwargs: Ignored.
    """

    def __init__(self, features_list, layers , device, training=True, flatten_output = True, input_dropout=0.0, **kwargs):
        super().__init__(*layers)

        # nn.Module.__init__ resets ``training`` to True, so the requested
        # mode has to be assigned *after* the parent constructor has run;
        # assigning it earlier silently discards ``training=False``.
        self.training = training
        self.features_list = features_list
        self.device = device
        self.input_dropout = input_dropout

        # Kept as a plain Python list: the modules are already registered by
        # nn.Sequential, so this alias adds no extra state-dict entries.
        self.layers = layers

        self.drop_out = False

    def forward(self, x,inital=False):
        """Apply every layer to its feature subset and join the results.

        Args:
            x: Input indexed as ``x[:, selector]``.
            inital: When true, ``layer.initialize(inputs)`` is called on each
                layer before it is evaluated.

        Returns:
            The layer outputs concatenated along dimension 1, or an empty
            tensor on ``self.device`` when the block has no layers.
        """
        outputs = []
        for position, layer in enumerate(self.layers):
            selector = self.features_list[position]
            layer_inp = x[:, selector]
            if isinstance(selector, int):
                # A single column comes back 1-D; layers expect 2-D input.
                layer_inp = layer_inp.reshape(-1, 1)

            if inital:
                layer.initialize(layer_inp)

            outputs.append(layer(layer_inp, self.training))

        if not outputs:
            return torch.tensor([]).to(self.device)

        # One concatenation instead of growing a tensor inside the loop.
        return torch.cat(outputs, dim=1)

    def model_save_id_constants(self):
        """Call ``save_id_constants()`` on every layer of the block."""
        for layer in self.layers:
            layer.save_id_constants()
                       
    
    
def gen_odst(num_features,input_dim,layer_dim,device,num_multiclass,max_order,choice_function,bin_function,features_list="all",monotone_list=None):
    """Build one ODST layer per component (feature or feature interaction).

    Args:
        num_features: Number of input features; used to enumerate the
            components when ``features_list`` is ``"all"``.
        input_dim: Unused; kept for interface compatibility (each layer's
            input width is derived from its component).
        layer_dim: Passed to every ODST as ``num_trees``.
        device: Passed to every ODST as ``device``.
        num_multiclass: ``tree_dim`` is 1 when this is <= 2, otherwise it is
            ``num_multiclass``.
        max_order: Highest interaction order to enumerate (1 to 4) when
            ``features_list`` is ``"all"``.
        choice_function: Passed through to every ODST.
        bin_function: Passed through to every ODST.
        features_list: ``"all"`` to enumerate every main effect and every
            interaction up to ``max_order``, or an explicit list whose
            entries are an ``int`` (main effect) or a sized collection of
            feature indices (interaction).
        monotone_list: Optional per-position ``monotone`` values, indexed by
            position in ``features_list``; only consulted for main effects.
            ``None`` or an empty sequence means no monotonicity.

    Returns:
        ``(features_list, layers)``: the component list that was used and
        the matching list of ODST layers.

    Raises:
        ValueError: If ``features_list`` is ``"all"`` and ``max_order`` is
            not 1, 2, 3 or 4.
    """
    if monotone_list is None:
        monotone_list = []

    if isinstance(features_list, str) and features_list == "all":
        if max_order not in (1, 2, 3, 4):
            # Previously this only printed "Order error" and then iterated
            # over the characters of "all", building meaningless layers.
            raise ValueError(f"Order error: max_order must be 1, 2, 3 or 4, got {max_order!r}")

        # Orders 3 and 4 used a hard-coded range(10); enumerate the real
        # feature count so every order follows the same rule.
        main_effects = list(range(num_features))
        features_list = list(main_effects)
        for order in range(2, int(max_order) + 1):
            features_list.extend(itertools.combinations(main_effects, order))

    tree_dim = 1 if num_multiclass <= 2 else num_multiclass

    layers = []
    for position, component in enumerate(features_list):
        monotone = False
        if isinstance(component, int):
            component_dim = 1
            if len(monotone_list) != 0:
                candidate = monotone_list[position]
                # Values that compare equal to False (e.g. 0) stay False.
                if candidate != False:  # noqa: E712
                    monotone = candidate
        else:
            component_dim = len(component)

        layers.append(
            ODST(in_features=component_dim,
                 num_trees=layer_dim,
                 tree_dim=tree_dim,
                 device=device,
                 flatten_output=True,
                 choice_function=choice_function,
                 bin_function=bin_function,
                 monotone=monotone)
        )

    return features_list, layers


## Robust Interpretability
def _interaction_grid(component_input):
    """Pair every first-column value with every row of the other columns.

    For an ``(n, m)`` input the result has ``n * n`` rows: block ``k`` holds
    all ``n`` rows with their first column replaced by ``input[k, 0]``.
    """
    num_rows = component_input.shape[0]
    first_column = component_input[:, :1].repeat_interleave(num_rows, dim=0)
    other_columns = component_input[:, 1:].repeat(num_rows, 1)
    return torch.cat([first_column, other_columns], dim=1)


def _normalise_columns(outputs):
    """Scale each column of ``outputs`` to unit L2 norm in place (zero columns untouched)."""
    norms = np.sqrt(np.sum(outputs ** 2, axis=0))
    nonzero = norms != 0
    outputs[:, nonzero] = outputs[:, nonzero] / norms[nonzero]


def _mean_dispersion(outputs):
    """Return (variance, mean absolute deviation) over rows, averaged over columns."""
    variance = np.var(outputs, axis=0).mean()
    abs_deviation = np.abs(outputs - outputs.mean(axis=0)).mean(axis=0).mean()
    return variance, abs_deviation


def cal_UoC(all_data_loader, model_path, tree_num, max_order, in_features,device,regression,normalize = False,interaction_list = None,num_seed=10):
    """Measure how much each component's output varies across saved models.

    ``num_seed`` models are rebuilt and loaded from ``f"{model_path}-{w}"``
    for ``w`` in ``range(num_seed)``.  Each component's layer is evaluated on
    the data for every model, and the spread of those outputs across models
    is summarised per component.

    Args:
        all_data_loader: Iterable of 2-D batches; only the first
            ``in_features`` columns of the LAST batch are analysed.
        model_path: Checkpoint path prefix.
        tree_num: Number of trees per layer (forwarded to ``gen_odst``).
        max_order: 1 for main effects only; 2 also adds all pairs when
            ``interaction_list`` is empty.
        in_features: Number of input feature columns.
        device: Device for the data and the models.
        regression: Unused; the models are always built with one output
            dimension.
        normalize: When true, each evaluation point's outputs are scaled to
            unit L2 norm across models before the statistics are taken.
        interaction_list: Optional explicit component list; ``None`` or
            empty means "derive it from ``max_order``".
        num_seed: Number of checkpoints to load.

    Returns:
        ``(var_features_list, abs_features_list)``: for every component, the
        across-model variance and the across-model mean absolute deviation,
        each averaged over the evaluation points.

    Raises:
        ValueError: If ``all_data_loader`` yields no batch.
    """
    if interaction_list is None:
        interaction_list = []

    last_batch = None
    for last_batch in all_data_loader:
        pass
    if last_batch is None:
        raise ValueError("all_data_loader yielded no batches")
    data__x = last_batch[:, :in_features].to(device)

    choice_function = nn_utils.entmax15
    bin_function = nn_utils.entmoid15

    input_dim = in_features
    num_features = input_dim
    layer_dim = tree_num
    num_multiclass = 1

    # NOTE(review): this scaling assumes the model holds every main effect
    # (and every pair when max_order != 1).  With an explicit
    # interaction_list the number of layers may differ - confirm.
    if max_order == 1:
        num_component = in_features
    else:
        num_component = in_features + (in_features*(in_features-1))/2

    if len(interaction_list) == 0:
        component_list = list(range(in_features))
        if max_order == 2:
            component_list.extend(list(combinations(component_list, 2)))
        odst_kwargs = {}
    else:
        component_list = interaction_list
        odst_kwargs = {"features_list": interaction_list}

    model_list = []
    for w in range(num_seed):
        features_list_cs, layers = gen_odst(num_features,input_dim,layer_dim,device,num_multiclass,max_order,choice_function, bin_function, **odst_kwargs)
        model = nn.Sequential(
            DenseBlock(features_list_cs,layers,device),
            nn_utils.Lambda(lambda x:  x.mean(dim=1)),
        )
        model = model.to(device)

        # map_location lets checkpoints saved on another device be loaded.
        load_model_state = torch.load(model_path + f"-{w}", map_location=device)
        model.load_state_dict(load_model_state)
        model[0].training = False

        model_list.append(model)

    var_features_list = []
    abs_features_list = []

    for position, feature_j in enumerate(component_list):
        print(f"Componet : {feature_j} processing...")

        component_input = data__x[:, feature_j]
        if len(component_input.shape) == 1:
            layer_input = component_input.reshape(-1, 1).float()
        else:
            # The grid does not depend on the model, so build it once.
            layer_input = _interaction_grid(component_input).float()

        with torch.no_grad():
            per_model = [
                model[0][position](layer_input, False).mean(dim=1).detach().cpu() / num_component
                for model in model_list
            ]

        # Rows: models (seeds); columns: evaluation points.
        all_output_trial = torch.stack(per_model).numpy()

        if normalize:
            _normalise_columns(all_output_trial)

        var_sum, abs_sum = _mean_dispersion(all_output_trial)
        var_features_list.append(var_sum)
        abs_features_list.append(abs_sum)

    return var_features_list,abs_features_list



############ Figure Main shape function ############  

def make_fig(data_x,data_y, regression,max_order,tree_num,columns_list,model_path,device,cs=False,fig=True,init_test=False,init_random_seed=0,uniform_transform=False):
    """Evaluate ten saved models on a held-out split and plot main effects.

    For ``w`` in ``0..9`` the model is rebuilt, its weights are loaded from
    ``f"{model_path}-{w}"``, a test split is drawn, the test metric is
    printed, and (optionally) one line plot per main-effect component is
    drawn.

    Args:
        data_x: 2-D feature matrix exposing ``.shape``.
        data_y: Targets; must be a torch tensor that can be concatenated to
            the features along ``dim=1``.
        regression: If true, report RMSE (targets are standardised with the
            training split's mean/std); otherwise report ROC AUC.
        max_order: Interaction order forwarded to ``gen_odst``.
        tree_num: Number of trees per layer.
        columns_list: Axis labels, indexed by feature column.
        model_path: Checkpoint path prefix.
        device: Device for the models and the evaluation data.
        cs: If true, the component list is read from the pickle file
            ``f"{model_path}-{w}_features_list"`` and layers come from
            ``gen_odst_cs``.
        fig: If true, draw the main-effect plots.
        init_test: If true, every model is evaluated on the split drawn with
            ``init_random_seed`` instead of its own index.
        init_random_seed: Split seed used when ``init_test`` is true.
        uniform_transform: If true, ``data_x`` is first mapped to [-1, 1]
            with the module-level quantile transformer.
    """
    if uniform_transform:
        data_x = 2*qt.fit_transform(data_x) -1

    # The column count does not depend on the split, so no split is needed.
    in_features = data_x.shape[1]

    # NOTE(review): with cs=True the loaded component set may contain a
    # different number of components than this formula assumes - confirm.
    if max_order == 1:
        num_component = in_features
    else:
        num_component = in_features + (in_features*(in_features-1))/2

    choice_function = nn_utils.entmax15
    bin_function = nn_utils.entmoid15

    input_dim = in_features
    num_features = input_dim
    layer_dim = tree_num
    num_multiclass = 1

    for w in range(0,10):

        if cs:
            feature_list_path = f"{model_path}-{w}_features_list"

            # SECURITY: pickle.load can execute arbitrary code; only point
            # model_path at files from a trusted source.
            with open(feature_list_path, 'rb') as fp:
                features_list_cs = pickle.load(fp)

            print(f"{w}-th component set : {features_list_cs}")

            _,layers = gen_odst_cs(features_list_cs,layer_dim,num_multiclass,device,choice_function,bin_function)

        else:
            features_list_cs, layers = gen_odst(num_features,input_dim,layer_dim,device,num_multiclass,max_order,choice_function, bin_function)

        model = nn.Sequential(
                DenseBlock(features_list_cs,layers,device),
                nn_utils.Lambda(lambda x:  x.mean(dim=1)))

        ##### Load model #####

        # map_location lets checkpoints saved on another device be loaded.
        load_model_state = torch.load(f"{model_path}-{w}", map_location=device)
        model.load_state_dict(load_model_state)

        model = model.to(device)
        model.eval()

        # From here on ``w`` is the split seed (and the label in the printed
        # report); the checkpoint above was chosen by the loop index.
        if init_test:
            w = init_random_seed

        train_x,test_x_,train__y,test_y_ = train_test_split(data_x,data_y, test_size=0.3, random_state=w)
        val_x,test_x,val_y,test_y = train_test_split(test_x_,test_y_, test_size=0.66, random_state=0)

        scaler = StandardScaler()
        scaler.fit(train_x)
        train_x = scaler.transform(train_x)
        test_x = scaler.transform(test_x)
        val_x = scaler.transform(val_x)

        if regression:
            test_y = (test_y - torch.mean(train__y))/torch.std(train__y)

        test_data = torch.cat([torch.tensor(test_x),test_y],dim=1)

        # A single batch holding the whole test set.
        test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)

        if regression:

            test_loss = 0
            for test__ in test_dataloader:
                test__x,test__y = test__[:,:in_features].to(device) , test__[:,in_features].to(device)

                test_loss +=  torch.sum( (model(test__x.float()).flatten() - test__y.flatten())**2 )

            test_rmse = torch.sqrt( test_loss.cpu().detach()/len(test_x) )
            print(f"state {w} ||  test rmse : {test_rmse*torch.std(train__y)}")

        else:
            all_test__y = torch.tensor([])
            all_test__output = torch.tensor([])
            for test__ in test_dataloader:
                test__x,test__y = test__[:,:in_features].to(device) , test__[:,in_features].to(device)
                all_test__y = torch.concat([all_test__y,test__y.detach().cpu()])
                all_test__output = torch.concat([all_test__output,(model(test__x.float()).reshape(-1,1)).detach().cpu()])

            test_measure = roc_auc_score(all_test__y,all_test__output)
            print(f"state {w} || test auc : {test_measure}")

        if fig:
            # Main effects: the input column and the position of its layer.
            # Indexing by position keeps layer, data column and label aligned
            # even when the component set does not start with 0, 1, 2, ...
            main_list = []
            main_positions = []
            for position, component in enumerate(features_list_cs):
                if isinstance(component, int):
                    main_list.append(component)
                    main_positions.append(position)

            max_feature = min(len(main_list), 11)

            # squeeze=False keeps ``axes`` indexable when one panel is drawn.
            f, axes = plt.subplots(1, max_feature, sharex=False, sharey=False, squeeze=False)
            axes = axes.ravel()
            f.set_size_inches((25, 2))

            f.text(0.09, 0.5, "Output Contribution", va='center', rotation='vertical')

            # Evaluate every shape function once and reuse it for both the
            # shared y-limits and the curves.
            shape_values = []
            with torch.no_grad():
                for j in range(0,max_feature):
                    layer_output = model[0][main_positions[j]](test__x[:,main_list[j]].reshape(-1,1).float(),False).mean(dim=1)
                    shape_values.append( np.array( layer_output.detach().cpu()/num_component ) )

            y_max_list = [np.max(values) for values in shape_values]
            y_min_list = [np.min(values) for values in shape_values]

            # 10% margin that always widens the range, whatever the sign.
            y_max = np.max(y_max_list) + np.abs(np.max(y_max_list))/10
            y_min = np.min(y_min_list) - np.abs(np.min(y_min_list))/10

            sns.set(font_scale=1.0)
            for i in range(0,max_feature):

                axes[i].set(xlabel = columns_list[main_list[i]])

                axes[i].set_ylim([y_min,y_max])

                scatter_x = np.array( test__x[:,main_list[i]].detach().cpu() )

                scatter_y = shape_values[i]

                sns.lineplot(  x=scatter_x.flatten(),y=scatter_y.flatten(),ax=axes[i],color = "blue")
            f.show()