import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

import torch.optim as optim

# generate a random orthonormal matrix
def rvs(dim=3):
    """Draw a random rotation matrix (orthonormal, determinant +1).

    The matrix is accumulated as a product of ``dim - 1`` Householder
    reflections acting on progressively smaller trailing sub-blocks. The
    reflection vectors are drawn from NumPy's global random state. Row
    signs are applied at the end so that the determinant is +1.

    Args:
        dim: Side length of the square matrix.

    Returns:
        A ``(dim, dim)`` ``numpy.ndarray`` of dtype float32.
    """
    rng = np.random
    signs = np.ones((dim,))
    rotation = np.eye(dim)
    for k in range(dim - 1):
        block = dim - k
        v = rng.normal(size=(block,))
        signs[k] = np.sign(v[0])
        v[0] -= signs[k]*np.sqrt((v*v).sum())
        # Householder reflection on the trailing (block x block) corner
        reflection = (np.eye(block) - 2.*np.outer(v, v)/(v*v).sum())
        embedded = np.eye(dim)
        embedded[k:, k:] = reflection
        rotation = np.dot(rotation, embedded)
    # Choose the last sign so that the determinant is +1
    signs[-1] = (-1)**(1-(dim % 2))*signs.prod()
    # Same as np.dot(np.diag(signs), rotation), without building the diagonal
    rotation = (signs*rotation.T).T
    return rotation.astype('float32')

def copy_net(net1, net2):
    """Overwrite the state of ``net1`` with the matching entries of ``net2``.

    Every entry of ``net2.state_dict()`` whose key also appears in
    ``net1.state_dict()`` is copied in place into ``net1``. Keys that
    ``net1`` does not have are ignored.
    """
    target_state = net1.state_dict()
    for key, value in net2.state_dict().items():
        if key in target_state:
            target_state[key].copy_(value.data)


class FCN(nn.Module):
    """Fully connected network made of equal-width linear layers.

    The network holds ``n_layers`` linear layers of width ``n_units``, each
    followed by the activation, and (unless ``no_output`` is set) a final
    linear ``output_layer`` without activation.

    When ``shared_rep`` is given, ``shared_rep.net`` is placed in front of
    this network's own layers as the first entry of ``self.layers``. The
    same module object is used, so its parameters are shared with every
    other network built on it.

    Args:
        input_dim: Number of input features.
        output_dim: Width of the output layer.
        n_layers: Number of hidden linear layers owned by this network
            (at least one layer is always created).
        n_units: Width of every hidden layer.
        activation: ``"RELU"``, ``"SIGMOID"`` or ``"IDENTITY"``. Any other
            value falls back to ReLU.
        no_output: If true, no output layer is created and ``forward``
            returns the last hidden activation.
        shared_rep: Optional object exposing ``.net`` (a module) and
            ``.n_layers`` (its layer count).
    """

    def __init__(self, input_dim, output_dim = 1,
                n_layers = 3, n_units = 10, activation = "RELU", no_output = False, shared_rep = None):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.shared_rep = shared_rep
        self.n_units = n_units
        self.no_output = no_output

        if shared_rep is None:
            self.n_layers = n_layers
            front = []
            first_in = input_dim
        else:
            self.n_layers = n_layers + shared_rep.n_layers
            front = [shared_rep.net]
            # The first own layer consumes the shared network's output, not
            # the raw input, so it must be sized to that output width.
            first_in = self._shared_output_width(shared_rep.net, input_dim)

        own_layers = [nn.Linear(first_in, n_units)]
        own_layers += [nn.Linear(n_units, n_units) for _ in range(n_layers - 1)]
        self.layers = nn.ModuleList(front + own_layers)

        if not no_output:
            self.output_layer = nn.Linear(n_units, output_dim)

        activations = {
            "RELU": F.relu,
            "SIGMOID": torch.sigmoid,
            "IDENTITY": lambda x: x,
        }
        # Unknown names keep the historical default of ReLU.
        self.activation = activations.get(activation, F.relu)

    @staticmethod
    def _shared_output_width(rep_net, default):
        """Return the feature width produced by the shared network.

        Falls back to ``default`` when ``rep_net`` does not expose the
        attributes an ``FCN`` has.
        """
        if getattr(rep_net, "no_output", False):
            return rep_net.n_units
        return getattr(rep_net, "output_dim", default)

    def forward(self, x):
        """Apply every entry of ``self.layers`` followed by the activation,
        then the output layer if the network has one."""
        for layer in self.layers:
            x = self.activation(layer(x))
        if not self.no_output:
            x = self.output_layer(x)
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

