import os
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torchvision.models import vgg16

from .cnn_synthetic_model import (
    CnnAccumulator,
    CNNColorDetector,
    CnnMultiColorAccumulator,
    PartialNonUniformCnnAccumulator,
    PartialNonUniformCnnMultiColorAccumulator,
)
from .mlp_synthetic_model import DecisionHead, IdentityMLP, ModuloModel, SyntheticModel


def load_decision_head(**kwarg):
    """
    Build a decision head selected by the 'decision_head_type' keyword.

    'decision_head_type' is removed from the keyword arguments and the remaining entries
    are forwarded as keyword arguments to the matching loader:
      - 'learned_modulo'     -> load_learned_modulo
      - 'synthetic_modulo'   -> load_synthetic_modulo
      - 'identity_mlp_layer' -> load_identity_mlp

    Returns:
        Whatever the selected loader returns.

    Raises:
        KeyError: if 'decision_head_type' is not given.
        NotImplementedError: if 'decision_head_type' is not one of the types listed above.
    """
    # `kwarg` is a fresh dict built from the call's keyword arguments, so popping from it
    # does not mutate a dict the caller unpacked with `**`.
    decision_head_type = kwarg.pop('decision_head_type')
    if decision_head_type == 'learned_modulo':
        return load_learned_modulo(**kwarg)
    if decision_head_type == 'synthetic_modulo':
        return load_synthetic_modulo(**kwarg)
    if decision_head_type == 'identity_mlp_layer':
        return load_identity_mlp(**kwarg)
    raise NotImplementedError(
        f"unsupported decision_head_type {decision_head_type!r}; expected one of "
        "'learned_modulo', 'synthetic_modulo', 'identity_mlp_layer'")


def load_learned_modulo(
        rank_increase, rank_to_increase, rank_increase_layer, num_hidden_layers, num_classes, decision_head_model_path):
    """
    Construct a DecisionHead and restore its parameters from disk.

    The first five arguments are passed positionally to DecisionHead;
    `decision_head_model_path` is handed to its `load_model_parameters` method.
    """
    head = DecisionHead(
        rank_increase,
        rank_to_increase,
        rank_increase_layer,
        num_hidden_layers,
        num_classes,
    )
    head.load_model_parameters(decision_head_model_path)
    return head


def load_trained_vgg(ckpt_path: str, num_classes: int, device: Union[torch.device, str] = 'cuda:0') -> nn.Module:
    """
    Load a trained VGG16 model from a checkpoint file.

    The checkpoint must be a mapping with a "state_dict" entry. The first dotted component
    of every key in that state dict is stripped before loading, so a key such as
    "model.features.0.weight" becomes "features.0.weight". Presumably the prefix is the
    attribute name of a training wrapper; confirm against the training code.

    Args:
        ckpt_path: path to the checkpoint; must end with '.ckpt' and exist.
        num_classes: number of output classes of the VGG16 classifier.
        device: device the returned model is moved to.

    Returns:
        A torchvision VGG16 on `device`, in eval mode, with the checkpoint weights loaded.

    Raises:
        AssertionError: if `ckpt_path` does not end with '.ckpt' or does not exist.
    """
    assert ckpt_path.endswith('.ckpt'), 'model path has to end with .ckpt'
    assert os.path.exists(ckpt_path), 'model path does not exist'
    # Deserialize onto the CPU. Without map_location, tensors saved from a GPU are restored
    # onto that same GPU, which fails on CPU-only hosts or when that GPU index is absent.
    # load_state_dict below copies the weights into the parameters already on `device`.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model_checkpoint = checkpoint["state_dict"]
    model = vgg16(pretrained=False, num_classes=num_classes).to(device).eval()
    # Drop everything up to and including the first '.'; a key without a dot maps to ''
    # (same as the previous split/join implementation).
    adjusted_model_checkpoint = OrderedDict(
        (key.partition(".")[2], value) for key, value in model_checkpoint.items())
    model.load_state_dict(adjusted_model_checkpoint)
    return model


def load_synthetic_modulo(modulo_number, max_number):
    """
    Construct a ModuloModel decision head.

    Unlike load_learned_modulo, no parameters are read from disk here.
    """
    modulo_head = ModuloModel(modulo_number, max_number)
    return modulo_head


def load_identity_mlp(input_shape, division_scale: int = 1):
    """Construct an IdentityMLP decision head; both arguments are forwarded positionally."""
    identity_head = IdentityMLP(input_shape, division_scale)
    return identity_head


