import torch
import torch.nn as nn
import numpy as np
import pickle

import copyreg
import types

from synthetic import simulate_var

class MlpBlock(nn.Module):
    """Two-layer perceptron: ``Linear -> GELU -> Linear``.

    Args:
        input_size: size of the last dimension of the input.
        hidden_size: width of the hidden layer.
        output_size: size of the last dimension of the output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names and creation order are part of the contract:
        # they determine the state_dict keys and the order in which the
        # random initialiser is consumed.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``.

        ``x`` must have ``input_size`` as its last dimension; any leading
        dimensions are passed through by ``nn.Linear``.
        """
        hidden = self.gelu(self.fc1(x))
        return self.fc2(hidden)

# A single mixer network.
class MixerBlock(nn.Module):
    """Mixer block that turns a ``[patches, channels]`` window into one scalar.

    The block applies two mixing stages, each followed by a residual
    connection:

    1. one MLP per channel, acting along the patch axis;
    2. one MLP per patch, acting along the channel axis.

    The mixed activations are stored in ``penalty_x`` (read by ``GC`` and by
    the module-level regularisers), multiplied element-wise with the input,
    flattened, and passed through ``output1 -> GELU -> output2``.

    Args:
        channels: size of the last input axis; also the number of first-stage
            MLPs.
        patches: size of the second-to-last input axis; also the number of
            second-stage MLPs.

    NOTE(review): in the 2-D path every forward call appends to ``penalty_x``
    (autograd graph included). Nothing in this class resets it, so the caller
    is responsible for clearing it between passes.
    """

    def __init__(self, channels, patches):
        super().__init__()
        # Stage 1: one MLP per channel, mixing information across patches.
        self.mlp_networks1 = nn.ModuleList(
            [MlpBlock(patches, 200, patches) for _ in range(channels)])

        # Stage 2: one MLP per patch, mixing information across channels.
        self.mlp_networks2 = nn.ModuleList(
            [MlpBlock(channels, 200, channels) for _ in range(patches)])

        # Mixed activations recorded by forward(); empty until the first call.
        self.penalty_x = torch.Tensor()

        # Not used by forward(). Kept so that parameter names (state_dict
        # keys) and the seeded initialisation order stay unchanged.
        self.output = nn.Linear(channels * patches, 1)

        # Prediction head: Linear -> GELU -> Linear.
        self.output1 = nn.Linear(channels * patches, 100)
        self.gelu = nn.GELU()
        self.output2 = nn.Linear(100, 1)

    def _mix(self, input):
        """Run both mixing stages (with skip connections) on 2-D or 3-D input.

        Negative dimensions are used so the same code serves
        ``[patches, channels]`` and ``[batch, patches, channels]``.
        """
        # [..., patches, channels] -> [..., channels, patches]
        x = input.transpose(-1, -2)
        out1 = torch.stack(
            [fc(x[..., i, :]) for i, fc in enumerate(self.mlp_networks1)],
            dim=-2)
        # Back to [..., patches, channels], then the first skip connection.
        x = input + out1.transpose(-1, -2)
        out2 = torch.stack(
            [fc(x[..., i, :]) for i, fc in enumerate(self.mlp_networks2)],
            dim=-2)
        # Second skip connection.
        return x + out2

    def forward(self, input):
        """Predict one value per sample.

        Args:
            input: tensor of shape ``[patches, channels]`` or
                ``[batch, patches, channels]``.

        Returns:
            Tensor of shape ``[1]`` for 2-D input, ``[batch, 1]`` for 3-D
            input.

        Raises:
            ValueError: if ``input`` is neither 2-D nor 3-D.

        Side effects:
            2-D input: the mixed activations are appended to ``penalty_x``
            along a new leading axis. 3-D input: ``penalty_x`` is replaced by
            the mixed activations of the batch.
        """
        if input.dim() not in (2, 3):
            raise ValueError(
                "MixerBlock expects a 2-D or 3-D input, got %d-D" % input.dim())

        x = self._mix(input)

        if input.dim() == 2:
            step = x.unsqueeze(0)
            if self.penalty_x.numel() == 0:
                # Replace the empty placeholder instead of concatenating with
                # it: the placeholder may sit on a different device or have a
                # different dtype than the activations.
                self.penalty_x = step
            else:
                self.penalty_x = torch.cat((self.penalty_x, step), dim=0)
            features = torch.flatten(x * input)
        else:
            self.penalty_x = x
            features = (x * input).reshape(x.shape[0], -1)

        return self.output2(self.gelu(self.output1(features)))

    def GC(self, threshold=True, ignore_lag=True):
        """Report norms of ``penalty_x`` and return the per-entry norm.

        Args:
            threshold: if truthy, return ``(norm > 0)`` as an int tensor
                instead of the raw norms.
            ignore_lag: if true, additionally print the norm taken jointly
                over axes 0 and 2 (one value per index of axis 1).

        Returns:
            The norm of ``penalty_x`` over axis 0, optionally thresholded.
            NOTE(review): the returned tensor does not depend on
            ``ignore_lag``; that flag only controls the printed summary.

        Side effects:
            Writes the axis-0 norms to ``GC_single.txt`` in the current
            working directory and may print to stdout.
        """
        if ignore_lag:
            joint = torch.norm(self.penalty_x, dim=(0, 2))
            for a in joint.cpu().detach().numpy():
                print(a, end=' ')
            print("")

        GC = torch.norm(self.penalty_x, dim=0)
        np.savetxt("GC_single.txt", GC.cpu().detach().numpy(), fmt='%.5f')
        if threshold:
            return (GC > 0).int()
        return GC

class MlpMixer(nn.Module):
    """Collection of ``p`` MixerBlock networks.

    Each block is constructed as ``MixerBlock(lag, p)``, i.e. with
    ``channels=lag`` and ``patches=p``.

    Args:
        p: number of mixer blocks to create.
        lag: value passed as ``channels`` to every block.
    """

    def __init__(self, p, lag):
        super().__init__()
        self.p = p
        self.lag = lag
        self.mixer_networks = nn.ModuleList(
            [MixerBlock(lag, p) for _ in range(p)])

    def GC(self, threshold=0, ignore_lag=True):
        """Stack the norms of every block's ``penalty_x``.

        Args:
            threshold: when non-zero, return ``(norms > threshold)`` as an int
                tensor; when zero, return the raw norms.
            ignore_lag: if true, each block's norm is taken jointly over axes
                0 and 2 of ``penalty_x``; otherwise over axis 0 only.

        Returns:
            A tensor with one leading entry per mixer block.
        """
        reduce_over = (0, 2) if ignore_lag else 0
        norms = torch.stack(
            [torch.norm(block.penalty_x, dim=reduce_over)
             for block in self.mixer_networks])

        if threshold == 0:
            return norms
        return (norms > threshold).int()

    def forward(self, input):
        """Identity: returns ``input`` unchanged."""
        return input

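# Ridge (L2) penalty for a mixer block. The function takes the block and a regularization strength lam.
# It sums the squared weights of fc1 and fc2 of every MlpBlock in mlp_networks1 and mlp_networks2,
# adds the squared weights of output1 and output2, multiplies the total by lam and returns it.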
def ridge_regularize(network, lam):
    '''
    Ridge (squared L2) penalty over the weight matrices of a mixer block.

    Covered weights: fc1 and fc2 of every block in network.mlp_networks1 and
    network.mlp_networks2, plus network.output1 and network.output2. Biases
    are not included.

    Args:
        network: object exposing mlp_networks1, mlp_networks2, output1, output2.
        lam: regularization strength.
    Returns:
        lam times the sum of squared weights.
    '''
    penalty = 0
    # Per-channel MLPs first, then per-patch MLPs; fc1 before fc2 in each.
    for block in [*network.mlp_networks1, *network.mlp_networks2]:
        for layer in (block.fc1, block.fc2):
            penalty += torch.sum(layer.weight ** 2)
    # Output head.
    for head in (network.output1, network.output2):
        penalty += torch.sum(head.weight ** 2)
    return lam * penalty

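# Sparsity penalty computed on network.penalty_x (the activations stored by the mixer block). The function takes the network, the regularization strength lam, and the penalty type penalty.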
def regularize(network, lam, penalty):
    """Group-sparsity penalty on ``network.penalty_x``.

    Args:
        network: object whose ``penalty_x`` attribute is a 3-D tensor.
        lam: regularization strength.
        penalty: one of

            * ``'GL'``   - sum over axis 1 of the norm taken jointly over
              axes 0 and 2;
            * ``'GSGL'`` - the ``'GL'`` term plus the sum of the norms taken
              over axis 0 only;
            * ``'H'``    - hierarchical: the ``'GL'`` term evaluated on every
              prefix ``[:, :, :k]`` of the last axis, ``k = 1..size``, summed.

    Returns:
        ``lam`` times the selected penalty.

    Raises:
        ValueError: if ``penalty_x`` is not 3-D (from the shape unpacking) or
            if ``penalty`` is not one of the supported names.
    """
    x = network.penalty_x
    # The unpacking doubles as a check that penalty_x is 3-D.
    _, _, lag = x.shape

    if penalty == 'GL':
        return lam * torch.sum(torch.norm(x, dim=(0, 2)))

    if penalty == 'GSGL':
        return lam * (torch.sum(torch.norm(x, dim=(0, 2)))
                      + torch.sum(torch.norm(x, dim=0)))

    if penalty == 'H':
        # Nested prefixes along the last axis: low indices take part in every
        # group, the highest index only in the last one.
        total_norm = 0
        for i in range(lag):
            total_norm += torch.sum(torch.norm(x[:, :, :(i + 1)], dim=(0, 2)))
        return lam * total_norm

    # Previously an unknown name fell through and returned None, which only
    # surfaced later as an unrelated error in the caller.
    raise ValueError('unsupported penalty: %s' % penalty)

# Compute the fraction of non-zero weights in a network; the argument passed in is a MixerBlock.
def nonzero_weight_ratio(model, eps=0.00001):
    """Return the fraction of weight entries whose magnitude exceeds ``eps``.

    Only parameters whose name contains ``'weight'`` are counted. A summary
    line is printed to stdout.

    Args:
        model: module exposing ``named_parameters()``.
        eps: magnitude above which an entry counts as non-zero.

    Returns:
        0-dim tensor holding ``non_zero / total``. When the model has no
        matching parameter the result is ``tensor(0.)`` (previously this case
        raised ZeroDivisionError).
    """
    total_nonzero = 0
    total_params = 0
    for name, param in model.named_parameters():
        if 'weight' in name:
            total_nonzero += torch.count_nonzero(param.detach().abs() > eps)
            total_params += param.numel()

    if total_params == 0:
        # Nothing to count: report an empty ratio instead of dividing by zero.
        ratio = torch.tensor(0.0)
    else:
        ratio = total_nonzero / total_params

    print("网络总参数：{0}，非零参数：{1}，比例：{2}".format(total_params, total_nonzero, ratio))
    return ratio


class MixerBlockSparse(nn.Module):
    """Mixer block variant that carries a per-patch mask.

    The mixing stages are the same as in a regular mixer block (per-channel
    MLPs over the patch axis, then per-patch MLPs over the channel axis, each
    with a skip connection), with hidden width 100. The prediction is the sum
    of the element-wise product of the mixed activations and the input.

    Args:
        channels: size of the last input axis; number of first-stage MLPs.
        patches: size of the first input axis; number of second-stage MLPs.
        mask: 1-D sequence or tensor with one entry per patch (a single entry
            is broadcast).

    NOTE(review): ``mask`` is stored but not applied in ``forward``, and the
    layers ``output``, ``output1``, ``gelu`` and ``output2`` are created but
    not used by ``forward``. They are kept so parameter names and seeded
    initialisation stay unchanged.

    NOTE(review): every forward call appends to ``penalty_x`` (autograd graph
    included). Nothing in this class resets it, so the caller must.
    """

    def __init__(self, channels, patches, mask):
        super().__init__()
        # Stage 1: one MLP per channel, mixing information across patches.
        self.mlp_networks1 = nn.ModuleList(
            [MlpBlock(patches, 100, patches) for _ in range(channels)])

        # Stage 2: one MLP per patch, mixing information across channels.
        self.mlp_networks2 = nn.ModuleList(
            [MlpBlock(channels, 100, channels) for _ in range(patches)])

        # Mixed activations recorded by forward(); empty until the first call.
        self.penalty_x = torch.Tensor()

        self.output = nn.Linear(channels * patches, 1)
        self.output1 = nn.Linear(channels * patches, 100)
        self.gelu = nn.GELU()
        self.output2 = nn.Linear(100, 1)

        self.mask = self._expand_mask(mask, patches, channels)

    @staticmethod
    def _expand_mask(mask, patches, channels):
        """Broadcast a per-patch mask to shape ``[patches, channels]``.

        The target shape used to be hard-coded to ``(10, 5)``, which only
        matched one particular configuration.
        """
        if isinstance(mask, torch.Tensor):
            # Same result as torch.tensor(mask), without the copy-construct
            # warning.
            mask = mask.detach().clone()
        else:
            mask = torch.tensor(mask)
        return mask.unsqueeze(1).expand(patches, channels)

    def forward(self, input):
        """Predict one value from a ``[patches, channels]`` window.

        Args:
            input: 2-D tensor of shape ``[patches, channels]``.

        Returns:
            Tensor of shape ``[1]``: ``sum(mixed * input)``.

        Side effects:
            Appends the mixed activations to ``penalty_x`` along a new
            leading axis.
        """
        # [patches, channels] -> [channels, patches]
        x = input.transpose(0, 1)
        out1 = torch.stack(
            [fc(x[i, :]) for i, fc in enumerate(self.mlp_networks1)], dim=0)

        # Back to [patches, channels], then the first skip connection.
        x = input + out1.transpose(0, 1)

        out2 = torch.stack(
            [fc(x[i, :]) for i, fc in enumerate(self.mlp_networks2)], dim=0)
        # Second skip connection.
        x = x + out2

        step = x.unsqueeze(0)
        if self.penalty_x.numel() == 0:
            # Replace the empty placeholder instead of concatenating with it:
            # the placeholder may sit on a different device or have a
            # different dtype than the activations.
            self.penalty_x = step
        else:
            self.penalty_x = torch.cat((self.penalty_x, step), dim=0)

        return torch.flatten(torch.sum(x * input))

    def GC(self, threshold=True, ignore_lag=True):
        """Report norms of ``penalty_x`` and return the per-entry norm.

        Args:
            threshold: if truthy, return ``(norm > 0)`` as an int tensor
                instead of the raw norms.
            ignore_lag: if true, additionally print the norm taken jointly
                over axes 0 and 2 (one value per index of axis 1).

        Returns:
            The norm of ``penalty_x`` over axis 0, optionally thresholded.
            NOTE(review): the returned tensor does not depend on
            ``ignore_lag``; that flag only controls the printed summary.

        Side effects:
            Writes the axis-0 norms to ``GC_single_parse.txt`` in the current
            working directory and may print to stdout.
        """
        if ignore_lag:
            joint = torch.norm(self.penalty_x, dim=(0, 2))
            for a in joint.cpu().detach().numpy():
                print(a, end=' ')
            print("")

        GC = torch.norm(self.penalty_x, dim=0)
        np.savetxt("GC_single_parse.txt", GC.cpu().detach().numpy(), fmt='%.5f')
        if threshold:
            return (GC > 0).int()
        return GC