import torch
import torch.nn as nn
import torch_geometric
import datetime
import numpy as np
from tqdm import tqdm
import math
from collections import OrderedDict
from copy import deepcopy

from .optimization_functions import BCE_loss, tripletLoss, MSE

from itertools import chain

import random

def compute_pnorm(parameters):
    """Return the combined L2 norm of every tensor in ``parameters``.

    Each tensor's own norm is squared, the squares are accumulated, and the
    square root of that total is returned as a Python float. An empty
    iterable yields ``0.0``.
    """
    squared_total = 0
    for tensor in parameters:
        squared_total += tensor.norm().item() ** 2
    return math.sqrt(squared_total)

def compute_gnorm(parameters):
    """Return the combined L2 norm of the gradients held by ``parameters``.

    Tensors whose ``.grad`` is ``None`` are skipped. The squared norms of the
    remaining gradients are accumulated and the square root of the total is
    returned as a Python float (``0.0`` when nothing has a gradient).
    """
    squared_total = 0
    for tensor in parameters:
        if tensor.grad is None:
            continue
        squared_total += tensor.grad.norm().item() ** 2
    return math.sqrt(squared_total)

def _squeeze_to_vector(tensor):
    """Squeeze singleton dimensions of ``tensor`` but keep at least one axis."""
    squeezed = tensor.squeeze()
    # A single-element tensor squeezes to 0-dim; restore one axis so that
    # ``shape[0]`` and the pairwise loss still work when a batch holds one pair.
    return squeezed.unsqueeze(0) if squeezed.dim() == 0 else squeezed


def binary_ranking_regression_loop(model, loader, optimizer, device, epoch, batch_size, training = True, absolute_penalty = 1.0, relative_penalty = 0.0, ranking_margin = 0.3):
    """Run one pass over ``loader`` with a combined MSE + margin-ranking loss.

    Rows of every batch are treated as consecutive pairs: rows 0, 2, 4, ...
    are compared against rows 1, 3, 5, ... (the original code notes this is
    meant for a negative batch sampler in which the negative immediately
    follows each anchor). An odd number of rows would give slices of unequal
    length -- presumably the sampler always yields complete pairs; confirm
    against the caller.

    Args:
        model: Callable as ``model(z, pos, node_batch)`` returning
            ``(output, latent_vector)``; only ``output`` is used here.
        loader: Iterable of ``(batch_data, y)``; ``batch_data`` must provide
            ``.to(device)``, ``.batch``, ``.x`` and ``.pos``.
        optimizer: Optimizer stepped once per batch when ``training`` is true;
            not touched otherwise.
        device: Device that ``batch_data`` and ``y`` are moved to.
        epoch: Unused; kept for interface compatibility.
        batch_size: Unused; kept for interface compatibility.
        training: If true, put the model in train mode and update weights;
            otherwise put it in eval mode and run without gradient tracking.
        absolute_penalty: Weight of the MSE term.
        relative_penalty: Weight of the margin-ranking term.
        ranking_margin: Margin passed to ``torch.nn.MarginRankingLoss``.

    Returns:
        Tuple ``(batch_losses, batch_sizes, batch_abs_losses,
        batch_rel_losses, batch_acc)`` of per-batch lists. Batches whose
        forward pass raised are reported on stdout and omitted.
    """
    if training:
        model.train()
    else:
        model.eval()

    # The criterion is the same for every batch, so build it once.
    criterion = torch.nn.MarginRankingLoss(margin=ranking_margin)

    batch_losses = []
    batch_rel_losses = []
    batch_abs_losses = []
    batch_sizes = []
    batch_acc = []

    for batch_data, y in loader:
        batch_data = batch_data.to(device)
        y = y.to(device)

        node_batch = batch_data.batch
        z = batch_data.x
        pos = batch_data.pos

        if training:
            optimizer.zero_grad()

        # Track gradients only when they will be back-propagated; a validation
        # pass then avoids building an autograd graph (matching the no_grad
        # evaluation loops in this file).
        with torch.set_grad_enabled(training):
            try:
                output, _ = model(z.squeeze(), pos, node_batch)
            except Exception as e:
                # Best effort: report the failure and move on to the next batch.
                print('failed to process batch due to error:', e)
                continue

            loss_absolute = MSE(y.squeeze(), output.squeeze()) # plain MSE loss

            # Split into (first, second) members of each consecutive pair.
            first_output = _squeeze_to_vector(output[0::2])
            second_output = _squeeze_to_vector(output[1::2])
            first_target = _squeeze_to_vector(y[0::2])
            second_target = _squeeze_to_vector(y[1::2])

            # notice that we treat the less negative score as being ranked "higher";
            # the 1e-8 offset turns exact ties into +1 instead of a 0 label.
            ranking_label = torch.sign((first_target - second_target) + 1e-8)
            loss_relative = criterion(first_output, second_output, ranking_label)

            loss = (loss_relative*relative_penalty) + (loss_absolute*absolute_penalty)

        if training:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10, norm_type=2)

            optimizer.step()

        # (binary) ranking accuracy, comparing values rounded to 2 decimals
        target_ranking = ((torch.round(first_target * 100.) / 100.) > (torch.round(second_target * 100.) / 100.)).type(torch.float)
        output_ranking = ((torch.round(first_output * 100.) / 100.) > (torch.round(second_output * 100.) / 100.)).type(torch.float)
        top_1_acc = torch.sum(output_ranking == target_ranking) / float(output_ranking.shape[0])

        batch_acc.append(top_1_acc.item())

        batch_sizes.append(y.shape[0])
        batch_losses.append(loss.item())

        batch_rel_losses.append(loss_relative.item())
        batch_abs_losses.append(loss_absolute.item())

    return batch_losses, batch_sizes, batch_abs_losses, batch_rel_losses, batch_acc


