import numpy as np
import csv, os, string, pickle
import training
import torch
import json
from datetime import datetime
import time
from datetime import timedelta
import pandas as pd
import numpy as np
import scipy.stats as st
class Arguments:
    """Mutable bag of experiment settings.

    Every attribute is given a default in ``__init__``; callers overwrite the
    ones they care about after construction and hand the object to
    ``train_model``.
    """

    def __init__(self):
        # --- Locations and previously produced artefacts ---
        self.data_path = None
        self.results_folder = None
        self.temp_folder = None
        self.docs_folder = None
        self.pretrained_vectors = False
        self.prev_model_file = None

        # --- Optimisation settings ---
        self.opt_type = "rms"
        self.lr = 0.0001
        self.nepochs = 50
        self.dropout_p = 0.5
        self.nfolds = 1
        self.resample = 0
        self.w_decay = 0

        # --- Network shape ---
        self.model_id = None
        self.n_layers = 3
        self.embed_dim = 300
        self.c_dim = 100
        self.h_dim = 256
        self.representation_dim = 4096

        # --- "Predict n words" setting ---
        self.t = 1

        # --- Extra arguments for synthetic experiments ---
        self.synth_args = None

class SupArgs:
    """Record of one supervised-classifier run: the four values it is built from."""

    def __init__(self, best_C, best_solver, accu, cv_duration):
        # Stored verbatim; no validation or conversion is applied.
        self.best_C, self.best_solver = best_C, best_solver
        self.accu, self.cv_duration = accu, cv_duration
 

def train_model(args):
    """Call ``training.train_model`` with the settings stored on ``args``.

    args: object exposing the attributes read below (e.g. an ``Arguments``
        instance).

    Returns whatever ``training.train_model`` returns.
    """
    # Map attribute names on ``args`` to the keyword names the trainer expects;
    # a few differ (temp_folder, docs_folder, synth_args).
    settings = dict(
        model_id=args.model_id,
        data_path=args.data_path,
        results_folder=args.results_folder,
        t=args.t,
        c_dim=args.c_dim,
        h_dim=args.h_dim,
        representation_dim=args.representation_dim,
        nepochs=args.nepochs,
        lr=args.lr,
        embed_dim=args.embed_dim,
        opt_type=args.opt_type,
        dropout_p=args.dropout_p,
        n_layers=args.n_layers,
        nfolds=args.nfolds,
        resample=args.resample,
        w_decay=args.w_decay,
        pretrained_vectors=args.pretrained_vectors,
        temp_model_folder=args.temp_folder,
        prev_model_file=args.prev_model_file,
        presampled_docs_file=args.docs_folder,
        synthetic_args=args.synth_args,
    )
    return training.train_model(**settings)

'''
Helper functions
'''
def saveClassifierDataframe(df, folder_name, file_name):
    """Pickle ``df`` to ``folder_name/file_name`` and return that path.

    The folder is expected to exist already; it is not created here.
    """
    target = os.path.join(folder_name, file_name)
    with open(target, 'wb') as handle:
        pickle.dump(df, handle)
    print("Dataframe saved...")
    return target


def saveModel(model, path):
    """Serialise ``model`` to ``path`` with ``torch.save`` and report where it went."""
    torch.save(model, path)
    print("model save at path:", path)


