import torch.multiprocessing.queue
from AbstractClass.AbstractStaticFactory import AbstractStaticFactory
from SourceCode.LoggerDegree import DegreeLogger
from SourceCode.LoggerExist import ExistLogger
from SourceCode.MetaGraphSketch import MetaGraphSketch
from SourceCode.ModelModule.AttentionMatrix import *
from SourceCode.ModelModule.EmbeddingModule import *
from SourceCode.ModelModule.MemoryMatrix import *
from SourceCode.ModelModule.RefineModule import *
from SourceCode.ModelModule.DecodeModule import *
from SourceCode.Model import *
from SourceCode.TaskRelatedClasses.DecoratorSupportGenerator import ShuffleDecoratorSupportGenerator, \
    SkewDecoratorSupportGenerator, ZipfDecoratorGenerator
from SourceCode.TaskRelatedClasses.QueryGenerator import QueryGeneratorForExist, SimpleQueryGenerator, \
    QueryGeneratorForDegree
from SourceCode.TaskRelatedClasses.SupportGenerator import ConstructionAnonymousSupportGenerator, BasicAnonymousSupportGenerator
from SourceCode.TaskRelatedClasses.TaskConsumer import *
from SourceCode.Logger import *
from SourceCode.TaskRelatedClasses.TaskProducer import TaskProducer
from SourceCode.ModelModule.LossFunc import *
import os
import random
import numpy as np
from torch.multiprocessing import Manager, Queue, Process

from Utils.Util import convert_base_sketch_accelerate

"""
extension type:
0 : basic for weight predict
1 : exist predict
"""