def evaluate_binary_ranking_regression_loop(model, loader, device, batch_size, dataset_size):
    """Collect targets and model outputs over ``loader`` without gradients.

    Results are written into two pre-allocated vectors of length
    ``dataset_size`` in the order the loader yields them. If the forward pass
    (or storing its output) raises, the error is printed and that batch's
    output slots are filled with NaN while its targets are still recorded.

    Args:
        model: Callable as ``model(z, pos, node_batch)`` returning
            ``(output, latent_vector)``; only ``output`` is used.
        loader: Iterable of ``(batch_data, y)``; ``batch_data`` must provide
            ``.to(device)``, ``.batch``, ``.x`` and ``.pos``.
        device: Device used for the inputs and the result buffers.
        batch_size: Unused; kept for interface compatibility.
        dataset_size: Total number of targets the loader will yield.

    Returns:
        ``(all_targets, all_outputs)`` as NumPy arrays of length
        ``dataset_size``.
    """
    model.eval()

    all_targets = torch.zeros(dataset_size, device=device)
    all_outputs = torch.zeros(dataset_size, device=device)

    start = 0
    for batch_data, y in tqdm(loader):
        batch_data = batch_data.to(device)
        y = y.to(device)

        node_batch = batch_data.batch
        z = batch_data.x
        pos = batch_data.pos

        # Flatten rather than squeeze: a batch of size 1 would squeeze to a
        # 0-dim tensor, which has no ``shape[0]`` to size the slice with.
        targets = y.reshape(-1)
        stop = start + targets.shape[0]
        all_targets[start:stop] = targets

        try:
            with torch.no_grad():
                output, _ = model(z.squeeze(), pos, node_batch)
            all_outputs[start:stop] = output.reshape(-1)
        except Exception as e:
            # Best effort: keep the targets aligned and mark outputs missing.
            print('failed to evaluate batch due to error:', e)
            all_outputs[start:stop] = float('nan')

        start = stop

    return all_targets.detach().cpu().numpy(), all_outputs.detach().cpu().numpy()


