import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
import numpy as np
from generate_T import find_invertible_submatrix, Generator_matrix
from Random_T import Random_matrix
# Default compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class EACR(nn.Module):
    """Model fusing a residual MLP head with three CNN branches over a Gram matrix.

    Given token indices ``x`` the model embeds them to ``E`` and computes:

    * ``R`` -- an MLP (``fc_residual``) over the flattened embeddings.
    * ``C`` -- an MLP (``fc``) over the concatenated outputs of three identical
      4-stage conv/pool stacks, fed with ``tanh`` of:
        1. ``A = E @ E^T`` (inner products between embedded positions),
        2. ``T @ A^T @ T^T`` (equal to ``T @ A @ T^T`` because ``A`` is symmetric),
        3. the same projection using ``T_random`` instead of ``T``.

    The result is ``s * C + (1 - s) * R`` with ``s = sigmoid(fusion_weight)``;
    both heads end in ``Linear(32, 2)`` so the output has 2 values per sample.

    Args:
        inport_length: number of rows of the embedding table
            (``num_embeddings`` of ``nn.Embedding``).
        embed_size: embedding dimension.
        dropout_prob: dropout probability used in both MLP heads.

    Side effects:
        Every construction (over)writes ``sub_T.txt`` and ``T_random_numpy.txt``
        in the current working directory. They are written with ``fmt='%d'``,
        so non-integer entries would be truncated -- presumably the generators
        return integer matrices (TODO confirm).
    """

    def __init__(self, inport_length, embed_size, dropout_prob=0.5):
        super().__init__()
        # Fixed settings for the external matrix generators.
        size, new_size, l, m = 9, 9, 3, 3
        T = Generator_matrix(size, l, m)
        sub_T = find_invertible_submatrix(T, new_size)
        T_random_numpy = Random_matrix(new_size)
        # Buffers (not plain attributes) so the matrices follow .to()/.cuda()/.cpu()
        # together with the layers. persistent=False keeps them out of
        # state_dict exactly as before, so existing checkpoints still load.
        self.register_buffer(
            "T", torch.tensor(sub_T, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "T_random",
            torch.tensor(T_random_numpy, dtype=torch.float32),
            persistent=False,
        )
        np.savetxt('sub_T.txt', sub_T, fmt='%d')
        np.savetxt('T_random_numpy.txt', T_random_numpy, fmt='%d')

        # forward() multiplies T (and T_random) against the seq_len x seq_len
        # Gram matrix and concatenates branch outputs, so T must be square with
        # side == seq_len.
        seq_len = self.T.shape[-1]
        # Padded 3x3 convs keep H/W; each of the four MaxPool2d(kernel_size=2,
        # stride=1) shrinks them by 1. Each branch ends with 16 channels.
        pooled = seq_len - 4
        branch_features = 16 * pooled * pooled

        # NOTE: submodules are created in the original order so parameter
        # initialisation consumes the torch RNG identically.
        self.embedding1 = nn.Embedding(inport_length, embed_size)
        self.activate = nn.Tanh()

        # Branch 1: operates on A.
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv1_1 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
        self.conv3_1 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

        # Branch 2: operates on the T projection of A.
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv1_2 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

        # Branch 3: operates on the T_random projection of A.
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=1)
        self.conv1_3 = nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1)
        self.conv2_3 = nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)

        # Raw (pre-sigmoid) mixing weight; sigmoid(0.5) ~= 0.62, so the CNN head
        # initially gets slightly more weight than the residual head.
        self.fusion_weight = nn.Parameter(torch.tensor([0.5]), requires_grad=True)

        # Input is the flattened (seq_len, embed_size) embedding
        # (576 for seq_len=9, embed_size=64).
        self.fc_residual = nn.Sequential(
            nn.Linear(seq_len * embed_size, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(32, 2),
        )
        # Input is three concatenated branch outputs (1200 for seq_len=9).
        self.fc = nn.Sequential(
            nn.Linear(3 * branch_features, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Linear(32, 2),
        )

    @staticmethod
    def _project(M, A):
        """Return ``M @ (M @ A)^T``, i.e. ``M @ A^T @ M^T`` (batched over leading dims)."""
        return torch.matmul(M, torch.matmul(M, A).transpose(-2, -1))

    @staticmethod
    def _conv_branch(feat, convs, pool):
        """Apply ``pool(relu(conv(feat)))`` for each conv in ``convs``, in order."""
        for conv in convs:
            feat = pool(F.relu(conv(feat)))
        return feat

    def forward(self, x):
        """Compute the fused 2-value output for a batch of index sequences.

        Args:
            x: integer index tensor of shape ``(batch, seq_len)`` (rank 2 is
               required because the Gram matrix is unsqueezed to a 4-D conv
               input). ``seq_len`` must equal the side length of ``T``.

        Returns:
            Tensor of shape ``(batch, 2)``.
        """
        E = self.embedding1(x)
        R = self.fc_residual(torch.flatten(E, 1))

        # (batch, seq_len, seq_len) inner products between embedded positions.
        A = torch.matmul(E, E.transpose(-2, -1))
        AA = self._project(self.T, A)
        AA_random = self._project(self.T_random, A)

        # tanh squashes the unbounded inner products; unsqueeze adds the
        # single input channel expected by conv1_*.
        branch_outputs = (
            self._conv_branch(
                self.activate(A).unsqueeze(1),
                (self.conv1_1, self.conv2_1, self.conv3_1, self.conv4_1),
                self.pool_1,
            ),
            self._conv_branch(
                self.activate(AA).unsqueeze(1),
                (self.conv1_2, self.conv2_2, self.conv3_2, self.conv4_2),
                self.pool_2,
            ),
            self._conv_branch(
                self.activate(AA_random).unsqueeze(1),
                (self.conv1_3, self.conv2_3, self.conv3_3, self.conv4_3),
                self.pool_3,
            ),
        )
        C = self.fc(torch.flatten(torch.cat(branch_outputs, dim=1), 1))

        s = torch.sigmoid(self.fusion_weight)
        return s * C + (1 - s) * R