import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np  
from pytorch_lightning.loggers import TensorBoardLogger

from src.baselines.vcip.auxiliary_model import AuxiliaryModel
from src.baselines.vcip.generative_model import GenerativeModel
from src.baselines.vcip.inference_model import InferenceModel
from src.utils.helper_functions import compute_kl_divergence
from src.data.cip_dataset import CIPDataset, get_dataloader
from src.utils.utils import set_seed
from scipy.stats import spearmanr
from scipy import stats
from src.utils.helper_functions import generate_perturbed_sequences
from src.data.her_data_generator import create_history_treatment_goal_samples, convert_dataloader_to_samples
from src.gift.utils.evaluator import save_evaluation_results
import os
import pickle
import random
import copy
from src.baselines.base_model import BaseCausalModel
from fvcore.nn import FlopCountAnalysis

class VAEModel(pl.LightningModule, BaseCausalModel):
    def __init__(self, config, dataset_collection):
        super(VAEModel, self).__init__()
        self.save_hyperparameters(config)
        self.config = config
        self.dataset_collection = dataset_collection
        self.generative_model = GenerativeModel(config)
        self.inference_model = InferenceModel(config)
        self.auxiliary_model = AuxiliaryModel(config, dataset_collection)
        
        self.lr = config['model']['lr']
        self.tau = config['exp']['tau']
        self.a_dim = config['dataset']['treatment_size']
        self.predict_X = config['dataset']['predict_X']
        self.automatic_optimization = False
        self.validation_step_outputs = []
        self.train_step_outputs = []
        self.count_flops_processed = False
        self.complexity_info = {}
        if "mimic" in config['dataset']['name']:
            self.scale_param = dataset_collection.train_f.scaling_params['output_means']
        elif "tumor" in config['dataset']['name']:
            self.scale_param = dataset_collection.train_scaling_params[1]['cancer_volume']

    def count_flops(self, batch):
        flops = FlopCountAnalysis(self, batch)
        self.batch_size = batch[1]['current_treatments'].shape[0]
        mflops = flops.total() / 1e6 / self.batch_size
        print('FLOPs: {:.2f} MFLOPs'.format(mflops))
        for name, module_flops in flops.by_module().items():
            module_flops_per_sample = module_flops / 1e6 / self.batch_size
        params = sum(p.numel() for p in self.parameters())
        print('Parameters:', params)
        return params, mflops

    def init_AuxiliaryModel(self, config):
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=config['exp']['patience'],
            mode='min')
        logger_board = TensorBoardLogger(
            save_dir=config['exp']['current_dir'], 
            name='', 
            version='')
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            dirpath=None,
            monitor='val_loss',
            filename='checkpoint-{epoch:02d}-{val_loss:.3f}',
            save_top_k=3,
            mode='min')
        trainer = pl.Trainer(
            logger=logger_board,
            max_epochs=config['exp']['epochs'],
            enable_progress_bar=False,
            enable_model_summary=False, 
            devices=config['exp']['gpus'],
            callbacks=[early_stop_callback, checkpoint_callback])
        trainer.fit(self.auxiliary_model)

    def train_dataloader(self) -> DataLoader:
        return get_dataloader(CIPDataset(self.dataset_collection.train_f.data, self.config, train=True), batch_size=self.config['exp']['batch_size'], shuffle=True)
    
    def val_dataloader(self) -> DataLoader:
        return get_dataloader(CIPDataset(self.dataset_collection.val_f.data, self.config), batch_size=self.config['exp']['batch_size'], shuffle=False)

    def forward(self, batch):
        H_t, targets = batch
        Y_targets = targets['outputs']
        X_targets = None if not self.predict_X else targets['vitals']
        a_seq = targets['current_treatments']
        elbo, reg_loss, _ = self.calculate_elbo(H_t, Y_targets, a_seq, X_targets=X_targets, num_samples=self.config.exp.num_samples, optimize_a=False)
        return elbo, reg_loss

    def training_step(self, batch, batch_idx):
        self.train()
        if batch_idx == 0 and not self.count_flops_processed:
            params, mflops = self.count_flops(batch)
            self.complexity_info['params'] = params
            self.complexity_info['mflops'] = mflops
            self.log('params', params)
            self.log('mflops', mflops)
            self.count_flops_processed = True
        
        opt = self.optimizers()
        opt.zero_grad()

        elbo, reg_loss = self(batch)
        self.manual_backward(elbo)
        opt.step()

        self.log('train_loss', elbo)
        self.train_step_outputs.append(reg_loss)
        return elbo

    def validation_step(self, batch, batch_idx):
        elbo, reg_loss = self(batch)
        self.log('val_loss', elbo)
        self.log('val_reg_loss', reg_loss)
        self.validation_step_outputs.append(reg_loss)
        return reg_loss

    def on_validation_epoch_end(self):
        avg_loss = torch.stack([x for x in self.validation_step_outputs]).mean()
        print(f"val_reg_loss loss: {avg_loss.item()}")
        self.validation_step_outputs.clear()
        return

    def on_train_epoch_end(self):
        avg_loss = torch.stack([x for x in self.train_step_outputs]).mean()
        print(f"train_reg_loss loss: {avg_loss.item()}")
        self.train_step_outputs.clear()
        return

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def configure_optimizers(self):
        for name, param in self.named_parameters():
            if not param.requires_grad:
                print(f"Warning: {name} does not require gradients")
        
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.config['exp']['weight_decay'])
        param_groups = optimizer.param_groups[0]['params']
        all_params = set(self.parameters())
        if len(param_groups) != len(list(all_params)):
            print("Warning: Not all parameters are in optimizer")
        print(f"Number of parameters: {len(list(all_params))}")
        
        return optimizer

    def optimize_interventions_onetime(self, encoder=None, num_iterations=100, learning_rate=0.01, batch_size=64):
        set_seed(self.config['exp']['seed'])
        
        if not self.config.exp.test:
            data = self.dataset_collection.val_f.data
        else:
            data = self.dataset_collection.test_f.data
        dataloader = get_dataloader(CIPDataset(data, self.config), batch_size=batch_size, shuffle=False)
        
        device = "cuda"

        self.inference_model.to(device)
        self.generative_model.to(device)
        self.auxiliary_model.to(device)

        self.inference_model.train()
        self.generative_model.train()
        self.auxiliary_model.train()

        results = []
        losses = []
        losses_2 = []
        ture_output_list = []
        output_after_actions_list = []
        ture_output_actions_list = []
        elbos = []
        
        for (i, batch) in enumerate(dataloader):
            print(f"-" * 50 + f"Batch {i}" + "-" * 50)
            H_t, targets = batch
            for key in H_t:
                H_t[key] = H_t[key].to(device)
            for key in targets:
                targets[key] = targets[key].to(device)
            Y_targets = targets['outputs']
            X_targets = None if not self.predict_X else targets['vitals']
            print(f"shape of Y_targets: {Y_targets.shape}")
            batch_size = Y_targets.size(0)
            best_elbo = float('inf')
            best_step = 0

            best_reg = float('inf')
            best_step_reg = 0
            
            a_seq = torch.rand(batch_size, self.tau, self.a_dim, device=device, requires_grad=True)
            best_a_seq = torch.sigmoid(a_seq).clone().detach()
            best_a_seq_reg = a_seq.clone().detach()
            optimizer = torch.optim.Adam([a_seq], lr=learning_rate)

            print(f"-" * 50)
            for iteration in range(num_iterations):
                a_seq_sigmoid = torch.sigmoid(a_seq)
                optimizer.zero_grad()
                elbo, reg_loss, reg_loss_2 = self.calculate_elbo(H_t, Y_targets, a_seq_sigmoid, X_targets=X_targets,
                                                num_samples=self.config.exp.num_samples, 
                                                optimize_a=True)
                loss = elbo
                loss.backward()
                optimizer.step()
                if elbo.item() < best_elbo:
                    best_elbo = elbo.item()
                    best_a_seq = torch.sigmoid(a_seq.clone().detach())
                    elbos.append(best_elbo)
                    best_step = iteration

                if reg_loss.item() < best_reg:
                    best_reg = reg_loss.item()
                    best_step_reg = iteration
                    best_a_seq_reg = torch.sigmoid(a_seq.clone().detach())

            if best_reg < 0:
                best_a_seq = best_a_seq_reg

            optimized_a_seq = best_a_seq
            
            output_after_actions = self.dataset_collection.val_f.simulate_output_after_actions(
                H_t, optimized_a_seq, self.dataset_collection.train_scaling_params)
            ture_output = targets['outputs'][:, -1, :].detach().cpu().numpy()

            true_actions = targets['current_treatments']
            ture_output_actions = self.dataset_collection.val_f.simulate_output_after_actions(
                H_t, true_actions, self.dataset_collection.train_scaling_params)

            loss = np.sqrt(((output_after_actions-ture_output) ** 2).mean())
            losses.append(loss)

            loss_2 = np.sqrt(((ture_output_actions-ture_output) ** 2).mean())
            losses_2.append(loss_2)

            ture_output_list.append(ture_output)
            output_after_actions_list.append(output_after_actions)
            ture_output_actions_list.append(ture_output_actions)

            results.append(optimized_a_seq)

        true_actions = targets['current_treatments']
        ture_output_actions = self.dataset_collection.val_f.simulate_output_after_actions(
            H_t, true_actions, self.dataset_collection.train_scaling_params)

        print(f"Mean loss v1: {sum(losses) / len(losses)}")
        print(f"losses: {losses}")
        print(f"Mean loss_2 v2: {sum(losses_2) / len(losses_2)}")
        
        ture_output_list = np.concatenate(ture_output_list, axis=0)
        output_after_actions_list = np.concatenate(output_after_actions_list, axis=0)
        ture_output_actions_list = np.concatenate(ture_output_actions_list, axis=0)
        print(f"ture_output_list shape: {ture_output_list.shape}")
        print(f"first 10 ture_output_list: {ture_output_list[:10]}")
        
        loss = np.sqrt((((output_after_actions_list-ture_output_list) * self.scale_param) ** 2).mean()) 
        loss_2 = np.sqrt((((ture_output_actions_list-ture_output_list) * self.scale_param) ** 2).mean())
        
        print(f"Mean loss v2: {loss}")
        print(f"Mean loss_2 v2: {loss_2}")

        return torch.cat(results, dim=0), loss

    def get_loss_outcome_after_actions(self, H_t, Y_targets, a_seq):
        output_after_actions = self.dataset_collection.val_f.simulate_output_after_actions(H_t, a_seq, self.dataset_collection.train_scaling_params)
        ture_output = Y_targets[:, -1, :].detach().cpu().numpy()
        loss = np.sqrt(((output_after_actions-ture_output) ** 2).mean())
        loss = F.mse_loss(torch.tensor(output_after_actions, device=Y_targets.device), Y_targets[:, -1, :])
        return loss, output_after_actions

    def calculate_elbo(self, H_t, Y_targets, a_seq, X_targets=None, num_samples=10, optimize_a=False, reduce=True):
        H_rep = self.auxiliary_model.build_representations(H_t, only_last=True)
        
        Z_s_i, loss_hy = self.inference_model.init_hidden_history(H_t)

        if self.config.dataset.treatment_mode == 'multilabel':
            predict_action_loss = self.generative_model.loss_predict_actions_bern(Z_s_i, a_seq)
        else:
            predict_action_loss = self.generative_model.loss_predict_actions_beta(Z_s_i, a_seq)
        self.inference_model.reset_states()
        self.generative_model.reset_states()
        a_seq_hiddens_1 = self.generative_model.build_action_encoding(a_seq)
        
        a_seq_hiddens = self.generative_model.build_reverse_action_encoding(a_seq)
        
        kl_losses = []
        reg_losses = []
        action_losses = []
        try:
            Y_last = Y_targets[:, -1, :]
        except:
            Y_targets = Y_targets.reshape(a_seq.shape[0], 1, -1)
            Y_targets = torch.from_numpy(Y_targets).to(device=a_seq.device) 
            Y_last = Y_targets[:, -1, :]
        reg_loss_2 = None
        
        hidden_states = self.generative_model.build_hidden_states(Z_s_i, a_seq)
        reversed_hidden_states = self.generative_model.build_reverse_hidden_states(hidden_states)

        tau = a_seq.shape[1]
        for s in range(tau):
            a_s = a_seq[:, s, :]
            a_s_hidden = a_seq_hiddens[:, s, :]
            
            if s == tau - 1:
                Z_s_inf = Z_s_i
            else:
                Z_s_inf = reversed_hidden_states[:, s+1, :]
                
            Z_s_inf = torch.cat([Z_s_inf, Z_s_i], dim=-1)
            
            if s < tau - 1:
                a_s_next = a_seq[:, s+1, :]
                if self.config.dataset.treatment_mode == 'multilabel':
                    action_loss = self.generative_model.bern_loss(Z_s_i, a_s_next)
                else:
                    action_loss = self.generative_model.beta_loss(Z_s_i, a_s_next)
                action_losses.append(action_loss)
            q_mu, q_logvar = self.inference_model(Z_s_inf, a_s_hidden, H_rep, Y_last)
            p_mu, p_logvar = self.generative_model(Z_s_i, a_s_hidden)
            Z_samples = self.inference_model.reparameterize_multiple(q_mu, q_logvar, num_samples)
            kl_loss = compute_kl_divergence(q_mu, q_logvar, p_mu, p_logvar)
            kl_losses.append(kl_loss)

            if s > 0 and s < tau - 1:
                Y_current = Y_targets[:, s, :]
                a = a_seq_hiddens_1[:, -1, :]
                reg_loss = self.generative_model.decoding_Y_loss_2(Z_samples.mean(dim=0), Y_current, a)
                if self.predict_X:
                    X_current = X_targets[:, s, :]
                    reg_loss += self.generative_model.decoding_X_loss(Z_samples.mean(dim=0), X_current)
                reg_losses.append(reg_loss)

            if s == tau - 1:
                Y_current = Y_targets[:, s, :]
                a = a_seq_hiddens_1[:, -1, :]
                reg_losses.append(self.generative_model.decoding_Y_loss_2(Z_samples.mean(dim=0), Y_current, a))
            Z_s_i = Z_samples.mean(dim=0)
        kl_loss = torch.cat(kl_losses, dim=-1).mean(dim=-1, keepdim=True)
        if optimize_a:
            reg_loss = reg_losses[-1]
        else:
            reg_loss = torch.cat(reg_losses, dim=-1).mean(dim=-1, keepdim=True)
            if self.config.exp.remove_aux:
                reg_loss = reg_losses[-1]
            reg_loss = reg_losses[-1]

        action_loss = torch.cat(action_losses, dim=-1).mean(dim=-1, keepdim=True) if action_losses else 0
        if optimize_a:
            elbo = reg_loss * self.config.exp.lambda_reg + \
                   kl_loss * self.config.exp.lambda_kl + \
                   action_loss * self.config.exp.lambda_step + \
                   predict_action_loss * self.config.exp.lambda_action
        else:
            elbo  = reg_loss * self.config.exp.lambda_reg + \
                   kl_loss * self.config.exp.lambda_kl + \
                   action_loss * self.config.exp.lambda_step + \
                   predict_action_loss * self.config.exp.lambda_action + \
                   loss_hy * self.config.exp.lambda_hy 
        if reduce:
            elbo = elbo.mean()
        return elbo, reg_loss.mean(), reg_loss_2

    def print_topk_diff(self, a, b, k=5):
        k = min(k, len(a))
        diff = torch.abs(a - b)
        topk = torch.topk(diff, k=k)
        return topk

    def optimize_interventions_discrete(self):
        results = {}
        model = self.config.model.name.split('/')[0]
        self.eval()
        tau = self.tau
        for i in [2, 4, 6, 8]:
            print(f'start predicting results for tau={i} ...')
            self.tau = i
            self.config.exp.tau = i
            print(self.tau, self.config.exp.tau)
            infos = self.optimize_interventions_discrete_onetime()
            results[i] = infos
        final = {model:results}
        save_dir = f'case_infos/{self.config.exp.seed}/{self.config.exp.test}'
        self.tau = tau

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        with open(save_dir + f'/case_infos_{model}.pkl', 'wb') as f:
            pickle.dump(final, f)

        write_str = []
        for key in results:
            rank = np.mean([results[key][i]['true_sequence_rank'] for i in range(len(results[key]))])
            write_str.append(f"{key}: Average rank of true sequences across all individuals: {rank}")
        with open(save_dir + f'/case_infos_{model}.txt', 'w') as f:
            f.write('\n'.join(write_str))

        return final

    def optimize_interventions_discrete_onetime(self, k=100):
        set_seed(self.config['exp']['seed'])
        if not self.config.exp.test:
            data = self.dataset_collection.val_f.data
        else:
            data = self.dataset_collection.test_f.data
        dataloader = get_dataloader(CIPDataset(data, self.config), 
                                batch_size=1, shuffle=False)
        device = "cuda"
        
        self.inference_model.to(device)
        self.generative_model.to(device)
        self.auxiliary_model.to(device)

        self.inference_model.eval()
        self.generative_model.eval()
        self.auxiliary_model.eval()
        
        all_ranks = []
        all_correlation = []  
        case_infos = []
        
        for (i, batch) in enumerate(dataloader):
            if i > 99:
                break
            print(f"-" * 50 + f"Individual {i}" + "-" * 50)
            H_t, targets = batch
            for key in H_t:
                H_t[key] = H_t[key].to(device)
            for key in targets:
                targets[key] = targets[key].to(device)
            Y_targets = targets['outputs']
            X_targets = None if not self.predict_X else targets['vitals']
            true_actions = targets['current_treatments']

            all_sequences = generate_perturbed_sequences(true_actions, k, self.tau, self.a_dim, device, treatment_mode=self.config.dataset.treatment_mode)
            elbos = []
            true_losses = []
            with torch.no_grad():
                for seq in all_sequences:
                    elbo, _, _ = self.calculate_elbo(H_t, Y_targets, seq, X_targets=X_targets,
                                                num_samples=self.config.exp.num_samples,
                                                optimize_a=True)
                    elbos.append(elbo.item())  
                    output_after_actions = self.dataset_collection.val_f.simulate_output_after_actions(
                        H_t, seq, self.dataset_collection.train_scaling_params)
                    true_output = targets['outputs'][:, -1, :].detach().cpu().numpy()
                    true_loss = np.sqrt(((output_after_actions - true_output) ** 2).mean())
                    true_losses.append(true_loss)

            model_losses = np.array(elbos)
            true_losses = np.array(true_losses)

            corr_model_true, _ = spearmanr(model_losses, true_losses)
            
            case_info = {
                'individual_id': i,
                'model_losses': model_losses,
                'true_losses': true_losses,
                'correlations': {
                    'model_true': corr_model_true,
                },
                'true_sequence': true_actions.cpu().numpy(),
                'true_sequence_rank': np.sum(model_losses < model_losses[-1]) + 1
            }
            case_infos.append(case_info)
            
            elbos = np.array(elbos)
            true_losses = np.array(true_losses)

            
            correlation, p_value = spearmanr(elbos, true_losses)
            all_correlation.append(correlation)
            
            true_seq_rank = np.sum(elbos < elbos[-1]) + 1
            all_ranks.append(true_seq_rank)
            
            print(f"Individual {i} - True sequence rank: {true_seq_rank} out of {k}")
            print(f"ELBO for true sequence: {elbos[-1]}")
            print(f"True Loss for true sequence: {true_losses[-1]}")
            print(f"First 10 ELBOs: {elbos[:10]}")
            print(f"First 10 True Losses: {true_losses[:10]}")
            print(f"Rank correlation for this individual: {correlation:.3f} (p-value: {p_value:.3f})")
        all_correlation = [c for c in all_correlation if not np.isnan(c)]
        avg_rank = sum(all_ranks) / len(all_ranks)
        if len(all_correlation) > 0:
            avg_correlation = sum(all_correlation) / len(all_correlation)
        else:
            avg_correlation = 0
        print(f"Average rank of true sequences across all individuals: {avg_rank:.2f} out of {k}")
        print(f"Average rank correlation across all individuals: {avg_correlation:.3f}")
        
        return case_infos
         
    def _candidate_action_objective(self, H_t, goal_tensor, random_action, reduce=False, **kwargs):
        elbo, _, _ = self.calculate_elbo(H_t, goal_tensor, random_action, optimize_a=True, reduce=reduce)
        return elbo

    def _update_patient_history(self, updated_history, action_np, output):
        if 'current_treatments' in updated_history:
            prev_treatments = updated_history['current_treatments']
            seq_len = prev_treatments.shape[1]
            
            new_treatments = np.zeros((prev_treatments.shape[0], seq_len + 1, prev_treatments.shape[2]))
            new_treatments[:, :-1, :] = prev_treatments
            new_treatments[:, -1, :] = action_np  
            
            updated_history['current_treatments'] = new_treatments
        if 'prev_outputs' in updated_history:
            prev_outputs = updated_history['prev_outputs']
            seq_len = prev_outputs.shape[1]
            
            new_outputs = np.zeros((prev_outputs.shape[0], seq_len + 1, prev_outputs.shape[2]))
            new_outputs[:, :-1, :] = prev_outputs
            new_outputs[:, -1, :] = updated_history['outputs'][:, -1, :]
            
            updated_history['prev_outputs'] = new_outputs
        if 'outputs' in updated_history:
            prev_outputs = updated_history['outputs']
            seq_len = prev_outputs.shape[1]
            
            new_outputs = np.zeros((prev_outputs.shape[0], seq_len + 1, prev_outputs.shape[2]))
            new_outputs[:, :-1, :] = prev_outputs
            new_outputs[:, -1, :] = output
            
            updated_history['outputs'] = new_outputs
        if 'prev_treatments' in updated_history:
            prev_treatments = updated_history['prev_treatments']
            seq_len = prev_treatments.shape[1]
            
            new_prev_treatments = np.zeros((prev_treatments.shape[0], seq_len + 1, prev_treatments.shape[2]))
            new_prev_treatments[:, :-1, :] = prev_treatments
            new_prev_treatments[:, -1, :] = updated_history['current_treatments'][:, -2, :]
            
            updated_history['prev_treatments'] = new_prev_treatments
        if 'static_features' in updated_history:
            static_features = updated_history['static_features']
            seq_len = static_features.shape[1]
            
            new_len = 0
            if 'prev_outputs' in updated_history:
                new_len = updated_history['prev_outputs'].shape[1]
            
            if new_len > seq_len:
                new_static = np.zeros((static_features.shape[0], new_len, static_features.shape[2]))
                new_static[:, :seq_len, :] = static_features
                for i in range(seq_len, new_len):
                    new_static[:, i, :] = static_features[:, -1, :]
                    
                updated_history['static_features'] = new_static
        if 'active_entries' in updated_history:
            active_entries = updated_history['active_entries']
            seq_len = active_entries.shape[1]
            
            new_active_entries = np.zeros((active_entries.shape[0], seq_len + 1, active_entries.shape[2]))
            new_active_entries[:, :-1, :] = active_entries
            new_active_entries[:, -1, :] = 1  
            
            updated_history['active_entries'] = new_active_entries
        if 'vitals' in updated_history:
            prev_vitals = updated_history['vitals']
            seq_len = prev_vitals.shape[1]
            
            new_vitals = np.zeros((prev_vitals.shape[0], seq_len + 1, prev_vitals.shape[2]))
            new_vitals[:, :-1, :] = prev_vitals
            new_vitals[:, -1, :] = updated_history['future_vitals'][:, :1, :]
            updated_history['vitals'] = new_vitals
        
        if 'future_vitals' in updated_history:
            updated_history['future_vitals'] = updated_history['future_vitals'][:, 1:, :]
        return updated_history