# weight initialization
# takes in a module and applies the specified weight initialization
def weights_init_uniform(m):
    """Initialise a linear layer with U(0, 1) weights and a zero bias.

    Intended for ``Module.apply``: modules whose class name does not
    contain ``'Linear'`` are left untouched.
    """
    if 'Linear' not in m.__class__.__name__:
        return
    # weights drawn uniformly from [0, 1], bias reset to 0
    m.weight.data.uniform_(0.0, 1.0)
    m.bias.data.fill_(0)
def weights_init_normal(m, sigma = None):
    '''Initialise a linear layer with zero-mean Gaussian weights and a zero bias.

    Intended for ``Module.apply``: modules whose class name does not
    contain ``'Linear'`` are left untouched.

    Args:
        m: The module to initialise.
        sigma: Standard deviation of the weights. When ``None`` the
            standard deviation is ``1 / sqrt(m.in_features)``.
    '''
    if 'Linear' not in m.__class__.__name__:
        return
    if sigma is None:
        # default scale shrinks with the fan-in of the layer
        std = 1/np.sqrt(m.in_features)
    else:
        std = sigma
    m.weight.data.normal_(0.0, std)
    m.bias.data.fill_(0)

def load_my_state_dict(self, state_dict):
    """Copy the matching entries of ``state_dict`` into the module ``self``.

    Entries whose key is not present in ``self.state_dict()`` are skipped.
    Values may be plain tensors or ``nn.Parameter`` objects. The copy is
    done in place, so the module keeps its existing parameter objects.

    Args:
        self: The module to update.
        state_dict: Mapping from state names to tensors or parameters.
    """
    own_state = self.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        if isinstance(param, nn.Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        own_state[name].copy_(param)

class task:
    """A regression task: a trainable network plus a reference "true" copy.

    ``self.net`` is the model being trained. ``set_true`` snapshots it into
    ``self.net_true``, which then generates data (``gen_data``) and serves
    as the reference for the test loss (``test``).

    The "representation" of a task is the set of parameters whose name has
    the form ``layers.<i>.<param>`` with ``i < n_rep``. Parameters nested
    inside a shared representation module have longer names and are not
    matched by this rule.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs.
        n_layers: Number of hidden layers of the network.
        n_units: Width of the hidden layers.
        n_rep: Number of leading layers that form the representation,
            capped at ``n_layers``.
        init: Accepted for compatibility; not used by the constructor.
            Call ``init_param`` to initialise the weights.
        no_output: Forwarded to ``FCN``.
        shared_rep: Forwarded to ``FCN``.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, input_dim, output_dim = 1, n_layers = 10,
                n_units = 10, n_rep = 10, init = "Gaussian", no_output = False, shared_rep = None, lr = 0.001):
        self.net = FCN(input_dim, output_dim, n_layers = n_layers,
                        n_units = n_units, no_output = no_output, shared_rep = shared_rep)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.n_units = n_units
        self.n_rep = min(n_rep, n_layers)
        self.lr = lr
        # set optimizer
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

    def reset(self):
        """Rebuild the Adam optimizer over the current parameters of ``self.net``."""
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

    def set_true(self):
        """Snapshot the current network as the reference network ``net_true``."""
        self.net_true = deepcopy(self.net)

    def test(self, num = 10000):
        """Return the MSE between ``net_true`` and ``net`` on ``num`` fresh
        standard-normal inputs."""
        x = torch.randn(num, self.input_dim)
        y_true = self.net_true(x)
        y_pred = self.net(x)
        return(self.criterion(y_true, y_pred))

    def gen_data(self, num = 100, sigma = 0.01):
        """Sample ``num`` standard-normal inputs and noisy ``net_true`` outputs.

        Returns:
            ``(x, y)`` where ``y = net_true(x) + sigma * noise``. ``y`` is
            computed without gradient tracking so it can be reused as a
            target across any number of backward passes.
        """
        x = torch.randn(num, self.input_dim)
        with torch.no_grad():
            y = self.net_true(x) + torch.randn(num, self.output_dim) * sigma
        return (x, y)

    @staticmethod
    def _unit_rows(weight):
        """Return a detached copy of ``weight`` with every row scaled to unit L2 norm."""
        t = weight.clone().detach()
        return t / (t**2).sum(1, keepdim=True)**(1/2)

    def _assign_weight(self, layer, new_weight):
        """Write ``new_weight`` into ``layer.weight``.

        The existing Parameter object is updated in place when the shapes
        agree, so the optimizer keeps tracking it. On a shape change the
        Parameter is replaced and the optimizer is rebuilt so that it does
        not keep updating the discarded tensor.
        """
        if layer.weight.shape == new_weight.shape:
            with torch.no_grad():
                layer.weight.copy_(new_weight)
        else:
            layer.weight = torch.nn.Parameter(new_weight.clone().detach())
            self.reset()

    def init_param(self, init = "Gaussian", sigma = None):
        """(Re)initialise the parameters of ``self.net``.

        Args:
            init: One of
                ``"Gaussian"`` - Gaussian weights, zero biases;
                ``"Gaussian_standard"`` - as ``"Gaussian"``, then the rows of
                the last module's weight are scaled to unit norm;
                ``"Uniform"`` - U(0, 1) weights, zero biases;
                ``"Diverse_prediction"`` - as ``"Gaussian"``, then the last
                module's weight is set to the first ``output_dim`` rows of a
                random orthonormal matrix and its bias to zero.
                Any other value leaves the network unchanged.
            sigma: Standard deviation for the Gaussian modes; ``None`` uses
                the fan-in based default of ``weights_init_normal``.
        """
        if init == "Gaussian":
            self.net.apply(lambda m: weights_init_normal(m, sigma))
        elif init == "Gaussian_standard":
            with torch.no_grad():
                self.net.apply(lambda m: weights_init_normal(m, sigma))
                # last registered module: the output layer when there is one
                last = list(self.net.modules())[-1]
                self._assign_weight(last, self._unit_rows(last.weight))
        elif init == "Uniform":
            self.net.apply(weights_init_uniform)
        elif init == "Diverse_prediction":
            with torch.no_grad():
                self.net.apply(lambda m: weights_init_normal(m, sigma))
                # last registered module: the output layer when there is one
                last = list(self.net.modules())[-1]
                rows = torch.from_numpy(rvs(self.n_units)[:self.output_dim, ])
                # normalize
                self._assign_weight(last, self._unit_rows(rows))
                last.bias.data.fill_(0)

    def add_noise(self, sigma = 0.1):
        """Add Gaussian noise with standard deviation ``sigma`` to the weights.

        The noise is drawn from NumPy's global random state and added in
        place, so the Parameter objects held by the optimizer stay valid.
        Biases are not changed.
        """
        with torch.no_grad():
            modules = list(self.net.modules())
            # modules[0] is the network, modules[1] its ModuleList
            for m in modules[2:]:
                noise = np.random.normal(0, size = tuple(m.weight.shape), scale = sigma).astype('float32')
                m.weight.add_(torch.from_numpy(noise).to(m.weight.device))

    @staticmethod
    def _is_rep_entry(name, n_rep):
        """Tell whether ``name`` is ``layers.<i>.<param>`` with ``i < n_rep``."""
        parts = name.split(".")
        return len(parts) == 3 and int(parts[1]) < n_rep

    def set_representation(self, source_task, n_rep = None, use_true = False):
        """Copy the representation layers of ``source_task`` into this task.

        Args:
            source_task: Task to copy from.
            n_rep: Number of leading layers to copy. When given it also
                becomes this task's ``n_rep``; when ``None`` the current
                ``self.n_rep`` is used.
            use_true: Copy from ``source_task.net_true`` instead of
                ``source_task.net``.
        """
        if n_rep is None:
            n_rep = self.n_rep
        else:
            self.n_rep = n_rep
        source_net = source_task.net_true if use_true else source_task.net
        own_state = self.net.state_dict()
        for name, param in source_net.state_dict().items():
            # only set layers in the Module list whose indices are < n_rep
            if name in own_state and self._is_rep_entry(name, n_rep):
                own_state[name].copy_(param.data)

    def fix_representation(self, fix = False, n_rep = None):
        """Set whether the representation layers receive gradient updates.

        NOTE: despite its name, ``fix`` is the value written to
        ``requires_grad``. ``fix=False`` (the default) freezes the
        representation and ``fix=True`` makes it trainable.

        Args:
            fix: New ``requires_grad`` value for the representation.
            n_rep: Number of leading layers affected; defaults to
                ``self.n_rep``.
        """
        if n_rep is None:
            n_rep = self.n_rep
        # Iterate the live parameters: the tensors returned by state_dict()
        # are detached, so flags set on them never reach the model.
        for name, param in self.net.named_parameters():
            if self._is_rep_entry(name, n_rep):
                param.requires_grad = fix
                if not fix:
                    # drop any stale gradient so the optimizer skips it
                    param.grad = None

    def training(self, dat, mute = True, with_rep = False, epoch = 10, batch_size = 32, fix_rep = None):
        """Train ``self.net`` on ``dat`` with mini-batches taken in order.

        Samples beyond the last full batch are not used.

        Args:
            dat: ``(inputs, targets)`` tensors.
            mute: Suppress the per-epoch progress output.
            with_rep: If true the representation layers are trained too;
                otherwise they are frozen.
            epoch: Number of passes over the data.
            batch_size: Mini-batch size.
            fix_rep: When not ``None`` it overrides ``with_rep`` and is
                passed to ``fix_representation`` as is (true means
                trainable).

        Returns:
            The test loss after the last epoch as a NumPy scalar array.
        """
        rep_trainable = bool(with_rep) if fix_rep is None else bool(fix_rep)
        self.fix_representation(rep_trainable)
        inputs = dat[0]
        # Targets are constants: detach them so each backward pass is
        # independent and no graph has to be retained between batches.
        targets = dat[1].detach()
        n_batches = inputs.shape[0] // batch_size
        test_loss = None
        for e in range(epoch):
            if not mute:
                print("Starting epoch %i"%e)
            total_loss = 0.0
            for i in range(n_batches):
                lo, hi = i*batch_size, (i+1)*batch_size
                self.optimizer.zero_grad()
                # forward + backward + optimize
                outputs = self.net(inputs[lo:hi])
                loss = self.criterion(outputs, targets[lo:hi])
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / n_batches if n_batches else float("nan")
            with torch.no_grad():
                test_loss = self.test().detach().numpy()
            if not mute:
                print("Average training loss: %f"%avg_loss)
                print("Average testing loss: %f"%test_loss)
        if test_loss is None:
            # no epoch was run: still report the current test loss
            with torch.no_grad():
                test_loss = self.test().detach().numpy()
        return(test_loss)

    def evaluating(self, dat):
        """Return the loss of ``self.net`` on ``dat = (inputs, targets)``."""
        outputs = self.net(dat[0])
        loss = self.criterion(outputs, dat[1])
        return loss

    def print_net(self, true_net = False):
        """Print every state entry of ``net`` (or of ``net_true`` if ``true_net``)."""
        net = self.net_true if true_net else self.net
        for name, param in net.state_dict().items():
            print(name)
            print(param)


class task_group:
    """A set of single-output tasks built on one shared representation.

    Each task owns ``pred_layers`` prediction layers stacked on top of the
    shared representation task, which has ``rep_layers`` layers.

    Args:
        num_task: Number of tasks to create.
        input_dim: Number of input features.
        rep_layers: Number of layers of the shared representation.
        pred_layers: Number of task-specific layers per task.
        n_units: Width of every hidden layer.
        activation: Accepted for compatibility; not used here.
        shared_rep: Existing representation task to reuse. When ``None`` a
            new one is created and initialised with ``sigma = 0.5``.
    """

    def __init__(self, num_task, input_dim,
                    rep_layers = 3, pred_layers = 2,
                    n_units = 10, activation = "RELU", shared_rep = None):
        self.task_set = []
        self.num_task = num_task
        self.input_dim = input_dim
        self.output_dim = 1
        self.rep_layers = rep_layers
        self.pred_layers = pred_layers
        self.n_units = n_units
        # set a unique representation
        if shared_rep is None:
            self.representation = task(self.input_dim, self.output_dim,
                                    self.rep_layers, self.n_units, no_output = True)
            self.representation.init_param(sigma = 0.5)
            self.representation.set_true()
        else:
            self.representation = shared_rep
            if not hasattr(shared_rep, "net_true"):
                # A snapshot is needed below to restore the representation;
                # take it now, before the tasks are built.
                shared_rep.set_true()
        for _ in range(self.num_task):
            new_task = task(input_dim = self.input_dim, output_dim = self.output_dim,
                            n_layers = self.pred_layers, n_units = self.n_units, n_rep = self.rep_layers,
                            shared_rep = self.representation)
            new_task.init_param(sigma = 0.5)
            self.task_set.append(new_task)
        # Each task's network contains the shared representation module, so
        # initialising a task also re-initialises the representation.
        # Restore it from its snapshot before freezing the reference nets.
        copy_net(self.representation.net, self.representation.net_true)
        for member in self.task_set:
            member.set_true()

"""
-0.4454,  0.0777,  0.8480,  0.2809,  0.1326,  0.3131
t1 = task(input_dim = 3, output_dim = 2, n_layers = 3, n_units = 3)
t1.init_param()

t0 = deepcopy(t1)
t1.init_param()
t1.set_representation(t0, n_rep = 3)

# dataset
dat_0 = t0.gen_data(1000)
dat_1 = t1.gen_data(1000)

# remove parameters
t0.init_param()
t1.init_param()

t0.training(dat_0, with_rep = True)
t1.set_representation(t0, n_rep = 3)
t1.training(dat_1)
"""

