import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model.Transnet import SourceNet, TargetNet
from utils import predict_mse_for_transnet, predict_mse_with_IDemb, load_word_embedding, Review_Dataset, Transnet_Dataset
from baseline_config import Config
from datetime import datetime


def train(train_dataloader, valid_dataloader, model_S, model_T, config, model_path):
    """Train the source/target network pair and checkpoint the best source model.

    Every batch performs three updates, each with its own Adam optimizer:

    1. ``model_T`` is fitted to predict the rating from the actual review
       (mean L1 loss).
    2. The parameters returned by ``model_S.trans_param()`` are fitted so that
       the source latent matches the detached target latent (MSE loss).
    3. ``model_S`` is fitted to predict the rating (summed L1 loss).

    After each epoch the three learning rates are decayed, ``model_S`` is
    scored on ``valid_dataloader`` with ``predict_mse_for_transnet``, and the
    whole ``model_S`` object is written to ``model_path`` with ``torch.save``
    whenever that score improves.  Training stops early once the score has
    failed to improve for more than ``config.patience_limit`` epochs in a row.

    Args:
        train_dataloader: iterable of batches; each batch unpacks into
            ``(user_reviews, item_reviews, reviews, ratings)`` tensors.
        valid_dataloader: loader handed to ``predict_mse_for_transnet``.
        model_S: source network; ``model_S(user_reviews, item_reviews)`` must
            return ``(latent, prediction)`` and it must expose ``trans_param()``.
        model_T: target network; ``model_T(reviews)`` must return
            ``(latent, prediction)``.
        config: object providing ``learning_rate``, ``l2_regularization``,
            ``learning_rate_decay``, ``patience_limit``, ``train_epochs`` and
            ``device``.
        model_path: destination of the best checkpoint.
    """
    opt_S = torch.optim.Adam(model_S.parameters(), config.learning_rate, weight_decay=config.l2_regularization)
    opt_trans = torch.optim.Adam(model_S.trans_param(), config.learning_rate, weight_decay=config.l2_regularization)
    # With nn.DataParallel the custom method lives on the wrapped module:
    # opt_trans = torch.optim.Adam(model_S.module.trans_param(), config.learning_rate, weight_decay=config.l2_regularization)
    opt_T = torch.optim.Adam(model_T.parameters(), config.learning_rate, weight_decay=config.l2_regularization)
    lr_sch_S = torch.optim.lr_scheduler.ExponentialLR(opt_S, config.learning_rate_decay)
    lr_sch_trans = torch.optim.lr_scheduler.ExponentialLR(opt_trans, config.learning_rate_decay)
    lr_sch_T = torch.optim.lr_scheduler.ExponentialLR(opt_T, config.learning_rate_decay)

    # Start from +inf so the first validated epoch always produces a checkpoint,
    # whatever the scale of the validation score.
    best_loss = float('inf')

    patience_limit = config.patience_limit
    patience_check = 0

    model_T.train()

    for epoch in range(config.train_epochs):

        model_S.train()  # the validation pass below switches model_S to eval mode
        total_loss, total_samples = 0.0, 0

        for batch in train_dataloader:

            user_reviews, item_reviews, reviews, ratings = [x.to(config.device) for x in batch]

            # Each optimizer steps immediately after its own backward pass.
            # opt_S owns *all* model_S parameters, so its zero_grad() in step 3
            # would otherwise erase the step-2 gradients before opt_trans
            # could apply them.

            # step 1: Train Target Network on the actual review.
            latent_T, pred_T = model_T(reviews)
            loss_T = F.l1_loss(pred_T, ratings)
            opt_T.zero_grad()
            loss_T.backward()
            opt_T.step()

            # step 2: Learn to Transform.
            latent_S, pred_S = model_S(user_reviews, item_reviews)
            loss_trans = F.mse_loss(latent_S, latent_T.detach())
            opt_trans.zero_grad()
            loss_trans.backward()
            opt_trans.step()

            # step 3: Train a predictor on the transformed input.
            # NOTE(review): assumes pred_S does not backpropagate through the
            # parameters opt_trans just updated (i.e. the source model detaches
            # its latent before the predictor) -- confirm in model/Transnet.py.
            loss_S = F.l1_loss(pred_S, ratings, reduction='sum')
            opt_S.zero_grad()
            loss_S.backward()
            opt_S.step()

            total_loss += loss_S.item()  # summed L1 loss of the source network
            total_samples += len(pred_S)

        lr_sch_S.step()
        lr_sch_trans.step()
        lr_sch_T.step()

        model_S.eval()
        valid_mse = predict_mse_for_transnet(model_S, valid_dataloader, config.device)

        # An empty training loader yields no samples; report NaN instead of
        # raising ZeroDivisionError.
        train_loss = total_loss / total_samples if total_samples else float('nan')
        # NOTE(review): the training figure is the mean L1 loss of step 3 even
        # though the log label says "mse"; the label is kept for log compatibility.
        print(f"#### Epoch {epoch:3d}; train mse {train_loss:.6f}; validation mse {valid_mse:.6f}")

        if best_loss > valid_mse:
            best_loss = valid_mse
            torch.save(model_S, model_path)
            patience_check = 0
        else:
            patience_check += 1
            if patience_check > patience_limit:
                print('early stopping!')
                break


