import torch
from transformer import Constants
from MAS import get_non_pad_mask
from tqdm import tqdm
import numpy as np
from scipy.stats import wasserstein_distance
from typing import List
import os

def eval_loglikelihood(model, dataloader, opt):
    """Return the accumulated negated model loss divided by the number of non-pad events.

    The model is put in eval mode and ``opt.num_grid`` is temporarily set to
    30 for the forward passes. Both are undone in a ``finally`` clause, so the
    shared ``opt`` object and the model's training mode are restored even if a
    batch raises part-way through evaluation.

    Args:
        model: callable as ``model(event_type, event_time, time_gap, opt)``
            returning ``(loss, _)``; must also provide ``eval()``/``train()``.
        dataloader: iterable of ``(event_time, time_gap, event_type)`` batches.
        opt: options object; ``device`` and ``num_grid`` are read here, and it
            is also forwarded to ``data_transformation`` and the model.

    Returns:
        ``sum(-loss) / (number of entries of event_type != Constants.PAD)``.

    Raises:
        ZeroDivisionError: if no batch was processed (empty dataloader, or the
            first batch has zero-length sequences).
    """
    total_tll = 0
    total_num_events = 0
    original_num_grid = opt.num_grid
    model.eval()
    try:
        for batch in tqdm(dataloader, mininterval=2, desc='  - (Evaluating) ', leave=False):
            event_time, time_gap, event_type = map(lambda x: x.to(opt.device), batch)
            # NOTE(review): a zero-length batch ends the whole evaluation
            # (break, not continue) -- confirm this is intended.
            if event_time.shape[1] == 0:
                break
            event_time, time_gap, event_type = data_transformation(event_time, time_gap, event_type, opt)
            non_pad_mask = get_non_pad_mask(event_type)
            # Zero out the times at padded positions.
            event_time = event_time * non_pad_mask.squeeze(-1)

            # Evaluation uses a fixed grid size regardless of the configured one.
            opt.num_grid = 30
            loss, _ = model(event_type, event_time, time_gap, opt)
            total_tll += -loss
            total_num_events += event_type.ne(Constants.PAD).sum().item()
    finally:
        # Always hand back opt and the model in the state callers expect.
        opt.num_grid = original_num_grid
        model.train()
    return total_tll / total_num_events
        

def eval_accuracy(model, dataloader, opt):
    """Return the fraction of non-pad positions whose predicted type matches.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: provides ``eval()`` and
            ``predict(event_type, event_time, time_gap, opt)``.
        dataloader: iterable of ``(event_time, time_gap, event_type)`` batches.
        opt: options object; ``device`` is read here and the object is
            forwarded to ``data_transformation`` and the model.

    Returns:
        ``correct / predicted`` where ``predicted`` is the sum of the non-pad
        mask over all processed batches.

    Raises:
        ZeroDivisionError: if no batch was processed.
    """
    correct = 0
    predicted = 0
    model.eval()
    progress = tqdm(dataloader, mininterval=2, desc='  - (Evaluating) ', leave=False)
    for batch in progress:
        event_time, time_gap, event_type = (tensor.to(opt.device) for tensor in batch)
        if event_time.shape[1] == 0:
            break
        event_time, time_gap, event_type = data_transformation(event_time, time_gap, event_type, opt)

        mask = get_non_pad_mask(event_type).squeeze(-1)
        event_time = event_time * mask
        type_pred = model.predict(event_type, event_time, time_gap, opt) * mask

        # Masked predictions are 0 at padded positions, so they "match" any
        # event_type entry equal to 0; those matches are subtracted back out.
        # NOTE(review): this relies on the pad id being 0 -- confirm against
        # Constants.PAD.
        matches = torch.sum(type_pred == event_type).item()
        zero_typed = torch.sum(event_type == 0).item()
        correct += matches - zero_typed
        predicted += torch.sum(mask).item()
    return correct / predicted
    
    
