from __future__ import annotations
import argparse
import yaml
from easydict import EasyDict
import os
import torch
import jinja2
from torchkge.data_structures import KnowledgeGraph

from src.autoregressive_models.arm import AutoRegressiveModel
from src.autoregressive_models.arm_convolution import ARMConvolution
from src.autoregressive_models.arm_transformer import ARMTransformer
from src.models import GenerativeModel, KBCWrapper, get_nbf_wrapper
from src.utils import load_fb15k237, load_wn18rr, get_prior_frequencies


def get_arm_model(config: dict, number_of_entities: int, number_of_relations: int, kg_train: KnowledgeGraph):
    prior_frequencies = get_prior_frequencies(kg_train).to(config['device']) if config['prior_type'] == 'frequency' else None
    number_of_relations += number_of_relations
    print("NUMBER OF RELATIONS arm", number_of_relations)
    arm_model = AutoRegressiveModel
    print(config['model_type'])
    match config['model_type']:
        case "arm_transformer":
            arm_model = ARMTransformer(
                config['num_blocks'],
                config['embedding_dimension'],
                config['dropout_prob'],
                config['num_heads'],
                config['num_neurons'],
                number_of_entities,
                number_of_relations)

        case "arm_convolution":
            arm_model = ARMConvolution(
                config['kernel_size'],
                config['m'],
                config['dropout_prob'],
                number_of_entities,
                number_of_relations,
                config['embedding_dimension'])
        case _:
            raise ValueError("Model Type Unknown")

    model = GenerativeModel(
        arm_model,
        config['embedding_dimension'],
        number_of_entities,
        number_of_relations,
        prior_frequencies,
        config['prior_type'],
    )
    return model

def load_pretrained_model(config, number_of_entities, number_of_relations, kg_train):

    if not os.path.exists(config['model_path']):
        raise ValueError("No pretrained model found, train a model and name it: [arm_convolution.model, arm_transformer.model, comlex.model, complex2.model]")
    print(config['model_type'])
    match config['model_type']:
        # ARM pretrained
        case "arm_transformer" | "arm_convolution":
            model = get_arm_model(config, number_of_entities, number_of_relations, kg_train)
            model_params = torch.load(config['model_path'])
            print("Model dimensions:")
            model.load_state_dict(model_params, strict=False)

        case "complex":
            from external_models.gekcs.models import ComplEx
            kbc_model = ComplEx(
                (number_of_entities, number_of_relations, number_of_entities),
                rank=config['embedding_dimension']
            )
            pretrained_kbc_model_path = torch.load(config['model_path'])
            kbc_model.load_state_dict(pretrained_kbc_model_path["weights"])
            model = KBCWrapper(
                kbc_model,
                config['embedding_dimension'],
                number_of_entities,
                number_of_relations
            )

        case "complex2":
            from external_models.gekcs.gekc_models import SquaredComplEx
            kbc_model = SquaredComplEx(
                (number_of_entities, number_of_relations, number_of_entities),
                rank=config['embedding_dimension']
            )
            pretrained_kbc_model_path = torch.load(config['model_path'])
            kbc_model.load_state_dict(pretrained_kbc_model_path["weights"])
            model = KBCWrapper(
                kbc_model,
                config['embedding_dimension'],
                number_of_entities,
                number_of_relations
            )

        case "nbf":
            from torchdrug import core
            from external_models.nbfnet import dataset, model, util, layer, task
            vars = {'gpus': [0]}
            cfg = util.load_config(config['config'], context=vars)
            if torch.cuda.is_available():
                cfg['gpus'] = [0]
            

            dataset = core.Configurable.load_config_dict(cfg.dataset)
          
            cfg['checkpoint'] = config['model_path']
            solver = util.build_solver(cfg, dataset)

             

            NBFWrapper = get_nbf_wrapper()  # Get the NBFWrapper class
            model = NBFWrapper(
                solver.test_set,
                solver.model,
                config['embedding_dimension'],
                number_of_entities,
                number_of_relations
            )
        case _:
            raise ValueError("Model Type Unknown")

    return model


