import torch
import torch_geometric
from torch_geometric.data import Data

# Module-wide default device: CUDA whenever it is available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def get_edgewise_edge_index(edge_index):
    """Build non-backtracking edge-to-edge connectivity for a directed graph.

    Every column (edge) of ``edge_index`` becomes a node of the new graph.
    Edge ``j`` is connected to edge ``i`` (``j -> i``) when ``j`` ends at the
    node where ``i`` starts and ``j`` is not the reverse of ``i``.  That is,
    for ``i = (u -> v)`` every ``j = (w -> u)`` with ``w != v`` is linked.

    Args:
        edge_index: ``(2, E)`` integer tensor; row 0 holds ``edge_index[0]``
            (start) nodes and row 1 holds ``edge_index[1]`` (end) nodes.

    Returns:
        ``(2, M)`` ``torch.int64`` tensor on the same device as
        ``edge_index``.  Row 0 holds the incoming edge ``j`` and row 1 holds
        the outgoing edge ``i``.  Columns are grouped by ``i`` in ascending
        order and sorted by ``j`` within each group.
    """
    edge_index = edge_index.long()
    dev = edge_index.device
    src, dst = edge_index[0], edge_index[1]
    num_edges = src.numel()
    if num_edges == 0:
        return torch.empty((2, 0), dtype=torch.int64, device=dev)

    num_nodes = int(edge_index.max()) + 1
    # Edge ids ordered by their end node, so all edges ending at a node are
    # contiguous (stable sort keeps ascending edge ids inside each node).
    _, edges_by_dst = torch.sort(dst, stable=True)
    in_degree = torch.bincount(dst, minlength=num_nodes)
    # in_ptr[n] = first slot of node n's incoming edges inside edges_by_dst.
    in_ptr = torch.cumsum(in_degree, dim=0) - in_degree

    # Edge i = (u -> v) is paired with each of the in_degree[u] edges ending at u.
    counts = in_degree[src]
    out_edge = torch.repeat_interleave(
        torch.arange(num_edges, device=dev), counts)
    # Position of each candidate inside its node's slice of edges_by_dst.
    group_start = torch.cumsum(counts, dim=0) - counts
    offset = torch.arange(out_edge.numel(), device=dev) - group_start[out_edge]
    in_edge = edges_by_dst[in_ptr[src[out_edge]] + offset]

    # Exclude backtracking: j = (v -> u) is the reverse of i = (u -> v).
    keep = src[in_edge] != dst[out_edge]
    return torch.stack([in_edge[keep], out_edge[keep]])

def get_edgewise_graph(graph, to_undirected=True):
    """Build the edge-level graph of ``graph``.

    Each edge of the (optionally symmetrized) ``graph.edge_index`` becomes a
    node.  Their connectivity is whatever ``get_edgewise_edge_index`` returns
    for that edge index.

    Args:
        graph: object exposing an ``edge_index`` tensor.
        to_undirected: if True, symmetrize the edges first with
            ``torch_geometric.utils.to_undirected``.

    Returns:
        ``Data`` whose ``x`` is row 1 (``edge_index[1]``) of the edge index
        used, i.e. one entry per edge.  Its ``edge_index`` is the edge-to-edge
        connectivity as ``torch.long``.
    """
    edge_index = (
        torch_geometric.utils.to_undirected(graph.edge_index)
        if to_undirected
        else graph.edge_index
    )
    line_edges = get_edgewise_edge_index(edge_index)
    return Data(x=edge_index[1], edge_index=line_edges.to(torch.long))

def get_undirected_edgewise_graph(graph, to_undirected=True):
    """Build an edge-level graph whose connectivity uses both edge orientations.

    ``get_edgewise_edge_index`` is run on the (optionally symmetrized) edge
    index and again on the same edges with start/end rows swapped.  The two
    results are concatenated column-wise.

    NOTE(review): given ``get_edgewise_edge_index``'s definition, the second
    call yields the transpose of the first (same pairs, rows swapped), so the
    result contains every connection in both directions.

    Args:
        graph: object exposing an ``edge_index`` tensor.
        to_undirected: if True, symmetrize the edges first with
            ``torch_geometric.utils.to_undirected``.

    Returns:
        ``Data`` whose ``x`` is row 1 of the (forward) edge index and whose
        ``edge_index`` is the concatenated connectivity as ``torch.long``.
    """
    if to_undirected:
        forward = torch_geometric.utils.to_undirected(graph.edge_index)
    else:
        forward = graph.edge_index
    # Same columns in the same order, with the two rows swapped.
    backward = forward.flip(0)

    line_edges = torch.cat(
        (get_edgewise_edge_index(forward), get_edgewise_edge_index(backward)),
        dim=1,
    )
    return Data(x=forward[1], edge_index=line_edges.to(torch.long))