def saveModelStats(args,unsup_id,unsup_duration=None,sup_duration=None,acc=None,best_c=None,
                    best_solver=None,cv_duration=None,df_filepath=None,best_model_file_name=None):
    """Record the settings and results of one run in 'models/meta_model.json'.

    The file holds a single JSON object keyed by ``unsup_id``; an existing
    entry under the same id is replaced. The 'models' folder and the JSON file
    are created when missing, and an empty file is treated as an empty object,
    so no manual setup is needed before the first run.

    args: object exposing the hyper-parameter attributes read below
        (e.g. an ``Arguments`` instance).
    unsup_id: key under which this run is stored.
    unsup_duration, sup_duration, acc, best_c, best_solver, cv_duration,
    df_filepath, best_model_file_name: optional results; each one is written
        only when it is not ``None`` (so a legitimate value of 0 is kept).

    Raises json.JSONDecodeError if the existing file holds invalid JSON.
    """
    folder_name = 'models'
    os.makedirs(folder_name, exist_ok=True)
    meta_path = os.path.join(folder_name, 'meta_model.json')

    # Settings that are always recorded.
    cur = {
        'model_id': args.model_id,
        "time saved": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
        "layers": args.n_layers,
        "embed_dim": args.embed_dim,                    # embedding layer dimension
        "h_dim": args.h_dim,                            # hidden layer dimension
        "representation_dim": args.representation_dim,  # representation dimension
        "nepochs": args.nepochs,                        # number of epochs
        "lr": args.lr,                                  # learning rate
        "dropout_p": args.dropout_p,                    # dropout
        'resample': args.resample,                      # frequency of resampling data
        "opt_type": args.opt_type,                      # optimizer
        "t": args.t,                                    # number of words in labels
        "w_decay": args.w_decay,                        # weight decay
    }

    # Results that are recorded only when supplied. Compare against None rather
    # than relying on truthiness so that e.g. an accuracy of 0.0 is not dropped.
    optional = {
        "unsup_duration": unsup_duration,
        "sup_duration": sup_duration,
        'accuracy': acc,
        'best_c': best_c,
        'best_solver': best_solver,
        'cv_duration': cv_duration,
        'df_filepath': df_filepath,
        'best_model_file_name': best_model_file_name,
    }
    cur.update({key: value for key, value in optional.items() if value is not None})

    # Load what is already on disk; a missing or empty file means "no runs yet".
    obj = {}
    if os.path.isfile(meta_path):
        with open(meta_path, 'r') as openfile:
            raw = openfile.read()
        if raw.strip():
            obj = json.loads(raw)
    obj[unsup_id] = cur

    # Serialise before opening for writing so that a serialisation error does
    # not truncate the existing file.
    updated_obj = json.dumps(obj, indent = 4)
    with open(meta_path, "w") as outfile:
        outfile.write(updated_obj)

def generateLCDataframe(predictmodel=None,embedding_type='model',contrastive=False,
                        sample_sizes=(100,600,1100,1600,2100,2600,3100,3600,4000),n_repetitions=10):
    """Build a learning-curve table: supervised sample count vs. accuracy with a 95% CI.

    For every entry of ``sample_sizes`` the classifier is trained
    ``n_repetitions`` times through ``training.train_classifier``; the mean
    accuracy and a Student-t 95% confidence interval over the repetitions are
    stored together with the per-repetition best C, best solver and
    cross-validation duration.

    predictmodel: model forwarded to ``training.train_classifier`` as ``model``.
    embedding_type: forwarded unchanged; the original note on this function
        lists 'model', 'BOW' and 'word2vec' as the available types.
    contrastive: forwarded unchanged.
    sample_sizes: iterable of ``n_samples`` values to evaluate.
    n_repetitions: number of repeated runs per sample size. With fewer than
        two repetitions the confidence bounds are not defined (NaN).

    Returns ``(df, sup_duration)``: a DataFrame with one row per sample size
    (columns 'n_samples', 'mean_accuracy', 'lb', 'ub', 'best_c_list',
    'best_solver_list', 'cv_duration_list') and the elapsed wall-clock time
    formatted via ``str(timedelta)``.
    """
    print("Running supervised learning...")
    sup_starttime=time.time()

    data=[]
    for n_samples in sample_sizes:
        accu_list=[]
        best_c_list=[]
        best_solver_list=[]
        cv_duration_list=[]
        for n in range(n_repetitions):
            print('Running reptition number: ',n,' with number of data used: ',n_samples)
            best_c,best_solver,accu,cv_duration=training.train_classifier(
                model=predictmodel,n_samples=n_samples,
                embedding_type=embedding_type,contrastive=contrastive)
            accu_list.append(accu)
            best_c_list.append(best_c)
            best_solver_list.append(best_solver)
            cv_duration_list.append(cv_duration)
        mean=np.mean(accu_list)
        # The confidence level is passed positionally: SciPy renamed this
        # keyword from ``alpha`` to ``confidence`` and later dropped ``alpha``,
        # so the positional form is the one that works on every version.
        lb,ub=st.t.interval(0.95, df=len(accu_list)-1, loc=mean, scale=st.sem(accu_list))
        data.append(
            {
                'n_samples':n_samples,
                'mean_accuracy':mean,
                'lb':lb,
                'ub':ub,
                'best_c_list':best_c_list,
                'best_solver_list':best_solver_list,
                'cv_duration_list':cv_duration_list
            }
        )
    df = pd.DataFrame(data)
    sup_duration=str(timedelta(seconds=time.time() - sup_starttime))
    return df,sup_duration
    



