import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from model.DeepCoNN import DeepCoNN
from model.Dattn import Dattn 
from model.DAML import DAML
from utils import predict_mse, predict_mse_with_IDemb, load_word_embedding, Review_Dataset
from baseline_config import Config
from datetime import datetime




# Training
def train(train_dataloader, valid_dataloader, model, config, model_file):
    """Train ``model`` and checkpoint it whenever the validation score improves.

    Each epoch runs one pass over ``train_dataloader`` with Adam and a
    summed MSE loss, decays the learning rate once via ``ExponentialLR``,
    then scores the model on ``valid_dataloader``. The whole model object is
    written to ``model_file`` with ``torch.save`` every time the validation
    score beats the best one seen so far. Training stops early once the score
    has failed to improve for more than ``config.patience_limit`` epochs in a
    row.

    Args:
        train_dataloader: Iterable of batches. Every batch is a sequence of
            tensors whose last element is the rating target; all preceding
            elements are passed positionally to ``model``.
        valid_dataloader: Loader handed to the validation helper.
        model: The ``nn.Module`` to optimise (modified in place).
        config: Object providing ``learning_rate``, ``l2_regularization``,
            ``learning_rate_decay``, ``patience_limit``, ``train_epochs``,
            ``with_idemb`` and ``device``.
        model_file: Destination passed to ``torch.save``.

    Returns:
        None. The best model is persisted to ``model_file``.
    """
    opt = torch.optim.Adam(model.parameters(), config.learning_rate, weight_decay=config.l2_regularization)
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(opt, config.learning_rate_decay)

    # The two modes differ only in how many input tensors a batch carries and
    # in which helper scores the validation set.
    evaluate = predict_mse_with_IDemb if config.with_idemb else predict_mse

    # torch.save does not create missing directories, so make sure the target
    # directory exists before the first checkpoint is attempted.
    if isinstance(model_file, (str, os.PathLike)):
        save_dir = os.path.dirname(model_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    # Start from +inf so the first finite validation score always produces a
    # checkpoint, whatever the scale of the ratings.
    best_loss = float('inf')

    patience_limit = config.patience_limit
    patience_check = 0

    for epoch in range(config.train_epochs):

        model.train()
        total_loss, total_samples = 0.0, 0

        for batch in train_dataloader:
            # The last tensor is the target; everything before it is model input.
            *inputs, ratings = (x.to(config.device) for x in batch)
            predict = model(*inputs)
            # Summed (not averaged) so the epoch mean can be computed exactly
            # even when the final batch is smaller.
            loss = F.mse_loss(predict, ratings, reduction='sum')

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()
            total_samples += len(predict)

        lr_sch.step()
        model.eval()

        valid_mse = evaluate(model, valid_dataloader, config.device)
        # An empty training loader would otherwise raise ZeroDivisionError.
        train_loss = total_loss / total_samples if total_samples else float('nan')

        print(f"Epoch {epoch+1:3d}; train mse {train_loss:.6f}; validation mse {valid_mse:.6f}")

        if best_loss > valid_mse:
            best_loss = valid_mse
            torch.save(model, model_file)
            patience_check = 0
        else:
            patience_check += 1
            if patience_check > patience_limit:
                print('early stopping!')
                break

def test(dataloader, model, config):
    """Score ``model`` on ``dataloader`` and return the resulting loss.

    The scoring helper is ``predict_mse_with_IDemb`` when
    ``config.with_idemb`` is truthy and ``predict_mse`` otherwise; either one
    is called with ``(model, dataloader, config.device)`` and its return value
    is passed through unchanged.
    """
    scorer = predict_mse_with_IDemb if config.with_idemb else predict_mse
    return scorer(model, dataloader, config.device)


if __name__ == '__main__':
    # Script entry point: build the datasets and the configured baseline
    # model, train it, then report the loss of the best checkpoint on the
    # validation and test splits.

    config = Config()

    print('## Load dataset')
    word_emb, word_dict = load_word_embedding(config.word2vec_file)

    train_dataset = Review_Dataset(config.train_file, word_dict, config)
    valid_dataset = Review_Dataset(config.valid_file, word_dict, config)
    test_dataset = Review_Dataset(config.test_file, word_dict, config)

    # Only the training split is shuffled; evaluation order is left as-is.
    train_dlr = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    valid_dlr = DataLoader(valid_dataset, batch_size=config.batch_size)
    test_dlr = DataLoader(test_dataset, batch_size=config.batch_size)

    # config.baseline is the model class to instantiate.
    baseline = config.baseline

    # model
    model = baseline(config, word_emb).to(config.device)
    model_name = model.__class__.__name__

    print(f'model_name : {model_name}')

    # Optional multi-GPU variant, kept for reference:
    # _model = Ournet(config, word_emb).cuda()
    # model = nn.DataParallel(_model).to(config.device)
    # model_name = _model.__class__.__name__

    # Drop the local names that are no longer needed; the loaders keep their
    # own references to the datasets.
    del train_dataset, valid_dataset, test_dataset, word_emb, word_dict

    # Read the path from the config *instance*, the same object the datasets
    # above were built from, rather than from the Config class.
    # The dataset name is the second '_'-separated token of the file name.
    dataset_name = os.path.basename(config.train_file).split('_')[1]
    # NOTE(review): the ':' in this stamp is not a legal file-name character on
    # Windows; the format is kept unchanged so existing artefact names match.
    timestamp = datetime.today().strftime("%m%d_%H:%M")

    model_file = f'model/model_save/{model_name}_{dataset_name}_reviews_{timestamp}.pt'

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(model_file), exist_ok=True)

    train(train_dlr, valid_dlr, model, config, model_file)

    # Load the best checkpoint once and reuse it for both evaluations.
    # NOTE(review): the checkpoint is a whole pickled module. On PyTorch
    # releases where torch.load defaults to weights_only=True this call needs
    # weights_only=False; only do that for checkpoints produced by this script.
    best_model = torch.load(model_file)
    valid_loss = test(valid_dlr, best_model, config)
    test_loss = test(test_dlr, best_model, config)

    print(f"valid_ mse is {valid_loss:.6f}")
    print(f"test mse is {test_loss:.6f}")
