import torch
import torch.nn as nn
import torch.nn.functional as F


import numpy as np
import torch.optim as optim
import logging
import os, sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from utils.flip_gradient import flip_gradient

class CT_trainer:
    """Train and run inference for a treatment/outcome model with a balancing objective.

    Wraps ``model`` (moved to CUDA when available, otherwise CPU) and provides the
    training loop, checkpointing, and one-step / autoregressive prediction helpers.

    Model interface relied upon (taken from how the model is used in this class;
    confirm against the model implementation):

    * ``model.forward(batch)`` returns ``(treatment_pred, outcome_pred, br)`` where
      ``batch`` is the dict built by :meth:`_build_model_batch`;
    * ``model.bce_loss(treatment_pred, target, kind=...)`` with ``kind`` in
      ``{'predict', 'confuse'}`` returns an unreduced per-entry loss;
    * ``model.br_treatment_outcome_head.alpha`` scales the confusion loss when
      ``balancing == "domain_confusion"``.

    Datasets are dicts of tensors indexed by patient along dim 0 (keys such as
    ``'current_covariates'``, ``'previous_treatments'``, ``'current_treatments'``,
    ``'outputs'``, ``'active_entries'``).
    """

    def __init__(self, model, hyperparams, params, balancing="domain_confusion"):
        """Store configuration and move ``model`` to the available device.

        Args:
            model: ``torch.nn.Module`` to train.
            hyperparams: dict with required keys ``'batch_size'`` and
                ``'learning_rate'``; ``'br_size'`` is read if present.
            params: dict with required key ``'num_epochs'``; ``'num_treatments'``,
                ``'num_covariates'``, ``'num_outputs'`` and ``'max_sequence_length'``
                are read if present (``None`` otherwise).
            balancing: ``"domain_confusion"`` or ``"grad_reverse"``; selects how the
                treatment loss is formed during training (anything else raises
                ``NotImplementedError`` when a loss is computed).
        """
        self.model = model
        self.balancing = balancing

        self.num_epochs = params['num_epochs']
        self.batch_size = hyperparams['batch_size']
        self.learning_rate = hyperparams['learning_rate']

        # Optional shape metadata, exposed for callers that read it. Only the
        # legacy build_feed_dictionary still depends on num_treatments.
        self.num_treatments = params.get('num_treatments')
        self.num_covariates = params.get('num_covariates')
        self.num_outputs = params.get('num_outputs')
        self.max_sequence_length = params.get('max_sequence_length')
        self.br_size = hyperparams.get('br_size')

        self.b_train_decoder = False

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    # ------------------------------------------------------------------ batching

    @staticmethod
    def _batch_bounds(dataset_size, batch_size):
        """Yield ``(start, stop)`` row bounds of consecutive mini-batches.

        The last batch is aligned to the end of the dataset (it may overlap the
        previous one) so that every batch has ``batch_size`` rows, except when the
        whole dataset is smaller than ``batch_size``. Yields nothing for an empty
        dataset.
        """
        num_batches = -(-dataset_size // batch_size)  # ceil division
        for i in range(num_batches):
            if i == num_batches - 1:
                yield max(dataset_size - batch_size, 0), dataset_size
            else:
                yield i * batch_size, (i + 1) * batch_size

    def _build_model_batch(self, current_covariates, previous_treatments, current_treatments,
                           active_entries=None):
        """Pack raw dataset tensors into the input dict consumed by ``self.model.forward``.

        The first covariate channel is passed as ``prev_outputs`` and the remaining
        channels at time step 0 as ``static_features``. ``active_entries`` is only
        included when given.
        """
        batch = dict()
        batch['prev_treatments'] = previous_treatments
        batch['vitals'] = None
        batch['prev_outputs'] = current_covariates[:, :, :1]
        batch['static_features'] = current_covariates[:, 0, 1:]
        batch['current_treatments'] = current_treatments
        if active_entries is not None:
            batch['active_entries'] = active_entries
        return batch

    def _slice_model_batch(self, dataset, start, stop):
        """Build the model input dict for rows ``[start, stop)`` of ``dataset`` on ``self.device``."""
        def take(key):
            return dataset[key][start:stop].to(self.device)

        # NOTE(review): datasets without 'active_entries' (e.g. the autoregressive
        # one) produce a batch without that key -- confirm the model tolerates it.
        active_entries = take('active_entries') if 'active_entries' in dataset else None
        return self._build_model_batch(take('current_covariates'), take('previous_treatments'),
                                       take('current_treatments'), active_entries)

    def gen_epoch(self, dataset, batch_size, training_mode=True):
        """Iterate over ``dataset`` in consecutive (unshuffled) mini-batches on ``self.device``.

        Yields:
            With ``training_mode=True``: ``(current_covariates, previous_treatments,
            current_treatments, init_state, outputs, active_entries)``.
            Otherwise: ``(current_covariates, previous_treatments,
            current_treatments, init_state)``.
            ``init_state`` is ``None`` unless ``self.b_train_decoder`` is set.
        """
        dataset_size = dataset['current_covariates'].shape[0]
        input_keys = ('current_covariates', 'previous_treatments', 'current_treatments')

        for start, stop in self._batch_bounds(dataset_size, batch_size):
            inputs = [dataset[key][start:stop].to(self.device) for key in input_keys]

            batch_init_state = None
            if self.b_train_decoder:
                batch_init_state = dataset['init_state'][start:stop].to(self.device)

            if training_mode:
                batch_outputs = dataset['outputs'][start:stop].to(self.device)
                batch_active_entries = dataset['active_entries'][start:stop].to(self.device)
                yield (*inputs, batch_init_state, batch_outputs, batch_active_entries)
            else:
                yield (*inputs, batch_init_state)

    # ------------------------------------------------------------------ losses / training

    def _compute_losses(self, batch, batch_outputs):
        """Return ``(total, outcome_mse, treatment_bce)`` for one batch.

        Both terms are averaged over the active entries of the batch only.

        Raises:
            NotImplementedError: if ``self.balancing`` is not a supported mode.
        """
        treatment_pred, outcome_pred, _ = self.model.forward(batch)
        mse_loss = F.mse_loss(outcome_pred, batch_outputs, reduction='none')

        targets = batch['current_treatments'].double()
        if self.balancing == 'grad_reverse':
            bce_loss = self.model.bce_loss(treatment_pred, targets, kind='predict')
        elif self.balancing == 'domain_confusion':
            bce_loss = self.model.bce_loss(treatment_pred, targets, kind='confuse')
            bce_loss = self.model.br_treatment_outcome_head.alpha * bce_loss
        else:
            raise NotImplementedError()

        active_entries = batch['active_entries']
        num_active = active_entries.sum()
        # bce_loss has no trailing feature axis, hence the squeeze.
        bce_loss = (active_entries.squeeze(-1) * bce_loss).sum() / num_active
        mse_loss = (active_entries * mse_loss).sum() / num_active
        return bce_loss + mse_loss, mse_loss, bce_loss

    def train_models(self, dataset_train, dataset_val, model_name, model_folder):
        """Train ``self.model`` for ``self.num_epochs`` epochs, then save a checkpoint.

        Args:
            dataset_train: training dataset dict (see class docstring).
            dataset_val: currently unused; no validation is run during training
                (see :meth:`compute_validation_loss`).
            model_name: the checkpoint is saved as ``f"{model_name}_final.pth"``.
            model_folder: directory the checkpoint is written to.
        """
        optimizer = self.get_optimizer()
        # The model may have been left in eval mode by an earlier call.
        self.model.train()

        for epoch in range(self.num_epochs):
            # Schedule that ramps from 0 towards 1 over training. It is only logged
            # here; domain confusion uses the model's own ``alpha`` instead.
            p = float(epoch) / float(self.num_epochs)
            alpha_current = 2. / (1. + np.exp(-10. * p)) - 1

            losses = None
            for (batch_current_covariates, batch_previous_treatments, batch_current_treatments, batch_init_state,
                 batch_outputs, batch_active_entries) in self.gen_epoch(dataset_train, batch_size=self.batch_size):
                batch = self._build_model_batch(batch_current_covariates, batch_previous_treatments,
                                                batch_current_treatments, batch_active_entries)
                optimizer.zero_grad()
                losses = self._compute_losses(batch, batch_outputs)
                losses[0].backward()
                optimizer.step()

            if losses is not None:  # empty training set: nothing to report
                training_loss, mse_loss, bce_loss = losses  # last batch of the epoch
                logging.info(
                    "Epoch {} out of {} | total loss = {} | outcome loss = {} | treatment loss = {} | current alpha = {} ".format(
                        epoch + 1, self.num_epochs, training_loss.item(), mse_loss.item(),
                        bce_loss.item(), alpha_current))

        checkpoint_name = model_name + "_final"
        self.save_network(model_folder, checkpoint_name)

    def compute_validation_loss(self, dataset):
        """Return the mean ``(total, outcome, treatment)`` loss over ``dataset``.

        Runs in eval mode without gradients; the previous train/eval mode of the
        model is restored afterwards. Batches hold at most 10000 patients.
        """
        dataset_size = dataset['current_covariates'].shape[0]
        batch_size = max(1, min(dataset_size, 10000))

        validation_losses = []
        validation_losses_outcomes = []
        validation_losses_treatments = []

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for (batch_current_covariates, batch_previous_treatments, batch_current_treatments, batch_init_state,
                     batch_outputs, batch_active_entries) in self.gen_epoch(dataset, batch_size=batch_size):
                    batch = self._build_model_batch(batch_current_covariates, batch_previous_treatments,
                                                    batch_current_treatments, batch_active_entries)
                    total, outcomes, treatments = self._compute_losses(batch, batch_outputs)
                    validation_losses.append(total.item())
                    validation_losses_outcomes.append(outcomes.item())
                    validation_losses_treatments.append(treatments.item())
        finally:
            self.model.train(was_training)

        validation_loss = np.mean(np.array(validation_losses))
        validation_loss_outcomes = np.mean(np.array(validation_losses_outcomes))
        validation_loss_treatments = np.mean(np.array(validation_losses_treatments))
        return validation_loss, validation_loss_outcomes, validation_loss_treatments

    # ------------------------------------------------------------------ checkpoints

    def load_model(self, model_name, model_folder):
        """Load ``f"{model_name}_final.pth"`` from ``model_folder`` into ``self.model``.

        Returns:
            ``self.model`` with the restored weights.
        """
        checkpoint_name = model_name + "_final"
        return self.load_network(self.model, model_folder, checkpoint_name)

    def load_network(self, model, model_dir, checkpoint_name):
        """Load ``f"{checkpoint_name}.pth"`` from ``model_dir`` into ``model`` and return it.

        Tensors are mapped onto ``self.device`` so GPU checkpoints load on CPU.
        """
        load_path = os.path.join(model_dir, f"{checkpoint_name}.pth")
        logging.info(f"Restoring model from {load_path}")

        model.load_state_dict(torch.load(load_path, map_location=self.device))
        return model

    def save_network(self, model_dir, checkpoint_name):
        """Save ``self.model``'s state dict to ``model_dir/checkpoint_name.pth``."""
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        save_path = os.path.join(model_dir, f"{checkpoint_name}.pth")
        torch.save(self.model.state_dict(), save_path)
        num_params = sum(p.numel() for p in self.model.parameters())
        logging.info(f"Number of parameters: {num_params}")
        logging.info(f"Model saved to: {save_path}")

    # ------------------------------------------------------------------ legacy

    def build_feed_dictionary(self, batch_current_covariates, batch_previous_treatments,
                              batch_current_treatments, batch_init_state,
                              batch_outputs=None, batch_active_entries=None,
                              alpha_current=1.0, lr_current=0.01, training_mode=True):
        """Legacy helper left over from the TensorFlow port.

        NOTE(review): keys such as ``self.current_covariates`` / ``self.alpha`` are
        placeholders that are never defined on this class, so calling this method
        raises ``AttributeError``. Nothing in this class calls it; kept only for
        interface compatibility.
        """
        batch_size = batch_previous_treatments.shape[0]
        zero_init_treatment = torch.zeros([batch_size, 1, self.num_treatments]).to(self.device)
        new_batch_previous_treatments = torch.cat([zero_init_treatment, batch_previous_treatments], dim=1)

        if training_mode:
            if self.b_train_decoder:
                feed_dict = {self.current_covariates: batch_current_covariates,
                             self.previous_treatments: batch_previous_treatments,
                             self.current_treatments: batch_current_treatments,
                             self.init_state: batch_init_state,
                             self.outputs: batch_outputs,
                             self.active_entries: batch_active_entries,
                             self.alpha: alpha_current}
            else:
                feed_dict = {self.current_covariates: batch_current_covariates,
                             self.previous_treatments: new_batch_previous_treatments,
                             self.current_treatments: batch_current_treatments,
                             self.outputs: batch_outputs,
                             self.active_entries: batch_active_entries,
                             self.alpha: alpha_current}
        else:
            if self.b_train_decoder:
                feed_dict = {self.current_covariates: batch_current_covariates,
                             self.previous_treatments: batch_previous_treatments,
                             self.current_treatments: batch_current_treatments,
                             self.init_state: batch_init_state,
                             self.alpha: alpha_current}
            else:
                feed_dict = {self.current_covariates: batch_current_covariates,
                             self.previous_treatments: new_batch_previous_treatments,
                             self.current_treatments: batch_current_treatments,
                             self.alpha: alpha_current}

        return feed_dict

    # ------------------------------------------------------------------ inference

    def _averaged_forward(self, dataset, output_index, num_samples):
        """Run the model over ``dataset`` in chunks and average one forward output.

        Args:
            dataset: dataset dict (see class docstring).
            output_index: index into the ``forward`` output tuple
                (1 = outcome predictions, 2 = balancing representation).
            num_samples: number of forward passes averaged per batch.

        Returns:
            Tensor of shape ``(dataset_size, seq_len, features)`` on ``self.device``.

        Raises:
            ValueError: if ``dataset`` is empty or ``num_samples < 1``.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1")
        dataset_size = dataset['current_covariates'].shape[0]
        if dataset_size == 0:
            raise ValueError("dataset is empty")
        batch_size = min(dataset_size, 10000)  # larger batches do not fit into memory

        result = None
        was_training = self.model.training
        # NOTE(review): eval() disables dropout, so the num_samples passes are
        # identical unless the model is stochastic in eval mode. If MC-dropout
        # averaging is intended, dropout must stay enabled here.
        self.model.eval()
        try:
            with torch.no_grad():
                for start, stop in self._batch_bounds(dataset_size, batch_size):
                    batch = self._slice_model_batch(dataset, start, stop)
                    total = 0
                    for _ in range(num_samples):
                        total = total + self.model.forward(batch)[output_index]
                    mean = total / num_samples
                    # Accept outputs that are either (B, T, F) or flattened (B*T, F).
                    mean = mean.reshape(stop - start, -1, mean.shape[-1])
                    if result is None:
                        result = torch.zeros((dataset_size,) + tuple(mean.shape[1:]), device=self.device)
                    result[start:stop] = mean
        finally:
            self.model.train(was_training)
        return result

    def get_balancing_reps(self, dataset, num_samples=50):
        """Return the balancing representation for every patient in ``dataset``.

        The third element of ``model.forward``'s output, averaged over
        ``num_samples`` passes, with shape ``(dataset_size, seq_len, br_features)``.
        """
        logging.info("Computing balancing representations.")
        return self._averaged_forward(dataset, output_index=2, num_samples=num_samples)

    def get_predictions(self, dataset, num_samples=50):
        """Return one-step-ahead outcome predictions for every patient in ``dataset``.

        The second element of ``model.forward``'s output, averaged over
        ``num_samples`` passes, with shape ``(dataset_size, seq_len, num_outputs)``.
        """
        logging.info("Performing one-step-ahead prediction.")
        return self._averaged_forward(dataset, output_index=1, num_samples=num_samples)

    def get_autoregressive_sequence_predictions(self, test_data, data_map, encoder_states, encoder_outputs,
                                                projection_horizon):
        """Predict ``projection_horizon`` steps ahead by feeding predictions back in.

        For each patient the decoder input starts at the patient's last observed
        step (``test_data['sequence_lengths'] - 1``). The predicted outcome of step
        ``t`` is written into covariate channel 0 of step ``t + 1`` before the next
        call to :meth:`get_predictions`.

        Side effect: stores the result in ``test_data['predicted_outcomes']``.

        Returns:
            Tensor ``(num_patients, projection_horizon, num_outputs)``.
        """
        logging.info("Performing multi-step ahead prediction.")
        current_treatments = data_map['current_treatments']
        previous_treatments = data_map['previous_treatments']

        sequence_lengths = test_data['sequence_lengths'] - 1
        num_patient_points = current_treatments.shape[0]

        current_dataset = dict()
        current_dataset['current_covariates'] = torch.zeros(
            (num_patient_points, projection_horizon, test_data['current_covariates'].shape[-1]), device=self.device)
        current_dataset['previous_treatments'] = torch.zeros(
            (num_patient_points, projection_horizon, test_data['previous_treatments'].shape[-1]), device=self.device)
        current_dataset['current_treatments'] = torch.zeros(
            (num_patient_points, projection_horizon, test_data['current_treatments'].shape[-1]), device=self.device)
        current_dataset['init_state'] = torch.zeros(
            (num_patient_points, encoder_states.shape[-1]), device=self.device)

        predicted_outputs = torch.zeros(
            (num_patient_points, projection_horizon, test_data['outputs'].shape[-1]), device=self.device)

        # Per-patient offsets differ, so this part stays a loop.
        for i in range(num_patient_points):
            seq_length = int(sequence_lengths[i])
            current_dataset['init_state'][i] = encoder_states[i, seq_length - 1]
            current_dataset['current_covariates'][i, 0, 0] = encoder_outputs[i, seq_length - 1]
            current_dataset['previous_treatments'][i] = previous_treatments[
                i, seq_length - 1:seq_length + projection_horizon - 1, :]
            current_dataset['current_treatments'][i] = current_treatments[
                i, seq_length:seq_length + projection_horizon, :]

        for t in range(projection_horizon):
            logging.debug("Autoregressive step %d / %d", t + 1, projection_horizon)
            predictions = self.get_predictions(current_dataset)
            predicted_outputs[:, t] = predictions[:, t]
            if t < projection_horizon - 1:
                current_dataset['current_covariates'][:, t + 1, 0] = predictions[:, t, 0]

        test_data['predicted_outcomes'] = predicted_outputs

        return predicted_outputs

    # ------------------------------------------------------------------ metrics

    def compute_loss_treatments_one_hot(self, target_treatments, treatment_predictions, active_entries):
        """Cross-entropy between one-hot targets and predicted probabilities, over active entries.

        ``treatment_predictions`` may be flattened; it is reshaped to the target's shape.
        """
        treatment_predictions = treatment_predictions.reshape(target_treatments.shape)
        cross_entropy_loss = torch.sum(
            (- target_treatments * torch.log(treatment_predictions + 1e-8)) * active_entries) \
            / torch.sum(active_entries)
        return cross_entropy_loss

    def compute_loss_predictions(self, outputs, predictions, active_entries):
        """Mean squared error over active entries.

        ``predictions`` may be flattened; it is reshaped to ``outputs``' shape.
        """
        predictions = predictions.reshape(outputs.shape)
        mse_loss = torch.sum(torch.square(outputs - predictions) * active_entries) \
            / torch.sum(active_entries)
        return mse_loss

    def evaluate_predictions(self, dataset):
        """Return ``(mean_mse, per_timestep_mse)`` of unscaled one-step predictions.

        Predictions are de-normalised with ``dataset['output_stds']`` and
        ``dataset['output_means']`` (moved onto the predictions' device) and compared
        with ``dataset['unscaled_outputs']`` over ``dataset['active_entries']``.
        """
        predictions = self.get_predictions(dataset)
        output_stds = torch.as_tensor(dataset['output_stds'], dtype=predictions.dtype, device=predictions.device)
        output_means = torch.as_tensor(dataset['output_means'], dtype=predictions.dtype, device=predictions.device)

        unscaled_outputs = dataset['unscaled_outputs'].to(self.device)
        active_entries = dataset['active_entries'].to(self.device)
        unscaled_predictions = (predictions * output_stds + output_means).reshape(unscaled_outputs.shape)

        mse = self.get_mse_at_follow_up_time(unscaled_predictions, unscaled_outputs, active_entries)
        mean_mse = torch.mean(mse)
        return mean_mse, mse

    def get_mse_at_follow_up_time(self, prediction, output, active_entires):
        """Per-time-step MSE over active entries.

        Sums over the last (feature) and first (patient) axes. The result is NaN for
        time steps with no active entries.
        """
        mses = torch.sum(torch.sum((prediction - output) ** 2 * active_entires, dim=-1), dim=0) \
            / active_entires.sum(dim=0).sum(dim=-1)
        return mses

    def get_optimizer(self):
        """Adam over all model parameters with ``self.learning_rate``."""
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer

    def compute_sequence_length(self, sequence):
        """Count, per sequence, the time steps whose features are not all zero.

        Assumes ``sequence`` is (batch, time, features); returns an int tensor (batch,).
        """
        used = torch.sign(torch.max(torch.abs(sequence), dim=2)[0])
        length = torch.sum(used, dim=1)
        length = length.int()
        return length