from ...util.ProgressBar import PBar
from ...data.DataLoading import MyData, MyDataSingle

import numpy as np
import pandas as pd


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


 
    
class ModelFitter:
    '''
    This is an internal class that allows for fitting of models.

    It wraps a pytorch model and provides a training loop (``fit``) and a
    batched inference routine (``predict``). Per-epoch losses of every call
    to ``fit`` are stored in ``train_loss`` / ``test_loss``, keyed by the
    zero-based index of the ``fit`` call.
    '''
    def __init__(self, model, data_prepare_class = MyData):
        '''
        Arguments
        ---------
            model: pytorch model
                This is the pytorch model that will be fit. The device used for
                training and prediction is taken from the model's first
                parameter (cpu is used if the model has no parameters).

            data_prepare_class: torch.utils.data.Dataset class
                This is the class that will be used to batch the data. It is
                called as ``data_prepare_class(X, Y)``.

        '''

        self.model = model
        self.train_loss = {}
        self.test_loss = {}
        self.predicted = {}
        self.n_trains = -1

        # A model without parameters would make a bare next() raise
        # StopIteration; fall back to the cpu in that case.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            self.device = first_param.device
        else:
            self.device = torch.device('cpu')

        self.data_prepare_class = data_prepare_class

        # Defined up-front so that predict() can be used before fit(); fit()
        # overwrites this with the parameters it is given.
        self.data_params = {}

        return

    def prepare_data(self, X, Y, upsample_index = None):
        '''
        This puts the data into a form that torch can easily use using the DataLoader
        function

        Arguments
        ---------
            X: tensor
                The inputs.

            Y: tensor
                The labels.

            upsample_index:
                Currently unused; kept so existing callers keep working.

        Returns
        ---------
            data_set: instance of ``data_prepare_class`` built from X and Y.
        '''

        data_set = self.data_prepare_class(X, Y)

        return data_set

    def _inference_data_params(self):
        '''
        Returns a copy of ``self.data_params`` that is safe for evaluation and
        prediction: order is preserved and no sample is dropped.
        '''
        params = dict(self.data_params)

        # Samplers are built for the training data, so their indices do not
        # apply to another data set. Note that removing a batch_sampler means
        # the DataLoader falls back to its default batch size.
        params.pop('sampler', None)
        params.pop('batch_sampler', None)

        # Shuffling would return predictions in a different order to the
        # inputs, and drop_last would silently discard the final samples.
        params['shuffle'] = False
        params['drop_last'] = False

        return params

    def _average_loss(self, generator):
        '''
        Returns the criterion loss per sample over everything in ``generator``,
        computed without gradients. Returns nan if the generator is empty.
        The caller is responsible for the train/eval mode of the model.
        '''
        total_loss = 0.0
        n_seen = 0

        with torch.no_grad():
            for inputs, labels in generator:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                batch_size = inputs.size(0)
                total_loss += loss.item() * batch_size
                n_seen += batch_size

        return total_loss / n_seen if n_seen else float('nan')

    def fit(self, X, Y, n_epochs, criterion, optimizer, data_params,
        verbose = True, X_val = None, Y_val = None):
        '''
        This fits the model.

        Arguments
        ---------

            X: tensor
                This is a tensor of the data for the model to be trained on

            Y: tensor
                This is a tensor of the labels for the model to be trained on

            n_epochs: int
                This is the number of epochs

            criterion: torch nn loss function
                This is the loss function that will be used in the training

            optimizer: torch nn optimizer
                This is the optimisation method used in the training

            data_params: dictionary
                This is a dictionary containing parameters that can be parsed into
                the torch.utils.data.DataLoader function. They are used as given
                for training; for validation and prediction, shuffling,
                drop_last and samplers are disabled.

            verbose: bool
                For printing progress

            X_val: tensor
                This is the testing set that the model will be evaluated on as the model
                is trained.

            Y_val: tensor
                This is the testing labels set that the model will be evaluated on as the model
                is trained. Validation only happens if both X_val and Y_val are given.
        Returns
        ---------
            model: pytorch model
                This is the trained model. It is left in training mode.


        '''
        self.n_trains += 1

        self.criterion = criterion
        self.optimizer = optimizer

        self.data_params = data_params

        testing_too = X_val is not None and Y_val is not None

        training_set = self.prepare_data(X, Y)
        training_generator = torch.utils.data.DataLoader(training_set, **data_params)

        if testing_too:
            testing_set = self.prepare_data(X_val, Y_val)
            testing_generator = torch.utils.data.DataLoader(
                testing_set, **self._inference_data_params())

        train_loss_temp = []
        test_loss_temp = []

        epoch_bar = PBar(show_length = 20, n_iterations=n_epochs)

        # Progress is reported each time another fifth of the epochs is done.
        # The threshold is tracked as an integer count of fifths because
        # repeatedly adding 0.2 drifts (0.2 * 3 > 0.6) and skips reports.
        reports_done = 0

        self.model.train()

        for epoch in range(n_epochs):
            training_loss = 0.0
            n_seen = 0

            for inputs, labels in training_generator:

                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                loss.backward()
                self.optimizer.step()

                # loss.item() is weighted by the batch size so that a smaller
                # final batch contributes proportionally to the epoch total.
                batch_size = inputs.size(0)
                training_loss += loss.item() * batch_size
                n_seen += batch_size

            # Normalise by the samples actually processed: with drop_last or a
            # sampler this can differ from the size of the data set.
            epoch_loss = training_loss / n_seen if n_seen else float('nan')
            train_loss_temp.append(epoch_loss)

            if testing_too:
                self.model.eval()
                epoch_test_loss = self._average_loss(testing_generator)
                test_loss_temp.append(epoch_test_loss)
                self.model.train()

            epoch_bar.update(1)
            bar = epoch_bar.give()
            if (epoch + 1) * 5 >= n_epochs * reports_done:
                if verbose:
                    printing_statement = 'Epochs: {}. epoch {} done. Loss per train sample: {:.2f}.'.format(bar,epoch+1,epoch_loss)
                    if testing_too:
                        printing_statement += ' Loss per test sample: {:.2f}.'.format(epoch_test_loss)
                    print(printing_statement)
                reports_done += 1

        self.train_loss[self.n_trains] = np.asarray(train_loss_temp)
        self.test_loss[self.n_trains] = np.asarray(test_loss_temp)

        return self.model

    def predict(self, X):
        '''
        This runs the model over X in batches and returns the outputs.

        Arguments
        ---------
            X: tensor
                This is the data to predict on. It is wrapped in MyDataSingle
                and batched using the data_params of the last call to fit
                (DataLoader defaults if fit has not been called), with
                shuffling, drop_last and samplers disabled so that the outputs
                line up with the rows of X.

        Returns
        ---------
            out: tensor
                The model outputs, on the cpu, concatenated along the first
                dimension. An empty tensor is returned if X yields no batches.
        '''

        predicting_set = MyDataSingle(X)
        predicting_generator = torch.utils.data.DataLoader(
            predicting_set, **self._inference_data_params())

        # Remember the current mode so prediction does not silently switch an
        # evaluation-mode model back to training mode.
        was_training = self.model.training
        self.model.eval()

        batch_outputs = []
        try:
            with torch.no_grad():
                for inputs in predicting_generator:
                    inputs = inputs.to(self.device)
                    batch_outputs.append(self.model(inputs).to('cpu'))
        finally:
            self.model.train(was_training)

        if not batch_outputs:
            return torch.empty(0)

        # A single concatenation avoids re-copying the growing result for
        # every batch.
        return torch.cat(batch_outputs, dim = 0)

    
    