
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from Network.network_utils import get_acti,get_inplace_acti, reset_linconv, reset_parameters, count_layers

## end of normalization functions
class Network(nn.Module):
    """Base class for networks whose layers are collected in ``self.model``.

    Provides device-movement helpers (``to`` / ``cuda`` / ``cpu``) that also move
    every entry of ``self.model`` (including nested ``Network`` instances, directly
    or inside an ``nn.ModuleList``), parameter re-initialisation through
    ``reset_parameters``, and a flat view of all parameter gradients.
    Subclasses are expected to populate ``self.model`` and implement ``forward``.
    """

    def __init__(self, args):
        super().__init__()
        self.num_inputs, self.num_outputs = args.num_inputs, args.num_outputs
        self.use_layer_norm = args.use_layer_norm
        self.hs = [int(h) for h in args.hidden_sizes]  # hidden layer sizes
        self.init_form = args.init_form  # forwarded to reset_parameters
        # NOTE: a plain list is not registered with nn.Module, so its entries are not
        # reached by parameters()/state_dict() unless a subclass replaces it with an
        # nn.ModuleList; to()/cuda()/cpu() below move its entries explicitly.
        self.model = []
        self.acti = get_inplace_acti(args.activation)
        self.activation_final = get_inplace_acti(args.activation_final)
        self.activation_final_name = args.activation_final
        # True once the network has been moved to a CUDA device via cuda() or to().
        self.iscuda = False
        self.device = args.gpu

    @staticmethod
    def _as_device(device):
        """Normalise a device specification into a ``torch.device``.

        A ``torch.device`` is returned unchanged; a non-numeric string such as
        ``"cpu"``, ``"cuda"`` or ``"cuda:1"`` is parsed by ``torch.device``; anything
        else (an int or a digit string) is treated as a GPU index, ``"cuda:<index>"``.
        """
        if isinstance(device, torch.device):
            return device
        if isinstance(device, str) and not device.isdigit():
            return torch.device(device)
        return torch.device(f"cuda:{device}")

    def _for_each_model_member(self, network_fn, module_fn):
        """Call ``network_fn`` on nested Networks and ``module_fn`` on other modules.

        ``nn.ModuleList`` entries of ``self.model`` are expanded one level so that
        Networks inside them go through their own device methods; deeper-nested
        ModuleLists are passed to ``module_fn`` as a whole.
        """
        for entry in self.model:
            members = entry if isinstance(entry, nn.ModuleList) else (entry,)
            for member in members:
                if isinstance(member, Network):
                    network_fn(member)
                else:
                    module_fn(member)

    def to(self, device=None):
        """Move the network (registered modules and ``self.model``) to ``device``.

        Args:
            device: a ``torch.device``, an int / digit-string GPU index, or a device
                string such as ``"cpu"`` or ``"cuda:1"``. It is stored in
                ``self.device``. When None, the stored ``self.device`` is used; if
                that is also None nothing is moved (like ``nn.Module.to()`` with no
                arguments).

        Returns:
            self, to allow chaining.
        """
        if device is not None:
            self.device = device
        target = self.device if device is None else device
        if target is None:
            return self
        use_device = self._as_device(target)
        super().to(use_device)
        self.iscuda = use_device.type == "cuda"
        self._for_each_model_member(
            lambda net: net.to(device=target),
            lambda module: module.to(use_device),
        )
        return self

    def cuda(self, device=None):
        """Move the network to a CUDA device.

        Args:
            device: GPU index, device string or ``torch.device``; stored in
                ``self.device`` when given. When None, the current CUDA device is
                used (matching ``nn.Module.cuda``).

        Returns:
            self, to allow chaining.
        """
        use_device = torch.device("cuda") if device is None else self._as_device(device)
        # Pass the device to super() so registered submodules land on the same GPU
        # as the entries of self.model.
        super().cuda(None if device is None else use_device)
        self.iscuda = True
        if device is not None:
            self.device = device
        self._for_each_model_member(
            lambda net: net.cuda(device=device),
            lambda module: module.to(use_device),
        )
        return self

    def cpu(self):
        """Move the network (registered modules and ``self.model``) to the CPU.

        Returns:
            self, to allow chaining.
        """
        super().cpu()
        self.iscuda = False
        cpu_device = torch.device("cpu")
        self._for_each_model_member(
            lambda net: net.cpu(),
            lambda module: module.to(cpu_device),
        )
        return self

    def reset_network_parameters(self, n_layers=-1):
        """Re-initialise parameters via ``reset_parameters`` using ``self.init_form``.

        ``n_layers`` is forwarded unchanged; its meaning is defined by
        ``reset_parameters`` (presumably -1 means all layers — TODO confirm).
        Returns whatever ``reset_parameters`` returns.
        """
        return reset_parameters(self, self.init_form, n_layers=n_layers)

    def get_gradients(self):
        """Return every parameter gradient concatenated into one 1-D tensor.

        Parameters are visited in ``self.parameters()`` order. A parameter whose
        ``.grad`` is None (e.g. it did not take part in the last backward pass) is
        represented by zeros, so the result always has one entry per parameter
        element. An empty tensor is returned when the module has no parameters.
        """
        grads = [
            (torch.zeros_like(param) if param.grad is None else param.grad.detach()).flatten()
            for param in self.parameters()
        ]
        if not grads:
            return torch.empty(0)
        return torch.cat(grads)

    def forward(self, x):
        '''
        Placeholder: every subclass should define forward, but the signatures differ
        between subclasses, so the base implementation does nothing and returns None.
        '''
        return

