import math
import datetime


import torch.nn as nn
import torch
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
import sys
import time




def loss_BCE(y_pred, y):
    """
    Binary cross-entropy evaluated directly on logits.

    Parameters
    ----------
    y_pred : torch.Tensor
        Raw (pre-sigmoid) model outputs.
    y : torch.Tensor
        Target labels; converted to float before the loss is computed.

    Returns
    -------
    torch.Tensor
        Scalar loss from ``nn.BCEWithLogitsLoss`` with its default (mean) reduction.
    """
    criterion = nn.BCEWithLogitsLoss()
    targets = y.float()
    return criterion(y_pred, targets)

def train_model_GDRO(num_epochs, model, loader_dict, train_loader, val_loader, device, optimizer, eta , per_step=25, early_stopping= True, orig_patience=1, tol = 0.001, save_best_model=True, model_name='GDRO_model',  verbose=True, save_at_steps=None, q_init=None):
    """
    Train ``model`` with group-DRO sample re-weighting.

    Each epoch runs a training phase over ``loader_dict[train_loader]`` and
    then a validation phase over ``loader_dict[val_loader]``. Batches are
    ``(x, y, g)`` triples; ``g`` is passed to ``model.weights_obj``.

    Training, per batch:
    - compute per-group losses;
    - update the group weights ``q_t`` with step size ``eta``;
    - push ``q_t`` back into ``model.weights_obj``;
    - turn ``q_t`` into per-sample weights;
    - optimise ``model.loss(x, y, w_t)``.

    Validation reuses the latest ``q_t`` without updating it.

    Parameters
    ----------
    num_epochs : int
        Maximum number of epochs.
    model : torch.nn.Module
        Must provide ``model.loss(x, y, w)`` and a ``weights_obj`` attribute
        with ``group_loss``, ``update_DRO_weights``, ``reset_p_weights`` and
        ``get_weights_sample``.
    loader_dict : dict
        Maps the keys ``train_loader`` / ``val_loader`` to batch iterables.
    train_loader, val_loader : hashable
        Keys into ``loader_dict``; also printed as the phase name.
    device : torch.device or str
        Device the labels are moved to for the accuracy computation.
    optimizer : torch.optim.Optimizer
    eta : float
        Step size forwarded to ``update_DRO_weights``.
    per_step : int
        Print progress every ``per_step`` batches when ``verbose``.
    early_stopping : bool
        If True, monitor the mean validation loss; otherwise monitor the
        mean training loss.
    orig_patience : int
        Consecutive non-improving epochs tolerated before stopping.
    tol : float
        Minimum decrease of the monitored loss that counts as improvement.
    save_best_model : bool
        Save the best parameters to ``<model_name>_current_best_model_parameters.pt``
        and reload them when stopping.
    model_name : str
        Prefix for saved parameter files.
    verbose : bool
        Print per-batch progress.
    save_at_steps : container of int, optional
        0-based epoch indices at which a parameter snapshot is saved.
    q_init : optional
        Initial group weights, passed as ``q_t`` to the first
        ``update_DRO_weights`` call. Previously ``q_t`` was read before ever
        being assigned, which raised UnboundLocalError.
        NOTE(review): the expected type/shape is defined by ``weights_obj``;
        confirm whether it accepts ``None``.
    """
    # total steps per phase
    total_step_train = len(loader_dict[train_loader])
    total_step_val = len(loader_dict[val_loader])
    best_model_path = model_name + '_current_best_model_parameters.pt'

    # placeholders to track loss
    best_loss = math.inf
    patience = orig_patience
    best_epoch = 0

    # group weights carried across batches and epochs
    q_t = q_init

    for epoch in range(num_epochs):

        # each epoch has a training and a validation phase
        for phase in [train_loader, val_loader]:
            is_train = phase == train_loader
            total_samples = 0
            total_correct = 0

            if is_train:
                total_loss_train = 0
                model.train()
            else:
                total_loss_val = 0
                model.eval()

            for i, (x, y, g) in enumerate(loader_dict[phase]):

                # loaders yielding a whole pre-batched tensor add a leading dim of size 1
                if x.shape[0] == 1:
                    x = x.squeeze(0)
                if y.shape[0] == 1:
                    y = y.transpose(0, 1)

                if is_train:
                    # update the group weights from the current per-group losses
                    loss_per_group_at_t = model.weights_obj.group_loss(x, y, g)
                    q_t = model.weights_obj.update_DRO_weights(q_t, loss_per_group_at_t, eta)
                    model.weights_obj.reset_p_weights(q_t)

                    # weighted loss and optimisation step
                    w_t = model.weights_obj.get_weights_sample(g, q_t)
                    loss_t = model.loss(x, y, w_t)
                    loss_t.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    # count each batch loss exactly once (it used to be added twice)
                    total_loss_train += loss_t.item()
                else:
                    # save memory during validation
                    with torch.no_grad():
                        w_t = model.weights_obj.get_weights_sample(g, q_t)
                        loss_t = model.loss(x, y, w_t)
                    total_loss_val += loss_t.item()

                # extra forward pass used only for accuracy; no graph needed
                with torch.no_grad():
                    y_pred = model(x.float())
                classes = torch.round(torch.sigmoid(y_pred))
                correct = (classes == y.to(device)).sum()
                sample_size = len(y)
                if verbose:
                    print('Accuracy at batch: ', correct/sample_size)

                # running amount correct & samples considered
                total_correct += correct
                total_samples += sample_size

                # every per_step batches, show progress
                if verbose and (i + 1) % per_step == 0:
                    total_step = total_step_train if is_train else total_step_val
                    print('Phase : {}, Epoch [{}/{}], Step [{}/{}], Worst group Loss (at batch): {:.4f},  accuracy: {:.4f}'
                          .format(phase,
                                  epoch + 1,
                                  num_epochs,
                                  i + 1,
                                  total_step,
                                  loss_t.item(),
                                  total_correct/total_samples,
                                  )
                          )

        # save the model at particular epochs
        if save_at_steps is not None and epoch in save_at_steps:
            torch.save(model.state_dict(), model_name + '_parameters_at_step_{}'.format(epoch))

        if early_stopping:
            # monitor the mean validation loss
            avg_loss_val = total_loss_val/total_step_val

            if avg_loss_val > (best_loss - tol):
                patience = patience - 1
                if patience == 0:
                    print('Early stopping at epoch{}: current loss {}  > {}, best loss.'.format(epoch, avg_loss_val, best_loss))
                    if save_best_model:
                        print('loading best param from epoch {}'.format(best_epoch))
                        model.load_state_dict(torch.load(best_model_path))
                    break
                print('Not improving at epoch {}, current loss {} > {}, patience is now {}'.format(epoch + 1, avg_loss_val, best_loss, patience))
            else:
                print('Improving: current loss {} < {}, loss last epoch '.format(avg_loss_val, best_loss))
                best_loss = avg_loss_val
                patience = orig_patience
                if save_best_model:
                    print('save best param at epoch {}'.format(epoch))
                    best_epoch = epoch
                    torch.save(model.state_dict(), best_model_path)
        else:
            # monitor the mean training loss instead
            avg_loss_train = total_loss_train/total_step_train

            if avg_loss_train > (best_loss - tol):
                patience = patience - 1
                if patience == 0:
                    print('Converged based on train loss at epoch{}: current loss {}  > {}, best loss. Loading best parameters'.format(epoch, avg_loss_train, best_loss))
                    if save_best_model:
                        print('loading best param from epoch {}'.format(best_epoch))
                        model.load_state_dict(torch.load(best_model_path))
                    break
                print('Not improving at epoch {}, current loss + tolerance {} > {}, patience is now {}'.format(epoch + 1, avg_loss_train + tol, best_loss, patience))
            else:
                print('Improving: current loss {} < {}, loss last epoch '.format(avg_loss_train, best_loss))
                best_loss = avg_loss_train
                patience = orig_patience
                if save_best_model:
                    print('save param at epoch {}'.format(epoch))
                    best_epoch = epoch
                    torch.save(model.state_dict(), best_model_path)
               