class Factory(AbstractStaticFactory):

    @staticmethod
    def seed_everything(seed=0):
        '''
        设置整个开发环境的seed
        :param seed:
        :param device:
        :return:
        '''

        random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # some cudnn methods can be random even after fixing the seed
        # unless you tell it to be deterministic
        torch.backends.cudnn.deterministic = True

    @staticmethod
    def construct_exist_model_from_basic(basic_model: BasicModel, construct_config):
        """Wrap a basic model into an extension model for edge-existence queries.

        Re-seeds the random generators, builds the decode net selected by
        ``construct_config['decode_exist_class']`` and prints the size of the
        resulting model's memory matrix.

        :param basic_model: basic model whose ``weight_decode_net.in_dim`` sets
            the input width of the new decode net.
        :param construct_config: mapping read for ``'decode_exist_class'`` and
            ``'exist_decode_hidden_layer_size'``.
        :return: the ``ExtensionModelForExist`` built around ``basic_model``.
        :raises AssertionError: if the configured decode class is not
            ``"ResExistDecodeNet"``.
        """
        Factory.seed_everything()
        print('construct extension model for exist from basic model....')

        decoder, extended_model = None, None
        if construct_config['decode_exist_class'] == "ResExistDecodeNet":
            decoder = ResExistDecodeNet(basic_model.weight_decode_net.in_dim,
                                        construct_config['exist_decode_hidden_layer_size'])
            extended_model = ExtensionModelForExist(basic_model=basic_model, decode_net=decoder)
        assert not (decoder is None or extended_model is None), " Initialization failure"

        # Number of cells in the memory matrix = product of its dimensions.
        cell_count = 1
        for extent in extended_model.memory_matrix.memory_matrix.shape:
            cell_count *= extent

        # The factor 4 presumably assumes 4-byte (float32) cells - TODO confirm.
        print('This model have', cell_count * 4 / 1024, 'KB memory')
        return extended_model

    @staticmethod
    def init_exist_MGS(basic_model_path, prod, extension_config):
        """Build a MetaGraphSketch for the edge-existence task on top of a saved basic model.

        The basic model is loaded onto the CPU, wrapped into an existence
        extension model, and only the new decode net is handed to the optimizer.

        :param basic_model_path: path of the saved basic model passed to ``torch.load``.
        :param prod: flag forwarded to the logger as ``prod_env``.
        :param extension_config: configuration mapping; its ``'construct_config'``
            entry supplies the loss class and the decode-net settings.
        :return: the assembled ``MetaGraphSketch``.
        """
        Factory.seed_everything()
        device = Factory.init_device(extension_config)
        construct_config = extension_config['construct_config']
        loss_func = Factory.init_loss_func(construct_config['loss_class'])

        # NOTE: torch.load relies on pickle; only load files from trusted sources.
        loaded_basic = torch.load(basic_model_path, map_location=torch.device('cpu'))
        exist_model = Factory.construct_exist_model_from_basic(loaded_basic, construct_config)
        decode_optimizer = Factory.init_optimizer(exist_model.decode_net, config=extension_config)

        # The generator input is a source vector followed by a destination vector.
        edge_input_dim = exist_model.embedding_net_1.in_dim * 2
        producer, consumer, _ = Factory.init_producer_consumer(device, config=extension_config,
                                                               extension_type=1,
                                                               input_dim=edge_input_dim)
        test_groups, group_descriptions = producer.produce_test_task()

        logger = Factory.init_logger(test_groups, group_descriptions, loss_func,
                                     prod, extension_config, extension_type=1)
        sketch = MetaGraphSketch(producer, consumer, exist_model, loss_func, device,
                                 decode_optimizer, logger)
        print('init exist MSG by loading base model done!')
        return sketch

    @staticmethod
    def construct_out_degree_model_from_basic(basic_model: BasicModel, construct_config):
        """Wrap a basic model into an extension model for degree queries.

        Re-seeds the random generators, creates a fresh
        ``BasicMemoryMatrixAndCM12ForDegree`` with the same dimensions and device
        as the basic model's memory matrix, builds the degree decode net on top
        of it and prints the size of the resulting memory matrix.

        :param basic_model: basic model providing the memory-matrix dimensions.
        :param construct_config: mapping read for ``'decode_exist_class'`` and
            ``'degree_decode_hidden_layer_size'``.
        :return: the ``ExtensionModelForDegree`` built around ``basic_model``.
        :raises AssertionError: if the configured decode class is not
            ``"ResDegreeDecodeNet"``.
        """
        Factory.seed_everything()
        # The message used to say "exist" (copied from the existence variant).
        print('construct extension model for degree from basic model....')
        degree_decode_net = None
        degree_model = None
        # NOTE: the decode class name is read from the 'decode_exist_class' key,
        # the same key the existence variant uses.
        if construct_config['decode_exist_class'] == "ResDegreeDecodeNet":
            previous_memory = basic_model.memory_matrix
            memory_matrix = BasicMemoryMatrixAndCM12ForDegree(previous_memory.row_dim, previous_memory.col_dim,
                                                              previous_memory.edge_embedding_dim,
                                                              previous_memory.device)
            read_dim = memory_matrix.read_dim
            degree_decode_hidden_layer_size = construct_config['degree_decode_hidden_layer_size']
            degree_decode_net = ResDegreeDecodeNet(read_dim, degree_decode_hidden_layer_size)
            degree_model = ExtensionModelForDegree(basic_model=basic_model, decode_net=degree_decode_net,
                                                   memory_matrix=memory_matrix)
        assert degree_decode_net is not None and degree_model is not None, " Initialization failure"

        # Number of cells in the memory matrix = product of its dimensions.
        memory_size = 1
        for dim in degree_model.memory_matrix.memory_matrix.shape:
            memory_size *= dim
        # The factor 4 presumably assumes 4-byte (float32) cells - TODO confirm.
        print('This model have', memory_size * 4 / 1024, 'KB memory')
        return degree_model

    @staticmethod
    def init_out_degree_MGS(basic_model_path, prod, extension_config):
        """Build a MetaGraphSketch for the degree task on top of a saved basic model.

        Besides wiring producer, consumer, logger and optimizer, this computes
        an address for every filtered training node (embedding net 1, then the
        refine net, then the attention matrix) and hands the resulting lookup
        table to the query generator.

        :param basic_model_path: path of the saved basic model passed to ``torch.load``.
        :param prod: flag forwarded to the logger as ``prod_env``.
        :param extension_config: configuration mapping; its ``'construct_config'``
            entry supplies the loss class and the decode-net settings.
        :return: the assembled ``MetaGraphSketch``.
        """
        Factory.seed_everything()
        device = Factory.init_device(extension_config)
        construct_config = extension_config['construct_config']
        loss_func = Factory.init_loss_func(construct_config['loss_class'])

        # NOTE: torch.load relies on pickle; only load files from trusted sources.
        loaded_basic = torch.load(basic_model_path, map_location=torch.device('cpu'))
        degree_model = Factory.construct_out_degree_model_from_basic(loaded_basic, construct_config)
        decode_optimizer = Factory.init_optimizer(degree_model.decode_net, config=extension_config)

        # The generator input is a source vector followed by a destination vector.
        edge_input_dim = degree_model.embedding_net_1.in_dim * 2
        producer, consumer, _ = Factory.init_producer_consumer(device, config=extension_config,
                                                               extension_type=2,
                                                               input_dim=edge_input_dim)

        # Pre-compute the address of every training node.
        train_nodes = producer.support_generator.get_base_support_generator().filtered_train_node_np
        with torch.no_grad():
            node_batch = torch.tensor(train_nodes).float()
            node_embedding = degree_model.embedding_net_1(node_batch)
            refined_embedding = degree_model.refine_net(node_embedding)
            node_addresses = degree_model.attention_matrix(refined_embedding)
        node_addresses = node_addresses.numpy()

        # Nodes are keyed by their raw bytes so array rows can act as dict keys.
        address_by_node = {node.tobytes(): node_addresses[row]
                           for row, node in enumerate(train_nodes)}
        producer.query_generator.set_node_address_dic(address_by_node)

        test_groups, group_descriptions = producer.produce_test_task()
        logger = Factory.init_logger(test_groups, group_descriptions, loss_func,
                                     prod, extension_config, extension_type=2)
        sketch = MetaGraphSketch(producer, consumer, degree_model, loss_func, device,
                                 decode_optimizer, logger)
        print('init degree MSG by loading base model done!')
        return sketch

    @staticmethod
    def init_MGS_by_load_model(model_path, prod, config):
        """Restore a saved model and assemble a MetaGraphSketch around it.

        The model is loaded onto the CPU and passed through
        ``convert_base_sketch_accelerate``. The task type is not fixed here:
        ``-1`` is passed to ``init_producer_consumer`` and the value it returns
        is used to pick the logger.

        :param model_path: path of the saved model passed to ``torch.load``.
        :param prod: flag forwarded to the logger as ``prod_env``.
        :param config: configuration mapping for device, loss, optimizer, data and logger.
        :return: the assembled ``MetaGraphSketch``.
        """
        Factory.seed_everything()
        device = Factory.init_device(config=config)

        # NOTE: torch.load relies on pickle; only load files from trusted sources.
        restored = torch.load(model_path, map_location=torch.device('cpu'))
        restored = convert_base_sketch_accelerate(restored)

        # Number of cells in the memory matrix = product of its dimensions.
        cell_count = 1
        for extent in restored.memory_matrix.memory_matrix.shape:
            cell_count *= extent
        # The factor 4 presumably assumes 4-byte (float32) cells - TODO confirm.
        print('This model have', cell_count * 4 / 1024, 'KB memory')
        print('load model done!')

        loss_func = Factory.init_loss_func(config['factory_config']['loss_class'])
        optimizer = Factory.init_optimizer(restored, config=config)

        # The generator input is a source vector followed by a destination vector.
        edge_input_dim = restored.embedding_net_1.in_dim * 2
        producer, consumer, resolved_type = Factory.init_producer_consumer(device, config=config,
                                                                           extension_type=-1,
                                                                           input_dim=edge_input_dim)
        test_groups, group_descriptions = producer.produce_test_task()

        logger = Factory.init_logger(test_groups, group_descriptions, loss_func,
                                     prod, config=config, extension_type=resolved_type)
        sketch = MetaGraphSketch(producer, consumer, restored, loss_func, device, optimizer, logger)
        print('init MSG by loading model done!')
        return sketch

    @staticmethod
    def init_basic_MGS(prod, config):
        """Create a brand-new basic model and assemble a MetaGraphSketch around it.

        :param prod: flag forwarded to the logger as ``prod_env``.
        :param config: configuration mapping for model, loss, optimizer, device,
            data and logger; ``'dim_config'`` supplies the source and
            destination input widths.
        :return: the assembled ``MetaGraphSketch``.
        """
        Factory.seed_everything()
        print('init MGS....')
        fresh_model = Factory.init_model(config=config)

        # Number of cells in the memory matrix = product of its dimensions.
        cell_count = 1
        for extent in fresh_model.memory_matrix.memory_matrix.shape:
            cell_count *= extent
        # The factor 4 presumably assumes 4-byte (float32) cells - TODO confirm.
        print('This model have', cell_count * 4 / 1024, 'KB memory')

        loss_func = Factory.init_loss_func(config['factory_config']['loss_class'])
        optimizer = Factory.init_optimizer(fresh_model, config=config)
        device = Factory.init_device(config=config)

        # The generator input is the destination and source vectors combined.
        dim_config = config["dim_config"]
        edge_input_dim = dim_config["dest_input_dim"] + dim_config["source_input_dim"]

        # The dataset named in the config decides which producer is built.
        producer, consumer, _ = Factory.init_producer_consumer(device, config=config,
                                                               extension_type=0,
                                                               input_dim=edge_input_dim)
        test_groups, group_descriptions = producer.produce_test_task()
        logger = Factory.init_logger(test_groups, group_descriptions, loss_func, prod,
                                     config=config, extension_type=0)
        sketch = MetaGraphSketch(producer, consumer, fresh_model, loss_func, device, optimizer, logger)
        print('init MSG done!')
        return sketch



    @staticmethod
    def init_producer_consumer(device, config, extension_type, input_dim):
        dataset_name = config["data_config"]["dataset_name"]
        skew_lower = config['data_config']['skew_lower']
        skew_upper = config['data_config']['skew_upper']
        item_lower = config['data_config']['item_lower']
        item_upper = config['data_config']['item_upper']
        dataset_path = config['data_config']['dataset_path']
        test_task_item_size = config['logger_config']['test_task_item_size']
        test_task_stream_length = config['logger_config']['test_task_stream_length']
        test_task_stream_length_ratio = config['logger_config']['test_task_stream_length_ratio']
        test_task_zipf_param = config['logger_config']['test_task_zip_param']
        test_task_group_size = config['logger_config']['test_task_group_size']
        zipf_param_upper = config['data_config']['zipf_param_upper']
        zipf_param_lower = config['data_config']['zipf_param_lower']
        assert test_task_stream_length is None or test_task_stream_length_ratio is None,\
            'only one of them can be assigned'
        task_producer, task_consumer, decorator_support_generator,query_generator = None, None, None, None
        if dataset_name == "BasicSketch" and (extension_type == 0 or extension_type == -1):
            base_support_generator = BasicAnonymousSupportGenerator( input_dim=input_dim,
                                                                           item_lower=item_lower, item_upper=item_upper)
            decorator_support_generator = ShuffleDecoratorSupportGenerator(base_support_generator)
            # skew is included in zipfDecoratorGenerator
            decorator_support_generator = ZipfDecoratorGenerator(base_support_generator=decorator_support_generator,
                                                                 zipf_param_upper=zipf_param_upper,
                                                                 zipf_param_lower=zipf_param_lower,
                                                                 skew_lower=skew_lower, skew_upper=skew_upper)
            decorator_support_generator.set_device(device)
            query_generator = SimpleQueryGenerator()
            extension_type = 0


        elif "BaseAdvancedSketch" in dataset_name and (extension_type == 0 or extension_type==-1):
            step_scale = config['train_config']['step_scale']
            total_step = config['train_config']['num_of_meta_task']
            if step_scale is None:
                total_step = None
            base_support_generator = ConstructionAnonymousSupportGenerator(data_path=dataset_path, input_dim=input_dim, step_scale=step_scale,
                                                                        total_step=total_step)
            base_support_generator.set_device(device)
            base_support_generator.set_sample_func('stream_length')
            decorator_support_generator = ShuffleDecoratorSupportGenerator(base_support_generator)
            decorator_support_generator = SkewDecoratorSupportGenerator(decorator_support_generator,
                                                                        skew_lower=skew_lower, skew_upper=skew_upper,
                                                                       )
            query_generator = SimpleQueryGenerator()
            extension_type = 0

        elif "ExistAdvancedSketch" in dataset_name and (extension_type == 1 or extension_type == -1):
            # set generate_fake_edge to True
            step_scale = config['train_config']['step_scale']
            total_step = config['train_config']['num_of_meta_task']
            if step_scale is None:
                total_step = None
            base_support_generator = ConstructionAnonymousSupportGenerator(data_path=dataset_path,
                                                                           input_dim=input_dim,
                                                                           item_lower=item_lower,
                                                                           item_upper=item_upper,
                                                                           generate_fake_edge=True,
                                                                           step_scale=step_scale,
                                                                           total_step=total_step)
            base_support_generator.set_device(device)
            base_support_generator.set_sample_func('stream_length')
            decorator_support_generator = ShuffleDecoratorSupportGenerator(base_support_generator)
            decorator_support_generator = SkewDecoratorSupportGenerator(decorator_support_generator,
                                                                        skew_lower=skew_lower,
                                                                        skew_upper=skew_upper)
            query_generator = QueryGeneratorForExist()
            extension_type = 1
        elif "DegreeAdvancedSketch" in dataset_name and (extension_type == 2 or extension_type==-1):
            step_scale = config['train_config']['step_scale']
            total_step = config['train_config']['num_of_meta_task']
            if step_scale is None:
                total_step = None
            base_support_generator = ConstructionAnonymousSupportGenerator(data_path=dataset_path,
                                                                           input_dim=input_dim,
                                                                           item_lower=item_lower,
                                                                           item_upper=item_upper,
                                                                           step_scale=step_scale,
                                                                           total_step=total_step)
            base_support_generator.set_sample_func('stream_length')
            base_support_generator.set_device(device)
            decorator_support_generator = ShuffleDecoratorSupportGenerator(base_support_generator)
            decorator_support_generator = SkewDecoratorSupportGenerator(decorator_support_generator,
                                                                        skew_lower=skew_lower,
                                                                        skew_upper=skew_upper)
            query_generator = QueryGeneratorForDegree()
            extension_type = 2

        task_producer = TaskProducer(decorator_support_generator, query_generator, device,
                                     test_task_zipf_param_list=test_task_zipf_param,
                                     test_task_length_list=test_task_stream_length,
                                     test_task_item_size_list=test_task_item_size,
                                     test_task_length_ratio_list=test_task_stream_length_ratio,
                                     test_task_group_size=test_task_group_size)
        task_consumer = TaskConsumer(device)
        assert task_consumer is not None
        assert task_producer is not None
        return task_producer, task_consumer,extension_type

    @staticmethod
    def init_logger(test_meta_task_group_list, meta_task_group_discribe_list, loss_func, prod, config, extension_type):
        logger = None
        if extension_type == 0:
            logger = BasicLogger(test_meta_task_group_list, meta_task_group_discribe_list,
                                 dataset_name=config['data_config']['dataset_name'],
                                 config=config,
                                 loss_func=loss_func,
                                 eval_query_batch_size=config['logger_config']['eval_query_batch_size'],
                                 eval_support_batch_size=config['logger_config']['eval_support_batch_size'],
                                 flush_gap=config['logger_config']['flush_gap'],
                                 train_comment=config['data_config']['train_comment'],
                                 eval_gap=config['logger_config']['eval_gap'],
                                 save_gap=config['logger_config']['save_gap'],
                                 prod_env=prod)
        elif extension_type == 1:
            logger = ExistLogger(test_meta_task_group_list, meta_task_group_discribe_list,
                                 dataset_name=config['data_config']['dataset_name'],
                                 config=config,
                                 loss_func=loss_func,
                                 eval_query_batch_size=config['logger_config']['eval_query_batch_size'],
                                 eval_support_batch_size=config['logger_config']['eval_support_batch_size'],
                                 flush_gap=config['logger_config']['flush_gap'],
                                 train_comment=config['data_config']['train_comment'],
                                 eval_gap=config['logger_config']['eval_gap'],
                                 save_gap=config['logger_config']['save_gap'],
                                 prod_env=prod)
        elif extension_type == 2 :
            logger = DegreeLogger(test_meta_task_group_list, meta_task_group_discribe_list,
                                 dataset_name=config['data_config']['dataset_name'],
                                 config=config,
                                 loss_func=loss_func,
                                 eval_query_batch_size=config['logger_config']['eval_query_batch_size'],
                                 eval_support_batch_size=config['logger_config']['eval_support_batch_size'],
                                 flush_gap=config['logger_config']['flush_gap'],
                                 train_comment=config['data_config']['train_comment'],
                                 eval_gap=config['logger_config']['eval_gap'],
                                 save_gap=config['logger_config']['save_gap'],
                                 prod_env=prod)
        assert logger is not None, "logger initialization failure"
        return logger

    @staticmethod
    def init_device(config):
        cuda_num = config['train_config']['cuda_num']
        if cuda_num == -1:
            device = torch.device('cpu')
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(cuda_num)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return device

    @staticmethod
    def init_loss_func(loss_class):
        return eval(loss_class + "()")

    @staticmethod
    def init_optimizer(model, config):
        lr = config['train_config']['lr']
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        return optimizer

    @staticmethod
    def init_model(config):
        row_dim = config["dim_config"]["row_dim"]
        col_dim = config["dim_config"]["col_dim"]
        source_embedding_dim = config["dim_config"]["source_embedding_dim"]
        dest_embedding_dim = config["dim_config"]["dest_embedding_dim"]
        edge_embedding_dim = config["dim_config"]["edge_embedding_dim"]
        source_refined_dim = config["dim_config"]["source_refined_dim"]
        dest_refined_dim = config["dim_config"]["dest_refined_dim"]
        source_input_dim = config["dim_config"]["source_input_dim"]
        source_embedding_hidden_layer_size = config["hidden_layer_config"]["source_embedding_hidden_layer_size"]
        source_refined_hidden_layer_size = config["hidden_layer_config"]["source_refined_hidden_layer_size"]
        edge_embedding_hidden_layer_size = config["hidden_layer_config"]["edge_embedding_hidden_layer_size"]

        model = None
        if config['factory_config']['model_class'] == "BasicModel":
            assert row_dim == col_dim and source_refined_dim == dest_refined_dim, \
                "error ! This model should have same dim in row&" \
                "col source_refined_dim&dest_refined_dim"
            # assert edge_embedding_dim is None or edge_embedding_dim == source_embedding_dim + dest_embedding_dim, \
            #     'error ! In this Model , edge_embedding dim should equal to the sum of source and dest'
            assert source_embedding_dim == dest_embedding_dim, 'This model must have same dim in source&dest embedding'
            embedding_net_1 = EmbeddingNet(source_input_dim, source_embedding_dim, source_embedding_hidden_layer_size)
            embedding_net_2 = EmbeddingNet(source_embedding_dim, edge_embedding_dim//2,
                                           edge_embedding_hidden_layer_size)
            refine_net = RefineNet(source_embedding_dim, source_refined_dim, source_refined_hidden_layer_size)
            attention_matrix = Factory.init_attention_matrix(source_refined_dim, row_dim,config)
            memory_matrix, weight_decode_net = Factory.init_memory_matrix_and_decode_nets(config)
            model = BasicModel(attention_matrix, embedding_net_1,embedding_net_2, weight_decode_net, refine_net, memory_matrix)

        assert model is not None, "model have not been set"
        return model

    @staticmethod
    def init_attention_matrix(refined_dim, slot_dim, config):
        return eval(config["factory_config"]["attention_class"] + "(refined_dim,slot_dim)")

    @staticmethod
    def init_memory_matrix_and_decode_nets(config):
        row_dim = config["dim_config"]["row_dim"]
        col_dim = config["dim_config"]["col_dim"]
        source_embedding_dim = config["dim_config"]["source_embedding_dim"]
        dest_embedding_dim = config["dim_config"]["dest_embedding_dim"]
        edge_embedding_dim = config["dim_config"]["edge_embedding_dim"]
        memory_class = config["factory_config"]["memory_class"]
        sparse_degree = config["factory_config"]["sparse_degree"]
        weight_decode_class = config["factory_config"]["decode_weight_class"]
        memory_matrix = None
        # this value will be used in eval expression
        read_dim = None
        device = Factory.init_device(config=config)
        # choose memory_class
        if memory_class == "BasicMemoryMatrixAndCM12":
            assert sparse_degree is None, "error! " + "this memory_class " + memory_class + "do not need a sparse_degree"
            if edge_embedding_dim is None:
                edge_embedding_dim = source_embedding_dim + dest_embedding_dim
            # read_dim = edge_embedding_dim + edge_embedding_dim + 1 + 1 + 1 + 1+edge_embedding_dim
            read_dim = edge_embedding_dim + edge_embedding_dim + 1 + 1 + 1 + 1

            memory_matrix = BasicMemoryMatrixAndCM12(row_dim, col_dim, edge_embedding_dim, device)
        assert memory_matrix is not None, "memory matrix have not been set"
        weight_decode_hidden_layer_size = config["hidden_layer_config"]["weight_decode_hidden_layer_size"]
        weight_decode_net = eval(weight_decode_class + "(read_dim, weight_decode_hidden_layer_size)")
        return memory_matrix, weight_decode_net