def eval_MAE(model, dataloader, opt):
    """Return sqrt(summed squared time error / summed non-pad mask).

    NOTE: a second ``eval_MAE`` defined further down in this module rebinds
    the name, so this definition is shadowed at import time. It differs from
    the later one only in comparing ``time_pred[:, :-1]`` (rather than
    ``time_pred[:, 1:]``) against ``time_gap``.

    Despite the name, the value is a root of squared errors, not a mean
    absolute error. The model is left in eval mode on return.

    Args:
        model: provides ``eval()`` and
            ``predict_time(event_type, event_time, time_gap, opt)``.
        dataloader: iterable of ``(event_time, time_gap, event_type)`` batches.
        opt: options object; ``device`` is read here.

    Raises:
        ZeroDivisionError: if no batch was processed.
    """
    squared_error_sum = 0
    mask_total = 0
    model.eval()
    batches = tqdm(dataloader, mininterval=2, desc='  - (Evaluating) ', leave=False)
    for batch in batches:
        event_time, time_gap, event_type = (t.to(opt.device) for t in batch)
        if event_time.shape[1] == 0:
            break
        event_time, time_gap, event_type = data_transformation(event_time, time_gap, event_type, opt)

        mask = get_non_pad_mask(event_type).squeeze(-1)
        event_time = event_time * mask
        time_pred = model.predict_time(event_type, event_time, time_gap, opt) * mask

        # Drop the last prediction column before comparing with the gaps.
        residual = time_pred[:, :-1] - time_gap
        squared_error_sum += torch.sum(residual.pow(2))
        mask_total += torch.sum(mask).item()
    return torch.sqrt(squared_error_sum / mask_total)


def plot_intensity(model, dataloader, opt):
    """Call ``model.plot_intensity`` once per batch, passing a 1-based batch id.

    The model is switched to eval mode and left there. Nothing is returned.

    Args:
        model: provides ``eval()`` and
            ``plot_intensity(event_type, event_time, time_gap, opt, batch_id)``.
        dataloader: iterable of ``(event_time, time_gap, event_type)`` batches.
        opt: options object; ``device`` is read here.
    """
    model.eval()
    progress = tqdm(dataloader, mininterval=2, desc='  - (Evaluating) ', leave=False)
    for batch_id, batch in enumerate(progress, start=1):
        event_time, time_gap, event_type = (tensor.to(opt.device) for tensor in batch)
        # A zero-length batch stops the loop; later batches are not plotted.
        if event_time.shape[1] == 0:
            break
        event_time, time_gap, event_type = data_transformation(event_time, time_gap, event_type, opt)

        mask = get_non_pad_mask(event_type).squeeze(-1)
        masked_time = event_time * mask
        model.plot_intensity(event_type, masked_time, time_gap, opt, batch_id)


def eval_MAE(model, dataloader, opt):
    """Return sqrt(summed squared time error / summed non-pad mask).

    Predictions from ``model.predict_time`` are masked, their first column is
    dropped, and the remainder is compared with ``time_gap``. Despite the
    name, the result is a root of squared errors rather than a mean absolute
    error. The model is switched to eval mode and left there.

    Args:
        model: provides ``eval()`` and
            ``predict_time(event_type, event_time, time_gap, opt)``.
        dataloader: iterable of ``(event_time, time_gap, event_type)`` batches.
        opt: options object; ``device`` is read here.

    Returns:
        A tensor holding the square root of the accumulated ratio.

    Raises:
        ZeroDivisionError: if no batch was processed.
    """
    sq_err = 0
    n_unmasked = 0
    model.eval()
    for batch in tqdm(dataloader, mininterval=2, desc='  - (Evaluating) ', leave=False):
        event_time, time_gap, event_type = [item.to(opt.device) for item in batch]
        if event_time.shape[1] == 0:
            break
        event_time, time_gap, event_type = data_transformation(event_time, time_gap, event_type, opt)

        keep = get_non_pad_mask(event_type).squeeze(-1)
        event_time = event_time * keep
        predicted_gap = model.predict_time(event_type, event_time, time_gap, opt) * keep

        # Skip the first prediction column so the widths line up with time_gap.
        diff = predicted_gap[:, 1:] - time_gap
        sq_err += torch.sum(diff.pow(2))
        n_unmasked += torch.sum(keep).item()
    return torch.sqrt(sq_err / n_unmasked)

def eval_MMD(model, dataloader, opt):
    """Return sqrt(summed squared time error / summed non-pad mask).

    NOTE(review): despite its name this function does not compute a maximum
    mean discrepancy; its arithmetic is identical to ``eval_MAE`` in this
    module. The metric is kept as-is so existing callers see the same value.

    Like the other evaluators in this module, prediction now runs with the
    model in eval mode. The mode the model had on entry is restored before
    returning, so callers observe no change in ``model.training``.

    Args:
        model: ``torch.nn.Module``-like object providing ``training``,
            ``eval()``, ``train(mode)`` and
            ``predict_time(event_type, event_time, time_gap, opt)``.
        dataloader: iterable of ``(event_time, time_gap, event_type)`` batches.
        opt: options object; ``device`` is read here.

    Raises:
        ZeroDivisionError: if no batch was processed.
    """
    total_num_correct = 0
    total_num_pred = 0
    was_training = model.training
    model.eval()
    try:
        for batch in tqdm(dataloader, mininterval=2, desc='  - (Evaluating) ', leave=False):
            event_time, time_gap, event_type = map(lambda x: x.to(opt.device), batch)
            if event_time.shape[1] == 0:
                break
            event_time, time_gap, event_type = data_transformation(event_time, time_gap, event_type, opt)

            non_pad_mask = get_non_pad_mask(event_type).squeeze(-1)
            event_time = event_time * non_pad_mask
            time_pred = model.predict_time(event_type, event_time, time_gap, opt) * non_pad_mask
            # First prediction column is dropped before comparing with the gaps.
            total_num_correct += torch.sum((time_pred[:, 1:] - time_gap).pow(2))
            total_num_pred += torch.sum(non_pad_mask).item()
    finally:
        # Put the model back in whichever mode the caller had it in.
        model.train(was_training)
    return torch.sqrt(total_num_correct / total_num_pred)