def load_dataset(config) -> tuple[KnowledgeGraph, KnowledgeGraph, KnowledgeGraph]:
    """Load the three knowledge-graph splits selected by ``config``.

    Args:
        config: Configuration mapping with ``config['dataset']['class']``
            (case-insensitive; ``fb15k237`` or ``wn18rr``),
            ``config['dataset']['path']`` and ``config['model_type']``. When
            ``model_type == 'nbf'`` the splits are loaded through
            ``src.utils.load_nbf_mapping``; otherwise through the dataset's own loader.

    Returns:
        The three ``KnowledgeGraph`` objects returned by the loader, in the
        loader's order (presumably train/valid/test — TODO confirm).

    Raises:
        ValueError: If the dataset class is unknown or the loader does not
            return exactly three graphs.
    """
    def assert_three(result: tuple[KnowledgeGraph, KnowledgeGraph, KnowledgeGraph] | tuple[KnowledgeGraph, KnowledgeGraph]
                     ) -> tuple[KnowledgeGraph, KnowledgeGraph, KnowledgeGraph]:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(result) != 3:
            raise ValueError(f"Expected 3 knowledge graphs from the dataset loader, got {len(result)}")
        return result

    dataset_class = config['dataset']['class'].lower()
    if dataset_class not in ("fb15k237", "wn18rr"):
        raise ValueError(f"Dataset unknown: {dataset_class}")

    data_path = config['dataset']['path']
    if config['model_type'] == 'nbf':
        # NBF models read the splits through a dedicated mapping loader.
        from src.utils import load_nbf_mapping
        return assert_three(load_nbf_mapping(dataset_class, data_path))

    loader = load_fb15k237 if dataset_class == "fb15k237" else load_wn18rr
    return assert_three(loader(data_home=data_path))

def get_cli_args(argv=None):
    """Parse command-line arguments and merge them into the YAML config file.

    The file given by ``--config`` is rendered as a Jinja2 template (with an
    empty context) and parsed with ``yaml.safe_load``. The config file is
    leading: a CLI value is copied into the config only when it is not ``None``
    and the config has no key of that name. Note that ``store_true`` flags
    default to ``False`` (not ``None``), so they are always added when absent.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        An ``EasyDict`` with the merged configuration. If the config has no
        ``model_path``, it is set to
        ``./src/saved_models/<dataset class, lower-cased>/<model_type>.model``.

    Raises:
        SystemExit: If argument parsing fails (e.g. ``--config`` is missing).
    """
    parser = argparse.ArgumentParser(
        description='Generative modelling of knowledge graphs')

    # General setup
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='input batch size for training')
    parser.add_argument('--test_batch_size', type=int, default=512,
                        help='input batch size for testing/validation')
    parser.add_argument('--seed', type=int, default=17,
                        metavar='S', help='random seed')
    parser.add_argument('--model_type', type=str,
                        help='Name of the model type')
    parser.add_argument('--model_name', type=str, help='Name of pretrained model (optional)')
    parser.add_argument('--resume', action='store_true', help='resume a pretrained model')
    parser.add_argument('--eval', action='store_true', help='Eval a pretrained model on validation set')
    parser.add_argument('--wandb', action='store_true',
                        help='Log results to wandb (please set up your own server)')
    parser.add_argument('--save_model', action='store_true',
                        help='save model parameters to model_path')

    # The config file is read unconditionally below, so it must be supplied.
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the (Jinja2-templated) YAML config file')
    # Learning
    parser.add_argument('--max_patience', type=int, default=20, help='patience before early stopping')
    parser.add_argument('--lr_patience', type=int, default=10,
                        help='patience before reducing the learning rate')

    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of epochs to train (default: 1000)')
    parser.add_argument('--pretrain_epochs', type=int, default=100,
                        help='number of pretraining epochs (default: 100)')
    parser.add_argument('--warmup_epochs', type=int, default=100,
                        help='number of warmup epochs (default: 100)')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--embedding_dimension', type=int, default=150,
                        help='The embedding dimension (1D). Default: 150')
    parser.add_argument('--dropout_prob', type=float,
                        help='Dropout probability.')
    parser.add_argument('--factor', type=float, default=0.5)
    parser.add_argument('--prediction_smoothing', default=0.0, type=float)
    parser.add_argument('--label_smoothing', default=0.0, type=float)
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='weight decay (default: 0.01)')

    parser.add_argument('--smooth_prior', action='store_true')
    parser.add_argument('--temperature', type=float, default=1.0)

    # Transformer
    parser.add_argument('--num_blocks', type=int,
                        help='number of transformer blocks')
    parser.add_argument('--num_heads', type=int,
                        help='number of attention heads in the transformer')
    parser.add_argument('--num_neurons', type=int,
                        help='number of neurons in the transformer')

    # Convolution
    parser.add_argument('--kernel_size', type=int,
                        help='kernel size of the convolution')

    parser.add_argument('--m', type=int,
                        help='hidden dimension after convolution')

    # TODO number of conv layers

    # Experiments
    parser.add_argument('--experiment', type=str, help='Name of the experiment to run')
    parser.add_argument('--plot', type=str, help='Name of the plot to produce')

    args = parser.parse_args(argv)

    with open(args.config, "r", encoding="utf-8") as f:
        raw = f.read()
    config_yaml = jinja2.Template(raw).render()
    config = EasyDict(yaml.safe_load(config_yaml))

    # The config file is leading: CLI values (including defaults) only fill keys it does not set.
    for name, value in vars(args).items():
        if value is not None and name not in config:
            config[name] = value

    # Default checkpoint location, derived from dataset and model type.
    if 'model_path' not in config:
        config['model_path'] = f"./src/saved_models/{config['dataset']['class'].lower()}/{config['model_type']}.model"

    return config