def runExperiment_model(run_type="one",skip_unsup=False,unsup_id=None,model_id=None,n_layers=None,
                        representation_dim=None,embed_dim=5000,h_dim=None,nepochs=150,opt_type='amsgrad',w_decay=0,
                        resample=2,dropout_p=0,lr=0.0002,prev_model_file=None,n_samples=4000):
    """Train (or load) a model, evaluate it with the linear classifier, and log the run.

    run_type: "one" runs the classifier once on ``n_samples`` samples;
        "df" builds the full learning-curve dataframe via
        ``generateLCDataframe`` and pickles it into the results folder.
        Any other value only prints a message.
    skip_unsup: when true, no training happens and the model is loaded from
        ``prev_model_file`` instead.
    unsup_id: run identifier; used for the results folder name, the saved
        model/dataframe file names and the key in the stats file.
    model_id: forwarded to training; the value "contrastive" additionally
        switches the classifier to its contrastive mode.
    n_layers, representation_dim, embed_dim, h_dim, nepochs, opt_type,
    w_decay, resample, dropout_p, lr: copied onto the ``Arguments`` object
        handed to ``train_model``.
    prev_model_file: path given to ``torch.load`` when ``skip_unsup`` is set.
    n_samples: sample count for the "one" run type.

    Returns None; results are persisted through ``saveModelStats``.
    """
    print(unsup_id)

    ## Fit the model
    args = Arguments()
    args.temp_folder = "models"  ## Temporary folder to hold intermediate models
    args.data_path = "data"      ## Folder for experiment data
    args.results_folder = "results/" + str(unsup_id)  ## Folder to hold results
    # makedirs also creates the missing "results" parent, which a plain mkdir
    # of the nested path would fail on.
    for folder in (args.temp_folder, args.data_path, args.results_folder):
        os.makedirs(folder, exist_ok=True)

    ## NN parameters
    args.model_id = model_id
    args.n_layers = n_layers
    args.embed_dim = embed_dim                    ## embedding layer dimension
    args.h_dim = h_dim                            ## hidden layer dimension
    args.representation_dim = representation_dim

    args.nepochs = nepochs      ## number of epochs
    args.lr = lr                ## learning rate
    args.dropout_p = dropout_p  ## dropout
    args.resample = resample    ## frequency of resampling data
    args.opt_type = opt_type    ## optimizer
    args.t = 4                  ## number of words in labels (fixed for this experiment)
    args.w_decay = w_decay      ## weight decay

    ## Run unsupervised training, or load an earlier model
    print('Rerun unsupervised learning....')
    unsup_starttime = time.time()

    model = None
    model_file = None  # path of the model file this run is based on, if any
    if not skip_unsup:
        model = train_model(args)
        model_file = 'models/' + str(unsup_id) + '_model.pt'
        saveModel(model, model_file)
    elif prev_model_file:
        # NOTE(review): recent PyTorch releases changed the default of
        # torch.load's ``weights_only`` argument, which may reject a fully
        # pickled model -- confirm against the installed version.
        model = torch.load(prev_model_file)
        model_file = prev_model_file
    else:
        # Kept as a message only; the run then continues with model=None.
        print("No model exception. You should pass a previous model file if you want to skip unsupervised learning")
    unsup_duration = str(timedelta(seconds=time.time() - unsup_starttime))

    ## Run linear classifier
    print("Running supervised learning...")
    contrastive = model_id == "contrastive"

    def generateDfRun():
        df, duration = generateLCDataframe(model, contrastive=contrastive)

        print('Saving dataframe and model stats...')
        # str() so that a non-string run id does not break the concatenation.
        file_path = saveClassifierDataframe(df, args.results_folder,
                                            str(unsup_id) + '_accuracy_new_df.pkl')
        # The model file path recorded here is the one saved/loaded above
        # (this name was previously undefined at this point).
        saveModelStats(args, unsup_id, unsup_duration=unsup_duration, sup_duration=duration,
                       df_filepath=file_path, best_model_file_name=model_file)
        print("Stats Saved")  # in meta_model.json

    def quickTestRun():
        best_c, best_solver, accu, cv_duration = training.train_classifier(
            model=model, n_samples=n_samples, embedding_type='model', contrastive=contrastive)
        saveModelStats(args, unsup_id, acc=accu, unsup_duration=unsup_duration, best_c=best_c,
                       best_solver=best_solver, cv_duration=cv_duration)
        return accu

    if run_type == "one":
        print(f"running one time on {n_samples} sample")
        quickTestRun()
    elif run_type == "df":
        print("running df")
        generateDfRun()
    else:
        print("Wrong run_type specified")



