from torch.optim.lr_scheduler import ExponentialLR
import time
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from model import * # models
from settings import * # settings


# Model's performance evaluation
def evaluation_pretrained(model, loader, criterion, seed, device):
    """
    Average per-sample loss of `model` over `loader`.

    Args:
        model: called as model(text, mask, categorical, numerical); element [0]
            of its return value is passed to `criterion` as the prediction
        loader: iterable of batches indexed as
            (text, categorical, numerical, y, mask)
        criterion: callable (y_hat, y) -> loss tensor supporting .item()
        seed: value handed to torch.manual_seed before iterating
        device: target passed to .to() for every batch tensor
    Returns the sum of (batch loss * batch size) divided by the number of samples
    """
    torch.manual_seed(seed)
    n_samples = 0
    summed_loss = 0
    for batch in loader:
        # evaluation mode is selected again for every batch
        model.eval()
        # pick the five batch fields by position and move them to the device
        text, categorical, numerical, y, mask = (
            batch[idx].to(device) for idx in range(5)
        )
        # forward pass without gradient tracking
        with torch.no_grad():
            outputs = model(text, mask, categorical, numerical)
        y_hat = outputs[0]
        # weight the batch loss by the batch size
        batch_size = y.shape[0]
        summed_loss += batch_size * criterion(y_hat, y).item()
        n_samples += batch_size
    return summed_loss/n_samples

# function for training the model
def training_pretrained(model, model_type, loader_train,  n_epochs, loader_validation, criterion, optimizer, factor, seed, verbose, device):
    """
    Train `model` with early stopping on the validation performance.

    Each epoch runs one pass over `loader_train`, then computes the validation
    loss (evaluation_pretrained) and the validation performance
    (performance_pretrained). The model is saved to 'checkpoint.pt' whenever
    the performance beats the best value so far by more than 0.1% (relative);
    the first epoch is always saved. Training stops at the first epoch that
    does not improve.

    Args:
        model: called as model(text, mask, categorical, numerical); element [0]
            of its return value is passed to `criterion`
        model_type: forwarded unchanged to performance_pretrained
        loader_train: iterable of batches indexed as
            (text, categorical, numerical, y, mask)
        n_epochs: maximum number of epochs
        loader_validation: loader handed to the two validation helpers
        criterion: callable (y_hat, y) -> loss tensor
        optimizer: optimizer stepped once per training batch
        factor: gamma of the ExponentialLR scheduler (stepped once per epoch)
        seed: value handed to torch.manual_seed at the start of every epoch
        verbose: if truthy, print timing, losses and performance per epoch
        device: target passed to .to() for every batch tensor
    Returns (model, epoch): the model reloaded from the last saved checkpoint
    and the index of the last epoch that was run. When n_epochs < 1 nothing is
    trained and the input model is returned with epoch 0.
    """
    checkpoint_path = 'checkpoint.pt'

    # tracking variables
    best_val_perf = None  # None until the first validation score is recorded
    epoch = 0  # keeps the return value defined when the loop never runs

    # scheduler
    scheduler = ExponentialLR(optimizer, gamma=factor)

    # training and validation loop
    for epoch in range(1, n_epochs+1):
        start = time.time()
        train_loss = 0  # training loss by sample
        total = 0  # number of samples
        # NOTE(review): re-seeding with the same value every epoch resets the
        # global torch RNG each time - confirm this is intended.
        torch.manual_seed(seed)
        # the validation helpers switch to eval mode, so re-enable training
        # mode once per epoch
        model.train()
        for batch in loader_train:
            text = batch[0]
            categorical = batch[1]
            numerical = batch[2]
            y = batch[3]
            mask = batch[4]

            # 1. clear gradients
            optimizer.zero_grad()

            # 2. to device
            text = text.to(device)
            mask = mask.to(device)
            categorical = categorical.to(device)
            numerical = numerical.to(device)
            y = y.to(device)

            # 3. forward pass and compute loss
            y_hat = model(text, mask, categorical, numerical)[0]
            loss = criterion(y_hat, y)

            # 4. backward pass
            loss.backward()

            # 5. optimization
            optimizer.step()

            # 6. record loss (weighted by batch size)
            train_loss += loss.item()*y.shape[0]
            total += y.shape[0]

        end = time.time()
        train_loss = train_loss/total
        if verbose:
            print("---------training time (s):", round(end-start,0), "---------")
            print("epoch:", epoch, "training loss:", round(train_loss,5))

        # model's performance evaluation (accuracy)
        val_loss = evaluation_pretrained(model, loader_validation, criterion, seed, device)
        validation_performance = performance_pretrained(model, loader_validation, model_type, seed, device)

        # scheduler step
        scheduler.step()

        if verbose:
            print("epoch:", epoch, "validation loss:", round(val_loss,5), "validation performance:", round(validation_performance,5))

        # The first epoch always counts as an improvement so that a checkpoint
        # from THIS run exists even if its performance is 0; afterwards the
        # increase in accuracy should be greater than 0.1%.
        improved = best_val_perf is None or validation_performance > best_val_perf*1.001
        if not improved:  # ends training
            break

        # save best model so far
        torch.save(model, checkpoint_path)
        best_val_perf = validation_performance

    # nothing was trained (n_epochs < 1): no checkpoint belongs to this run
    if best_val_perf is None:
        return model, epoch

    # load the last checkpoint with the best model
    # (the file was written by this function above, so unpickling it is trusted)
    try:
        # newer PyTorch defaults to weights_only=True, which rejects a fully
        # pickled module; request the full unpickler explicitly
        model = torch.load(checkpoint_path, weights_only=False)
    except TypeError:
        # older PyTorch releases do not accept the weights_only keyword
        model = torch.load(checkpoint_path)

    return model, epoch
        
# performance computation
def performance_pretrained(model, loader_target, model_type, seed, device):
    """
    Performance computation (accuracy).

    Args:
        model: called as model(text, mask, categorical, numerical.float());
            element [0] of its return value is treated as per-class scores
            (softmax is applied over dim 1)
        loader_target: iterable of batches indexed as
            (text, categorical, numerical, y, mask)
        model_type: not used by this function; kept so existing callers work
        seed: value handed to torch.manual_seed before iterating
        device: target passed to .to() for every batch tensor
    Returns the fraction of samples whose arg-max class equals the label
    """
    torch.manual_seed(seed)
    # evaluation mode (does not change between batches, so set it once)
    model.eval()

    preds_list = []
    labels_list = []
    for batch in loader_target:
        # inputs and labels, moved to the device
        text = batch[0].to(device)
        categorical = batch[1].to(device)
        numerical = batch[2].to(device)
        y = batch[3].to(device)
        mask = batch[4].to(device)
        # prediction
        with torch.no_grad():
            pred = model(text, mask, categorical, numerical.float())[0]
        # compute softmax probabilities
        preds_list.append(F.softmax(pred, dim=1))
        labels_list.append(y)

    labels = torch.cat(labels_list)
    probs = torch.cat(preds_list)

    # count matches with a single tensor reduction instead of Python's
    # builtin sum(), which would iterate over the tensor element by element
    n_correct = (torch.argmax(probs, dim=1) == labels).sum().item()

    return n_correct/labels.shape[0]  # accuracy
    
    



        
  

  
