import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from model.Prompt2Rec import Ournet
from ablation.Lattn import Lattn
from ablation.Mattn import Mattn
from ablation.NoAttn import NoAttn

from utils import load_word_embedding, Vocab_Dataset, predict_mse_ournet
from ournet_config import Config
from datetime import datetime


# train
def train(train_dataloader, valid_dataloader, model, config, model_file):
    """Train ``model`` with Adam and checkpoint it on the best validation MSE.

    After every epoch the learning-rate scheduler is stepped, the model is
    scored on ``valid_dataloader`` with ``predict_mse_ournet``, and the whole
    model object is written to ``model_file`` whenever the validation MSE
    improves. Training stops early once the validation MSE has failed to
    improve for more than ``config.patience_limit`` consecutive epochs.

    Args:
        train_dataloader: Iterable of batches. Each batch must unpack into
            ``(uid, iid, user_vocablist, item_vocablist, ratings)``, and each
            item must support ``.to(device)``.
        valid_dataloader: Passed unchanged to ``predict_mse_ournet``.
        model: Module called as ``model(uid, iid, user_vocablist,
            item_vocablist)``; it must return five values, the first of which
            is the rating prediction.
        config: Object providing ``learning_rate``, ``l2_regularization``,
            ``learning_rate_decay``, ``patience_limit``, ``train_epochs`` and
            ``device``.
        model_file: Destination handed to ``torch.save``.
    """
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint location exists before spending time on training.
    if isinstance(model_file, (str, os.PathLike)):
        model_dir = os.path.dirname(os.fspath(model_file))
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

    # optimizer
    opt = torch.optim.Adam(
        model.parameters(),
        config.learning_rate,
        weight_decay=config.l2_regularization,
    )
    # scheduler
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(opt, config.learning_rate_decay)

    # Start from +inf so the first finite validation score always produces a
    # checkpoint; a fixed ceiling could leave model_file unwritten.
    best_loss = float('inf')

    # early stopping
    patience_limit = config.patience_limit
    patience_check = 0

    for epoch in range(config.train_epochs):

        model.train()
        total_loss, total_samples = 0, 0

        for batch in train_dataloader:
            uid, iid, user_vocablist, item_vocablist, ratings = (
                x.to(config.device) for x in batch
            )
            predict, _, _, _, _ = model(uid, iid, user_vocablist, item_vocablist)
            # Summed (not averaged) so the epoch loss can be normalised by the
            # exact number of samples, independent of the last batch's size.
            loss = F.mse_loss(predict, ratings, reduction='sum')

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()
            total_samples += len(predict)

        lr_sch.step()
        model.eval()

        # No gradients are needed for validation; avoid building a graph.
        with torch.no_grad():
            valid_mse = predict_mse_ournet(model, valid_dataloader, config.device)
        train_loss = total_loss / total_samples

        print(f"Epoch {epoch+1:3d}; train mse {train_loss:.6f}; validation mse {valid_mse:.6f}")

        # save best model
        if valid_mse < best_loss:
            best_loss = valid_mse
            torch.save(model, model_file)
            patience_check = 0
            continue

        # No improvement this epoch: stop once the streak exceeds the limit.
        patience_check += 1
        if patience_check > patience_limit:
            print('early stopping!')
            break

# test
def test(dataloader, model, config):
    """Return the value of ``predict_mse_ournet`` for ``model`` on ``dataloader``.

    The evaluation runs on ``config.device``.
    """
    return predict_mse_ournet(model, dataloader, config.device)
        


if __name__ == '__main__':
    # Script entry point: build the datasets and model, train with
    # checkpointing, then report validation/test MSE of the saved checkpoint.

    config = Config()
    # print(config)
    print('## Load dataset')
    word_emb, word_dict = load_word_embedding(config.word2vec_file)

    train_dataset = Vocab_Dataset(config.train_file, word_dict, config)
    valid_dataset = Vocab_Dataset(config.valid_file, word_dict, config)
    test_dataset = Vocab_Dataset(config.test_file, word_dict, config)

    # Only the training split is shuffled.
    train_dlr = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    valid_dlr = DataLoader(valid_dataset, batch_size=config.batch_size)
    test_dlr = DataLoader(test_dataset, batch_size=config.batch_size)

    # model
    model = Ournet(config, word_emb).to(config.device)
    model_name = model.__class__.__name__

    # model : DataParallel
    # _model = Ournet(config, word_emb).cuda()
    # model = nn.DataParallel(_model).to(config.device)
    # model_name = _model.__class__.__name__

    # Drop the local names; the loaders and the model hold their own references.
    del train_dataset, valid_dataset, test_dataset, word_emb, word_dict

    # Dataset name is the second '_'-separated token of the train file's
    # basename. Read it from the config *instance* so it matches the path the
    # datasets above were actually loaded from.
    dataset_name = config.train_file.split('/')[-1].split('_')[1]
    # NOTE(review): the ':' in this timestamp is not a valid filename
    # character on Windows; the format is kept unchanged here.
    timestamp = datetime.today().strftime("%m%d_%H:%M")

    model_file = f'model/model_save/{model_name}_{dataset_name}_vocab_{timestamp}.pt'

    train(train_dlr, valid_dlr, model, config, model_file)

    # Deserialize the best checkpoint once and reuse it for both evaluations.
    # NOTE(review): the checkpoint is a whole pickled model; recent PyTorch
    # releases default torch.load to weights_only=True, which may require
    # passing weights_only=False here - confirm against the installed version.
    best_model = torch.load(model_file)
    valid_loss = test(valid_dlr, best_model, config)
    test_loss = test(test_dlr, best_model, config)

    print(f"valid_ mse is {valid_loss:.6f}")
    print(f"test mse is {test_loss:.6f}")