import math
from collections import OrderedDict

import torch
import torch.nn as nn

from common import set_random_seed
from simple_stochastic_layers import LinearVariational, Conv2dVariational, exact_kl_divergence
from multihead_learner import ContinualMultiheadMLP


def generate_new_network(shared_structure, in_size, pre_var, device, convnet=False):
    """Build a shared variational network as an ``nn.Sequential``.

    Args:
        shared_structure: Layer sizes.
            - MLP: every entry is the ``out_features`` of a ``LinearVariational`` layer.
            - Convnet: every entry except the last is the ``out_channels`` of a
              5x5 ``Conv2dVariational`` followed by 2x2 max-pooling. The last entry is
              the ``out_features`` of a ``LinearVariational`` layer applied to the
              flattened convolutional features.
        in_size: For an MLP, the number of input features of the first layer. For a
            convnet, an input shape whose first element is the channel count; a dummy
            tensor of shape ``(1, *in_size)`` is pushed through the conv stack to
            size the final linear layer.
        pre_var: Forwarded as ``prior_std`` to every variational layer.
        device: Device the variational, pooling and flatten modules are moved to.
        convnet: Build the convolutional variant instead of the MLP.

    Returns:
        An ``nn.Sequential`` whose children are named ``layer{i}``, ``cnn{i}``,
        ``pooling{i}``, ``flatten`` and ``activation{i}``. Every layer (or conv+pool
        pair) is followed by an ``nn.Tanh``.
    """
    net_struct = []
    last_idx = len(shared_structure) - 1
    for i, layer in enumerate(shared_structure):
        if not convnet:
            linear = LinearVariational(in_features=shared_structure[i - 1] if i > 0 else in_size,
                                       out_features=layer,
                                       prior_std=pre_var)
            linear.to(device=device)
            net_struct.append((f"layer{i}", linear))
        elif i < last_idx:
            cnn = Conv2dVariational(in_channels=shared_structure[i - 1] if i > 0 else in_size[0],
                                    out_channels=layer,
                                    kernel_size=5,
                                    prior_std=pre_var)
            cnn.to(device=device)
            net_struct.append((f"cnn{i}", cnn))
            maxpool = nn.MaxPool2d(kernel_size=2)
            maxpool.to(device)
            net_struct.append((f"pooling{i}", maxpool))
        else:
            flat = nn.Flatten()
            flat.to(device)
            net_struct.append(("flatten", flat))
            # Infer the flattened feature size by pushing a dummy input through the
            # conv stack built so far. Only the output shape is needed, so no autograd
            # graph is recorded.
            with torch.no_grad():
                cnn_in = torch.rand(1, *in_size, device=device)
                conv_out_size = nn.Sequential(OrderedDict(net_struct))(cnn_in).size(1)
            linear = LinearVariational(in_features=conv_out_size,
                                       out_features=layer,
                                       prior_std=pre_var)
            linear.to(device=device)
            net_struct.append((f"layer{i}", linear))
        # Every stage, including the final linear layer, is followed by a tanh.
        net_struct.append((f"activation{i}", nn.Tanh()))
    return nn.Sequential(OrderedDict(net_struct))