def get_edge_initialization(data, init_type='zeros'):
    """Create initial per-edge features for ``data``.

    Args:
        data: graph object exposing ``edge_index`` and ``x`` tensors.
        init_type: one of
            ``'zeros'`` / ``'ones'`` – constant tensors;
            ``'random'`` – standard-normal samples;
            ``'data_edge'`` – copy of ``data.x`` rows selected by
            ``data.edge_index[0]``.

    Returns:
        Tensor with ``data.edge_index.shape[-1]`` rows and
        ``data.x.shape[-1]`` columns.  The constant and random variants are
        allocated on the module-level ``device``.  ``'data_edge'`` stays on
        the device of ``data.x``.

    Raises:
        ValueError: if ``init_type`` is not one of the options above.
    """
    if init_type == 'data_edge':
        return data.x[data.edge_index[0]]

    shape = (data.edge_index.shape[-1], data.x.shape[-1])
    if init_type == 'zeros':
        return torch.zeros(shape, device=device)
    if init_type == 'ones':
        return torch.ones(shape, device=device)
    if init_type == 'random':
        # torch.randn takes a shape; randn_like would need a template tensor.
        return torch.randn(shape, device=device)
    raise ValueError("Invalid Initialization Type")

def get_best_val(best_val, val, relation='greater_is_better'):
    """Return whichever of ``best_val`` / ``val`` is better.

    With ``relation == 'greater_is_better'`` the larger value wins.  Any other
    ``relation`` string makes the smaller value win.  On ties, or when the
    comparison is False (e.g. NaN), ``best_val`` is kept.
    """
    maximize = relation == 'greater_is_better'
    improved = best_val < val if maximize else best_val > val
    return val if improved else best_val

def calculate_total_model_gradient_norm(model):
    """Return the global L2 norm of all parameter gradients as a Python float.

    Parameters whose ``.grad`` is ``None`` are skipped.  With no gradients at
    all the result is ``0.0``.
    """
    squared_norms = (
        p.grad.detach().data.norm(2).item() ** 2
        for _, p in model.named_parameters()
        if p.grad is not None
    )
    return sum(squared_norms) ** (1 / 2)


def log_gradients(model):
    """Collect per-parameter gradient L2 norms plus their combined norm.

    Returns:
        dict with one ``"gradients/<param name>"`` entry per parameter that has
        a ``.grad`` (its gradient's L2 norm), in ``named_parameters`` order.
        The last entry is ``"Loss/total_gradient"``, the L2 norm over all of
        them (``0.0`` when no gradients exist).
    """
    grad_dict = {
        f"gradients/{name}": p.grad.detach().data.norm(2).item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    squared_total = sum(norm ** 2 for norm in grad_dict.values())
    grad_dict["Loss/total_gradient"] = squared_total ** (1 / 2)
    return grad_dict

def log_statistics(model):
    """Summarize every parameter tensor by its max, min and mean.

    Returns:
        dict with keys ``"model_max/<name>"``, ``"model_min/<name>"`` and
        ``"model_mean/<name>"`` (in that order per parameter) mapping to
        Python scalars.
    """
    reducers = (
        ("model_max", torch.max),
        ("model_min", torch.min),
        ("model_mean", torch.mean),
    )
    stats = {}
    for param_name, param in model.named_parameters():
        for prefix, reduce in reducers:
            stats[f"{prefix}/{param_name}"] = reduce(param).detach().item()
    return stats

def initialize_optimizer(model, args, optimizer_type='adam', transductive=False):
    """Create an optimizer and an optional learning-rate scheduler.

    Args:
        model: module whose ``parameters()`` are optimized.
        args: config object.  Always read: ``args.model.lr``,
            ``args.model.weight_decay`` and ``args.lr_schedule``.  Read for the
            ``'cosine'`` schedule: ``args.epochs`` and, unless ``transductive``,
            ``args.dataset.num_samples`` and ``args.dataset.batch_size``.  Read
            for ``'reduce_lr_on_plateau'``: ``args.patience``.
        optimizer_type: ``'adam'`` (``torch.optim.Adam``) or ``'sgd'``
            (``torch.optim.SGD``).
        transductive: for the cosine schedule, use ``T_max = args.epochs``.
            Otherwise use ``T_max = epochs * num_samples // batch_size``
            (presumably one scheduler step per batch; confirm with caller).

    Returns:
        ``(optimizer, scheduler)``.  ``scheduler`` is ``None`` for
        ``args.lr_schedule == 'constant'``.

    Raises:
        ValueError: for an unknown ``optimizer_type`` or ``args.lr_schedule``.
    """
    optimizer_classes = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    optimizer_cls = optimizer_classes.get(optimizer_type)
    if optimizer_cls is None:
        raise ValueError(f"Invalid optimizer option {optimizer_type}")
    optimizer = optimizer_cls(
        model.parameters(),
        lr=args.model.lr,
        weight_decay=args.model.weight_decay)

    if args.lr_schedule == 'cosine':
        if transductive:
            T = args.epochs
        else:
            T = args.epochs * args.dataset.num_samples // args.dataset.batch_size
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=T)
    elif args.lr_schedule == 'reduce_lr_on_plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=args.patience, factor=0.5, mode='min',
        )
    elif args.lr_schedule == 'constant':
        scheduler = None
    else:
        raise ValueError("Invalid Lr schedule")
    return optimizer, scheduler