def counting_distance(x: np.ndarray, Y: np.ndarray, t_max: float):
    """Return the L1 distance between ``x`` and every row of ``Y``.

    Both inputs are first capped element-wise at ``t_max``; ``x`` is then
    broadcast against ``Y`` and absolute differences are summed over the last
    axis.

    Args:
        x: a single sequence of values.
        Y: a stack of sequences whose last axis is compatible with ``x``.
        t_max: upper cap applied to both inputs.

    Returns:
        Array with one distance per row of ``Y``.
    """
    capped_rows = np.minimum(Y, t_max)
    capped_ref = np.minimum(x, t_max)
    gaps = np.abs(capped_ref[np.newaxis] - capped_rows)
    return np.sum(gaps, axis=-1)

def gaussian_kernel(x: np.ndarray, sigma2: float = 1):
    """Return ``exp(-x / (2 * sigma2))`` element-wise.

    ``x`` is used as given (it is not squared here), and ``sigma2`` is the
    squared bandwidth.
    """
    bandwidth = 2 * sigma2
    return np.exp(-x / bandwidth)


def match_shapes(X: List, Y: List, t_max: float):
    """Pad two collections of sequences into equally wide 2-D arrays.

    For every sequence only the entries strictly below ``t_max`` are kept.
    They are written left-aligned into a row, and the rest of the row is
    filled with ``t_max``. The common width is the largest kept count found in
    either collection.

    Args:
        X: sequences (arrays supporting ``seq < t_max`` and boolean indexing).
        Y: sequences of the same kind.
        t_max: cut-off and padding value.

    Returns:
        ``(new_X, new_Y)`` with shapes ``(len(X), width)`` and
        ``(len(Y), width)``.

    Raises:
        ValueError: if ``X`` or ``Y`` is empty.
    """
    def kept_count(seq):
        return (seq < t_max).sum()

    width = max(max(map(kept_count, X)), max(map(kept_count, Y)))

    def pad(sequences):
        padded = np.ones((len(sequences), width)) * t_max
        for row, seq in zip(padded, sequences):
            kept = seq[seq < t_max]
            # `row` is a view into `padded`, so this writes in place.
            row[:len(kept)] = kept
        return padded

    return pad(X), pad(Y)


def MMD(X: List, Y: List, t_max: float, sample_size: int = None, sigma: float = None):
    """Estimate the maximum mean discrepancy between two sets of sequences.

    Sequences are padded to a common width with ``match_shapes``, divided by
    ``t_max``, and compared pairwise with ``counting_distance``. The distances
    are passed through ``gaussian_kernel`` and the statistic
    ``sqrt(E[k(x, x')] - 2 E[k(x, y)] + E[k(y, y')])`` is returned.

    Args:
        X: first collection of sequences.
        Y: second collection of sequences.
        t_max: time horizon used for padding and normalisation.
        sample_size: if given, this many rows are drawn from each padded
            collection with ``np.random.choice`` (which samples with
            replacement by default) before computing distances.
        sigma: kernel bandwidth; when ``None`` the median of all pairwise
            distances is used.

    Returns:
        ``(mmd, sigma)`` -- the statistic and the bandwidth actually used.

    NOTE(review): if every pairwise distance is 0 the median bandwidth is 0
    and the kernel evaluates 0/0, giving NaN; pass ``sigma`` explicitly in
    that situation.
    """
    # Bring both collections to the same (n_sequences, width) array layout.
    X, Y = match_shapes(X, Y, t_max)

    if sample_size is not None:
        # Index the arrays directly (instead of building Python lists of rows)
        # so X and Y remain ndarrays for the arithmetic below.
        X = X[np.random.choice(len(X), sample_size)]
        Y = Y[np.random.choice(len(Y), sample_size)]

    # Normalise time so the horizon becomes 1.
    X = X / t_max
    Y = Y / t_max
    t_max = 1

    # Flattened pairwise distances within X, across X/Y, and within Y.
    x_x_d = np.concatenate([counting_distance(x, X, t_max=t_max) for x in X])
    x_y_d = np.concatenate([counting_distance(x, Y, t_max=t_max) for x in X])
    y_y_d = np.concatenate([counting_distance(y, Y, t_max=t_max) for y in Y])

    if sigma is None:
        sigma = np.median(np.concatenate([x_x_d, x_y_d, y_y_d]))
    sigma2 = sigma ** 2

    E_x_x = np.mean(gaussian_kernel(x_x_d, sigma2))
    E_x_y = np.mean(gaussian_kernel(x_y_d, sigma2))
    E_y_y = np.mean(gaussian_kernel(y_y_d, sigma2))

    # Floating-point rounding can push the squared statistic slightly below
    # zero; clamp so the square root does not turn that into NaN.
    mmd_squared = np.maximum(E_x_x - 2 * E_x_y + E_y_y, 0.0)
    return np.sqrt(mmd_squared), sigma

