from typing import Tuple
import os

import torch # type: ignore

from difflogic import LogicLayer, GroupSum, PackBitsTensor, CompiledLogicNet

from library import configs
from library import models


def get_layer_data(model: torch.nn.Module) -> dict:
    """Collect the connection-index data of every ``LogicLayer`` in *model*.

    The model is iterated positionally, so it is expected to be an iterable
    container of layers (e.g. ``torch.nn.Sequential``).

    Args:
        model: Iterable module whose children may include ``LogicLayer``s.

    Returns:
        Mapping ``position -> {attribute name -> value}``. Every LogicLayer
        entry has ``'indices'``. CUDA-implemented layers also carry
        ``'given_x_indices_of_y_start'`` and ``'given_x_indices_of_y'``.
        Non-LogicLayer positions are absent.
    """
    # NOTE(review): these attributes are saved separately from state_dict by
    # save_model, presumably because LogicLayer does not register them as
    # parameters/buffers — confirm against difflogic.
    collected = {}
    for position, module in enumerate(model):
        if not isinstance(module, LogicLayer):
            continue
        entry = {'indices': module.indices}
        if module.implementation == 'cuda':
            # The CUDA kernels need the extra reverse-lookup index tensors.
            entry['given_x_indices_of_y_start'] = module.given_x_indices_of_y_start
            entry['given_x_indices_of_y'] = module.given_x_indices_of_y
        collected[position] = entry
    return collected


def set_layer_data(model: torch.nn.Module, layer_data: dict) -> torch.nn.Module:
    """Restore ``LogicLayer`` index data previously captured by ``get_layer_data``.

    Args:
        model: Iterable module (e.g. ``torch.nn.Sequential``) whose
            LogicLayer children are updated in place.
        layer_data: Mapping ``position -> {attribute name -> value}``.
            Positions missing from this mapping are left untouched.

    Returns:
        The same *model* object, after its layers have been mutated.

    Raises:
        KeyError: if a CUDA-implemented layer's entry lacks the CUDA-specific
            keys (e.g. data captured from a non-CUDA layer).
    """
    for position, module in enumerate(model):
        if not isinstance(module, LogicLayer) or position not in layer_data:
            continue
        saved = layer_data[position]
        module.indices = saved['indices']
        if module.implementation != 'cuda':
            continue
        module.given_x_indices_of_y_start = saved['given_x_indices_of_y_start']
        module.given_x_indices_of_y = saved['given_x_indices_of_y']
    return model


def save_model(model: torch.nn.Module, config: configs.DifflogicConfig, model_path: str, model_name: str) -> None:
    """Persist *model* and *config* as three sibling files under *model_path*.

    Files written (the same names that ``load_model`` reads):
        ``{model_name}_config.json``     - via ``config.to_json``
        ``{model_name}_state_dict.pt``   - ``model.state_dict()``
        ``{model_name}_layer_data.pt``   - LogicLayer index data from
                                           ``get_layer_data``

    Args:
        model: Iterable model (e.g. ``torch.nn.Sequential``) to save.
        config: Configuration used to rebuild the model on load.
        model_path: Target directory; it is created if missing. An empty
            string means the current working directory.
        model_name: Prefix shared by all three files.
    """
    # Create the target directory so a fresh path does not fail on the first
    # write; os.makedirs('') raises, so skip it for the current directory.
    if model_path:
        os.makedirs(model_path, exist_ok=True)

    model_state_dict = model.state_dict()
    # Index data is saved separately from the state dict.
    layer_data = get_layer_data(model)

    model_config_path = os.path.join(model_path, f'{model_name}_config.json')
    model_state_dict_path = os.path.join(model_path, f'{model_name}_state_dict.pt')
    layer_data_path = os.path.join(model_path, f'{model_name}_layer_data.pt')

    config.to_json(model_config_path)
    torch.save(model_state_dict, model_state_dict_path)
    torch.save(layer_data, layer_data_path)


def load_model(model_path: str, model_name: str) -> Tuple[torch.nn.Module, configs.DifflogicConfig]:
    """Rebuild a model written by ``save_model`` from *model_path*.

    Reads ``{model_name}_config.json``, ``{model_name}_state_dict.pt`` and
    ``{model_name}_layer_data.pt``. It recreates the architecture with
    ``models.create_model`` and strictly loads the weights. It then restores
    the LogicLayer index data. All tensors are mapped onto
    ``config.model_config.device``.

    Args:
        model_path: Directory containing the three saved files.
        model_name: Prefix shared by the three files.

    Returns:
        ``(model, config)``: the restored model and its configuration.

    Raises:
        RuntimeError: from ``load_state_dict`` when the saved keys do not
            match the rebuilt model exactly (``strict=True``).
    """
    model_config_path = os.path.join(model_path, f'{model_name}_config.json')
    model_state_dict_path = os.path.join(model_path, f'{model_name}_state_dict.pt')
    layer_data_path = os.path.join(model_path, f'{model_name}_layer_data.pt')

    config = configs.DifflogicConfig(
        data_config=configs.DataConfig(),
        train_config=configs.TrainConfig(),
        test_config=configs.TestConfig(),
        model_config=configs.ModelConfig(),
        experiment_config=configs.ExperimentConfig(),
        compilation_config=configs.CompilationConfig()
    )
    # Seed these fields before reading the JSON.
    # NOTE(review): presumably these are fallbacks for config files that omit
    # them — confirm that from_json overwrites them when present.
    config.data_config.input_size = 784
    config.data_config.lower_bound_fixed = 5
    config.from_json(model_config_path)

    device = config.model_config.device
    model_state_dict = torch.load(model_state_dict_path, map_location=device)
    # Map the index tensors onto the same device as the weights. Otherwise
    # they come back on whatever device they were saved from.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # layer-data files from trusted sources.
    layer_data = torch.load(layer_data_path, map_location=device, weights_only=False)

    model = models.create_model(config)
    # strict=True raises on any missing/unexpected key, so the returned key
    # lists carry no information here.
    model.load_state_dict(model_state_dict, strict=True)

    model = set_layer_data(model, layer_data)
    return model, config