import argparse
from ast import arg
import sys, time
import os
import pandas as pd
import pickle
import torch
from pyscipopt import Model, Branchrule, SCIP_RESULT
sys.path.append('../')
import  src.utils as utils
import src.model as model
from src.utils import set_device_seed, set_scip
from src.logger import Logger
logger = Logger.logger
import shutil
from collections import deque,OrderedDict
import dgl
import heapq
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
import time
import copy
import pickle

# Weights applied to each epoch's 'train_loss' and 'test_loss' when
# ModelSelect.select_from_train_data builds its ranking score
# (score = train_loss_weight * train_loss + test_loss_weight * test_loss).
train_loss_weight = 1
test_loss_weight = 1
    
# Instance type -> sub-directory name used below '../check_points/' when this
# script builds checkpoint and training-log paths.
problem_name = {
    'setcover': 'setcover_400r_1000c_0.05d_100mc_0se',
    'cauctions': 'cauctions_0se',
    'facility': 'facility_0se',
    'indset': 'indset_400n_4a_0se',
    'gisp': 'gisp',
    'wpms': 'wpms',
    'fcmcnf': 'fcmcnf',
}

# The three tables below map an instance type to a folder relative to
# '../data/instances/', one table per difficulty level. All tables share the
# same keys as ``problem_name``.
easy_instance_folders = {
    'setcover': 'setcover_400r_1000c_0.05d_100mc_0se/transfer_500r',
    'cauctions': 'cauctions_0se/transfer_100_500',
    'facility': 'facility_0se/transfer_100_100_5',
    'indset': 'indset_400n_4a_0se/transfer_500n',
    'gisp': 'gisp/transfer_easy',
    'wpms': 'wpms/transfer_easy',
    'fcmcnf': 'fcmcnf/transfer_easy',
}

medium_instance_folders = {
    'setcover': 'setcover_400r_1000c_0.05d_100mc_0se/transfer_1000r',
    'cauctions': 'cauctions_0se/transfer_200_1000',
    'facility': 'facility_0se/transfer_200_100_5',
    'indset': 'indset_400n_4a_0se/transfer_1000n',
    'gisp': 'gisp/transfer_medium',
    'wpms': 'wpms/transfer_medium',
    'fcmcnf': 'fcmcnf/transfer_medium',
}

# Backward-compatible alias: the table was originally published under this
# misspelled name, and existing code still refers to it.
mdeium_instance_folders = medium_instance_folders

hard_instance_folders = {
    'setcover': 'setcover_400r_1000c_0.05d_100mc_0se/transfer_2000r',
    'cauctions': 'cauctions_0se/transfer_300_1500',
    'facility': 'facility_0se/transfer_400_100_5',
    'indset': 'indset_400n_4a_0se/transfer_1500n',
    'gisp': 'gisp/transfer_hard',
    'wpms': 'wpms/transfer_hard',
    'fcmcnf': 'fcmcnf/transfer_hard',
}


def check_deque(temp_deque):
    """Print which slots of ``temp_deque`` are filled (debugging helper).

    Every ``None`` entry is shown as ``None`` and every other entry as the
    placeholder string ``"data"``, so the occupancy pattern can be read
    without dumping the payloads. Returns nothing.
    """
    occupancy = deque(None if item is None else "data" for item in temp_deque)
    print("check_list:", occupancy)

def geo_mean(x, shift=1.0):
    """Return the shifted geometric mean of ``x`` as a 0-dim torch tensor.

    Computes ``exp(mean(log(x + shift))) - shift`` in float32.

    Args:
        x: Values accepted by ``torch.tensor`` (e.g. a list of numbers).
        shift: Constant added before the logarithm and subtracted afterwards.
    """
    shifted = torch.tensor(x).float() + shift
    mean_of_logs = torch.log(shifted).mean()
    return torch.exp(mean_of_logs) - shift

def num_mean(x):
    """Return the arithmetic mean of ``x``, or ``0`` when ``x`` is empty."""
    count = len(x)
    return sum(x) / count if count != 0 else 0