class ContinualMultiheadBNN(ContinualMultiheadMLP):
    """Multi-head continual learner whose shared network is variational (Bayesian).

    ``self.shared_net`` holds the variational layers that are regularised. ``self.prior``
    is a second network of the same architecture, initialised to the same values.
    ``loss`` adds ``kl_weight`` times a KL term between the two to the Monte-Carlo
    averaged task loss. Heads and task switching come from ``ContinualMultiheadMLP``.
    """

    def __init__(self, shared_structure, in_size, out_size, device, kl_weight, delta=0.1, n_MC=3, pre_var=1e1,
                 post_var=1e-3, use_rolling_prior=False):
        """
        Args:
            shared_structure: Layer sizes of the shared network (see ``generate_new_network``).
            in_size: Input feature count (MLP) or input shape with channels first (convnet).
            out_size: Forwarded to ``ContinualMultiheadMLP``.
            device: Device the shared network and the prior are built on.
            kl_weight: Multiplier of the KL term in ``loss``.
            delta: Stored as ``self.delta``; not used within this class.
            n_MC: Number of stochastic forward passes averaged in ``loss``.
            pre_var: Passed to the ``prior_std`` argument of every variational layer.
            post_var: Stored as ``self.post_var``; not used within this class.
            use_rolling_prior: If true, ``adapt_new_task`` re-copies the prior from the
                current ``shared_net`` at every task switch.
        """
        # The parent only receives the last layer width (presumably to size its heads);
        # the shared network it builds is replaced by the variational one below.
        super().__init__([shared_structure[-1]], in_size, out_size, device)
        # Heuristic: a structure with several entries whose first and last differ is
        # treated as a convnet.
        is_conv = (len(shared_structure) > 1 and shared_structure[0] != shared_structure[-1])
        self.shared_net = generate_new_network(shared_structure, in_size, pre_var, device, convnet=is_conv)
        # The prior starts as an exact copy of the freshly initialised shared network.
        # NOTE(review): assigning an nn.Module attribute registers it, so the prior's
        # parameters appear in self.parameters() and presumably receive gradients through
        # the KL term. Confirm the optimizer is meant to update them (otherwise consider
        # self.prior.requires_grad_(False)).
        self.prior = generate_new_network(shared_structure, in_size, pre_var, device, convnet=is_conv)
        self.prior.load_state_dict(self.shared_net.state_dict())

        self.use_rolling_prior = use_rolling_prior
        # Positional arguments of generate_new_network, reused to rebuild the prior.
        self.ctor_params = [shared_structure, in_size, pre_var, device, is_conv]

        self.delta = delta
        self.n_MC = n_MC
        self.kl_weight = kl_weight
        self.post_var = post_var

    def forward(self, task_data, task_id=None):
        """Delegate to ``ContinualMultiheadMLP.forward``."""
        return super(ContinualMultiheadBNN, self).forward(task_data, task_id)

    def adapt_new_task(self, task_id=None):
        """Prepare for a new task.

        In order, this method:
        1. Sanitises every variational layer's parameters in place with
           ``nan_to_num_`` (default replacement values).
        2. If ``use_rolling_prior`` is set, rebuilds the prior as a copy of the current
           ``shared_net``.
        3. Defers to the parent implementation.
        """
        with torch.no_grad():
            for module in self.shared_net.modules():
                if isinstance(module, (LinearVariational, Conv2dVariational)):
                    # In-place on the existing tensors; no reassignment needed.
                    module.w_mu.nan_to_num_()
                    module.w_p.nan_to_num_()
                    if module.include_bias:
                        module.b_mu.nan_to_num_()
                        module.b_p.nan_to_num_()
        if self.use_rolling_prior:
            self.prior = generate_new_network(*self.ctor_params)
            self.prior.load_state_dict(self.shared_net.state_dict())
        super(ContinualMultiheadBNN, self).adapt_new_task(task_id)

    def loss(self, task_data, task_labels, label_loss, task_id=None, is_test=False, specific_seed=None):
        """Monte-Carlo empirical loss, plus the weighted KL regulariser during training.

        Args:
            task_data: Batch passed to ``forward``.
            task_labels: Targets for ``label_loss``. Its length sets the ``m`` argument
                of ``exact_kl_divergence``.
            label_loss: Callable ``label_loss(outputs, labels)`` giving the loss of one pass.
            task_id: Head selector forwarded to ``forward``.
            is_test: If true, return only the averaged empirical loss (no KL term).
            specific_seed: If given, return instead a list of the ``n_MC`` individual
                losses computed under this seed. The caller's RNG state is left unchanged.

        Returns:
            The empirical loss averaged over ``n_MC`` forward passes. Unless ``is_test``
            is set, ``kl_weight`` times the summed ``exact_kl_divergence`` between every
            variational layer of ``shared_net`` and its ``prior`` counterpart is added,
            provided that sum is non-negative.
        """
        if specific_seed is not None:
            return self._seeded_mc_losses(task_data, task_labels, label_loss, task_id, specific_seed)

        avg_empiric_loss = 0.0
        for _ in range(self.n_MC):
            # Each forward pass is stochastic, so averaging gives a Monte-Carlo estimate.
            outputs = self.forward(task_data, task_id)
            avg_empiric_loss += (1 / self.n_MC) * label_loss(outputs, task_labels)
        if is_test:
            return avg_empiric_loss

        # NOTE(review): named like a sample count but holds its square root; confirm this
        # is the intended `m` for exact_kl_divergence.
        n_samples = math.sqrt(len(task_labels))
        dvrg = 0
        for module_name, module in self.shared_net.named_modules():
            if not isinstance(module, (LinearVariational, Conv2dVariational)):
                continue
            # The prior has the same architecture, so module names line up one-to-one.
            prior_module = self.prior.get_submodule(module_name)
            dvrg += exact_kl_divergence(module.w_mu, module.w_p,
                                        prior_module.w_mu, prior_module.w_p, m=n_samples)
            if module.include_bias:
                dvrg += exact_kl_divergence(module.b_mu, module.b_p,
                                            prior_module.b_mu, prior_module.b_p, m=n_samples)
        if dvrg < 0:  # should not happen; fall back to the unregularised loss
            return avg_empiric_loss
        return avg_empiric_loss + self.kl_weight * dvrg

    def _seeded_mc_losses(self, task_data, task_labels, label_loss, task_id, seed):
        """Return the ``n_MC`` individual losses computed under a fixed ``seed``.

        ``set_random_seed`` reseeds global generators. The torch CPU/CUDA state (via
        ``torch.random.fork_rng``), Python ``random`` state and NumPy state are
        snapshotted first and restored afterwards, even if ``forward`` raises. A seeded
        evaluation therefore does not perturb the random stream of the surrounding run.
        NOTE(review): assumes set_random_seed only touches these generators; verify in
        ``common``.
        """
        import random

        import numpy as np

        py_state = random.getstate()
        np_state = np.random.get_state()
        try:
            with torch.random.fork_rng():
                set_random_seed(seed)
                return [label_loss(self.forward(task_data, task_id), task_labels)
                        for _ in range(self.n_MC)]
        finally:
            random.setstate(py_state)
            np.random.set_state(np_state)