def classification_loop(model, loader, optimizer, device, epoch, batch_size, training = True):
    """Run one pass over ``loader`` for binary classification with ``BCE_loss``.

    Args:
        model: Callable as ``model(z, pos, node_batch)`` returning
            ``(output, latent_vector)``; only ``output`` is used. The accuracy
            computation applies a sigmoid to ``output``, i.e. it is treated as
            logits.
        loader: Iterable of ``(batch_data, y)``; ``batch_data`` must provide
            ``.to(device)``, ``.batch``, ``.x`` and ``.pos``. ``y`` is cast to
            float32.
        optimizer: Optimizer stepped once per batch when ``training`` is true;
            not touched otherwise.
        device: Device that ``batch_data`` and ``y`` are moved to.
        epoch: Unused; kept for interface compatibility.
        batch_size: Unused; kept for interface compatibility.
        training: If true, put the model in train mode and update weights;
            otherwise put it in eval mode and run without gradient tracking.

    Returns:
        Tuple ``(batch_losses, batch_sizes, batch_accuracies)`` of per-batch
        lists. Batches whose forward pass raised are reported on stdout and
        omitted.
    """
    if training:
        model.train()
    else:
        model.eval()

    batch_losses = []
    batch_sizes = []
    batch_accuracies = []

    for batch_data, y in loader:
        y = y.type(torch.float32)

        batch_data = batch_data.to(device)
        y = y.to(device)

        node_batch = batch_data.batch
        z = batch_data.x
        pos = batch_data.pos

        if training:
            optimizer.zero_grad()

        # Track gradients only when they will be back-propagated; a validation
        # pass then avoids building an autograd graph (matching the no_grad
        # evaluation loop in this file).
        with torch.set_grad_enabled(training):
            try:
                output, _ = model(z.squeeze(), pos, node_batch)
            except Exception as e:
                # Best effort: report the failure and move on to the next batch.
                print('failed to process batch due to error:', e)
                continue

            loss = BCE_loss(y.squeeze(), output.squeeze())

        # Fraction of samples whose thresholded sigmoid(output) matches y.
        predictions = torch.round(torch.sigmoid(output.squeeze().detach()))
        acc = 1.0 - (torch.sum(torch.abs(y.squeeze().detach() - predictions)) / y.shape[0])

        if training:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10, norm_type=2)

            optimizer.step()

        batch_sizes.append(y.shape[0])
        batch_losses.append(loss.item())
        batch_accuracies.append(acc.item())

    return batch_losses, batch_sizes, batch_accuracies


def evaluate_classification_loop(model, loader, device, batch_size, dataset_size):
    """Collect targets and raw model outputs over ``loader`` without gradients.

    Results are written into two pre-allocated vectors of length
    ``dataset_size`` in the order the loader yields them. If the forward pass
    (or storing its output) raises, the error is printed and that batch's
    output slots are filled with NaN while its targets are still recorded.

    Args:
        model: Callable as ``model(z, pos, node_batch)`` returning
            ``(output, latent_vector)``; only ``output`` is used.
        loader: Iterable of ``(batch_data, y)``; ``batch_data`` must provide
            ``.to(device)``, ``.batch``, ``.x`` and ``.pos``. ``y`` is cast to
            float32.
        device: Device used for the inputs and the result buffers.
        batch_size: Unused; kept for interface compatibility.
        dataset_size: Total number of targets the loader will yield.

    Returns:
        ``(all_targets, all_outputs)`` as NumPy arrays of length
        ``dataset_size``.
    """
    model.eval()

    all_targets = torch.zeros(dataset_size, device=device)
    all_outputs = torch.zeros(dataset_size, device=device)

    start = 0
    for batch_data, y in loader:
        y = y.type(torch.float32)

        batch_data = batch_data.to(device)
        y = y.to(device)

        node_batch = batch_data.batch
        z = batch_data.x
        pos = batch_data.pos

        # Flatten rather than squeeze: a batch of size 1 would squeeze to a
        # 0-dim tensor, which has no ``shape[0]`` to size the slice with.
        targets = y.reshape(-1)
        stop = start + targets.shape[0]
        all_targets[start:stop] = targets

        try:
            with torch.no_grad():
                output, _ = model(z.squeeze(), pos, node_batch)
            all_outputs[start:stop] = output.reshape(-1)
        except Exception as e:
            # Best effort: keep the targets aligned and mark outputs missing.
            print('failed to evaluate batch due to error:', e)
            all_outputs[start:stop] = float('nan')

        start = stop

    return all_targets.detach().cpu().numpy(), all_outputs.detach().cpu().numpy()