def test(dataloader, model, config):
    """Score ``model`` on ``dataloader`` and return the resulting loss.

    Thin wrapper that forwards to ``predict_mse_for_transnet`` using the
    device stored in ``config.device``.
    """
    return predict_mse_for_transnet(model, dataloader, config.device)




if __name__ == '__main__':
    # Script entry point: build the datasets and both networks, train them,
    # then report the best checkpoint's score on the validation and test splits.

    config = Config()

    print('## Load dataset')
    word_emb, word_dict = load_word_embedding(config.word2vec_file)

    train_dataset = Transnet_Dataset(config.train_file, word_dict, config)
    valid_dataset = Transnet_Dataset(config.valid_file, word_dict, config)
    test_dataset = Transnet_Dataset(config.test_file, word_dict, config)

    # Only the training split is shuffled.
    train_dlr = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    valid_dlr = DataLoader(valid_dataset, batch_size=config.batch_size)
    test_dlr = DataLoader(test_dataset, batch_size=config.batch_size)

    source_model = SourceNet(config, word_emb).to(config.device)
    target_model = TargetNet(config, word_emb).to(config.device)

    # model : DataParallel

    # _source = SourceNet(config, word_emb).cuda()
    # source_model = nn.DataParallel(_source).to(config.device)

    # _target = TargetNet(config, word_emb).cuda()
    # target_model = nn.DataParallel(_target).to(config.device)

    # Drop the module-level names that are not used again below.
    del train_dataset, valid_dataset, test_dataset, word_emb, word_dict

    # Read the path from the same config instance the datasets were built from;
    # the dataset name is the second '_'-separated token of the file name.
    dataset_name = config.train_file.split('/')[-1].split('_')[1]
    # NOTE(review): the ':' in the timestamp is not a legal file-name character
    # on Windows; the format is kept unchanged so checkpoint names stay the same.
    timestamp = datetime.today().strftime("%m%d_%H:%M")

    model_file = f'model/model_save/transnet_{dataset_name}_reviews_{timestamp}.pt'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(model_file), exist_ok=True)

    train(train_dlr, valid_dlr, source_model, target_model, config, model_file)

    # Load the best checkpoint once and reuse it for both evaluations.
    # NOTE(review): the checkpoint is a whole pickled module; recent PyTorch
    # releases may need weights_only=False here -- confirm for the installed version.
    best_model = torch.load(model_file)
    valid_loss = test(valid_dlr, best_model, config)
    test_loss = test(test_dlr, best_model, config)

    print(f"valid_ mse is {valid_loss:.6f}")
    print(f"test mse is {test_loss:.6f}")