import torch.nn as nn
import torch
from u_base import *
from u_evaluation import evaluate
from torch.utils.data import DataLoader,TensorDataset
import torch.optim as optim
from dataloader import *

class TabularMLP(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) for tabular features.

    The network outputs raw logits. ``trainMLP`` fits them with
    ``BCEWithLogitsLoss`` (an independent sigmoid per output), and ``testMLP``
    returns the corresponding sigmoid probabilities.

    Args:
        input_dim: Number of input features.
        num_classes: Number of output logits.
        device: Device used for training and inference. The default is
            evaluated once, when the class is defined.
        init_weights: If True, apply Kaiming-normal init to every Linear weight
            and zero its bias.
        hidden: Hidden-layer width. ``0`` selects a heuristic: the smallest
            power of two >= sqrt(input_dim * num_classes), halved once if it
            exceeds ``input_dim``.
    """

    def __init__(self, input_dim=20, num_classes=3, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), init_weights=False, hidden=0):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = num_classes
        if hidden == 0:
            # Smallest power of two >= sqrt(input_dim * num_classes), halved once
            # if that would make the hidden layer wider than the input layer.
            self.H = int(2 ** np.ceil(np.log2(np.sqrt(input_dim * num_classes))))
            if self.H > input_dim:
                self.H //= 2
        else:
            self.H = hidden

        # tier-2 mode: one hidden layer, raw logits out.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, self.H),
            nn.ReLU(inplace=True),
            nn.Linear(self.H, num_classes),
        )

        self.device = device

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return the logits for a batch ``x`` of shape (batch, input_dim)."""
        return self.classifier(x)

    def _initialize_weights(self):
        """Kaiming-normal init for Linear weights; zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def trainMLP(self, Xtr, Ytr, log=False, *, epochs=20, lr=0.001, batch_size=64):
        """Fit the network on (Xtr, Ytr) with Adam and BCEWithLogitsLoss.

        Args:
            Xtr: Training features (array-like or tensor), shape (n, input_dim).
            Ytr: Training targets in [0, 1]. Their shape must match the logits,
                i.e. (n, num_classes), as BCEWithLogitsLoss requires.
            log: If True, print the mean per-sample training loss each epoch.
            epochs: Number of passes over the training data.
            lr: Adam learning rate.
            batch_size: Mini-batch size.
        """
        # as_tensor accepts NumPy arrays and tensors alike. It avoids the
        # torch.tensor(tensor) copy-construct warning.
        train_dataset = TensorDataset(
            torch.as_tensor(Xtr, dtype=torch.float32),
            torch.as_tensor(Ytr, dtype=torch.float32))
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # Guard the per-epoch average against an empty training set.
        n_samples = max(len(train_dataset), 1)

        self.to(self.device)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(self.parameters(), lr=lr)

        for epoch in range(epochs):
            self.train()
            running_loss = 0.0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = self(inputs)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                # loss is a batch mean; re-weight by batch size for a per-sample average.
                running_loss += loss.item() * inputs.size(0)
            if log:
                avg_loss = running_loss / n_samples
                print(f"Epoch {epoch+1}: Train Loss = {avg_loss:.4f}")

    def testMLP(self, Xte, Yte=None, batch_size=64):
        """Return sigmoid probabilities for ``Xte`` as a NumPy array.

        Rows are produced in the same order as ``Xte``, so they line up with
        the corresponding labels. ``Yte`` is accepted for backward
        compatibility only; it is not used to compute the predictions.

        Returns:
            float32 array of shape (len(Xte), num_classes).
        """
        # shuffle=False is essential: a shuffled loader would return the
        # predictions in a random order that no longer matches Yte.
        test_loader = DataLoader(
            TensorDataset(torch.as_tensor(Xte, dtype=torch.float32)),
            batch_size=batch_size, shuffle=False)
        # Ensure the weights live on self.device even if trainMLP was never called.
        self.to(self.device)
        self.eval()
        all_preds = []
        with torch.no_grad():
            for (inputs,) in test_loader:
                outputs = self(inputs.to(self.device))
                all_preds.append(torch.sigmoid(outputs).cpu())
        if not all_preds:
            # torch.cat([]) would raise on an empty input.
            return np.empty((0, self.output_dim), dtype=np.float32)
        return torch.cat(all_preds).numpy()
    
if __name__ == '__main__':
    # Binary-relevance MLP baseline: train one single-output TabularMLP per
    # label, then stack the per-label probabilities into an (n_test, n_labels)
    # matrix for evaluation.
    datasnames = ['20NG', 'Corel5k', 'HumanGO', 'Ohsumed', 'TMC2007_500', 'Yelp', 'Bibtex', 'Bookmarks', 'Delicious', 'Imdb']
    num_label = [20, 374, 14, 23, 22, 5, 159, 208, 983, 28]
    num_train = [12933, 3332, 2053, 9301, 19140, 7240, 4941, 58862, 10743, 81363]
    path = "data/"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    genpath = "test"
    os.makedirs(genpath, exist_ok=True)

    # Walk the three parallel metadata lists in lockstep.
    for dataIdx, (name, n_labels, n_train) in enumerate(zip(datasnames, num_label, num_train)):
        print(name)
        Xtr, Ytr, Xte, Yte = readData(path + name, n_labels, n_train)
        n, d = np.shape(Xtr)
        m, c = np.shape(Yte)
        print(dataIdx, "==========", n, d, m, c)

        t0 = time()
        prd_MLP = []
        for i in range(c):
            print(i, c)
            model = TabularMLP(input_dim=d, num_classes=1, device=device)
            # Pass the (n, 1) column slice directly. trainMLP/testMLP do their
            # own float32 tensor conversion, so there is no need to wrap it here.
            model.trainMLP(Xtr, Ytr[:, i:i + 1])
            prd_MLP.append(model.testMLP(Xte, Yte[:, i:i + 1]).flatten())
        t1 = time()
        # np.transpose turns the list of c per-label vectors into an (m, c) matrix.
        saveResult(name, 'BRMLP3', evaluate(np.transpose(prd_MLP), Yte), t1 - t0)
