import numpy as np
import os
import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from net.net_avt import Net, NetTransformer
from dataloader.dataset_avt import MultimodalDataset, MultimodalValDataset
from utils import (random_feature_A_masking, random_feature_V_masking, random_feature_T_masking,
                   kd_loss, std_base_loss, test_logging, collate_fn, collate_fn_val)


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
STAGE = 'train_KD_avt_0402'
LOG_DIR = f'log/{STAGE}'
WEIGHT_DIR = f'weights/{STAGE}'
# Checkpoint used to initialise BOTH the frozen teacher and the trainable student.
PRETRAINED_WEIGHTS = 'weights/train_single_avt_0401/model_G_300.pth'
DATA_CSV = 'dataset/data.csv'
EPOCHS = 10000
BATCHSIZE = 16
TEST_BATCHSIZE = 16
TLEN = 600  # number of video frames with 30 fps
BACKPROP_STEP = 78  # mini-batches accumulated per optimizer update (also the loss divisor)
LOG_STEP = 78  # log running training means every LOG_STEP steps within an epoch
VAL_STEP = 10  # validate and checkpoint the student every VAL_STEP epochs
NUM_WORKERS = 16
LEARNING_RATE = 0.0002
TRAIN_SPLIT = 'Training'
VAL_SPLIT = 'Development'

# Entries of the model output dict that are distilled from teacher to student.
# The order is the order in which the per-feature KD terms are summed.
KD_FEATURE_KEYS = (
    'x_v_homo', 'x_a_homo', 'x_t_homo',
    'x_v_hete', 'x_a_hete', 'x_t_hete',
    'x_v_noise', 'x_a_noise', 'x_t_noise',
    'latents',
    'x_v_homo_dec', 'x_a_homo_dec', 'x_t_homo_dec',
    'x_v_hete_dec', 'x_a_hete_dec', 'x_t_hete_dec',
    'x_v_n_dec', 'x_a_n_dec', 'x_t_n_dec',
)

# Labels for the six values returned by std_base_loss, in return order.
# The first value is accumulated per batch and reported as sqrt(mean(...)) under 'RMSE'.
BASE_LOSS_KEYS = ('RMSE', 'MAE', 'REC', 'CON', 'ORTH', 'KL')
# Order in which the scalars are written to TensorBoard for each model.
SCALAR_TAG_ORDER = ('MAE', 'RMSE', 'REC', 'CON', 'ORTH', 'KL')


def _build_model(state_dict, device):
    """Return a DataParallel-wrapped NetTransformer on *device* loaded from *state_dict*."""
    model = torch.nn.DataParallel(NetTransformer()).to(device)
    model.load_state_dict(state_dict)
    return model


def _new_history():
    """Return one empty per-batch value list for every base loss term."""
    return {key: [] for key in BASE_LOSS_KEYS}


def _running_means(history):
    """Average every loss history; the 'RMSE' entry is the square root of its mean."""
    means = {key: np.mean(values) for key, values in history.items()}
    means['RMSE'] = np.sqrt(means['RMSE'])
    return means


def _distillation_loss(out_t, out_s, label):
    """Sum kd_loss over every feature in KD_FEATURE_KEYS.

    Each term receives the teacher score and the label, followed by the
    student feature and the matching teacher feature.
    """
    teacher_score = out_t['score']
    return sum(kd_loss(teacher_score, label, out_s[key], out_t[key])
               for key in KD_FEATURE_KEYS)


def _log_training(writer, epoch, step, steps_per_epoch,
                  history_t, history_s, kd_history, global_step):
    """Print and write to TensorBoard the running means of the current epoch.

    The histories are not reset after logging, so each report averages every
    batch seen so far in the epoch.
    """
    means_t = _running_means(history_t)
    means_s = _running_means(history_s)
    mean_kd = np.mean(kd_history)
    print('Epoch: {:d}  Step: {:d} / {:d}'
          ' | T: MAE: {:.4f}, RMSE: {:.4f} | S: MAE: {:.4f}, RMSE: {:.4f} | KD: {:.4f}'.format(
              epoch + 1, step + 1, steps_per_epoch,
              means_t['MAE'], means_t['RMSE'], means_s['MAE'], means_s['RMSE'], mean_kd))
    for prefix, means in (('T', means_t), ('S', means_s)):
        for tag in SCALAR_TAG_ORDER:
            writer.add_scalar(f'Train Loss/{prefix} {tag}', means[tag], global_step=global_step)
    writer.add_scalar('Train Loss/KD', mean_kd, global_step=global_step)


