import os
import uuid
import math

import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.utils.data
import math
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

class RoGraTrainer:
    """Train and evaluate a RoGra-style model on mini-batches of a distance graph.

    During ``fit`` the weights of the best epoch (by training-set score) are
    checkpointed to a uniquely named file under ``tmp_dir``; that file is
    removed before ``fit`` returns, even if training raises.
    """

    def __init__(self, device, logger):
        """
        Args:
            device: target passed to ``Tensor.to`` for every model input batch.
            logger: object exposing ``log(message: str)``; used for per-epoch logs.
        """
        self.device = device
        self.logger = logger
        self.tmp_dir = 'tmp'
        # exist_ok avoids a check-then-create race between concurrent trainers.
        os.makedirs(self.tmp_dir, exist_ok=True)

    @staticmethod
    def _complement(n_samples, idx):
        """Return the sorted ids in ``range(n_samples)`` that are not in ``idx``."""
        return np.setdiff1d(np.arange(n_samples), np.asarray(idx))

    @staticmethod
    def _make_loader(idx, batch_size):
        """Wrap a collection of sample ids in a shuffling DataLoader."""
        return torch.utils.data.DataLoader(Dataset(idx), batch_size=batch_size, shuffle=True, num_workers=1)

    def _evaluate(self, model, loader, dataset, other_idx, other_batch_size):
        """Score ``model`` on ``loader``: AUC for the 'bcicha' dataset, accuracy otherwise."""
        metric = compute_auc if dataset == 'bcicha' else compute_accuracy
        return metric(model, self.X, self.y, self.adj, self.K, self.alpha, loader,
                      self.device, other_idx, other_batch_size)

    def fit(self, model, X, y, train_idx, distances, args, test_idx=None, report_test=False):
        """Train ``model`` on ``train_idx`` and return it with its best weights loaded.

        Args:
            model: module called as ``model(X_batch, adj_batch, K, alpha)``; its
                output is fed to ``F.nll_loss``, so it should return log-probabilities.
            X, y: numpy arrays (converted with ``torch.from_numpy``) whose first
                axis is indexed by sample id.
            train_idx: sample ids used for training and for model selection.
            distances: 2-D numpy array indexed by sample id on both axes; the
                sub-matrix ``distances[idx][:, idx]`` is passed to the model.
            args: namespace read for ``K``, ``alpha``, ``bs``, ``lr``, ``wd``,
                ``epochs`` and ``dataset``.
            test_idx: sample ids used only for per-epoch reporting when
                ``report_test`` is true. Defaults to every id not in ``train_idx``.
            report_test: also log the test-set score each epoch. Ignored if the
                test set is empty.

        Returns:
            ``model`` in eval mode, holding the weights of the epoch with the best
            training score (AUC when ``args.dataset == 'bcicha'``, accuracy
            otherwise). With ``args.epochs == 0`` the model is returned untouched.

        Raises:
            ValueError: if ``train_idx`` is empty.
        """
        self.K = args.K
        self.alpha = args.alpha
        batch_size = args.bs
        n_samples = X.shape[0]

        # Use the caller's split as given: re-splitting here would let samples
        # the caller holds out for testing leak into training.
        train_idx = np.asarray(train_idx)
        if len(train_idx) == 0:
            raise ValueError('train_idx must not be empty')

        train_batch_size = min(batch_size // 2, len(train_idx))
        other_idx = self._complement(n_samples, train_idx)
        other_batch_size = min(batch_size - train_batch_size, len(other_idx))
        train_loader = self._make_loader(train_idx, train_batch_size)

        if report_test:
            test_idx = np.asarray(test_idx if test_idx is not None else other_idx)
            # An empty test set cannot be batched, so reporting is skipped.
            report_test = len(test_idx) > 0
        if report_test:
            test_batch_size = min(batch_size // 2, len(test_idx))
            other_idx_test = self._complement(n_samples, test_idx)
            other_batch_size_test = min(batch_size - test_batch_size, len(other_idx_test))
            test_loader = self._make_loader(test_idx, test_batch_size)

        self.adj = torch.from_numpy(distances.astype(np.float32))
        self.X, self.y = torch.from_numpy(X), torch.from_numpy(y)
        file_path = os.path.join(self.tmp_dir, str(uuid.uuid4()))

        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

        best_acc = 0.0

        try:
            for epoch in range(args.epochs):
                model.train()
                for sampled_train_idx in train_loader:
                    # Reset per batch; otherwise each step re-applies the
                    # gradients of every earlier batch in the epoch.
                    optimizer.zero_grad()

                    # The training sub-graph contains only the sampled labelled ids.
                    idx = sampled_train_idx
                    _X = self.X[idx].to(self.device)
                    _y = self.y[idx.long()].to(self.device)
                    _adj = self.adj[idx][:, idx]

                    outputs = model(_X, _adj, self.K, self.alpha)
                    loss = F.nll_loss(outputs[:len(idx)], _y)

                    loss.backward()
                    optimizer.step()

                model.eval()
                acc = self._evaluate(model, train_loader, args.dataset, other_idx, other_batch_size)
                if acc >= best_acc:
                    best_acc = acc
                    torch.save(model.state_dict(), file_path)
                if report_test:
                    test_acc = self._evaluate(model, test_loader, args.dataset,
                                              other_idx_test, other_batch_size_test)
                    self.logger.log('--> Epoch {}: loss {:5.4f}; accuracy: {:5.4f}; best accuracy: {:5.4f}; test accuracy: {:5.4f}'.format(epoch, loss.item(), acc, best_acc, test_acc))
                else:
                    self.logger.log('--> Epoch {}: loss {:5.4f}; accuracy: {:5.4f}; best accuracy: {:5.4f}'.format(epoch, loss.item(), acc, best_acc))

            # Restore the best epoch's weights (no checkpoint exists if epochs == 0).
            if os.path.exists(file_path):
                model.load_state_dict(torch.load(file_path))
        finally:
            # Never leave checkpoint files behind, even when training fails.
            if os.path.exists(file_path):
                os.remove(file_path)

        model.eval()
        return model

    def test(self, model, test_idx, batch_size=128):
        """Return the accuracy of ``model`` on ``test_idx`` as a Python float.

        Must be called after ``fit``; it reuses ``self.X``, ``self.y``,
        ``self.adj``, ``self.K`` and ``self.alpha``. Each test id is scored in
        its own batch together with up to ``batch_size - 1`` randomly drawn
        other ids. Always reports accuracy, including for the 'bcicha' dataset
        where ``fit`` selects on AUC.
        """
        test_batch_size = 1
        other_idx_test = self._complement(self.X.shape[0], test_idx)
        other_batch_size_test = min(batch_size - test_batch_size, len(other_idx_test))
        test_loader = self._make_loader(test_idx, test_batch_size)
        acc = compute_accuracy(model, self.X, self.y, self.adj, self.K, self.alpha, test_loader, self.device, other_idx_test, other_batch_size_test)
        return acc.item()





def compute_auc(model, X, y, adj, K, alpha, loader, device, other_idx, other_batch_size):
    """Return the ROC AUC of ``model`` over every sample id yielded by ``loader``.

    Used for the last ('bcicha') dataset and assumes a binary task: each
    sample's score is column 1 of the softmax of the model output.

    Each batch of target ids is extended with ``other_batch_size`` ids drawn
    without replacement from ``other_idx``. The model sees the sub-graph of
    all of them, but only the target rows are scored.
    """
    labels, scores = [], []

    with torch.no_grad():
        for target_idx in loader:
            n_target = len(target_idx)
            context_idx = np.random.choice(other_idx, other_batch_size, replace=False)
            sub_idx = np.concatenate((target_idx, context_idx))

            sub_X = X[sub_idx].to(device)
            sub_y = y[sub_idx][:n_target].to(device)
            sub_adj = adj[sub_idx][:, sub_idx]

            log_probs = model(sub_X, sub_adj, K, alpha)
            labels.extend(sub_y.cpu().numpy())
            # softmax of log-probabilities recovers the class probabilities.
            probs = log_probs[:n_target].softmax(dim=1).cpu().numpy()
            scores.extend(probs[:, 1])

    return roc_auc_score(np.array(labels), np.array(scores))



    
def compute_accuracy(model, X, y, adj, K, alpha, loader, device, other_idx, other_batch_size):
    """Return the fraction of ids from ``loader`` whose predicted class matches ``y``.

    Each batch of target ids is extended with ``other_batch_size`` ids drawn
    without replacement from ``other_idx``. The model sees the sub-graph of
    all of them, but only the target rows are scored.

    Returns:
        A 0-dim float64 tensor (callers use ``.item()`` on it).

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for target_idx in loader:
            n_target = len(target_idx)
            context_idx = np.random.choice(other_idx, other_batch_size, replace=False)
            sub_idx = np.concatenate((target_idx, context_idx))

            sub_X = X[sub_idx].to(device)
            sub_y = y[sub_idx][:n_target].to(device)
            sub_adj = adj[sub_idx][:, sub_idx]

            log_probs = model(sub_X, sub_adj, K, alpha)
            predicted = log_probs[:n_target].max(dim=1).indices.type_as(sub_y)
            n_correct += predicted.eq(sub_y).double().sum()
            n_seen += n_target

    return n_correct / n_seen
    


class RoGra(nn.Module):
    """Inception feature extractor followed by residual graph convolutions.

    Each sample is encoded by three InceptionBlocks and global average
    pooling. The per-sample features are then propagated over a
    K-nearest-neighbour graph built from the batch distance matrix.
    """

    def __init__(self, input_size, nb_classes, num_layers=2, n_feature_maps=64, dropout=0.5):
        """
        Args:
            input_size: channels of the input series (first InceptionBlock input).
            nb_classes: number of output classes.
            num_layers: number of graph-convolution layers; must be 1, 2 or 3.
            n_feature_maps: width of the Inception blocks and hidden GCN layers.
            dropout: dropout probability between GCN layers (unused for 1 layer).

        Raises:
            ValueError: if ``num_layers`` is not 1, 2 or 3.
        """
        super(RoGra, self).__init__()
        # Any other value used to skip every graph layer silently.
        if num_layers not in (1, 2, 3):
            raise ValueError('num_layers must be 1, 2 or 3, got {}'.format(num_layers))
        self.num_layers = num_layers

        self.block_1 = InceptionBlock(input_size, n_feature_maps)
        self.block_2 = InceptionBlock(n_feature_maps, n_feature_maps)
        self.block_3 = InceptionBlock(n_feature_maps, n_feature_maps)

        if self.num_layers == 1:
            self.gc1 = GraphConvolution(n_feature_maps, nb_classes)
        elif self.num_layers == 2:
            self.gc1 = GraphConvolution(n_feature_maps, n_feature_maps)
            self.gc2 = GraphConvolution(n_feature_maps, nb_classes)
            self.dropout = dropout
        elif self.num_layers == 3:
            self.gc1 = GraphConvolution(n_feature_maps, n_feature_maps)
            self.gc2 = GraphConvolution(n_feature_maps, n_feature_maps)
            self.gc3 = GraphConvolution(n_feature_maps, nb_classes)
            self.dropout = dropout

    @staticmethod
    def _knn_adjacency(distances, K, alpha):
        """Build a row-normalised sparse K-nearest-neighbour graph.

        Row i keeps the ``K`` columns with the smallest distances (all of them
        if ``K`` exceeds the row length). Their weights are proportional to
        ``exp(-alpha * d_ij)`` and sum to 1 per row.

        Args:
            distances: dense 2-D tensor of pairwise distances.
            K: neighbours kept per row.
            alpha: scale of the exponential distance decay.

        Returns:
            A float32 sparse COO tensor with the same size as ``distances``.
        """
        n_rows, n_cols = distances.shape
        k = min(K, n_cols)
        # argsort-then-slice keeps the same neighbour choice (incl. ties) as before.
        neighbours = torch.argsort(distances, dim=1)[:, :k]
        neighbour_dist = torch.gather(distances, 1, neighbours)
        # softmax(-alpha*d) equals exp(-alpha*d) / sum(exp(-alpha*d)) but is
        # computed stably, so large distances cannot underflow to 0/0 = NaN.
        weights = torch.softmax(-alpha * neighbour_dist.float(), dim=1)
        rows = torch.arange(n_rows, device=distances.device).unsqueeze(1).expand_as(neighbours)
        indices = torch.stack((rows.reshape(-1), neighbours.reshape(-1)))
        return torch.sparse_coo_tensor(indices, weights.reshape(-1), (n_rows, n_cols))

    def forward(self, x, adj, K, alpha):
        """Return per-sample log-probabilities of shape (n_samples, nb_classes).

        Args:
            x: batch for the Conv1d-based Inception blocks, i.e.
                (n_samples, input_size, length).
            adj: dense (n_samples, n_samples) distance matrix between the rows of ``x``.
            K: neighbours kept per sample when building the graph.
            alpha: scale of the ``exp(-alpha * distance)`` edge weights.
        """
        # W_message always exists, unlike the optional bias.
        device = self.gc1.W_message.device
        adj = self._knn_adjacency(adj, K, alpha).to(device)

        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        # Squeeze only the pooled time axis so a single-sample batch keeps its batch dim.
        x = F.avg_pool1d(x, x.shape[-1]).squeeze(-1)

        if self.num_layers == 1:
            x = self.gc1(x, adj)
        elif self.num_layers == 2:
            x = F.relu(self.gc1(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.gc2(x, adj)
        elif self.num_layers == 3:
            x = F.relu(self.gc1(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
            x = F.relu(self.gc2(x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.gc3(x, adj)

        x = F.log_softmax(x, dim=1)

        return x


import torch
import torch.nn as nn

class GraphConvolution(nn.Module):
    """Residual graph convolution computing ``skip(x) + adj @ (x @ W_message) + bias``.

    ``skip`` is the identity when ``in_features == out_features``; otherwise
    it is a bias-free linear projection (``W_residual``).
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Weight applied to node features before neighbourhood aggregation.
        self.W_message = nn.Parameter(torch.FloatTensor(in_features, out_features))

        # The skip path needs a projection only when the widths differ.
        self.W_residual = (
            nn.Linear(in_features, out_features, bias=False)
            if in_features != out_features
            else None
        )

        # Registered as None when disabled so ``self.bias`` always exists.
        self.register_parameter(
            'bias',
            nn.Parameter(torch.FloatTensor(out_features)) if bias else None,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Draw W_message (and bias, if any) from U(-1/sqrt(in), 1/sqrt(in))."""
        fan_in = self.W_message.size(0)
        bound = 1. / (fan_in ** 0.5)
        with torch.no_grad():
            self.W_message.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Aggregate transformed features over ``adj`` and add the skip path.

        Args:
            input: node features, (n_nodes, in_features).
            adj: adjacency usable by ``torch.spmm`` (e.g. sparse (n_nodes, n_nodes)).
        """
        aggregated = torch.spmm(adj, torch.mm(input, self.W_message))
        if self.bias is not None:
            aggregated = aggregated + self.bias

        skip = input if self.W_residual is None else self.W_residual(input)
        return skip + aggregated

class InceptionBlock(nn.Module):
    """1-D Inception block with a residual connection.

    Four parallel branches each produce a share of ``out_channels``:
      1. a 1x1 conv;
      2. a 1x1 conv followed by a width-3 conv;
      3. a 1x1 conv followed by a width-5 conv;
      4. a width-3 max-pool followed by a 1x1 conv.
    Their concatenation is added to the input and passed through ReLU. The
    input is first projected by 1x1 conv + BatchNorm when the channel counts
    differ. Sequence length is preserved.
    """

    def __init__(self, in_channels, out_channels):
        """
        Raises:
            ValueError: if ``out_channels < 4`` (every branch needs a channel).
        """
        super(InceptionBlock, self).__init__()
        if out_channels < 4:
            raise ValueError('out_channels must be at least 4, got {}'.format(out_channels))

        branch_channels = out_channels // 4
        # The last branch absorbs the remainder, so the concatenation always
        # has exactly out_channels channels, as the residual add requires.
        last_channels = out_channels - 3 * branch_channels

        self.branch1 = nn.Sequential(
            nn.Conv1d(in_channels, branch_channels, kernel_size=1),
            nn.BatchNorm1d(branch_channels),
            nn.ReLU()
        )

        self.branch2 = nn.Sequential(
            nn.Conv1d(in_channels, branch_channels, kernel_size=1),
            nn.BatchNorm1d(branch_channels),
            nn.ReLU(),
            nn.Conv1d(branch_channels, branch_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(branch_channels),
            nn.ReLU()
        )

        self.branch3 = nn.Sequential(
            nn.Conv1d(in_channels, branch_channels, kernel_size=1),
            nn.BatchNorm1d(branch_channels),
            nn.ReLU(),
            nn.Conv1d(branch_channels, branch_channels, kernel_size=5, padding=2),
            nn.BatchNorm1d(branch_channels),
            nn.ReLU()
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_channels, last_channels, kernel_size=1),
            nn.BatchNorm1d(last_channels),
            nn.ReLU()
        )

        # Shortcut projection so the residual matches out_channels.
        self.expand = in_channels != out_channels
        if self.expand:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)
            self.bn_shortcut = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Map (batch, in_channels, length) to (batch, out_channels, length)."""
        branch1_out = self.branch1(x)
        branch2_out = self.branch2(x)
        branch3_out = self.branch3(x)
        branch4_out = self.branch4(x)

        out = torch.cat([branch1_out, branch2_out, branch3_out, branch4_out], dim=1)

        if self.expand:
            x = self.shortcut(x)
            x = self.bn_shortcut(x)

        out += x
        return F.relu(out)

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset whose items are the stored sample ids themselves.

    Wrapping an id collection lets a DataLoader shuffle and batch ids; the
    caller then uses each batch to index its own feature/label tensors.
    """

    def __init__(self, idx):
        # Any indexable, sized collection of ids (e.g. a numpy array or list).
        self.idx = idx

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, index):
        return self.idx[index]

