from collections import namedtuple
from typing import Generator

import torch
import yaml
import torch.nn as nn

from mind_the_pad.model_analysis.plot_model_parameters import conv2d_modules
from mind_the_pad.paths import experiments_mnist_folder, experiments_folder
from mind_the_pad.train_mnist.model import build_model_by_padding_size_mode, MnistPadding

# One experiment run: ``path`` is the run folder and ``data`` is the object
# parsed from that run's YAML config / hparams file.
Experiment = namedtuple('Experiment', 'path data')


def iter_mnist_exprmnt() -> Generator[Experiment, None, None]:
    """Yield every MNIST experiment run that has a ``config.yaml``.

    Runs are expected two directory levels below ``experiments_mnist_folder``
    (``<padding setup folder>/<experiment id folder>/``). Run folders without a
    ``config.yaml`` are reported on stdout and skipped.

    Yields:
        ``Experiment(path, data)`` where ``path`` is the run folder and
        ``data`` is the result of ``yaml.safe_load`` on its ``config.yaml``.
    """
    for padding_setup_folder in experiments_mnist_folder.dirs():
        for id_experiment_folder in padding_setup_folder.dirs():
            data_path = id_experiment_folder / 'config.yaml'
            if not data_path.exists():
                print('Skipping', str(data_path), "because config.yaml doesn't exist")
                continue
            # Explicit encoding: the default text encoding is platform-dependent.
            with open(data_path, encoding='utf-8') as f:
                experiment_data: dict = yaml.safe_load(f)
            yield Experiment(id_experiment_folder, experiment_data)


# An experiment run (``path``, ``data`` as in ``Experiment``) together with
# the ``model`` built from its config and populated with its saved weights.
ExperimentWithModel = namedtuple('ExperimentWithModel', 'path data model')


def iter_mnist_exprmnt_with_model_loaded() -> Generator[ExperimentWithModel, None, None]:
    """Yield every MNIST experiment together with its trained model.

    For each run produced by ``iter_mnist_exprmnt``, a model is built from the
    config's ``padding`` and ``padding_mode`` entries (with 10 output classes)
    and its weights are loaded from ``model.pth`` onto the CPU. Runs without a
    ``model.pth`` are reported on stdout and skipped.

    Yields:
        ``ExperimentWithModel(path, data, model)``.
    """
    for exprmnt_path, exprmnt_data in iter_mnist_exprmnt():
        model_params_path = exprmnt_path / 'model.pth'
        # Check for saved weights before building the model so that skipped
        # runs don't pay for model construction.
        if not model_params_path.exists():
            print(str(exprmnt_path), "doesn't have final model.pth saved. Skipping")
            continue
        padding_size = exprmnt_data['padding']
        padding_mode = exprmnt_data['padding_mode']
        model = build_model_by_padding_size_mode(padding_size, padding_mode, output_classes=10)
        # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
        state_dict = torch.load(model_params_path, map_location='cpu')
        model.load_state_dict(state_dict)
        yield ExperimentWithModel(exprmnt_path, exprmnt_data, model)


def iter_emnist_exprmnt() -> Generator[Experiment, None, None]:
    """Yield every EMNIST experiment run with its hyper-parameters.

    Each sub-folder of ``experiments_folder / 'EMNIST'`` is treated as one run,
    and its hyper-parameters are read from
    ``lightning_logs/version_0/hparams.yaml``. Runs missing that file are
    reported on stdout and skipped, the same way ``iter_mnist_exprmnt`` skips
    runs without a ``config.yaml`` (previously a missing file aborted the
    whole iteration with ``FileNotFoundError``).

    Yields:
        ``Experiment(path, data)`` where ``path`` is the run folder and
        ``data`` is the result of ``yaml.safe_load`` on ``hparams.yaml``.
    """
    emnist_experiments = experiments_folder / 'EMNIST'
    for run_folder in emnist_experiments.dirs():
        hparams_path = run_folder / 'lightning_logs' / 'version_0' / 'hparams.yaml'
        if not hparams_path.exists():
            print('Skipping', str(run_folder), "because hparams.yaml doesn't exist")
            continue
        # Explicit encoding: the default text encoding is platform-dependent.
        with open(hparams_path, encoding='utf-8') as f:
            hparams = yaml.safe_load(f)
        yield Experiment(run_folder, hparams)


def _load_emnist_model(exprmnt_path, hparams, device=None):
    """Build ``MnistPadding(**hparams)`` and load its weights from ``last.ckpt``.

    ``device`` is passed to ``torch.load`` as ``map_location``; ``None`` keeps
    the storage locations recorded in the checkpoint.
    """
    model = MnistPadding(**hparams)
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    ckpt = torch.load(exprmnt_path / 'last.ckpt', map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    return model


def iter_emnist_exprmnt_with_model_loaded(device=None) -> Generator[ExperimentWithModel, None, None]:
    """Yield every EMNIST experiment together with its model loaded from ``last.ckpt``.

    Args:
        device: forwarded to ``torch.load`` as ``map_location``.

    Yields:
        ``ExperimentWithModel(path, hparams, model)``.
    """
    for exprmnt_path, hparams in iter_emnist_exprmnt():
        model = _load_emnist_model(exprmnt_path, hparams, device)
        yield ExperimentWithModel(exprmnt_path, hparams, model)


def iter_emnist_exprmnt_with_model_loaded_same_only(device=None) -> Generator[ExperimentWithModel, None, None]:
    """Like ``iter_emnist_exprmnt_with_model_loaded`` but only for ``padding_type == 'same'`` runs.

    Runs with any other ``padding_type`` are skipped before their model is built.

    Args:
        device: forwarded to ``torch.load`` as ``map_location``.

    Yields:
        ``ExperimentWithModel(path, hparams, model)``.
    """
    for exprmnt_path, hparams in iter_emnist_exprmnt():
        if hparams['padding_type'] != 'same':
            continue
        model = _load_emnist_model(exprmnt_path, hparams, device)
        yield ExperimentWithModel(exprmnt_path, hparams, model)




def num_conv_layers(model, skip_1x1_convs):
    """Return how many ``nn.Conv2d`` modules ``conv2d_modules`` yields for *model*.

    ``skip_1x1_convs`` is forwarded unchanged to ``conv2d_modules``; anything it
    yields that is not an ``nn.Conv2d`` instance is not counted.
    """
    candidates = conv2d_modules(model, skip_1x1_convs=skip_1x1_convs)
    return sum(1 for module in candidates if isinstance(module, nn.Conv2d))


def num_conv_relu_layers(model):
    """Count Conv2d-followed-by-ReLU occurrences in ``model.modules()`` order.

    The modules are walked in the order ``model.modules()`` yields them. Each
    ``nn.ReLU`` is counted if at least one ``nn.Conv2d`` has been seen since the
    previously counted ReLU (or since the start). Modules of other types in
    between do not reset this state, and consecutive ReLUs after one conv are
    counted only once.
    """
    relu_pairs = 0
    awaiting_relu = False
    for module in model.modules():
        is_conv = isinstance(module, nn.Conv2d)
        is_relu = isinstance(module, nn.ReLU)
        if is_conv:
            awaiting_relu = True
        if awaiting_relu and is_relu:
            relu_pairs += 1
            awaiting_relu = False
    return relu_pairs
