import torch
import sys
sys.path.append("NeuralExt")
from model.model import ExtensionNet
from model.reinforce import ReinforceNet
from model.straight_through import STNet
import argparse


def fetch_nsfe_args(argv=None):
    """Build the argument namespace expected by the NeuralExt networks.

    The parser mirrors the NeuralExt (ExtensionNet) training CLI. By default it
    is parsed against an empty token list, so ``sys.argv`` is ignored and every
    field holds its declared default; callers then overwrite the fields they
    need.

    Args:
        argv: Optional list of command-line tokens to parse instead of the
            empty list (e.g. ``["--epochs", "10"]``). ``None`` (the default)
            keeps the original behaviour of returning pure defaults.

    Returns:
        argparse.Namespace: the parsed arguments. Dashed option names are
        exposed with underscores (``--k-clique-no`` -> ``k_clique_no``).
    """
    parser = argparse.ArgumentParser(description='ExtensionNet trainer')

    # NOTE(review): several help strings below appear copy-pasted from other
    # options and may not describe what the option actually controls.

    # training params
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument('--lr_decay_step_size', default=40, type=int, help='how many epochs between lr steps')
    parser.add_argument('--lr_decay_factor', default=0.95, type=float, help='ratio by which lr is decreased')
    parser.add_argument('--lr_lower_bound', default=0.00005, type=float, help='lowest lr allowed')
    parser.add_argument('--epochs', default=300, type=int, help='Number of sweeps over the dataset to train')
    parser.add_argument('--b_sizes', default=[32], nargs='*', type=int, help='batch sizes to search over')
    parser.add_argument('--l_rates', default=[0.0001], nargs='*', type=float, help='learning rates to search over')
    parser.add_argument('--depths', default=[5], nargs='*', type=int, help='number of layers to search over')
    parser.add_argument('--widths', default=[128], nargs='*', type=int, help='model widths to search over')
    parser.add_argument('--output_sets', default=1, type=int, help='how many output sets to predict')
    parser.add_argument('--base_gnn', default='gcn', type=str, help='architecture of first module of network')
    parser.add_argument('--rand_seeds', default=[1], nargs='*', type=int, help='random seeds to run')
    parser.add_argument('--final_penalty', default=3., type=float, help='final beta penalty value')
    parser.add_argument('--input_feat_dim', default=None, type=int, help='input node feature dimension')
    parser.add_argument('--n_eigs', default=8, type=int, help='----')
    parser.add_argument('--heads', default=8, type=int, help='erdos original')
    parser.add_argument('--concat', action='store_true', help='concat heads')

    parser.add_argument('--problem', default='cut', type=str, help='which problem to study')
    parser.add_argument('--k-clique-no', default=None, type=int, help='which problem to study')

    parser.add_argument('--features', default='random', type=str, help='which node features to use')
    parser.add_argument('--lap-method', default=None, type=str, help='which node features to use')
    parser.add_argument('--time_limit', default=None, type=float, help='gurobi time limit')

    parser.add_argument('--n_pertubations', default=1, type=int, help='number of different level set families to produce')
    parser.add_argument('--reweight', action='store_true', help='will print out what value the solutions attain')
    parser.add_argument('--permute_method', default=None, type=int, help='number of different level set families to produce')
    parser.add_argument('--window', default=None, type=int, help='window size')
    parser.add_argument('--penalty', default=None, type=float, help='beta penalty value')

    parser.add_argument('--k', default=4, type=int, help='cardinality size for fixed cardinality Lovasz')

    parser.add_argument('--aug', action='store_true', help='use node feature data augmentation')
    parser.add_argument('--testing', action='store_true', help='use node feature data augmentation')

    parser.add_argument('--compute-greedy', action='store_true', help='greedy max estimation')
    parser.add_argument('--compute-rand', action='store_true', help='greedy max estimation')
    parser.add_argument('--rand-prob', default=0.5, type=float, help='greedy max estimation')
    parser.add_argument('--F1', action='store_true', help='will print out what value the solutions attain')

    parser.add_argument('--debug', action='store_true', help='do not run wandb logging')

    parser.add_argument('--cardinality_const', default=5, type=int, help='add a cardinality constraint')
    parser.add_argument('--n_tries', default=None, type=int, help='use a non-exact ground truth for when no efficient method to solve exists')

    parser.add_argument('--optimizer', default='adam', type=str, help='which optimizer to use')

    # debugging
    parser.add_argument('--print_best', action='store_true', help='will print out what value the solutions attain')
    parser.add_argument('--test_freq', default=None, type=int, help='how many epochs per test evaluation')
    parser.add_argument('--local', action='store_true', help='which problem to study')
    parser.add_argument('--curriculum', default=None, type=int, help='how many epochs per test evaluation')

    parser.add_argument('--experiment', action='store_true', help='will print out what value the solutions attain')

    parser.add_argument('--extension', default='lovasz', type=str, help='which set-function extension to use (e.g. lovasz, neural)')
    parser.add_argument('--one_dim_extension', default=None, type=str, help='how many epochs per test evaluation')
    parser.add_argument('--k_min', default=2, type=int, help='how many epochs per test evaluation')
    parser.add_argument('--k_max', default=4, type=int, help='how many epochs per test evaluation')
    parser.add_argument('--num_sets', default=None, type=int, help='how many epochs per test evaluation')
    parser.add_argument('--new_diff', action='store_true', help='which problem to study')

    parser.add_argument('--warmup', action='store_true', help='which problem to study')
    parser.add_argument('--neural', default='v1', type=str, help='for prototyping different versions of neural extension')
    parser.add_argument('--eig_sym', action='store_true', help='which problem to study')

    parser.add_argument('--n_sets', default=None, type=int, help='how many random spanning trees to use')
    parser.add_argument('--max_val', default=5., type=float, help='how many random spanning trees to use')

    parser.add_argument('--chain_length', default=6, type=int, help='how many random spanning trees to use')

    # TSP data parameters
    parser.add_argument('--tsp_n_points', default=1000, type=int, help='dataset size (train, val and test)')
    parser.add_argument('--tsp_max_size', default=20, type=int, help='largest size graph')
    parser.add_argument('--tsp_box_size', default=2., type=float, help='choose box points lie within')

    # datasets and saving
    parser.add_argument('--dataset_names', default=["ENZYMES"], nargs='*', type=str, help='datasets to run over')
    parser.add_argument('--dataset_scale', default=1., type=float, help='proportion of dataset to use')
    parser.add_argument('--data_root', default='/data/scratch/joshrob/data', type=str, help='root directory where data is found')
    parser.add_argument('--save_root', default='/data/scratch/joshrob/comb-opt', type=str, help='root directory where results are saved')
    parser.add_argument('--save_name', default=None, type=str, help='save filename')
    parser.add_argument('--reinforce', action='store_true', help='Reinforce baseline')
    parser.add_argument('--num_reinforce_samples', default=200, type=int, help="Number of samples that reinforce is trained on in each epoch")
    parser.add_argument('--straight_through', action='store_true', help='Straight-through baseline')
    parser.add_argument('--erdos', action='store_true', help='Erdos baseline')
    parser.add_argument('--num_erdos_samples', default=1000, type=int)
    parser.add_argument('--erdos_penalty', default=1.5, type=float, help="Value of the penalty coefficient for Erdos")
    parser.add_argument('--straight_through_samples', action='store_true', help='Sample sets for straight-through instead of fixing')
    parser.add_argument('--num_st_samples', default=1000, type=int, help="number of straight through samples")
    parser.add_argument('--real_nfe', action='store_true', help="straight through samples")
    parser.add_argument('--save_evals', action='store_true', help="straight through samples")

    parser.add_argument('--bounded_k', default=3, type=int, help="Value of k for bounded cardinality extension")
    parser.add_argument('--reinforce_with_baseline', action='store_true', help='Use Reinforce with baseline')
    parser.add_argument('--ER_scale_experiment', action='store_true', help='Erdos renyi scaling and time experiment')
    parser.add_argument('--ER_howmany', default=10, type=int, help='how many graphs per size')
    parser.add_argument('--ER_graph_size_ub', default=500, type=int, help='how many graphs per size')
    parser.add_argument('--ER_prob', default=0.75, type=float, help="parameter of ER model")
    parser.add_argument('--doubly_nonnegative', action='store_true', help='Completely positive matrix')

    parser.add_argument('--early_stop', action='store_true', help="straight through samples")
    parser.add_argument('--patience', default=30, type=int, help="Value of k for bounded cardinality extension")

    # An explicit empty list (never None) so argparse does not fall back to
    # sys.argv, which would pick up the host program's own CLI flags.
    args = parser.parse_args(args=[] if argv is None else argv)

    return args


