#!/usr/bin/env python
import argparse
import sys
sys.path.insert(0,"../")

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger, MLFlowLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor, Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
#from lightning.pytorch import seed_everything
from pytorch_lightning import seed_everything
import numpy as np

import os

import buildFlow#, cv_data_utils, covid_data_utils
from data_loaders import tumor_data_utils #, data_utils, moabb_data_utils, robert_data_utils 
from buildFlow import cnfModule
import models
#from azureml.core.run import Run
import torch
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset, TensorDataset
from data_loaders.tumor_data_utils import transform_to_our_input
from tumor.data_utils import data_to_torch_tensor, process_data , process_counterfactual_seq_test_data, data_to_torch_tensor_multistep
#from buildFlow import startState
import pickle
import logging
from plotters import plot_trajectories
import wandb

#python train_model.py --seed=44 --max_epochs=2


def compute_norm_mse_loss(ground_truth_outputs, predictions, active_entries, norm=1):
    """
    Masked mean squared error between ground truth and scaled predictions.

    The predictions (and only the predictions) are divided by ``norm`` before
    being subtracted from the ground truth. The squared residuals are then
    multiplied element-wise by ``active_entries`` and averaged over ALL
    elements (masked-out entries still count in the denominator).

    Args:
    ground_truth_outputs (torch.tensor): true outputs
    predictions (torch.tensor): model predictions
    active_entries (torch.tensor): multiplicative weights / mask applied to the squared residuals
    norm (int): constant the predictions are divided by

    Returns:
    mse_loss (torch.tensor): scalar tensor holding the masked mean squared error
    """
    scaled_predictions = predictions / norm
    residual = ground_truth_outputs - scaled_predictions
    weighted_sq_error = residual.pow(2) * active_entries
    return weighted_sq_error.mean()


