import numpy as np
import torch
import torch.nn as nn
import time
from util.time import *
from util.env import *
import pandas as pd
import torch.nn.functional as F

from util.data import *
from util.preprocess import *

import pandas as pd
import numpy as np
from torch.nn import functional


#########################  Loss function for dim4  #########################

def loss_func(out, con_index, cat_index, output_info, y_true):
    """Compute the continuous (MSE) and categorical (cross-entropy) losses.

    Args:
        out: Model output indexed as ``out[batch, slot, node]``. The first
            ``len(con_index)`` entries of the last axis are continuous nodes
            (only slot 0 is used for them); the remaining entries are
            categorical nodes, one per column, with per-class scores laid out
            along axis 1.
        con_index: Column indices of the continuous targets in ``y_true``.
        cat_index: Column indices of the categorical targets in ``y_true``.
            They are consumed in consecutive groups of ``out.size(1)``
            columns, one group per categorical node.
        output_info: Per-node output widths; entries greater than 1 are
            treated as categorical nodes with that many classes.
        y_true: Ground-truth matrix of shape ``(batch, n_columns)``. The
            categorical columns are used as class-probability targets
            (e.g. one-hot).

    Returns:
        ``(loss_con, loss_cat)``. ``loss_con`` is the mean-squared error over
        the continuous targets. ``loss_cat`` is the cross-entropy averaged
        over categorical nodes, or the integer ``0`` when there are none.
    """
    n_con = len(con_index)

    ##### loss of continuous nodes (slot 0 of the first n_con columns)
    y_pred_con = out[:, 0, :n_con]
    y_true_con = y_true[:, con_index]
    loss_con = F.mse_loss(y_pred_con, y_true_con, reduction='mean')

    ##### loss of categorical nodes
    cate_lens = [n for n in output_info if n > 1]
    if not cate_lens:
        # No categorical nodes: skip the target reshape below. Reshaping an
        # empty ``y_true[:, cat_index]`` with ``-1`` would raise.
        return loss_con, 0

    y_pred_cat = out[:, :, n_con:]
    n_slots = y_pred_cat.size(1)

    # Regroup the flat categorical columns into (batch, node, slot), then
    # swap to (batch, slot, node) so the layout matches y_pred_cat.
    y_true_cat = y_true[:, cat_index]
    y_true_cat = y_true_cat.reshape(y_true_cat.size(0), -1, n_slots).permute(0, 2, 1)

    loss_cat = 0
    for i, cate_len in enumerate(cate_lens):
        # Only the first ``cate_len`` slots (this node's class count) are used.
        y_true_temp = y_true_cat[:, :cate_len, i]
        y_pred_temp = y_pred_cat[:, :cate_len, i]
        loss_cat += F.cross_entropy(y_pred_temp, y_true_temp, reduction='mean')

    return loss_con, loss_cat / len(cate_lens)

#########################  Test function  #########################
def test(model, dataloader, config, is_test = "val"):
    """Evaluate ``model`` on ``dataloader`` and average its losses.

    Args:
        model: Model exposing ``datainfo()`` returning
            ``(con_index, cat_index, output_info)``; called as ``model(x)``.
        dataloader: Iterable of ``(x, y, labels, _)`` batches that also
            supports ``len()`` (used for progress reporting).
        config: Not used by this function; kept for interface compatibility.
        is_test: ``"val"`` returns only the losses. ``"test"`` also decodes
            the predictions and collects them with the ground truth and labels.

    Returns:
        For ``"val"``: ``(avg_loss_con, avg_loss_cat)``.
        Otherwise: ``((avg_loss_con, avg_loss_cat), [predicted, ground, labels])``.
        Each of the three is a nested Python list with one row per sample.
        They are empty unless ``is_test == "test"``.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    device = get_device()
    loss_con_list = []
    loss_cat_list = []
    now = time.time()

    # Per-batch results; concatenated once at the end rather than grown with
    # torch.cat on every batch, which would be quadratic in the dataset size.
    predicted_chunks, ground_chunks, label_chunks = [], [], []

    test_len = len(dataloader)
    con_index, cat_index, output_info = model.datainfo()
    cat_lens = [n for n in output_info if n > 1]
    cat_cols = (torch.as_tensor(cat_index, dtype=torch.long, device=device)
                if is_test == "test" else None)
    model.eval()

    for batch_num, (x, y, labels, _) in enumerate(dataloader, start=1):
        x, y, labels = [item.to(device).float() for item in [x, y, labels]]

        with torch.no_grad():
            predicted = model(x).float().to(device)
            loss_con, loss_cat = loss_func(predicted, con_index, cat_index, output_info, y)

            if is_test == "test":
                predicted_chunks.append(
                    _decode_predictions(predicted, y, con_index, cat_lens, cat_cols))
                ground_chunks.append(y)
                # Repeat each sample's label across every column of y.
                label_chunks.append(labels.unsqueeze(1).repeat(1, y.shape[1]))

        loss_con_list.append(loss_con.item())
        if loss_cat != 0:
            loss_cat_list.append(loss_cat.item())

        # batch_num is never reused as an inner loop variable, so the
        # progress counter stays correct in "test" mode too.
        if batch_num % 10000 == 1 and batch_num > 1:
            print(timeSincePlus(now, batch_num / test_len))

    avg_loss_con = sum(loss_con_list) / len(loss_con_list)
    avg_loss_cat = sum(loss_cat_list) / len(loss_cat_list) if loss_cat_list else 0
    avg_loss = avg_loss_con, avg_loss_cat

    if is_test == "val":
        return avg_loss

    test_predicted_list = _cat_to_list(predicted_chunks)
    test_ground_list = _cat_to_list(ground_chunks)
    test_labels_list = _cat_to_list(label_chunks)
    return avg_loss, [test_predicted_list, test_ground_list, test_labels_list]


def _decode_predictions(predicted, y, con_index, cat_lens, cat_cols):
    """Turn raw model output into a tensor with the same layout as ``y``.

    Continuous columns receive the slot-0 predictions. For each categorical
    node, the column of its arg-max class is set to 1; all other categorical
    columns stay 0.
    """
    n_con = len(con_index)
    pred_res = torch.zeros_like(y)
    pred_res[:, con_index] = predicted[:, 0, :n_con]

    rows = torch.arange(y.size(0), device=y.device)
    # NOTE(review): each node's class columns are assumed to start
    # ``max(cat_lens)`` entries after the previous node's. loss_func groups
    # the same columns by ``predicted.size(1)``, so the two are expected to
    # agree. Confirm against the model definition.
    stride = max(cat_lens) if cat_lens else 0
    offset = 0
    for j, n_classes in enumerate(cat_lens):
        best = torch.argmax(predicted[:, :n_classes, n_con + j], dim=1)
        # One vectorized write per node instead of a per-sample Python loop.
        pred_res[rows, cat_cols[offset + best]] = 1
        offset += stride
    return pred_res


def _cat_to_list(chunks):
    """Concatenate per-batch tensors along dim 0 and return nested lists."""
    return torch.cat(chunks, dim=0).tolist() if chunks else []
