import torch
import torch.nn as nn
import os

import numpy as np
import loss_f

from layers import *


class RSNN(nn.Module):
    """Sequential stack of layers built from ``layer_config``.

    The name and the spike-based losses suggest a recurrent spiking neural
    network -- NOTE(review): inferred from naming, confirm against ``layers``.

    Args:
        network_config: mapping read for ``'loss'``, ``'sparsity_strength'``
            and ``'n_steps'`` here, plus ``'desired_count'`` /
            ``'undesired_count'`` when the loss mode is ``"count"``.  It is
            also forwarded to every layer constructor and to the criterion.
        layer_config: mapping of layer name -> per-layer config; each entry
            must carry a ``'type'`` key.  Layers are created in iteration
            order.
        criterion: object providing ``spike_count``, ``spike_softmax`` and
            ``spike_framewise`` (only the one matching the loss mode is used).
    """

    # Number of consecutive time steps summed into one frame by the
    # "framewise" loss mode.
    _FRAME_LEN = 5

    def __init__(self, network_config, layer_config, criterion):
        super().__init__()
        self.layer_config = layer_config
        self.network_config = network_config
        self.loss_mode = network_config['loss']
        self.sparsity_strength = network_config['sparsity_strength']
        self.criterion = criterion
        self.n_steps = network_config['n_steps']

        # Plain Python lists: the layers are NOT registered as sub-modules
        # here; parameters are exposed through get_parameters() /
        # get_arch_parameters() instead.
        self.layers = []
        self.recurrent_layers = []
        self.init_network()

        self.mode = None

    def new(self):
        """Return a fresh ``RSNN`` with the same configuration and a copy of
        this model's weights and architecture parameters.

        The copy gets this model's mode, the "other variables" of every
        recurrent layer, freshly initialised weights, and finally the values
        of all architecture and weight parameters copied element-wise.
        """
        model_new = RSNN(self.network_config, self.layer_config, self.criterion)
        model_new.set_mode(self.get_mode())
        # Both models are built from the same layer_config, so their layer
        # lists line up one-to-one.
        for src_layer, dst_layer in zip(self.layers, model_new.layers):
            if src_layer.type == "recurrent":
                dst_layer.set_other_variables(src_layer.get_other_variables())
        model_new.init_weights()
        for dst, src in zip(model_new.get_arch_parameters(), self.get_arch_parameters()):
            dst.data.copy_(src.data)
        for dst, src in zip(model_new.get_parameters(), self.get_parameters()):
            dst.data.copy_(src.data)
        return model_new

    def init_network(self):
        """Instantiate one layer per ``layer_config`` entry, in order.

        Entries whose ``'type'`` is not one of ``"recurrent"``, ``"linear"``,
        ``"pooling"`` or ``"direct"`` are silently skipped.
        """
        for key in self.layer_config:
            c = self.layer_config[key]
            layer_type = c['type']
            if layer_type == "recurrent":
                self.layers.append(Recurrent_Layer(self.network_config, c, key))
            elif layer_type == "linear":
                self.layers.append(Feedforward_Layer(self.network_config, c, key))
            elif layer_type == "pooling":
                self.layers.append(Pooling_Layer(self.network_config, c, key))
            elif layer_type == "direct":
                self.layers.append(Direct_Layer(self.network_config, c, key))

    def init_weights(self):
        """Ask every layer to initialise its weights."""
        for layer in self.layers:
            layer.init_weights()

    def forward(self, psc):
        """Feed ``psc`` through every layer in order and return the result."""
        for layer in self.layers:
            psc = layer(psc)
        return psc

    @staticmethod
    def _to_class_scores(scores):
        """Drop the two trailing dimensions (when they have size 1) and
        return the result as a detached NumPy array on the CPU."""
        # Out-of-place squeeze on a detached tensor: never touches the
        # autograd graph or the metadata of a view into ``outputs``.
        return scores.detach().squeeze(-1).squeeze(-1).cpu().numpy()

    @staticmethod
    def _count_correct(predicted, labels):
        """Return how many entries of ``predicted`` equal ``labels``."""
        return np.array(predicted == labels).astype(int).sum().item()

    def model_loss(self, inputs, labels):
        """Run the forward pass and compute the loss and accuracy counts.

        Args:
            inputs: passed unchanged to ``forward``.
            labels: target class indices as a tensor (``.cpu()`` is called on
                it).  In "framewise" mode it is read as 2-D via
                ``labels.shape[0] * labels.shape[1]``.

        Returns:
            ``(loss, outputs, correct, total)`` where ``loss`` comes from the
            criterion, ``outputs`` is the raw network output, and
            ``correct`` / ``total`` are Python ints.

        Raises:
            ValueError: if the configured loss mode is not one of
                ``"count"``, ``"softmax"``, ``"softmax_last"``,
                ``"framewise"``.
        """
        n_steps = self.network_config['n_steps']
        outputs = self(inputs)

        if self.loss_mode == "count":
            # Target spike counts: ``desired_count`` for the labelled class,
            # ``undesired_count`` for every other class.
            desired_count = self.network_config['desired_count']
            undesired_count = self.network_config['undesired_count']

            # Allocate on the same device as the outputs instead of forcing
            # CUDA, so the model also runs on CPU and on non-default GPUs.
            targets = torch.ones(
                outputs.shape[0], outputs.shape[1], 1, 1, device=outputs.device
            ) * undesired_count
            for i, label in enumerate(labels):
                targets[i, label, ...] = desired_count
            loss = self.criterion.spike_count(outputs, targets, self.network_config)

            # Prediction = class with the largest sum over dim 4.
            out = self._to_class_scores(torch.sum(outputs, dim=4))
            predicted = np.argmax(out, axis=1).reshape(-1)
            labels = labels.cpu().numpy()
            total = len(labels)
            correct = self._count_correct(predicted, labels)

        elif self.loss_mode == "softmax":
            loss = self.criterion.spike_softmax(outputs, labels, self.network_config)

            out = self._to_class_scores(torch.sum(outputs, dim=4))
            predicted = np.argmax(out, axis=1)
            labels = labels.cpu().numpy()
            total = len(labels)
            correct = self._count_correct(predicted, labels)

        elif self.loss_mode == "softmax_last":
            loss = self.criterion.spike_softmax(outputs, labels, self.network_config)

            # Prediction uses only the last index along the final dimension.
            out = self._to_class_scores(outputs[..., -1])
            predicted = np.argmax(out, axis=1)
            labels = labels.cpu().numpy()
            total = len(labels)
            correct = self._count_correct(predicted, labels)

        elif self.loss_mode == "framewise":
            shape = outputs.shape
            l_shape = labels.shape

            # Group the last dimension into n_steps // _FRAME_LEN frames and
            # sum inside each frame.  The view requires outputs to hold
            # exactly shape[0] * shape[1] * n_steps elements.
            frame_len = self._FRAME_LEN
            s_outputs = outputs.view(
                shape[0], shape[1], 1, 1, n_steps // frame_len, frame_len
            ).sum(dim=5)
            loss = self.criterion.spike_framewise(s_outputs, labels, self.network_config)

            outs = s_outputs.view(shape[0], shape[1], -1).detach().cpu().numpy()
            labels = labels.cpu().numpy()

            # One prediction per (sample, frame).
            predicted = np.argmax(outs, axis=1)

            total = int(l_shape[0] * l_shape[1])
            correct = self._count_correct(predicted, labels)

        else:
            raise ValueError('Unrecognized loss function.')

        return loss, outputs, correct, total

    def model_loss_arch(self, inputs, labels):
        """Return ``(loss, outputs)`` for architecture optimisation.

        Starts from ``model_loss`` and, when ``sparsity_strength > 0``, adds
        for every recurrent layer that has architecture parameters
        ``sparsity_strength * sum(p[..., 1:] ** 2)`` for each ``p`` obtained
        by iterating over that layer's first architecture parameter.
        """
        loss, outputs, _correct, _total = self.model_loss(inputs, labels)
        if self.sparsity_strength > 0:
            for layer in self.layers:
                if layer.type != "recurrent":
                    continue
                arch = layer.get_arch_parameters()
                # Truthiness check (same as get_arch_parameters below) so
                # that both ``None`` and an empty collection are skipped
                # instead of failing on ``arch[0]``.
                if not arch:
                    continue
                for p in arch[0]:
                    # Out-of-place add keeps the criterion's tensor intact.
                    loss = loss + self.sparsity_strength * torch.sum(torch.square(p[..., 1:]))
        return loss, outputs

    def get_parameters(self):
        """Collect the weight parameters of all layers.

        A new ``nn.ParameterList`` is built on every call, stored as
        ``self.weight_parameters`` and returned.
        """
        params = []
        for layer in self.layers:
            params.extend(layer.get_parameters())
        self.weight_parameters = nn.ParameterList(params)
        return self.weight_parameters

    def get_arch_parameters(self):
        """Collect the architecture parameters of all recurrent layers.

        Layers whose ``get_arch_parameters()`` result is falsy are skipped.
        A new ``nn.ParameterList`` is built on every call, stored as
        ``self.arch_parameters`` and returned.
        """
        params = []
        for layer in self.layers:
            if layer.type != "recurrent":
                continue
            arch = layer.get_arch_parameters()
            if arch:
                params.extend(arch)
        self.arch_parameters = nn.ParameterList(params)
        return self.arch_parameters

    def set_mode(self, m):
        """Store mode ``m`` and propagate it to every layer."""
        self.mode = m
        for layer in self.layers:
            layer.set_mode(m)

    def get_mode(self):
        """Return the mode last set with ``set_mode`` (``None`` initially)."""
        return self.mode

    def record_best_arch(self):
        """Ask every recurrent layer to record its current best architecture."""
        for layer in self.layers:
            if layer.type == "recurrent":
                layer.record_best_arch()




        