# Per-dataset divisors applied to event times and time gaps. Datasets that are
# not listed (including 'taobao', whose factor is 1) are left unscaled.
_TIME_SCALE_BY_DATASET = {
    'retweet': 100,
    'earthquake': 5,
    'hawkes1': 10,
    'hawkes2': 10,
    'nonstationary_poisson': 10,
    'stationary_poisson': 10,
    'nonstationary_renewal': 10,
    'self_correcting': 10,
    'stationary_renewal': 10,
}

# Datasets whose sequences keep their first event and their original origin.
_UNSHIFTED_DATASETS = ("exp-decay-multivariate", "half-sin_multivariate")


def data_transformation(event_time, time_gap, event_type, opt):
    """Normalise one batch before it is fed to the model.

    Steps, in order:

    1. Cast ``event_time`` and ``time_gap`` to float64.
    2. Unless ``opt.data_name`` is one of the two multivariate datasets, shift
       every row so its first event is at time 0 and drop that first event
       (one column from ``event_time``/``event_type``, two from ``time_gap``).
       For the multivariate datasets only the first ``time_gap`` column is
       dropped.
    3. Divide times and gaps by the dataset-specific scale, if any.
    4. If ``opt.seq_trunc and opt.train_able``: zero every column at or beyond
       the shortest row's count of non-zero ``event_type`` entries.
    5. If ``opt.delete_outlier and opt.train_able``: drop rows whose largest
       event time lies more than 2.5 standard deviations from the batch
       median of those maxima.

    All steps work out-of-place, so the tensors passed in by the caller are
    never modified. Outlier removal is skipped for batches with fewer than
    two rows and uses an inclusive bound, so a batch whose maxima have zero
    (or undefined) spread is returned intact instead of being emptied.

    Args:
        event_time: 2-D tensor of event timestamps (rows are sequences).
        time_gap: 2-D tensor of inter-event gaps.
        event_type: 2-D integer tensor of event types; 0 is treated as
            padding by the truncation step.
        opt: options object providing ``data_name``, ``seq_trunc``,
            ``train_able`` and ``delete_outlier``.

    Returns:
        ``(event_time, time_gap, event_type)`` after transformation.
    """
    event_time = event_time.type(torch.float64)
    time_gap = time_gap.type(torch.float64)

    if opt.data_name in _UNSHIFTED_DATASETS:
        time_gap = time_gap[:, 1:]
    else:
        event_time = (event_time - event_time[:, 0:1])[:, 1:]
        event_type = event_type[:, 1:]
        time_gap = time_gap[:, 2:]

    scale = _TIME_SCALE_BY_DATASET.get(opt.data_name)
    if scale is not None:
        # Out-of-place division: time_gap may still be a view of the input.
        event_time = event_time / scale
        time_gap = time_gap / scale

    if opt.seq_trunc and opt.train_able:
        min_length = (event_type != 0).sum(dim=1).min().item()
        # Clone before slice-assigning so views of the caller's tensors are
        # not written through.
        event_type = event_type.clone()
        event_time = event_time.clone()
        time_gap = time_gap.clone()
        event_type[:, min_length:] = 0
        event_time[:, min_length:] = 0
        time_gap[:, min_length:] = 0

    # With a single row the standard deviation is undefined, so there is
    # nothing meaningful to filter.
    if opt.delete_outlier and opt.train_able and event_time.shape[0] > 1:
        max_observed = event_time.max(dim=1).values
        deviation = torch.abs(max_observed - max_observed.median())
        # Inclusive bound keeps every row when all maxima are identical.
        keep = deviation <= 2.5 * max_observed.std()
        time_gap = time_gap[keep, :]
        event_type = event_type[keep, :]
        event_time = event_time[keep, :]
    return event_time, time_gap, event_type