import torch
import numpy as np

def evaluate_coords(model, device, test_loader, criterion, MC_dropout):
    """Run one pass over ``test_loader`` and collect predictions and errors.

    Args:
        model: Module called as ``model(ap_coords, data)``.
        device: Device the batch tensors are moved to before the forward pass.
        test_loader: Iterable yielding ``(ap_coords, data, labels)`` tensor
            triples.
        criterion: Callable ``criterion(outputs, labels)`` returning a tensor;
            its ``.mean()`` is used as the batch loss.
        MC_dropout: If true the model is put in ``train()`` mode for the pass,
            otherwise in ``eval()`` mode. The mode is not restored afterwards.

    Returns:
        A tuple ``(epoch_loss, res_dict)``. ``epoch_loss`` is the average of
        the per-batch mean losses. ``res_dict`` holds NumPy arrays concatenated
        over all batches along the first axis under the keys
        ``"output_locations"``, ``"labels"``, ``"errors"``, ``"ap_coords"``
        and ``"test_data"``. ``"errors"`` is the Euclidean norm of
        ``outputs - labels`` taken over the last axis.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # NOTE(review): train() switches every submodule to training mode, not
    # only dropout layers -- e.g. BatchNorm would use batch statistics and
    # update its running estimates. Confirm that is acceptable for the model.
    if MC_dropout:
        model.train()
    else:
        model.eval()

    ap_coords_batches = []
    data_batches = []
    label_batches = []
    output_batches = []
    error_batches = []

    epoch_loss = 0.0
    batch_num = 0

    with torch.no_grad():
        for ap_coords, data, labels in test_loader:
            ap_coords = ap_coords.float().to(device)
            data = data.float().to(device)
            labels = labels.float().to(device)

            outputs = model(ap_coords, data)
            epoch_loss += criterion(outputs, labels).mean().item()
            batch_num += 1

            outputs_np = outputs.cpu().numpy()
            labels_np = labels.cpu().numpy()

            ap_coords_batches.append(ap_coords.cpu().numpy())
            data_batches.append(data.cpu().numpy())
            label_batches.append(labels_np)
            output_batches.append(outputs_np)

            # The coordinate components live on the last axis, so reduce over
            # axis=-1 rather than a hard-coded axis index.
            diff = outputs_np - labels_np
            error_batches.append(np.sqrt(np.sum(diff ** 2, axis=-1)))

    if batch_num == 0:
        # Without this guard the averaging below would fail with an opaque
        # error instead of pointing at the real problem.
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    epoch_loss /= batch_num

    res_dict = {
        "output_locations": np.concatenate(output_batches, axis=0),
        "labels": np.concatenate(label_batches, axis=0),
        "errors": np.concatenate(error_batches, axis=0),
        "ap_coords": np.concatenate(ap_coords_batches, axis=0),
        "test_data": np.concatenate(data_batches, axis=0),
    }

    return epoch_loss, res_dict


def _merge_first_and_third_axes(array):
    """Swap axes 1 and 2 of a 4-D array, then merge the first two axes.

    Presumably maps ``[B, E, T, 2]`` to ``[B * T, E, 2]`` -- TODO confirm the
    axis meaning against the model's output layout.
    """
    swapped = array.transpose(0, 2, 1, 3)
    return swapped.reshape(-1, swapped.shape[2], swapped.shape[3])


def evaluate_MC_drop_out(opt_eval, model, device, test_loader, criterion):
    """Evaluate the model over repeated passes and fuse the predictions.

    ``evaluate_coords`` is run ``opt_eval.num_sampling`` times with
    ``opt_eval.MC_dropout`` as its dropout flag. The per-pass predictions are
    reduced to a mean and variance across passes, which are then handed to
    ``gaussian_fusion``.

    Args:
        opt_eval: Options object providing ``num_sampling`` and ``MC_dropout``.
        model: Model forwarded to ``evaluate_coords``.
        device: Device forwarded to ``evaluate_coords``.
        test_loader: Loader forwarded to ``evaluate_coords``.
        criterion: Loss forwarded to ``evaluate_coords``.

    Returns:
        A tuple ``(eval_loss, fin_res_dict)`` where ``eval_loss`` is the
        average of the per-pass epoch losses and ``fin_res_dict`` contains the
        fused mean/variance/error, the across-pass mean/variance/error, the
        labels as returned by ``evaluate_coords`` and the loss under
        ``"eval loss"``.

    Raises:
        ValueError: If ``opt_eval.num_sampling`` is less than 1.
    """
    num_sampling = opt_eval.num_sampling
    if num_sampling < 1:
        # With zero passes there is no result to take labels from and the
        # loss average would divide by zero.
        raise ValueError("opt_eval.num_sampling must be at least 1")

    sampled_locations = []
    all_epoch_loss = 0.0
    for _ in range(num_sampling):
        epoch_loss, res_dict = evaluate_coords(
            model, device, test_loader, criterion, opt_eval.MC_dropout
        )
        sampled_locations.append(
            _merge_first_and_third_axes(res_dict['output_locations'])
        )
        all_epoch_loss += epoch_loss

    # Labels are taken from the last pass only.
    # NOTE(review): this assumes test_loader yields samples in the same order
    # on every pass (i.e. no shuffling) -- confirm against the caller.
    all_labels = _merge_first_and_third_axes(res_dict['labels'])
    # Only index 0 along axis 1 is compared with the fused estimate;
    # presumably the label is identical for every entry on that axis.
    labels = all_labels[:, 0, :]

    sampled_locations = np.stack(sampled_locations, axis=0)

    output_locations_mean = np.mean(sampled_locations, axis=0)
    output_locations_var = np.var(sampled_locations, axis=0)
    output_locations_fusion_mean, output_locations_fusion_var = gaussian_fusion(
        output_locations_mean, output_locations_var
    )

    # Euclidean distance between prediction and label over the coordinate axis.
    output_locations_error = np.sqrt(
        np.sum((output_locations_mean - all_labels) ** 2, axis=2)
    )
    output_locations_fusion_error = np.sqrt(
        np.sum((output_locations_fusion_mean - labels) ** 2, axis=1)
    )

    output_locations_fusion_mean_epoch_loss = all_epoch_loss / num_sampling

    fin_res_dict = {
        "output_locations_fusion_mean": output_locations_fusion_mean,
        "output_locations_fusion_var": output_locations_fusion_var,
        "output_locations_fusion_error": output_locations_fusion_error,

        "output_locations_mean": output_locations_mean,
        "output_locations_var": output_locations_var,
        "output_locations_error": output_locations_error,
        "labels": res_dict['labels'],
        "eval loss": output_locations_fusion_mean_epoch_loss,
    }

    return output_locations_fusion_mean_epoch_loss, fin_res_dict


def gaussian_fusion(ap_locations_mean, ap_locations_var):
    """Fuse per-AP Gaussian location estimates by inverse-variance weighting.

    For every point, the x and y estimates of all APs are combined as
    independent Gaussians: each estimate is weighted by the reciprocal of its
    variance, the fused mean is the weighted average, and the fused variance
    is the reciprocal of the summed weights.

    Args:
        ap_locations_mean: Array of shape ``(ap_point_num, ap_num, D)``; only
            components 0 (x) and 1 (y) of the last axis are used.
        ap_locations_var: Array of the same layout holding the variances.

    Returns:
        A tuple ``(ap_fusion_mean, ap_fusion_var)`` of float64 arrays, each of
        shape ``(ap_point_num, 2)``.

    Raises:
        ValueError: If there is at least one point but no AP estimates.
    """
    ap_point_num, ap_num, _ = ap_locations_mean.shape
    if ap_point_num and not ap_num:
        raise ValueError("gaussian_fusion needs at least one AP estimate per point")

    # Work in float64 regardless of the input dtype so the tiny regulariser
    # below and the large precisions it produces stay representable.
    means = np.asarray(ap_locations_mean, dtype=np.float64)[:, :, [0, 1]]
    variances = np.asarray(ap_locations_var, dtype=np.float64)[:, :, [0, 1]]

    # 1e-30 keeps the reciprocal finite when every sample agreed (variance 0).
    precisions = 1.0 / (variances + 1e-30)
    precision_sum = precisions.sum(axis=1)

    ap_fusion_mean = (precisions * means).sum(axis=1) / precision_sum
    # The fused variance is the reciprocal of the accumulated precision, not
    # the accumulated precision itself.
    ap_fusion_var = 1.0 / precision_sum

    return ap_fusion_mean, ap_fusion_var