class baseline_template(torch.nn.Module):
    """Shared wrapper that adapts a NeuralExt network to this project's config.

    Construction builds the default NeuralExt argument namespace via
    ``fetch_nsfe_args``, copies the relevant settings from ``conf`` into it,
    lets the subclass apply its own argument overrides, and finally lets the
    subclass build ``self.net``. Subclasses override
    ``overwrite_baseline_specifc_args`` and ``set_baseline_type``; this base
    class leaves ``self.net`` unset.
    """

    def __init__(self, conf):
        super().__init__()
        self.conf = conf
        self.args = fetch_nsfe_args()
        # Order matters: generic overrides, then subclass overrides, then the
        # network is built from the final args.
        self.overwrite_relevant_args()
        self.overwrite_baseline_specifc_args()
        self.set_baseline_type()

    def overwrite_relevant_args(self):
        """Copy training/model settings from ``self.conf`` into ``self.args``.

        Reads ``conf.training.{num_epochs, batch_size, learning_rate, seed}``
        and ``conf.model.{depths, widths, problem}``. It also forces a GAT
        backbone with constant one-dimensional node features. If greedy or
        random estimation is enabled, the epoch count is reset to 0.
        """
        training = self.conf.training
        model = self.conf.model

        self.args.epochs = training.num_epochs
        self.args.save_evals_epoch = False
        # One more than the configured epoch count, computed before the
        # greedy/random reset below. Presumably this ensures periodic test
        # evaluation never fires during training — TODO confirm.
        self.args.test_freq = self.args.epochs + 1
        self.args.b_sizes = training.batch_size
        self.args.depths = model.depths
        self.args.l_rates = training.learning_rate
        self.args.widths = model.widths
        self.args.base_gnn = 'gat'
        self.args.features = 'one'
        self.args.input_feat_dim = 1
        self.args.rand_seeds = training.seed
        self.args.debug = False
        self.args.problem = model.problem

        # Both flags are argparse store_true booleans, so truthiness suffices.
        if self.args.compute_greedy or self.args.compute_rand:
            self.args.epochs = 0

    def set_baseline_type(self):
        """Hook for subclasses: build and assign ``self.net``."""
        pass

    def overwrite_baseline_specifc_args(self):
        """Hook for subclasses: apply baseline-specific overrides to ``self.args``."""
        pass

    def forward(self,
                query_batch_data,
                query_batch_data_node_sizes,
                query_batch_data_edge_sizes,
                query_batch_adj,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
                diagnostic_mode=False):
        """Run the wrapped network on the corpus batch.

        Only ``corpus_batch_data`` is used. The query-side arguments, the size
        and adjacency arguments, and ``diagnostic_mode`` are accepted for
        interface compatibility and are ignored.

        Returns:
            tuple: ``(loss, best_solutions)``. ``loss`` is ``output["loss"]``
            from the wrapped network. ``best_solutions`` stacks
            ``output["best_sets"]`` along a new leading dimension, detaches it
            from the autograd graph, and multiplies it by -1.

        Raises:
            NotImplementedError: if ``self.args.aug`` is set, because node
                feature augmentation is not supported by this wrapper.
        """
        # Previously an `assert False` (stripped under `python -O`, after which
        # the code referenced an undefined `device`); fail explicitly instead.
        if self.args.aug:
            raise NotImplementedError(
                "Node-feature augmentation (args.aug) is not supported by the baseline wrapper"
            )

        data = corpus_batch_data
        warmup = False
        output = self.net(data, self.args, warmup)

        loss = output["loss"]
        # NOTE: Baseline code returns the negated detected max clique size
        # (always an integer), so flip the sign here.
        best_solutions = torch.stack(output["best_sets"], dim=0).detach() * -1
        return loss, best_solutions
    
