import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error,mean_squared_error,r2_score,roc_auc_score,f1_score,accuracy_score
from plmdata_repeat.features import get_atom_feature_dims, get_bond_feature_dims 

def seed_torch(seed=0):
    """Seed torch's RNGs and switch cuDNN into deterministic mode.

    Prints the seed, then seeds the default torch generator and the current
    CUDA device's generator. It also disables cuDNN benchmark autotuning and
    enables cuDNN deterministic algorithms. Python's ``random`` module and
    NumPy are NOT seeded here.

    Args:
        seed: integer seed passed to the torch seeding functions.
    """
    print("Seed", seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    # Benchmark mode chooses kernels by timing them, which can differ between runs.
    cudnn.benchmark = False
    cudnn.deterministic = True
    
def eval_rocauc(y_true, y_pred):
    """Compute ROC-AUC averaged across tasks (columns).

    Each column of ``y_true`` / ``y_pred`` is treated as one binary task.
    A column contributes only if it contains at least one label equal to 1
    and at least one equal to 0, because AUC is undefined otherwise. NaN
    labels are ignored within a column. 1-D inputs are treated as a single
    task.

    Args:
        y_true: array of 0/1 labels (NaN = unlabeled), shape (n,) or (n, tasks).
        y_pred: scores with the same layout as ``y_true``. Any monotone score
            works (e.g. logits), since ROC-AUC depends only on ranking.

    Returns:
        The mean ROC-AUC over the tasks where it is defined.

    Raises:
        RuntimeError: if no task has both a positive and a negative label.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Accept single-task 1-D inputs by viewing them as one column.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    rocauc_list = []
    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # AUC is only defined when both classes are present.
        if np.any(column == 1) and np.any(column == 0):
            # NaN != NaN, so this mask is False exactly at the NaN (unlabeled) entries.
            is_labeled = column == column
            rocauc_list.append(roc_auc_score(column[is_labeled], y_pred[is_labeled, i]))

    if not rocauc_list:
        raise RuntimeError('No positively labeled data available. Cannot compute ROC-AUC.')

    return sum(rocauc_list) / len(rocauc_list)

def training(epoch, model, loader, optimizer, device, criterion=torch.nn.L1Loss()):
    """Run one training epoch of ``model`` over ``loader``.

    Skipped batches:
      * Batches with a single node (``batch.x.shape[0] == 1``) or whose
        ``batch.batch[-1] == 0``. The latter is presumably a single-graph
        batch with a PyG-style graph-assignment vector; the reason for the
        skip is likely normalization layers that need >1 sample (confirm).
      * Batches whose targets are all NaN. The criterion would average over
        an empty tensor, giving a NaN loss, and the optimizer would still
        step (e.g. applying weight decay) without any supervision signal.

    NaN targets in the remaining batches are masked out of the loss.

    From epoch ``thres`` (50) onward, an L1 penalty on all model parameters
    is added. Its weight ramps linearly from 0 up to 1e-3 over the next
    ``thres`` epochs.

    Args:
        epoch: current epoch index. Controls the L1 penalty weight and
            appears in log lines.
        model: module called as ``model(batch)``.
        loader: iterable of batch objects exposing ``to``, ``x``, ``batch``
            and ``y``.
        optimizer: optimizer over the model's parameters.
        device: device each batch is moved to.
        criterion: loss applied to the non-NaN prediction/target pairs.
    """
    model.train()
    thres = 50
    if epoch >= thres:
        l1_lambda = min(1.0, (epoch - thres) / thres) * 1e-3
    else:
        l1_lambda = 0
    for step, batch in enumerate(loader):
        batch = batch.to(device)

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            del batch
            continue

        is_valid = ~torch.isnan(batch.y)
        if not is_valid.any():
            # Nothing to supervise: avoid a NaN loss and a gradient-free optimizer step.
            del batch, is_valid
            continue

        pred = model(batch)
        optimizer.zero_grad()
        loss = criterion(
            pred.to(torch.float32)[is_valid], batch.y.to(torch.float32)[is_valid]
        )
        if l1_lambda > 0:
            # Only pay for the full parameter sweep once the penalty is active.
            l1_loss = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_loss
        loss.backward()
        if step % 100 == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {loss.item()}")

        optimizer.step()
        del batch, pred, loss
        torch.cuda.empty_cache()


def validate(model, loader, device, task_type='regression'):
    """Evaluate ``model`` on ``loader`` and return a dict of metrics.

    Batches with a single node (``batch.x.shape[0] == 1``) are skipped.
    Labels that are NaN are excluded from the metrics, matching the masking
    used during training.

    Args:
        model: module called as ``model(batch)``.
        loader: iterable of batch objects exposing ``to``, ``x`` and ``y``.
        device: device each batch is moved to.
        task_type: ``'classification'`` or anything else, which is treated as
            regression.

    Returns:
        For classification: ``{'rocauc', 'f1', 'acc'}``. ROC-AUC uses the raw
        model outputs. F1/accuracy threshold ``sigmoid(output) > 0.5``, so the
        outputs are presumably logits (confirm).
        Otherwise: ``{'rmse', 'mae', 'r2'}`` over all flattened non-NaN targets.

    Raises:
        RuntimeError: if the loader yields no evaluable batch.
    """
    model.eval()
    y_true = []
    y_pred = []

    for batch in loader:
        batch = batch.to(device)
        if batch.x.shape[0] == 1:
            continue
        with torch.no_grad():
            pred = model(batch)
        y_true.append(batch.y.detach().cpu())
        y_pred.append(pred.detach().cpu())

    if not y_true:
        raise RuntimeError('No evaluable batches in loader (it was empty or every batch was skipped).')

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

    if task_type == 'classification':
        hard_y_pred = (torch.from_numpy(y_pred).sigmoid() > 0.5).int().numpy()
        # eval_rocauc handles NaN labels per task itself.
        rocauc = eval_rocauc(y_true, y_pred)
        labeled = ~np.isnan(y_true)
        f1 = f1_score(y_true[labeled], hard_y_pred[labeled])
        acc = accuracy_score(y_true[labeled], hard_y_pred[labeled])
        return {'rocauc': rocauc, 'f1': f1, 'acc': acc}

    y_true, y_pred = y_true.reshape(-1), y_pred.reshape(-1)
    labeled = ~np.isnan(y_true)
    y_true, y_pred = y_true[labeled], y_pred[labeled]
    mae = mean_absolute_error(y_true, y_pred)
    # `squared=False` was removed from mean_squared_error in scikit-learn 1.6;
    # taking the square root explicitly works across all versions.
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    return {'rmse': rmse, 'mae': mae, 'r2': r2}


def print_info(set_name, perf):
    """Print one tab-separated line of metrics for a dataset split.

    The line starts with ``set_name`` followed by two tabs. Each metric is
    then rendered as ``name: value`` with the value as a 4-decimal float,
    left-aligned in a 10-character field, followed by a space and a tab.

    Args:
        set_name: label printed at the start of the line.
        perf: mapping of metric name to a number formattable with ``.4f``.
    """
    pieces = [f"{set_name}\t\t"]
    pieces.extend(f"{name}: {value:<10.4f} \t" for name, value in perf.items())
    print("".join(pieces))

class AtomEncoder(torch.nn.Module):
    """Embed integer atom (node) features as a sum of per-feature embeddings.

    One ``torch.nn.Embedding`` table is built per feature column. Each table
    is Xavier-uniform initialised and created with ``max_norm=1``.
    ``forward`` looks up column ``i`` of ``x`` in table ``i`` and sums the
    looked-up vectors.

    Args:
        emb_dim: size of every embedding vector.
        optional_full_atom_features_dims: vocabulary size of each feature
            column. When ``None``, ``get_atom_feature_dims()`` is used.
    """

    def __init__(self, emb_dim, optional_full_atom_features_dims=None):
        super().__init__()

        self.atom_embedding_list = torch.nn.ModuleList()
        feature_dims = (
            get_atom_feature_dims()
            if optional_full_atom_features_dims is None
            else optional_full_atom_features_dims
        )

        for vocab_size in feature_dims:
            table = torch.nn.Embedding(vocab_size, emb_dim, max_norm=1)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Sum the embeddings of every feature column of ``x``.

        ``x`` is indexed as ``x[:, col]``, so it must be 2-D with one column per
        feature and integer indices valid for the matching table.
        """
        return sum(self.atom_embedding_list[col](x[:, col]) for col in range(x.shape[1]))


class BondEncoder(torch.nn.Module):
    """Embed integer bond (edge) features as a sum of per-feature embeddings.

    One ``torch.nn.Embedding`` table is built per feature column. Each table
    is Xavier-uniform initialised and created with ``max_norm=1``.
    ``forward`` looks up column ``i`` of ``edge_attr`` in table ``i`` and
    sums the looked-up vectors.

    Args:
        emb_dim: size of every embedding vector.
        optional_full_bond_features_dims: vocabulary size of each feature
            column. When ``None`` (the default), ``get_bond_feature_dims()``
            is used. This mirrors ``AtomEncoder``'s
            ``optional_full_atom_features_dims`` argument.
    """

    def __init__(self, emb_dim, optional_full_bond_features_dims=None):
        super(BondEncoder, self).__init__()

        if optional_full_bond_features_dims is not None:
            full_bond_feature_dims = optional_full_bond_features_dims
        else:
            full_bond_feature_dims = get_bond_feature_dims()
        self.bond_embedding_list = torch.nn.ModuleList()
        for dim in full_bond_feature_dims:
            emb = torch.nn.Embedding(dim, emb_dim, max_norm=1)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        """Sum the embeddings of every feature column of ``edge_attr``.

        ``edge_attr`` is indexed as ``edge_attr[:, i]``, so it must be 2-D with
        one column per feature and integer indices valid for the matching table.
        """
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])

        return bond_embedding
    
if __name__ == "__main__":
    # Placeholder entry point: running this module directly does nothing.
    # It only provides helpers for other modules to import.
    pass
        