class PolicyBranching_baseline(Branchrule):
    """Branching rule that branches on the candidate a learned agent scores highest.

    At every branching call the solver state is extracted with
    ``utils.extract_state``, converted to a graph with
    ``utils.graph_transform``, scored by ``agent``, and the pseudo branching
    candidate with the highest score is branched on.

    Args:
        scip: The PySCIPOpt model this rule is attached to.
        agent: Callable network; called with the graph and expected to return
            a graph whose ``'v'`` nodes carry a score tensor under ``'s'``.
        device: Device the graph is moved to before the agent is called.
        args: Parsed command-line arguments (stored, not read by this rule).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # This rule always reports BRANCHED, so these two counters are never
        # incremented here; they are kept because they are public attributes.
        self.dom_reduction = 0
        self.cut_off = 0
        # Buffer handed to utils.extract_state on every call.
        self.state_buffer = {}
        self.khalil_root_buffer = {}
        # Number of branching decisions taken by this rule.
        self.branch_count = 0

    def branchinit(self):
        """Reset the per-run buffer and call counter."""
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Branch on the highest-scored pseudo candidate.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}``.
        """
        gcn_state = utils.extract_state(self.model, self.state_buffer)
        g = utils.graph_transform(gcn_state).to(self.device)

        # Inference only: no autograd graph is needed to take an argmax.
        # NOTE(review): the agent is not switched to eval() here -- confirm
        # the caller does so if the network has dropout / batch-norm layers.
        with torch.no_grad():
            var_scores = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()

        # Scores are looked up by each candidate's LP column position
        # (assumes the agent's 'v' rows are ordered that way).
        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(var_scores[action_set].argmax())

        self.model.branchVar(cands[best_cand])
        self.branch_count += 1

        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_38(Branchrule):
    """Branching rule driven by a learned agent, using the ``new_38`` state extraction.

    At every branching call the solver state is extracted with
    ``utils.extract_state_new_38``, converted to a graph with
    ``utils.graph_transform_new_38``, scored by ``agent``, and the pseudo
    branching candidate with the highest score is branched on.

    Args:
        scip: The PySCIPOpt model this rule is attached to.
        agent: Callable network; called with the graph and expected to return
            a graph whose ``'v'`` nodes carry a score tensor under ``'s'``.
        device: Device the graph is moved to before the agent is called.
        args: Parsed command-line arguments (stored, not read by this rule).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # This rule always reports BRANCHED, so these two counters are never
        # incremented here; they are kept because they are public attributes.
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        """Reset the per-run buffer and call counter."""
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Branch on the highest-scored pseudo candidate.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}``.
        """
        state = utils.extract_state_new_38(self.model)
        g = utils.graph_transform_new_38(state).to(self.device)

        # Inference only: no autograd graph is needed to take an argmax.
        # NOTE(review): the agent is not switched to eval() here -- confirm
        # the caller does so if the network has dropout / batch-norm layers.
        with torch.no_grad():
            var_scores = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()

        # Scores are looked up by each candidate's LP column position
        # (assumes the agent's 'v' rows are ordered that way).
        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(var_scores[action_set].argmax())

        self.model.branchVar(cands[best_cand])

        return {'result': SCIP_RESULT.BRANCHED}

class PolicyBranching_39(Branchrule):
    """Branching rule driven by a learned agent, using the ``new_39`` state extraction.

    At every branching call the solver state is extracted with
    ``utils.extract_state_new_39``, converted to a graph with
    ``utils.graph_transform_new_39``, scored by ``agent``, and the pseudo
    branching candidate with the highest score is branched on.

    Args:
        scip: The PySCIPOpt model this rule is attached to.
        agent: Callable network; called with the graph and expected to return
            a graph whose ``'v'`` nodes carry a score tensor under ``'s'``.
        device: Device the graph is moved to before the agent is called.
        args: Parsed command-line arguments (stored, not read by this rule).
    """

    def __init__(self, scip, agent, device, args):
        super().__init__()
        self.model = scip
        self.device = device
        self.agent = agent
        self.args = args
        self.count = 0
        # This rule always reports BRANCHED, so these two counters are never
        # incremented here; they are kept because they are public attributes.
        self.dom_reduction = 0
        self.cut_off = 0

    def branchinit(self):
        """Reset the per-run buffer and call counter."""
        self.root_buffer = {}
        self.count = 0

    def branchexeclp(self, allowaddcons):
        """Branch on the highest-scored pseudo candidate.

        Returns:
            ``{'result': SCIP_RESULT.BRANCHED}``.
        """
        state = utils.extract_state_new_39(self.model)
        g = utils.graph_transform_new_39(state).to(self.device)

        # Inference only: no autograd graph is needed to take an argmax.
        # NOTE(review): the agent is not switched to eval() here -- confirm
        # the caller does so if the network has dropout / batch-norm layers.
        with torch.no_grad():
            var_scores = self.agent(g).nodes['v'].data['s']

        cands, *_ = self.model.getPseudoBranchCands()

        # Scores are looked up by each candidate's LP column position
        # (assumes the agent's 'v' rows are ordered that way).
        action_set = [c.getCol().getLPPos() for c in cands]
        best_cand = int(var_scores[action_set].argmax())

        self.model.branchVar(cands[best_cand])

        return {'result': SCIP_RESULT.BRANCHED}