def create_loader(X, y, batch_size, workers, shuffle=True, include_weights=False, weights=None, pin_memory=True, g=None):
    """
    Create a single DataLoader over ``(X, y)`` or ``(X, y, g)``.

    Parameters
    ----------
    X, y : torch.Tensor
        Features and labels; must share the same first dimension.
    batch_size : int
    workers : int
        ``num_workers`` for the DataLoader.
    shuffle : bool
        Shuffle each epoch. Ignored when ``include_weights`` is True, because
        DataLoader does not accept ``shuffle=True`` together with a sampler.
    include_weights : bool
        Draw samples with a ``WeightedRandomSampler`` using ``weights``.
    weights : sequence of float, optional
        Per-sample weights; required when ``include_weights`` is True.
    pin_memory : bool
    g : torch.Tensor, optional
        Group labels; when given, each batch is ``(x, y, g)``.

    Returns
    -------
    torch.utils.data.DataLoader

    Raises
    ------
    ValueError
        If ``include_weights`` is True but ``weights`` is None.
    """
    tensors = (X, y) if g is None else (X, y, g)
    data = TensorDataset(*tensors)

    # previously `sampler` was unbound when include_weights was False
    sampler = None
    if include_weights:
        if weights is None:
            raise ValueError('include_weights=True requires `weights`')
        sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))
        # the sampler already randomises order; DataLoader rejects shuffle+sampler
        shuffle = False

    return DataLoader(data,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      sampler=sampler,
                      num_workers=workers,
                      pin_memory=pin_memory)


        