def multistep_predict(model, data_path, max_horizon, trainer, args=None):
    """
    Runs one-step and multi-step (counterfactual) evaluation on the tumor test
    split and reports the resulting RMSEs through ``wandb.log``.

    Two prediction passes are made with ``trainer.predict``:
    1. on the full test sequences, to obtain predictions and latent states;
    2. on the counterfactual sequences built by
       ``process_counterfactual_seq_test_data`` from those states.

    Args:
    model: model passed to ``trainer.predict``; must expose ``startState``
    data_path (str): directory, relative to the parent directory, that holds
        ``new_cancer_sim_{coeff}_{coeff}_kappa_{kappa}.p``
    max_horizon (int): projection horizon used for the counterfactual pass
    trainer: object exposing ``predict(model=..., dataloaders=...)``
        (presumably a pytorch_lightning Trainer - confirm against caller)
    args: object with ``coeff`` and ``kappa`` attributes (e.g. argparse namespace)

    Returns:
    None. All metrics are emitted via ``wandb.log``.

    Raises:
    ValueError: if ``args`` is not supplied
    """
    if args is None:
        raise ValueError("multistep_predict requires `args` with `coeff` and `kappa` attributes")

    coeff = args.coeff
    kappa = args.kappa
    transformed_datapath = f"../{data_path}/new_cancer_sim_{coeff}_{coeff}_kappa_{kappa}.p"
    print(f"Loading transformed data from {transformed_datapath}")
    # NOTE: pickle executes arbitrary code on load; only use trusted data files.
    with open(transformed_datapath, "rb") as pickle_file:
        pickle_map = pickle.load(pickle_file)

    _, _, test_data = process_data(pickle_map)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Predicting with: {device}")

    def to_raw_scale(values):
        # Undo the output standardisation using the statistics stored with the data.
        return values * test_data['output_stds'] + test_data['output_means']

    sample_proportion = 1
    # NOTE(review): hard-coded number of time steps kept from the mask -
    # presumably the test sequence length minus a margin; confirm.
    obs_len = 58

    # ------------------------------------------------------------------
    # Pass 1: predictions / latent states on the full test sequences
    # ------------------------------------------------------------------
    (
        test_X,
        test_A,
        _,
        test_y,
        _,
        treated_indices,
        untreated_indices,
        test_mask,
    ) = data_to_torch_tensor(
        test_data,
        sample_prop=sample_proportion,
    )

    # The whole test set is processed as a single batch.
    batch_size = test_X.shape[0]

    data_o, data_c, data_Time = transform_to_our_input(test_X, test_A)
    data_o = data_o.to(device)
    data_Time = data_Time.to(device)

    data_c = torch.cat((data_c, test_mask[:, :obs_len, :]), dim=-1).type('torch.FloatTensor').to(device)

    # generate initial states
    state0 = model.startState(data_o[:, 0, :], batch_size).type('torch.FloatTensor').to(device)
    # create dataloaders
    data_set = TensorDataset(
        data_o,
        data_c,
        data_Time,
        test_mask[:, :obs_len, :].type('torch.FloatTensor').to(device),
        state0,
    )
    data_set = DataLoader(data_set, batch_size=batch_size)

    # predict state space; a single batch, hence the [0]
    _out = trainer.predict(model=model, dataloaders=data_set)[0]
    states = _out[1].to(device)
    preds = _out[0].to(device)
    pred_y_test = preds[:, :, 0].to(device)

    outcomes_test = torch.tensor(test_y[:, :, 0], dtype=torch.float, device=device)[:, :-1]
    active_entries_test = torch.tensor(
        test_y[:, :, 1],
        dtype=torch.float,
        device=device,
    )[:, :-1]

    test_mask = test_mask.to(device)  # our model implementation does not predict one ahead
    # compute norm mse loss - outcomes loss
    loss_y_test = compute_norm_mse_loss(
        outcomes_test,
        pred_y_test,
        active_entries_test,
    )

    loss_y_treated = compute_norm_mse_loss(
        to_raw_scale(outcomes_test[treated_indices]),
        to_raw_scale(pred_y_test[treated_indices]),
        test_mask[treated_indices, 1:-1, 0] * active_entries_test[treated_indices],
    )

    loss_y_untreated = compute_norm_mse_loss(
        to_raw_scale(outcomes_test[untreated_indices]),
        to_raw_scale(pred_y_test[untreated_indices]),
        test_mask[untreated_indices, 1:-1, 0] * active_entries_test[untreated_indices],
    )

    loss_y_test_obs = compute_norm_mse_loss(
        to_raw_scale(outcomes_test),
        to_raw_scale(pred_y_test),
        test_mask[:, 1:-1, 0] * active_entries_test,
    )
    wandb.log(
        {
            "RMSE Test Outcome loss": np.sqrt(loss_y_test.item()),
            "RMSE Test TREATED loss": np.sqrt(loss_y_treated.item()),
            "RMSE Test UNTREATED loss": np.sqrt(loss_y_untreated.item()),
            "Test loss obs": np.sqrt(loss_y_test_obs.item()),
        },
    )

    # Select trajectories with more than 10 observed points after step 25 for plotting.
    # NOTE(review): this previously read an undefined `data_ind` (its assignment
    # via get_observation_mask was commented out), raising NameError. `test_mask`
    # is used everywhere else here as the observation mask, so it is substituted;
    # confirm it is the intended replacement.
    interesting_mask = (torch.sum(test_mask[:, 25:], dim=1) > 10)[:, 0]

    plot_trajectories(
        data_o[interesting_mask, :, :],
        states[interesting_mask, :, -1:],
        data_Time,
        'tumor',
        chart_type="test_int",
        logger=None,
    )

    # Last feature column over the final 5 time steps of the original inputs.
    patient_types = test_X[:, -5:, -1]

    # ------------------------------------------------------------------
    # Pass 2: counterfactual multi-step prediction
    # ------------------------------------------------------------------
    logging.info("Processing counterfactual data...")
    processed_data_multi = process_counterfactual_seq_test_data(
        data_map=test_data,
        states=states.cpu().detach().numpy(),
        projection_horizon=max_horizon,
    )

    test_X, test_A, _, test_y, _, test_mask = data_to_torch_tensor_multistep(
        processed_data_multi,
        sample_prop=sample_proportion,
        max_horizon=max_horizon,
    )

    # Missing observations are filled with the standardised value of a raw 0.
    fill_value = -test_data['output_means'] / test_data['output_stds']
    state0 = test_X.type('torch.FloatTensor').to(device)
    test_X = torch.full_like(test_A[:, :, -2:], fill_value).to(device)

    # Only the first step carries an actual observation; column 1 holds the patient type.
    test_X[:, 0, 0] = state0[:, 0, -1]
    test_X[:, :, 1] = patient_types

    data_o, data_c, data_Time = transform_to_our_input(test_X.to(device), test_A.to(device), len=max_horizon)
    data_c = torch.cat((data_c, test_mask[:, :-1, :]), dim=-1).type('torch.FloatTensor').to(device)
    data_o = data_o.to(device)
    data_Time = data_Time.to(device)

    outcomes_test = torch.tensor(test_y[:, :, 0], dtype=torch.float, device=device)
    active_entries_test = torch.tensor(
        test_y[:, :, 1],
        dtype=torch.float,
        device=device,
    )
    test_mask = test_mask.to(device)

    # create dataloaders
    data_set = TensorDataset(data_o, data_c, data_Time, test_mask.type('torch.FloatTensor'), state0[:, 0, :])
    data_set = DataLoader(data_set, batch_size=batch_size)

    # predict state space
    pred_y_test = trainer.predict(model=model, dataloaders=data_set)[0][0]
    pred_y_test = pred_y_test[:, :, 0].to(device)

    # Loss at the final projection step, in standardised and raw scale.
    last_step = max_horizon - 1
    last_step_weights = test_mask[:, max_horizon, 0] * active_entries_test[:, last_step]
    loss_y_test = compute_norm_mse_loss(
        outcomes_test[:, last_step],
        pred_y_test[:, last_step],
        last_step_weights,
    )
    loss_y_obs_test = compute_norm_mse_loss(
        to_raw_scale(outcomes_test[:, last_step]),
        to_raw_scale(pred_y_test[:, last_step]),
        last_step_weights,
    )

    # Index 0 of each list is the RMSE over all steps; indices 1..max_horizon are per step.
    all_step_weights = test_mask[:, 1:, 0] * active_entries_test
    rmses = [
        torch.sqrt(
            compute_norm_mse_loss(outcomes_test, pred_y_test, all_step_weights),
        ).item(),
    ]
    obs_rmses = [
        torch.sqrt(
            compute_norm_mse_loss(
                to_raw_scale(outcomes_test),
                to_raw_scale(pred_y_test),
                all_step_weights,
            ),
        ).item(),
    ]

    for i in range(max_horizon):
        step_weights = test_mask[:, i + 1, 0] * active_entries_test[:, i]
        rmses.append(
            torch.sqrt(
                compute_norm_mse_loss(
                    outcomes_test[:, i],
                    pred_y_test[:, i],
                    step_weights,
                ),
            ).item(),
        )
        obs_rmses.append(
            torch.sqrt(
                compute_norm_mse_loss(
                    to_raw_scale(outcomes_test[:, i]),
                    to_raw_scale(pred_y_test[:, i]),
                    step_weights,
                ),
            ).item(),
        )

    wandb.log(
        {
            "RMSE MULTISTEP OBS Test Outcome loss": np.sqrt(loss_y_obs_test.item()),
        },
    )
    wandb.log(
        {
            "RMSE MULTISTEP Test Outcome loss": np.sqrt(loss_y_test.item()),
        },
    )

    # RMSEs at other interval. The per-step values are logged in REVERSE order
    # (label "@ 1" is the last projection step); the all-steps entry is skipped.
    per_step = zip(reversed(rmses[1:]), reversed(obs_rmses[1:]))
    for label, (rmse, obs_rmse) in enumerate(per_step, start=1):
        wandb.log({f"RMSE @ {label} ": rmse})
        wandb.log({f"RMSE OBS @ {label} ": obs_rmse})




if __name__ == '__main__':
    # Script entry point: builds the command-line parser in two stages
    # (generic options first, then model- and dataset-specific options).
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--max_epochs', type=int, default=1000) #250
    parser.add_argument('--gpu', default=1, type=int)
    parser.add_argument('--model', default = "cnf", type = str)
    parser.add_argument('--dataset_name', default = "tumor", type = str, help = "dataset to train on")
    parser.add_argument('--entity', default = "edebrouwer", type = str, help = "name of the wandb logger entity")
    # Help text previously duplicated the --entity description.
    parser.add_argument('--ckpt_name', default = "2lhta0c2", type = str, help = "name of the checkpoint to load")

    # First pass: tolerate unknown options so the generic ones can be read early.
    partial_args, _ = parser.parse_known_args()
    print('main not defined')
    # NOTE(review): `model_cls` and `dataset_cls` are not defined anywhere in this
    # file, so the next lines raise NameError as written. Presumably they should be
    # resolved from partial_args.model / partial_args.dataset_name - confirm and wire up.
    parser = model_cls.add_model_specific_args(parser)
    parser = dataset_cls.add_dataset_specific_args(parser)
    args = parser.parse_args()