class ModelSelect():
    """Rank saved training epochs by a loss-based score and evaluate them with SCIP.

    Constructing an instance immediately runs ``select_from_train_data``,
    which loads the per-epoch training log, ranks the epochs, prints the
    ranking and then solves the validation instances with the reference
    checkpoint and with each of the ten best-ranked epochs.

    Args:
        agent: Network whose weights are replaced by each evaluated checkpoint.
        device: Device used as ``map_location`` when loading checkpoints and
            passed to the branching rule.
        args: Parsed command-line arguments. Reads ``ins_type``, ``method``,
            ``check_point``, ``scip_seed`` and ``time_limit``.
        select_instances: Stored on the instance; not used by the methods here.
        valid_instances: Paths of the instances solved by ``valid_evaluate``.
        filename_dir: Path of the pickled per-epoch training log.
    """

    def __init__(self, agent, device, args, select_instances, valid_instances, filename_dir):
        self.epoch_info_dict = None
        self.instances = None
        self.agent = agent
        self.device = device
        self.args = args
        self.select_instances = select_instances
        self.valid_instances = valid_instances
        self.filename_dir = filename_dir

        self.select_from_train_data()

    @staticmethod
    def _clip(value, width=4):
        """Return the first ``width`` characters of ``str(value)`` for compact output."""
        return str(value)[:width]

    def select_from_train_data(self):
        """Load the training log, rank epochs by weighted loss and evaluate the best ones.

        The log is expected to map an epoch key to a dict containing at least
        ``'train_loss'``, ``'test_loss'`` and ``'test_accuracy'``. The computed
        scores are written back into that dict as ``'branch_score_1'`` and,
        for the ten best epochs, ``'branch_score_2'``.
        """
        args = self.args

        # SECURITY: pickle.load can execute arbitrary code; only point
        # filename_dir at training logs produced by a trusted run.
        with open(self.filename_dir, 'rb') as f:
            self.epoch_info_dict = pickle.load(f)

        print("################ load train data success ################")

        # Stage 1: weighted sum of train and test loss, lower is better.
        score_dict = {}
        for key, value in self.epoch_info_dict.items():
            branch_score = train_loss_weight * value['train_loss'] + test_loss_weight * value['test_loss']
            score_dict[key] = branch_score
            self.epoch_info_dict[key]['branch_score_1'] = branch_score

        ranked = sorted(score_dict.items(), key=lambda item: item[1])
        top_10_dict = dict(ranked[:10])

        print("################ get branch score 1 success ################")

        # Stage 2: currently the same formula as stage 1, restricted to the
        # ten best epochs; stored under its own key and re-sorted.
        for key in top_10_dict:
            info = self.epoch_info_dict[key]
            branch_score_2 = train_loss_weight * info['train_loss'] + test_loss_weight * info['test_loss']
            info["branch_score_2"] = branch_score_2
            top_10_dict[key] = branch_score_2

        top_10_dict = dict(sorted(top_10_dict.items(), key=lambda item: item[1]))

        print("################ get branch score 2 success ################")

        for key, value in top_10_dict.items():
            print(f"epoch {key} : branch_score:{value}")

        if top_10_dict:
            best_key = next(iter(top_10_dict))
            print("######################################################################################")
            print(f"the {args.ins_type} {args.method} selected epoch is {best_key}")
            print("######################################################################################")

        # Reference checkpoint requested on the command line. Its metadata is
        # looked up leniently so that a checkpoint missing from the training
        # log does not discard the evaluation that was just computed.
        key = args.check_point
        ave_sol_time, ave_nodes = self.valid_evaluate(key)
        ref_info = self.epoch_info_dict.get(key, {})
        ref_acc = self._clip(ref_info.get('test_accuracy', 'n/a'))
        ref_score = self._clip(ref_info.get('branch_score_1', 'n/a'))
        print(f"best {key} \t acc:{ref_acc} \t branch score:{ref_score} \t sol_time:{self._clip(ave_sol_time, 7)} \t nodes:{ave_nodes}")

        for key in top_10_dict:
            ave_sol_time, ave_nodes = self.valid_evaluate(key)
            info = self.epoch_info_dict[key]
            print(f"epoch{key} \t acc:{self._clip(info['test_accuracy'])} \t branch score:{self._clip(info['branch_score_2'])} \t sol_time:{self._clip(ave_sol_time, 7)} \t nodes:{ave_nodes}")

    def valid_evaluate(self, epoch_id):
        """Solve every validation instance with the checkpoint ``epoch_id``.

        Loads ``../check_points/<problem>/<method>/<epoch_id>.pt`` into
        ``self.agent`` and solves each path in ``self.valid_instances`` with
        the branching rule that matches ``args.method``.

        Args:
            epoch_id: Checkpoint file stem.

        Returns:
            ``(ave_sol_time, ave_nodes)``: arithmetic means of the solving
            time and node count over the instances (``0`` for no instances).

        Raises:
            Exception: If ``args.method`` has no matching branching rule.
        """
        args = self.args

        branchers = {
            'baseline': PolicyBranching_baseline,
            'new_38': PolicyBranching_38,
            'new_39': PolicyBranching_39,
        }
        if args.method not in branchers:
            raise Exception("policy branching method is not find")
        brancher_cls = branchers[args.method]

        check_point_path = f'../check_points/{problem_name[args.ins_type]}/{args.method}/{epoch_id}.pt'
        # SECURITY: torch.load unpickles the file; load trusted checkpoints only.
        self.agent.load_state_dict(torch.load(check_point_path, map_location=self.device))

        sol_time = []
        nodes = []

        for instance in self.valid_instances:
            # Named scip_model so it does not shadow the imported ``model`` module.
            scip_model = Model()
            try:
                scip_model.hideOutput()
                scip_model.readProblem(f"{instance}")
                set_scip(scip_model, args.scip_seed, restart=False, separator=False, primal_heuristic=False)
                scip_model.setIntParam('timing/clocktype', 2)  # 1: CPU user seconds, 2: wall clock time
                scip_model.setRealParam('limits/time', args.time_limit)

                brancher = brancher_cls(scip_model, self.agent, self.device, args)
                scip_model.includeBranchrule(brancher, "Evaluate", "Policy branching on variable",
                                             priority=99999, maxdepth=-1, maxbounddist=1)

                scip_model.optimize()

                # Only solving time and node count feed the returned averages,
                # so nothing else is queried from the solver.
                sol_time.append(scip_model.getSolvingTime())
                nodes.append(scip_model.getNNodes())
            finally:
                # Release the problem even if solving raised.
                scip_model.freeProb()

        ave_sol_time = num_mean(sol_time)
        ave_nodes = num_mean(nodes)

        return ave_sol_time, ave_nodes