def train_model(num_epochs, model, loader_dict, train_loader, val_loader, device, optimizer, loss_fn, per_step=25, early_stopping= True, orig_patience=1, tol = 0.001, save_best_model=True, model_name='model',  verbose=True, save_at_steps=None, scheduler=None):
    """
    Train ``model`` with ``loss_fn``, evaluating on a validation loader every epoch.

    Each epoch runs a training phase over ``loader_dict[train_loader]`` and
    then a validation phase over ``loader_dict[val_loader]``. Batches are
    ``(x, y)`` pairs. Inputs are moved to ``device`` and cast to float before
    the forward pass.

    Parameters
    ----------
    num_epochs : int
        Maximum number of epochs.
    model : torch.nn.Module
    loader_dict : dict
        Maps the keys ``train_loader`` / ``val_loader`` to batch iterables.
    train_loader, val_loader : hashable
        Keys into ``loader_dict``; also printed as the phase name.
    device : torch.device or str
        Device inputs and labels are moved to.
    optimizer : torch.optim.Optimizer
    loss_fn : callable
        ``loss_fn(y_pred, y)``. Predicted classes for accuracy depend on it:
        - ``loss_BCE`` or an ``nn.BCEWithLogitsLoss`` instance: rounded sigmoid;
        - ``nn.CrossEntropyLoss``: argmax over dim 1;
        - ``nn.MSELoss``: threshold at 0.5;
        - anything else: raw outputs.
    per_step : int
        Print progress every ``per_step`` batches when ``verbose``.
    early_stopping : bool
        If True, monitor the mean validation loss; otherwise monitor the mean
        training loss.
    orig_patience : int
        Consecutive non-improving epochs tolerated before stopping.
    tol : float
        Minimum decrease of the monitored loss that counts as improvement.
    save_best_model : bool
        Save the best parameters to ``<model_name>_current_best_model_parameters.pt``
        and reload them when stopping.
    model_name : str
        Prefix for saved parameter files.
    verbose : bool
        Print per-batch progress.
    save_at_steps : container of int, optional
        0-based epoch indices at which a parameter snapshot is saved.
    scheduler : optional
        If given, ``scheduler.step()`` is called after every optimizer step.
    """
    # total steps per phase
    total_step_train = len(loader_dict[train_loader])
    total_step_val = len(loader_dict[val_loader])
    best_model_path = model_name + '_current_best_model_parameters.pt'

    def _predict_classes(out):
        """Map raw model outputs to predicted classes according to ``loss_fn``."""
        # nn.Module has no __eq__, so the old `loss_fn == nn.BCEWithLogitsLoss()`
        # compared identity with a fresh instance and never matched; isinstance does.
        if loss_fn is loss_BCE or isinstance(loss_fn, nn.BCEWithLogitsLoss):
            return torch.round(torch.sigmoid(out))
        if isinstance(loss_fn, nn.CrossEntropyLoss):
            return torch.argmax(out, dim=1)
        if isinstance(loss_fn, nn.MSELoss):
            # if more than 0.5, then 1, else 0
            return 1*(out > 0.5)
        return out

    # placeholders to track loss
    best_loss = math.inf
    patience = orig_patience
    best_epoch = 0

    for epoch in range(num_epochs):

        # each epoch has a training and a validation phase
        for phase in [train_loader, val_loader]:
            is_train = phase == train_loader
            total_samples = 0
            total_correct = 0

            if is_train:
                total_loss_train = 0
                model.train()
            else:
                total_loss_val = 0
                model.eval()

            for i, (x, y) in enumerate(loader_dict[phase]):

                # move to device (as train_BERT does) so forward pass and
                # accuracy comparison happen on the same device
                x = x.to(device)
                y = y.to(device)

                # loaders yielding a whole pre-batched tensor add a leading dim of size 1
                if x.shape[0] == 1:
                    x = x.squeeze(0)
                if y.shape[0] == 1:
                    y = y.transpose(0, 1)

                if is_train:
                    y_pred = model(x.float())
                    loss = loss_fn(y_pred, y)
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    optimizer.zero_grad()

                    # count each batch loss exactly once (it used to be added twice)
                    total_loss_train += loss.item()
                else:
                    # save memory during validation
                    with torch.no_grad():
                        y_pred = model(x.float())
                        loss = loss_fn(y_pred, y)
                    total_loss_val += loss.item()

                classes = _predict_classes(y_pred)
                correct = (classes == y).sum()
                sample_size = len(y)

                # running amount correct & samples considered
                total_correct += correct
                total_samples += sample_size

                # every per_step batches, show progress
                if verbose and (i + 1) % per_step == 0:
                    total_step = total_step_train if is_train else total_step_val
                    print('Phase : {}, Epoch [{}/{}], Step [{}/{}], Loss (at batch): {:.4f},  accuracy: {:.4f}'
                          .format(phase,
                                  epoch + 1,
                                  num_epochs,
                                  i + 1,
                                  total_step,
                                  loss.item(),
                                  total_correct/total_samples,
                                  )
                          )

        # save the model at particular epochs
        if save_at_steps is not None and epoch in save_at_steps:
            torch.save(model.state_dict(), model_name + '_parameters_at_step_{}'.format(epoch))

        if early_stopping:
            # monitor the mean validation loss
            avg_loss_val = total_loss_val/total_step_val

            if avg_loss_val > (best_loss - tol):
                patience = patience - 1
                if patience == 0:
                    print('Early stopping at epoch{}: current loss {}  > {}, best loss.'.format(epoch, avg_loss_val, best_loss))
                    if save_best_model:
                        print('loading best param from epoch {}'.format(best_epoch))
                        model.load_state_dict(torch.load(best_model_path))
                    break
                print('Not improving at epoch {}, current loss {} > {}, patience is now {}'.format(epoch + 1, avg_loss_val, best_loss, patience))
            else:
                print('Improving: current loss {} < {}, loss last epoch '.format(avg_loss_val, best_loss))
                best_loss = avg_loss_val
                patience = orig_patience
                if save_best_model:
                    print('save best param at epoch {}'.format(epoch))
                    best_epoch = epoch
                    torch.save(model.state_dict(), best_model_path)
        else:
            # monitor the mean training loss instead
            avg_loss_train = total_loss_train/total_step_train

            if avg_loss_train > (best_loss - tol):
                patience = patience - 1
                if patience == 0:
                    print('Converged based on train loss at epoch{}: current loss {}  > {}, best loss. Loading best parameters'.format(epoch, avg_loss_train, best_loss))
                    if save_best_model:
                        print('loading best param from epoch {}'.format(best_epoch))
                        model.load_state_dict(torch.load(best_model_path))
                    break
                print('Not improving at epoch {}, current loss + tolerance {} > {}, patience is now {}'.format(epoch + 1, avg_loss_train + tol, best_loss, patience))
            else:
                print('Improving: current loss {} < {}, loss last epoch '.format(avg_loss_train, best_loss))
                best_loss = avg_loss_train
                patience = orig_patience
                if save_best_model:
                    print('save param at epoch {}'.format(epoch))
                    best_epoch = epoch
                    torch.save(model.state_dict(), best_model_path)
               