def _build_accumulator(
    num_colors: int,
    weight_init_scheme: str,
    redundant_channels: int,
    inv_variance: int,
    random_expand_to: int,
    input_channels: int,
):
    """
    Build the accumulator for the given number of colors and weight init scheme.

    num_colors <= 1 selects a single-color accumulator; otherwise a multi-color one.

    Raises:
        NotImplementedError: if weight_init_scheme is not 'uniform' or 'non_uniform'.
    """
    if num_colors <= 1:
        if weight_init_scheme == 'uniform':
            # TODO: it is meaningless to pass the parameter weight_init_scheme to CnnAccumulator,
            #  as it only supports uniform weight init. If 'non_uniform' is applied, then another
            #  class PartialNonUniformCnnAccumulator is built, and it does not accept parameter 'weight_init_scheme'.
            #  Better remove the 'weight_init_scheme' in the CnnAccumulator.
            return CnnAccumulator(input_channels=input_channels, weight_init_scheme=weight_init_scheme)
        if weight_init_scheme == 'non_uniform':
            return PartialNonUniformCnnAccumulator(
                input_channels=input_channels, random_expand_to=random_expand_to)
    else:
        if weight_init_scheme == 'uniform':
            # TODO: similar problems as above. CnnMultiColorAccumulator only accepts weight_init_scheme = uniform or
            #  weight_init_scheme = uniform_with_bias. If the weight_init_scheme = uniform_with_bias, then it builds
            #  another class PartialNonUniformCnnMultiColorAccumulator, which does not accept parameter weight_init_scheme.
            #  So CnnMultiColorAccumulator with uniform_with_bias will never be built here.
            #  Better re-design the APIs.
            return CnnMultiColorAccumulator(
                num_colors=num_colors,
                redundant_channels=redundant_channels,
                weight_init_scheme=weight_init_scheme)
        if weight_init_scheme == 'non_uniform':
            return PartialNonUniformCnnMultiColorAccumulator(
                num_colors=num_colors,
                redundant_channels=redundant_channels,
                inv_variance=inv_variance,
                random_expand_to=random_expand_to)
    raise NotImplementedError(
        f"unsupported weight_init_scheme {weight_init_scheme!r}; expected 'uniform' or 'non_uniform'")


def load_synthetic_model(
    decision_head: Dict,
    color_list: Optional[List[List[int]]] = None,
    redundant_channels: int = 0,
    background_pixel: Tuple[int, ...] = (0, 0, 0),
    weight_init_scheme: str = 'uniform',
    inv_variance: int = 5,
    random_expand_to: int = 3,
    input_channels: int = 1,
    softmax: bool = False,
) -> nn.Module:
    """
    Assemble a SyntheticModel from an optional color detector, an accumulator and a decision head.

    Args:
        decision_head: keyword arguments for load_decision_head; must contain 'decision_head_type'.
        color_list: colors passed to CNNColorDetector. None builds no color detector. More than one
            color selects a multi-color accumulator.
        redundant_channels: non-negative; forwarded to CNNColorDetector and to the multi-color accumulators.
        background_pixel: forwarded to CNNColorDetector.
        weight_init_scheme: 'uniform' or 'non_uniform'; selects the accumulator class.
        inv_variance: only used by the multi-color 'non_uniform' accumulator.
        random_expand_to: only used by the 'non_uniform' accumulators.
        input_channels: only used by the single-color accumulators.
        softmax: forwarded to SyntheticModel.

    Raises:
        AssertionError: if redundant_channels is negative.
        NotImplementedError: for an unsupported weight_init_scheme, or (from load_decision_head)
            an unsupported decision head type.
    """
    assert redundant_channels >= 0, 'number of redundant channels has to be non negative'
    if color_list is None:
        color_detector = None
        num_colors = 0
    else:
        color_detector = CNNColorDetector(
            color_list, redundant_channels=redundant_channels, background_pixel=background_pixel)
        num_colors = len(color_list)
    accumulator = _build_accumulator(
        num_colors=num_colors,
        weight_init_scheme=weight_init_scheme,
        redundant_channels=redundant_channels,
        inv_variance=inv_variance,
        random_expand_to=random_expand_to,
        input_channels=input_channels,
    )
    # Use a distinct name so the `decision_head` config dict is not rebound to a module.
    head = load_decision_head(**decision_head)
    return SyntheticModel(accumulator, head, color_detector=color_detector, softmax=softmax)