if __name__ == "__main__":
    # NOTE: the names assigned below (args, device, agent, ...) are kept at
    # module scope rather than inside a main() function, so any code in this
    # file that looks them up as globals keeps working.
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--method', help='Branch method.', default='baseline')

    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--check_point', type=str, default="bc_89")
    parser.add_argument('--device_id', '-d', type=int, default=0)
    # Parameters for evaluating test instances
    parser.add_argument('--scip_seed', type=int, default=0)
    parser.add_argument('--ins_type', type=str, help='instance type, e.g. setcover', default='setcover')
    parser.add_argument('--select_config', type=str,
                        help='difficulty of the selection instances: easy, medium or hard', default='medium')
    parser.add_argument('--valid_config', type=str,
                        help='difficulty of the validation instances: easy, medium or hard', default='hard')
    parser.add_argument('--time_limit', type=int, help='limit time for solving each instance', default=3000)
    parser.add_argument('--num_instance', type=int, default=2)

    args = parser.parse_args()
    device = set_device_seed(args)

    # Difficulty name -> table mapping an instance type to its folder.
    folders_by_difficulty = {
        "easy": easy_instance_folders,
        "medium": mdeium_instance_folders,
        "hard": hard_instance_folders,
    }

    if args.select_config not in folders_by_difficulty:
        raise Exception("args select_config is not find")
    select_instances_flag = folders_by_difficulty[args.select_config][args.ins_type]

    if args.valid_config not in folders_by_difficulty:
        raise Exception("args valid_config is not find")
    valid_instances_flag = folders_by_difficulty[args.valid_config][args.ins_type]

    filename_dir = f'../check_points/{problem_name[args.ins_type]}/{args.method}/train_data.pkl'

    # Every supported method uses the same network configuration.
    if args.method not in ('baseline', 'new_38', 'new_39'):
        raise Exception("model method is not find")
    agent = model.GCNN_Net(v_dim=17).to(device)

    # Instance files are numbered from 1: instance_1.lp ... instance_<num_instance>.lp
    select_instances = [os.path.join(f'../data/instances/{select_instances_flag}',
                                     f'instance_{idx}.lp') for idx in range(1, args.num_instance + 1)]

    valid_instances = [os.path.join(f'../data/instances/{valid_instances_flag}',
                                    f'instance_{idx}.lp') for idx in range(1, args.num_instance + 1)]

    epoch_select_model = ModelSelect(agent, device, args, select_instances, valid_instances, filename_dir)


