import torch as th
import torch.nn as nn
import numpy as np
import pprint
from utils.embed import flatten
from utils.th_utils import orthogonal_init_
from collections import OrderedDict

# ------------------------ NN Modules with Gradient Projection ------------------------
class GPLayer(nn.Module):
    """Linear layer that can snapshot a second-moment matrix of its input.

    The layer wraps ``nn.Linear(in_dim, out_dim)`` (stored as ``adaptor``)
    and owns a zero-initialised ``(in_dim, in_dim)`` buffer named ``cov``.
    While recording is switched on, every forward pass overwrites ``cov``
    with a statistic computed from the current input; the linear output is
    unaffected by recording.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.adaptor = nn.Linear(in_dim, out_dim)
        # Registered as a buffer so it is saved with the state dict and moved
        # by .to()/.cuda(), but is not a trainable parameter.
        self.register_buffer("cov", th.zeros(in_dim, in_dim))
        self.do_record = False

    def forward(self, K):
        """Apply the linear map to ``K``; optionally refresh ``cov`` first."""
        if self.do_record:
            # Work on a detached copy so the statistic never enters autograd.
            # assumes flatten() returns a 2-D (num_rows, in_dim) tensor -- TODO confirm
            feats = flatten(K).detach().clone()
            # NOTE(review): the mean is taken along dim 1 (per row), not
            # dim 0 (per column) -- confirm this is the intended centering.
            centered = feats - feats.mean(dim=1, keepdim=True)
            num_rows = centered.shape[0]
            # Overwrites (does not accumulate) the previously stored value.
            self.cov.copy_(th.matmul(centered.T, centered) / num_rows)
        return self.adaptor(K)

    # ---------------- APIs ----------------

    def set_rec(self):
        """Turn recording on for subsequent forward passes."""
        self.do_record = True

    def get_cov(self):
        """Turn recording off and return ``[cov]`` as a one-element list.

        The returned tensor is a clone, so callers may modify it without
        touching the buffer held by the layer.
        """
        self.do_record = False
        snapshot = self.cov.clone()
        return [snapshot]

class GPNNBase(nn.Module):
    """Base class for networks whose ``layers`` may contain ``GPLayer`` modules.

    Subclasses populate ``self.layers`` and implement ``forward``. The two
    helper methods below simply fan out to every ``GPLayer`` found directly
    in ``self.layers``; any other module in the list is ignored.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()

    def forward(self, x):
        """Must be provided by subclasses."""
        raise NotImplementedError

    def set_rec(self):
        """Call ``set_rec()`` on every ``GPLayer`` in ``self.layers``."""
        for module in self.layers:
            if not isinstance(module, GPLayer):
                continue
            module.set_rec()

    def get_cov(self):
        """Return the concatenated ``get_cov()`` results of all ``GPLayer``s.

        Results are gathered in the order the layers appear in
        ``self.layers``.
        """
        collected = []
        for module in self.layers:
            if not isinstance(module, GPLayer):
                continue
            collected.extend(module.get_cov())
        return collected

class GPMLP(GPNNBase):
    """Multi-layer perceptron assembled from ``GPLayer`` blocks.

    The network has ``hidden_layer - 1`` blocks of ``GPLayer`` followed by
    an activation (each of width ``hidden_dim``), then a final ``GPLayer``
    mapping to ``out_dim``. With ``use_last_activ`` one more activation is
    appended after the final layer.

    ``activ_name`` is matched case-insensitively against ``'tanh'``,
    ``'relu'``, ``'leaky_relu'`` and ``'sigmoid'``; any other name raises
    ``KeyError``.
    """

    def __init__(self, in_dim, out_dim, hidden_layer=3, hidden_dim=128, activ_name='relu', use_last_activ=False):
        super().__init__()

        activations = {
            'tanh': nn.Tanh,
            'relu': nn.ReLU,
            'leaky_relu': nn.LeakyReLU,
            'sigmoid': nn.Sigmoid,
        }
        make_activ = activations[activ_name.lower()]

        # Feature width seen at the input of each layer, in order.
        widths = [in_dim] + [hidden_dim for _ in range(hidden_layer - 1)]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(GPLayer(fan_in, fan_out))
            stack.append(make_activ())
        stack.append(GPLayer(widths[-1], out_dim))

        if use_last_activ:
            stack.append(make_activ())

        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Feed ``x`` through every module in ``self.layers`` in order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

# ------------------------ Conventional NN Modules ------------------------

class FCNet(nn.Module):
    """Plain fully-connected network built as an ``nn.Sequential``.

    Layout, in order:
      * an optional ``nn.LayerNorm(in_dim)`` when ``use_layer_norm`` is set,
      * ``hidden_layer - 1`` blocks of ``nn.Linear`` + activation, each of
        width ``hidden_dim``,
      * a final ``nn.Linear`` mapping to ``out_dim``,
      * an optional trailing activation when ``use_last_activ`` is set.

    ``activ_name`` is matched case-insensitively against ``'tanh'``,
    ``'relu'``, ``'leaky_relu'`` and ``'sigmoid'``; any other name raises
    ``KeyError``.
    """

    def __init__(self, in_dim, out_dim, hidden_layer=3, hidden_dim=128, activ_name='relu', use_last_activ=False, use_layer_norm=False):
        super().__init__()

        activations = {
            'tanh': nn.Tanh,
            'relu': nn.ReLU,
            'leaky_relu': nn.LeakyReLU,
            'sigmoid': nn.Sigmoid,
        }
        make_activ = activations[activ_name.lower()]

        modules = [nn.LayerNorm(in_dim)] if use_layer_norm else []

        # Feature width seen at the input of each linear layer, in order.
        widths = [in_dim] + [hidden_dim for _ in range(hidden_layer - 1)]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(make_activ())
        modules.append(nn.Linear(widths[-1], out_dim))

        if use_last_activ:
            modules.append(make_activ())

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.layers(x)
    