def train_BERT(num_epochs, model,  loader_dict, train_loader, val_loader, device, optimizer, loss_fn, per_step=25, early_stopping= True, orig_patience=1, tol = 0.001, save_best_model=True, model_name='model',  verbose=True, save_at_steps=None, scheduler=None):
    """
    Train a (BERT-style) ``model`` with ``loss_fn``, evaluating on a validation loader every epoch.

    Each epoch runs a training phase over ``loader_dict[train_loader]`` and
    then a validation phase over ``loader_dict[val_loader]``. Batches are
    ``(x, y)`` pairs moved to ``device``. ``x`` is passed to the model
    unchanged, with no float cast.

    Parameters
    ----------
    num_epochs : int
        Maximum number of epochs.
    model : torch.nn.Module
        Called as ``model(x)`` and expected to return logits.
    loader_dict : dict
        Maps the keys ``train_loader`` / ``val_loader`` to batch iterables.
    train_loader, val_loader : hashable
        Keys into ``loader_dict``; also printed as the phase name.
    device : torch.device or str
        Device inputs and labels are moved to.
    optimizer : torch.optim.Optimizer
    loss_fn : callable
        ``loss_fn(logits, y)``.
        - When it is ``loss_BCE``, ``y`` gets a trailing dim via ``unsqueeze(1)``.
        - ``loss_BCE`` or an ``nn.BCEWithLogitsLoss`` instance: classes are the rounded sigmoid.
        - ``nn.CrossEntropyLoss``: argmax over dim 1.
        - ``nn.MSELoss``: threshold at 0.5.
        - anything else: raw logits.
    per_step : int
        Print progress every ``per_step`` batches when ``verbose``.
    early_stopping : bool
        If True, monitor the mean validation loss; otherwise monitor the mean
        training loss.
    orig_patience : int
        Consecutive non-improving epochs tolerated before stopping.
    tol : float
        Minimum decrease of the monitored loss that counts as improvement.
    save_best_model : bool
        Save the best parameters to ``<model_name>_current_best_model_parameters.pt``
        and reload them when stopping.
    model_name : str
        Prefix for saved parameter files.
    verbose : bool
        Print per-batch progress and diagnostics (label/prediction rates,
        running sample count, batch timing).
    save_at_steps : container of int, optional
        0-based epoch indices at which a parameter snapshot is saved.
    scheduler : optional
        If given, ``scheduler.step()`` is called after every optimizer step.
    """
    # total steps per phase
    total_step_train = len(loader_dict[train_loader])
    total_step_val = len(loader_dict[val_loader])
    print('Total steps train: ', total_step_train)
    print('Total steps val: ', total_step_val)
    best_model_path = model_name + '_current_best_model_parameters.pt'

    # nn.Module has no __eq__, so `loss_fn == nn.BCEWithLogitsLoss()` never
    # matched an instance; isinstance does.
    is_bce = loss_fn is loss_BCE or isinstance(loss_fn, nn.BCEWithLogitsLoss)

    # placeholders to track loss
    best_loss = math.inf
    patience = orig_patience
    best_epoch = 0

    for epoch in range(num_epochs):

        # each epoch has a training and a validation phase
        for phase in [train_loader, val_loader]:
            is_train = phase == train_loader
            total_samples = 0
            total_correct = 0

            if is_train:
                total_loss_train = 0
                model.train()
            else:
                total_loss_val = 0
                model.eval()

            for i, (x, y) in enumerate(loader_dict[phase]):
                start_time = time.time()

                # move to device
                x = x.to(device)
                y = y.to(device)

                # loss_BCE needs labels with a trailing dim to match the logits
                if loss_fn is loss_BCE:
                    y = y.unsqueeze(1)

                # loaders yielding a whole pre-batched tensor add a leading dim of size 1
                if x.shape[0] == 1:
                    x = x.squeeze(0)
                if y.shape[0] == 1:
                    y = y.transpose(0, 1)

                if is_train:
                    # single forward pass (a second, unused one was removed)
                    logits = model(x)
                    loss = loss_fn(logits, y)
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    optimizer.zero_grad()

                    # count each batch loss exactly once (it used to be added twice)
                    total_loss_train += loss.item()
                else:
                    # save memory during validation
                    with torch.no_grad():
                        logits = model(x)
                        loss = loss_fn(logits, y)
                    total_loss_val += loss.item()

                # map logits to predicted classes
                if is_bce:
                    classes = torch.round(torch.sigmoid(logits))
                    if verbose:
                        print('p(y_pred = 1 in batch): ', classes.mean())
                elif isinstance(loss_fn, nn.CrossEntropyLoss):
                    classes = torch.argmax(logits, dim=1)
                elif isinstance(loss_fn, nn.MSELoss):
                    # if more than 0.5, then 1, else 0
                    classes = 1*(logits > 0.5)
                else:
                    # use this batch's logits (previously a stale `y_pred` from training)
                    classes = logits
                correct = (classes == y).sum()
                sample_size = len(y)

                if verbose:
                    # cast first: mean() is undefined for integer label tensors
                    print('p(y = 1 in batch): ', y.float().mean())

                # running amount correct & samples considered
                total_correct += correct
                total_samples += sample_size
                if verbose:
                    print('total samples: ', total_samples)

                # every per_step batches, show progress
                if verbose and (i + 1) % per_step == 0:
                    total_step = total_step_train if is_train else total_step_val
                    print('Phase : {}, Epoch [{}/{}], Step [{}/{}], Loss (at batch): {:.4f},  accuracy: {:.4f}'
                          .format(phase,
                                  epoch + 1,
                                  num_epochs,
                                  i + 1,
                                  total_step,
                                  loss.item(),
                                  total_correct/total_samples,
                                  )
                          )

                if verbose:
                    print('Time taken for batch: ', time.time() - start_time)

        # save the model at particular epochs
        if save_at_steps is not None and epoch in save_at_steps:
            torch.save(model.state_dict(), model_name + '_parameters_at_step_{}'.format(epoch))

        if early_stopping:
            # monitor the mean validation loss
            avg_loss_val = total_loss_val/total_step_val

            if avg_loss_val > (best_loss - tol):
                patience = patience - 1
                if patience == 0:
                    print('Early stopping at epoch{}: current loss {}  > {}, best loss.'.format(epoch, avg_loss_val, best_loss))
                    if save_best_model:
                        print('loading best param from epoch {}'.format(best_epoch))
                        model.load_state_dict(torch.load(best_model_path))
                    break
                print('Not improving at epoch {}, current loss {} > {}, patience is now {}'.format(epoch + 1, avg_loss_val, best_loss, patience))
            else:
                print('Improving: current loss {} < {}, loss last epoch '.format(avg_loss_val, best_loss))
                best_loss = avg_loss_val
                patience = orig_patience
                if save_best_model:
                    print('save best param at epoch {}'.format(epoch))
                    best_epoch = epoch
                    torch.save(model.state_dict(), best_model_path)
        else:
            # monitor the mean training loss instead
            print('total loss train: ', total_loss_train)
            print('total step train: ', total_step_train)
            avg_loss_train = total_loss_train/total_step_train

            if avg_loss_train > (best_loss - tol):
                patience = patience - 1
                if patience == 0:
                    print('Converged based on train loss at epoch{}: current loss {}  > {}, best loss. Loading best parameters'.format(epoch, avg_loss_train, best_loss))
                    if save_best_model:
                        print('loading best param from epoch {}'.format(best_epoch))
                        model.load_state_dict(torch.load(best_model_path))
                    break
                print('Not improving at epoch {}, current loss + tolerance {} > {}, patience is now {}'.format(epoch + 1, avg_loss_train + tol, best_loss, patience))
            else:
                print('Improving: current loss {} < {}, loss last epoch '.format(avg_loss_train, best_loss))
                best_loss = avg_loss_train
                patience = orig_patience
                if save_best_model:
                    print('save param at epoch {}'.format(epoch))
                    best_epoch = epoch
                    torch.save(model.state_dict(), best_model_path)
               
        