class SFE(baseline_template):
    """Baseline that trains an ``ExtensionNet`` with the ``'lovasz'`` extension."""

    def __init__(self, conf):
        super().__init__(conf)

    def set_baseline_type(self):
        """Create the ExtensionNet, move it to the configured device, and reset it."""
        target_device = self.conf.training.device
        self.net = ExtensionNet(self.args.depths, self.args.widths, self.args)
        self.net.to(target_device).reset_parameters()
        # Unlike NSFE, this baseline explicitly marks the network non-deterministic.
        self.net.deterministic = False

    def overwrite_baseline_specifc_args(self):
        """Select the Lovasz extension."""
        self.args.extension = 'lovasz'

class NSFE(baseline_template):
    """Baseline that trains an ``ExtensionNet`` with the ``'neural'`` extension."""

    def __init__(self, conf):
        super().__init__(conf)

    def set_baseline_type(self):
        """Create the ExtensionNet, move it to the configured device, and reset it."""
        target_device = self.conf.training.device
        self.net = ExtensionNet(self.args.depths, self.args.widths, self.args)
        self.net.to(target_device).reset_parameters()

    def overwrite_baseline_specifc_args(self):
        """Configure the neural extension on top of the one-dimensional Lovasz extension."""
        overrides = {
            'extension': 'neural',
            'one_dim_extension': 'lovasz',
            'neural': 'v4',
            'n_sets': 4,
        }
        for field, value in overrides.items():
            setattr(self.args, field, value)

class REINFORCE(baseline_template):
    """Baseline that trains a ``ReinforceNet`` with the REINFORCE flag enabled."""

    def __init__(self, conf):
        super().__init__(conf)

    def set_baseline_type(self):
        """Create the ReinforceNet, move it to the configured device, and reset it."""
        target_device = self.conf.training.device
        self.net = ReinforceNet(self.args.depths, self.args.widths, self.args)
        self.net.to(target_device).reset_parameters()

    def overwrite_baseline_specifc_args(self):
        """Enable REINFORCE and use 1000 samples per epoch (the CLI default is 200)."""
        args = self.args
        args.reinforce = True
        args.num_reinforce_samples = 1000

class ST(baseline_template):
    """Baseline that trains an ``STNet`` with the straight-through flag enabled."""

    def __init__(self, conf):
        super().__init__(conf)

    def set_baseline_type(self):
        """Create the STNet, move it to the configured device, and reset it."""
        target_device = self.conf.training.device
        self.net = STNet(self.args.depths, self.args.widths, self.args)
        self.net.to(target_device).reset_parameters()

    def overwrite_baseline_specifc_args(self):
        """Enable the straight-through estimator flag."""
        setattr(self.args, 'straight_through', True)

