import numpy as np
import torch
import torch.nn as nn
import time
from util.time import *
from util.env import *
from sklearn.metrics import mean_squared_error
from test import *
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import precision_score, recall_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader, random_split, Subset
from torch.nn import  functional


#########################  Loss function  #########################
def loss_func(out, con_index, cat_index, output_info, y_true):
    """Compute the continuous (MSE) and categorical (cross-entropy) losses.

    Args:
        out: model prediction, indexed here on three axes. The last axis holds
            ``len(con_index)`` continuous outputs followed by the categorical
            outputs. Only index 0 of axis 1 is used for the continuous part.
        con_index: column indices of the continuous targets inside ``y_true``.
        cat_index: column indices of the categorical targets inside ``y_true``.
        output_info: iterable of per-output widths; entries greater than 1 are
            treated as categorical outputs with that many classes.
        y_true: ground truth, indexed here on two axes (rows, columns).

    Returns:
        ``(loss_con, loss_cat)``. ``loss_con`` is the mean squared error of the
        continuous outputs. ``loss_cat`` is the cross-entropy averaged over the
        categorical outputs, or the plain integer ``0`` when ``output_info``
        contains no entry greater than 1.
    """
    n_con = len(con_index)

    ##### split the prediction into its continuous / categorical parts
    pred_con = out[:, 0, :n_con]
    pred_cat = out[:, :, n_con:]

    ##### split the ground truth the same way
    target_con = y_true[:, con_index]
    flat_cat = y_true[:, cat_index]
    # Regroup the flat categorical columns into chunks of pred_cat.size(1)
    # values, then swap the last two axes so the layout matches pred_cat.
    target_cat = flat_cat.reshape(flat_cat.size(0), -1, pred_cat.size(1)).permute(0, 2, 1)

    ##### continuous loss
    loss_con = F.mse_loss(pred_con, target_con, reduction='mean')

    ##### categorical loss
    class_counts = [width for width in output_info if width > 1]
    if not class_counts:
        return loss_con, 0

    # The target has the same shape as the logits, so cross_entropy reads it
    # as per-class probabilities (presumably one-hot rows -- confirm with the
    # data pipeline). Only the first `width` entries of each chunk are used.
    per_feature = [
        functional.cross_entropy(
            pred_cat[:, :width, col],
            target_cat[:, :width, col],
            reduction='mean',
        )
        for col, width in enumerate(class_counts)
    ]
    loss_cat = sum(per_feature) / len(class_counts)

    return loss_con, loss_cat

#########################  train function  #########################
def train(model = None, save_path = '', config={},  train_dataloader=None, val_dataloader=None, feature_map={}, test_dataloader=None, test_dataset=None, dataset_name='swat', train_dataset=None):
    """Train ``model`` with Adam and write its state dict to ``save_path``.

    Every batch is scored with ``loss_func``; the continuous and categorical
    losses are summed and back-propagated. With a validation loader the
    checkpoint is written after every epoch, the learning rate is halved after
    3 epochs without a new best validation loss, and training stops once the
    learning rate falls below 1e-5. Without a validation loader the checkpoint
    is written whenever the accumulated training loss of an epoch improves.

    Args:
        model: object providing ``parameters()``, ``train()``, ``state_dict()``,
            ``__call__`` and ``datainfo()``; the latter must return
            ``(con_index, cat_index, output_info)``.
        save_path: destination handed to ``torch.save``.
        config: mapping; ``config['epoch']`` is the number of epochs. The
            whole mapping is also forwarded to ``test``.
        train_dataloader: sized iterable yielding 4-tuples whose first two
            items are the input and the target batch.
        val_dataloader: optional loader evaluated with ``test`` each epoch.
        feature_map, test_dataloader, test_dataset, dataset_name,
        train_dataset: accepted for interface compatibility, unused here.

    Returns:
        ``(i_epoch, epoch, mean_con_loss, mean_cat_loss, conti_loss, cate_loss)``
        where ``i_epoch`` is the index of the last epoch run (``-1`` if none
        ran), the two means are per-batch averages of the last epoch, and
        ``conti_loss`` / ``cate_loss`` are the last validation losses (``None``
        when no validation loader was given or no epoch ran).
    """
    # NOTE: the mutable defaults in the signature are kept for interface
    # compatibility; they are never mutated in this function.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    epoch = config['epoch']
    dataloader = train_dataloader
    device = get_device()
    n_batches = len(dataloader)

    min_loss = 1e+8
    lr_patience = 3        # epochs without improvement before halving the lr
    min_lr = 0.00001       # stop once the lr has decayed below this value
    stale_epochs = 0
    con_index, cat_index, output_info = model.datainfo()

    # Defaults so the return statement is always well-defined: previously
    # `conti_loss` / `cate_loss` were unbound when no validation loader was
    # supplied, and `i_epoch` was unbound when `epoch` was 0.
    i_epoch = -1
    acu_con_loss = 0
    acu_cat_loss = 0
    conti_loss = None
    cate_loss = None

    start_time = time.time()

    for i_epoch in range(epoch):
        acu_loss = 0
        acu_con_loss = 0
        acu_cat_loss = 0
        # Re-assert training mode every epoch in case validation changed it.
        model.train()
        for x, y, _, _ in dataloader:
            x, y = [item.float().to(device) for item in [x, y]]
            optimizer.zero_grad()

            predicted = model(x).float().to(device)
            loss_con, loss_cat = loss_func(predicted, con_index, cat_index, output_info, y)
            loss = loss_con + loss_cat

            loss.backward()
            optimizer.step()

            acu_loss += loss.item()
            acu_con_loss += loss_con.item()
            # loss_cat is the plain integer 0 when there are no categorical
            # outputs; float() handles both that and a one-element tensor.
            acu_cat_loss += float(loss_cat)

        # use val dataset to judge
        if val_dataloader is not None:
            conti_loss, cate_loss = test(model, val_dataloader, config, "val")
            val_loss = conti_loss + cate_loss
            print('epoch ({} / {}) (Loss:{:.7f}, Loss_con:{:.7f}, Loss_cat:{:.7f}, ACU_loss:{:.7f}, Val_loss:{:.7f}, lr:{:.7f})'.format(i_epoch, epoch, acu_loss/n_batches, acu_con_loss/n_batches, acu_cat_loss/n_batches, acu_loss, val_loss, optimizer.param_groups[0]['lr'] ), flush=True)

            # NOTE(review): the checkpoint is overwritten every epoch, so the
            # file holds the latest weights, not the best-validation weights.
            # Confirm this is intended.
            torch.save(model.state_dict(), save_path)

            if val_loss < min_loss:
                min_loss = val_loss
                stale_epochs = 0
            else:
                stale_epochs += 1

            if stale_epochs >= lr_patience:
                optimizer.param_groups[0]['lr'] /= 2
                stale_epochs = 0

            # The checkpoint for this epoch has already been written above,
            # so no second save is needed before stopping.
            if optimizer.param_groups[0]['lr'] < min_lr:
                break

        elif acu_loss < min_loss:
            torch.save(model.state_dict(), save_path)
            min_loss = acu_loss

    end_time = time.time()
    # Despite the wording, this is the wall-clock time of the whole run.
    print(f"{end_time - start_time:.5f} sec for this epoch")

    return i_epoch, epoch, acu_con_loss/n_batches, acu_cat_loss/n_batches, conti_loss, cate_loss