# Maps the public ``loss_function`` names accepted by ``contrastive_loop`` to
# the ``distance_metric`` value handed to ``tripletLoss``.
_TRIPLET_DISTANCE_METRICS = {
    'euclidean': 'euclidean',
    'euclidean-normalized': 'euclidean_normalized',
    'manhattan': 'manhattan',
    'cosine': 'cosine',
}


def contrastive_loop(model, loader, optimizer, device, epoch, loss_function, batch_size, margin, training = True):
    """Run one pass over ``loader`` with a triplet loss on the latent vectors.

    Rows of the model output are read as consecutive triplets: rows 0, 3, ...
    are anchors, rows 1, 4, ... positives and rows 2, 5, ... negatives.

    Args:
        model: Callable as ``model(z, pos, node_batch)`` returning a 2-D
            latent tensor (it is indexed as ``latent[i::3, :]``).
        loader: Iterable of ``batch_data`` objects providing ``.to(device)``,
            ``.batch``, ``.x`` and ``.pos``.
        optimizer: Optimizer stepped once per batch when ``training`` is true;
            not touched otherwise.
        device: Device that ``batch_data`` is moved to.
        epoch: Unused; kept for interface compatibility.
        loss_function: One of ``'euclidean'``, ``'euclidean-normalized'``,
            ``'manhattan'`` or ``'cosine'``.
        batch_size: Unused; kept for interface compatibility.
        margin: Margin forwarded to ``tripletLoss``.
        training: If true, put the model in train mode and update weights;
            otherwise put it in eval mode and run without gradient tracking.

    Returns:
        List of per-batch loss values. Batches whose forward pass raised are
        reported on stdout and omitted.

    Raises:
        ValueError: If ``loss_function`` is not one of the supported names.
    """
    # Fail fast with a clear message instead of hitting an unbound ``loss``
    # variable in the middle of the epoch.
    if loss_function not in _TRIPLET_DISTANCE_METRICS:
        raise ValueError(
            'unknown loss_function {!r}; expected one of {}'.format(
                loss_function, sorted(_TRIPLET_DISTANCE_METRICS)
            )
        )
    distance_metric = _TRIPLET_DISTANCE_METRICS[loss_function]

    if training:
        model.train()
    else:
        model.eval()

    batch_losses = []

    for batch_data in loader:
        batch_data = batch_data.to(device)

        node_batch = batch_data.batch
        z = batch_data.x
        pos = batch_data.pos

        if training:
            optimizer.zero_grad()

        # Track gradients only when they will be back-propagated; a validation
        # pass then avoids building an autograd graph.
        with torch.set_grad_enabled(training):
            try:
                latent_vector = model(z.squeeze(), pos, node_batch)
            except Exception as e:
                # Best effort: report the failure and move on to the next batch.
                print('failed to process batch due to error:', e)
                continue

            anchor = latent_vector[0::3, :]
            positive = latent_vector[1::3, :]
            negative = latent_vector[2::3, :]

            loss = tripletLoss(anchor, positive, negative, margin = margin, reduction = 'mean', distance_metric = distance_metric)

        if training:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10, norm_type=2)

            optimizer.step()

        batch_losses.append(loss.item())

    return batch_losses