def runExperiment_baseline(run_type="one",unsup_id=None,embedding_type='BOW',n_samples=4000):
    """Evaluate a baseline embedding with the linear classifier (no model training).

    run_type: 'one' trains the classifier once on ``n_samples`` samples; any
        other value builds the learning-curve dataframe via
        ``generateLCDataframe`` and pickles it into the results folder.
    unsup_id: run identifier; used for the results folder and the dataframe
        file name.
    embedding_type: forwarded to the classifier / dataframe generation.
    n_samples: sample count for the 'one' run type.

    Returns None.
    """
    print(unsup_id)

    args = Arguments()
    args.temp_folder = "models"  ## Temporary folder to hold intermediate models
    args.data_path = "data"      ## Folder for experiment data
    args.results_folder = "results/" + str(unsup_id)  ## Folder to hold results
    # makedirs also creates the missing "results" parent, which a plain mkdir
    # of the nested path would fail on.
    for folder in (args.temp_folder, args.data_path, args.results_folder):
        os.makedirs(folder, exist_ok=True)

    def generateDfRun():
        # Generate the learning-curve dataframe and persist it.
        df, duration = generateLCDataframe(embedding_type=embedding_type)
        print('supervised duaration: ', duration)
        # str() so that a non-string run id does not break the concatenation.
        saveClassifierDataframe(df, args.results_folder, str(unsup_id) + '_accuracydf.pkl')

    def quickTestRun():
        _, _, accu, _ = training.train_classifier(n_samples=n_samples, embedding_type=embedding_type)
        return accu

    if run_type == 'one':
        # NOTE: the accuracy of the quick run is not stored or returned here.
        quickTestRun()
    else:
        generateDfRun()


if __name__ == '__main__':
    # REMEMBER TO CHANGE BATCH SIZE ACROSS RUNS in Training.py file

    # Base RBL: train from scratch and run the classifier once.
    runExperiment_model(
        run_type="one",
        skip_unsup=False,
        unsup_id='sample_run',
        model_id='base',
        n_layers=3,
        representation_dim=None,
        embed_dim=5000,
        h_dim=4096,
        nepochs=150,
        opt_type='amsgrad',
        w_decay=0.01,
        resample=2,
        dropout_p=0,
        lr=0.0002,
        prev_model_file=None,
        n_samples=4000,
    )



