from turtle import end_fill
import torch
from tqdm import tqdm
import os
import numpy as np
import sys
from models import TwoStepModel
import time


def train(model, dataloader, valloader, epochs=30, hook=None):
    """Train ``model.model`` with Adam, validating after every epoch.

    After each epoch the validation loss is computed with
    ``model.get_graph_and_per_node_losses(valloader)``. Whenever it improves on
    the best value so far by more than ``1e-4`` the model is saved via
    ``model.save_model``; otherwise a patience counter is incremented.
    Training stops early once the counter reaches
    ``model.config.patience_limit`` (5 when the config has no such attribute).

    Args:
        model: Wrapper object. This function reads ``model.model``,
            ``model.lr``, ``model.config`` (``gpu``, ``model_save_dir``,
            ``model_type``, ``n_masks``, ``task`` and optionally
            ``patience_limit``), ``model.per_node_val_losses`` (must already
            exist before the first batch), and calls ``model.to``,
            ``model.loss_func``, ``model.get_additional_loss_terms``,
            ``model.get_graph_and_per_node_losses`` and ``model.save_model``.
        dataloader: Sized iterable of training batches; each batch must
            support ``.to(device)``.
        valloader: Validation data, passed unchanged to
            ``model.get_graph_and_per_node_losses``.
        epochs: Maximum number of epochs to run.
        hook: Optional callable invoked as ``hook(model)`` after each
            validation pass. Its result is logged with the validation scores;
            if the result of the last call equals ``"save_times"``, the
            per-epoch CPU times are written to a text file in the current
            working directory.

    Returns:
        The best (lowest) validation loss observed, or ``np.inf`` if no epoch
        ever improved on it (e.g. ``epochs == 0``).

    Side effects:
        Sets ``model.d_optimizer``, ``model.device``,
        ``model.per_node_val_losses`` and ``model.graph_val_losses``; moves
        the model to the selected device; prints progress to stdout.
    """
    model.d_optimizer = torch.optim.Adam(model.model.parameters(), model.lr, weight_decay=0)

    device = torch.device(f"cuda:{model.config.gpu}" if torch.cuda.is_available() else 'cpu')
    print(device)
    model.device = device
    model.to(device)

    best_val_loss = np.inf
    num_batches = len(dataloader)
    # Prediction loss of the most recent training batch (before extra terms).
    pred_loss = 0
    # A new validation loss must beat the best one by more than this margin.
    min_improvement = 1e-4

    patience_limit = model.config.__dict__.get('patience_limit', 5)
    patience = 0

    # model.model is never rebound in this function (the optimizer above is
    # tied to its parameters), so the type check is done once, not per batch.
    is_masked = isinstance(model.model, TwoStepModel.MaskedModel)

    # NOTE(review): train mode is never set explicitly here; if
    # get_graph_and_per_node_losses switches the network to eval mode, later
    # epochs would train in that mode -- confirm against its implementation.

    prgbar = tqdm(range(epochs))
    epoch_times = []
    hook_result = None
    for epoch in prgbar:
        # process_time() is CPU time of this process, not wall-clock time.
        start_time = time.process_time()
        if patience >= patience_limit:
            break

        running_loss = 0.0
        for graph_batch in dataloader:
            graph_batch = graph_batch.to(device)

            model.d_optimizer.zero_grad()

            predictions = model.model(graph_batch, per_node_val_losses=model.per_node_val_losses)

            if is_masked:
                # The masked model returns a pair; the second entry gets its
                # own loss, which is averaged with the main loss below.
                no_last_predictions = predictions[1]
                predictions = predictions[0]
                no_last_loss = model.loss_func(no_last_predictions, graph_batch)

            loss_val = model.loss_func(predictions, graph_batch)
            pred_loss = loss_val.item()

            loss_val += model.get_additional_loss_terms()

            if is_masked:
                loss_val = (loss_val + no_last_loss) / 2

            loss_val.backward()
            model.d_optimizer.step()

            running_loss += loss_val.item()

        # Validation runs after every epoch.
        val_losses = {"prediction_train_loss": pred_loss}

        model.per_node_val_losses, model.graph_val_losses = model.get_graph_and_per_node_losses(valloader)

        valloss = model.graph_val_losses.mean().item()
        val_losses['prediction_val_loss'] = valloss
        prgbar.set_description(f"Valloss: {valloss}")

        if hook is not None:
            hook_result = hook(model)
            val_losses['hook_result'] = hook_result

        # Logged once, after the hook, so the line carries the hook result too.
        print("validation scores :", val_losses, flush=True)

        if valloss - best_val_loss < -min_improvement:
            model.save_model(os.path.join(model.config.model_save_dir, model.config.model_type + f"_{model.config.n_masks}.pth"))
            best_val_loss = valloss
            patience = 0
        else:
            patience += 1

        # An empty dataloader yields no batches; report NaN instead of
        # raising ZeroDivisionError after the validation pass has already run.
        train_loss_avg = running_loss / num_batches if num_batches else float("nan")
        print("Epoch: ", epoch, ", training loss: ", train_loss_avg, flush=True)
        epoch_times.append(time.process_time() - start_time)

    print("best val loss", best_val_loss)

    if hook_result == "save_times":
        with open(f'{model.config.model_type}_{model.config.n_masks}_{model.config.task}_epoch_times.txt','w') as f:
            f.writelines(f'{epoch_time}\n' for epoch_time in epoch_times)

    return best_val_loss