def _train_one_epoch(model_T, model_S, optimizer_S, dataloader, device, writer,
                     epoch, train_flags, pending_batches):
    """Run one epoch of distillation training on the student.

    Args:
        model_T: frozen teacher; evaluated on the unmasked inputs.
        model_S: student; evaluated on randomly feature-masked inputs.
        optimizer_S: optimizer over the student's parameters.
        dataloader: training loader yielding (v, a, t, label, _, _) batches.
        device: device the batches are moved to.
        writer: TensorBoard writer for the training scalars.
        epoch: zero-based epoch index (used for printing only).
        train_flags: TensorBoard global step of the next training log entry.
        pending_batches: number of mini-batches whose gradients are already
            accumulated in the student but not yet applied.

    Returns:
        The updated ``(train_flags, pending_batches)`` pair.
    """
    history_t = _new_history()
    history_s = _new_history()
    kd_history = []

    model_T.eval()
    model_S.train()  # validation leaves the student in eval mode
    for step, (train_data_v, train_data_a, train_data_t, train_label, _, _) in enumerate(dataloader):
        train_data_v = train_data_v.to(device)
        train_data_a = train_data_a.to(device)
        train_data_t = train_data_t.to(device)
        train_label = train_label.to(device)

        # The teacher is frozen: run it without autograd bookkeeping.
        with torch.no_grad():
            out_t = model_T(train_data_v, train_data_a, train_data_t)
        out_s = model_S(random_feature_V_masking(train_data_v),
                        random_feature_A_masking(train_data_a),
                        random_feature_T_masking(train_data_t))

        with torch.no_grad():
            terms_t = tuple(std_base_loss(out_t, train_label))
        terms_s = tuple(std_base_loss(out_s, train_label))
        kd = _distillation_loss(out_t, out_s, train_label)

        # Only student terms and the KD term carry gradients. The teacher's
        # loss terms are constants w.r.t. the student and are logged only.
        loss = (sum(terms_s) + kd) / BACKPROP_STEP
        loss.backward()

        # Count accumulated mini-batches across epoch boundaries so that every
        # update averages exactly BACKPROP_STEP batches, matching the divisor
        # above even when the epoch length is not a multiple of BACKPROP_STEP.
        pending_batches += 1
        if pending_batches == BACKPROP_STEP:
            optimizer_S.step()
            optimizer_S.zero_grad()
            pending_batches = 0

        for key, term in zip(BASE_LOSS_KEYS, terms_s):
            history_s[key].append(term.item())
        for key, term in zip(BASE_LOSS_KEYS, terms_t):
            history_t[key].append(term.item())
        kd_history.append(kd.item())

        if (step + 1) % LOG_STEP == 0:
            _log_training(writer, epoch, step, len(dataloader),
                          history_t, history_s, kd_history, train_flags)
            train_flags += 1

    return train_flags, pending_batches


def _validate(model_S, dataloader, device):
    """Run the student over the validation loader.

    Returns:
        ``(scores, preds)``: the concatenated reference scores yielded by the
        loader and the concatenated student predictions.
    """
    model_S.eval()
    scores = []
    preds = []
    with torch.no_grad():
        for test_v_pack, test_a_pack, test_t_pack, _, score, d_flag in dataloader:
            # Merge the two leading axes, keeping axis 2 and the feature axis.
            # NOTE(review): 4647 / 208 / 768 look like the per-stream feature
            # widths and 10 like the number of segments per sample - confirm
            # against MultimodalValDataset / collate_fn_val.
            test_v_pack = test_v_pack.to(device).view(-1, test_v_pack.shape[2], 4647)
            test_a_pack = test_a_pack.to(device).view(-1, test_a_pack.shape[2], 208)
            test_t_pack = test_t_pack.to(device).reshape(-1, test_t_pack.shape[2], 768)
            out = model_S(test_v_pack, test_a_pack, test_t_pack)
            # Average the scores in groups of 10, then scale by d_flag.
            pred = torch.mean(out['score'].cpu().view(-1, 10), dim=1) * d_flag

            scores.append(score)
            preds.append(pred)
    return torch.cat(scores), torch.cat(preds)


def main():
    """Distil a frozen teacher NetTransformer into a student of the same architecture.

    Both models start from the same checkpoint. The teacher sees the original
    inputs, the student sees randomly masked ones, and only the student is
    optimised. Every VAL_STEP epochs the student is evaluated on the
    validation split and its weights are saved under WEIGHT_DIR.
    """
    device = torch.device('cuda:0')
    os.makedirs(WEIGHT_DIR, exist_ok=True)

    train_dataset = MultimodalDataset(DATA_CSV, TRAIN_SPLIT, TLEN)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=BATCHSIZE,
                                  shuffle=True,
                                  num_workers=NUM_WORKERS,
                                  collate_fn=collate_fn,
                                  pin_memory=True,
                                  drop_last=True)
    test_dataset = MultimodalValDataset(DATA_CSV, VAL_SPLIT, TLEN)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=TEST_BATCHSIZE,
                                 shuffle=False,
                                 num_workers=NUM_WORKERS,
                                 collate_fn=collate_fn_val,
                                 pin_memory=True,
                                 drop_last=False)

    # Read the checkpoint once; load_state_dict copies the values into each model.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(PRETRAINED_WEIGHTS, map_location=device)
    model_T = _build_model(state_dict, device)
    for param in model_T.parameters():
        param.requires_grad = False
    model_S = _build_model(state_dict, device)
    del state_dict

    optimizer_S = Adam(model_S.parameters(), lr=LEARNING_RATE)

    train_flags = 1  # TensorBoard global step for training scalars
    val_flags = 1  # TensorBoard global step for validation metrics
    pending_batches = 0  # mini-batches accumulated since the last optimizer step

    writer = SummaryWriter(log_dir=LOG_DIR)
    try:
        for epoch in range(EPOCHS):
            train_flags, pending_batches = _train_one_epoch(
                model_T, model_S, optimizer_S, train_dataloader, device, writer,
                epoch, train_flags, pending_batches)

            if (epoch + 1) % VAL_STEP == 0:
                score_list, pred_list = _validate(model_S, test_dataloader, device)
                torch.save(model_S.state_dict(), f'{WEIGHT_DIR}/model_S_{epoch + 1}.pth')
                test_logging(score_list, pred_list, writer, val_flags, VAL_SPLIT)
                val_flags += 1
    finally:
        # Flush pending events even when training is interrupted or fails.
        writer.close()


if __name__ == '__main__':
    main()
