import copy
import os
import re
import time
import logging
import random
from argparse import Namespace
from pathlib import Path
from copy import deepcopy
import os
import numpy as np
import sklearn as sk
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn as sk
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from experiments.data_eicu import get_eicu_tvt_datasets
from experiments.data_mimic import collate_fn_biclass, collate_fn_extrap, get_m4_tvt_datasets, collate_fn_synthetic
from experiments.data_physionet12 import get_physionet12_tvt_datasets
from experiments.data_physionet12_full import get_physionet12_tvt_datasets_full
from experiments.data_physionet12_full_hour import get_physionet12_tvt_datasets_full_hour
from experiments.data_p12 import get_p12_tvt_datasets
from experiments.data_synthetic import get_data_loaders
from model.model_factory import ModelFactory

from utils import record_experiment


class BaseExperiment:
    """Base experiment class.

    Wires together data loading, model construction, optimisation, the
    training loop with early stopping, and final persistence. Subclasses are
    expected to override ``validation_step`` and ``test_step``.
    """

    def __init__(self, args: Namespace):
        """Set up logging, RNG seeds, data loaders, model and optimiser.

        Args:
            args: Parsed command-line namespace. The attributes read here are
                ``epochs_max``, ``patience``, ``proj_path``, ``ml_task``,
                ``data``, ``leit_model``, ``ivp_solver``, ``test_info``,
                ``random_state``, ``device``, ``lr``, ``weight_decay``,
                ``lr_scheduler_step`` and ``lr_decay``. ``args.exp_name`` is
                written back onto the namespace.
        """
        self.args = args
        self.epochs_max = args.epochs_max
        self.patience = args.patience
        self.proj_path = Path(args.proj_path)
        self.mf = ModelFactory(self.args)
        self.tags = [
            self.args.ml_task,
            self.args.data,
            self.args.leit_model,
            self.args.ivp_solver,
            self.args.test_info]

        # The experiment name doubles as the log file and checkpoint stem.
        self.args.exp_name = '_'.join(
            self.tags + [("r"+str(args.random_state))])

        # Seed every RNG before any data or model is created.
        torch.manual_seed(args.random_state)
        np.random.seed(args.random_state)
        random.seed(args.random_state)

        self._init_logger()
        self.device = torch.device(args.device)
        self.logger.info(f'Device: {self.device}')

        self.variable_num, self.dltrain, self.dlval, self.dltest = self.get_data(args)
        self.model = self.get_model().to(self.device)
        num_params = sum(p.numel() for p in self.model.parameters())
        self.logger.info(f'num_params={num_params}')

        # Only parameters that require gradients are handed to the optimiser.
        self.optim = torch.optim.Adam(
            (p for p in self.model.parameters() if p.requires_grad),
            lr=args.lr, weight_decay=args.weight_decay)

        self.scheduler = None
        if args.lr_scheduler_step > 0:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optim, args.lr_scheduler_step, args.lr_decay)

    def _init_logger(self):
        """Point the root logger at ``<proj_path>/log/<exp_name>.log``."""
        log_dir = self.proj_path / 'log'
        # A fresh checkout may not have the log directory yet.
        log_dir.mkdir(parents=True, exist_ok=True)

        # force=True replaces any handlers installed by a previous experiment.
        logging.basicConfig(filename=log_dir / (self.args.exp_name+'.log'),
                            filemode='w',
                            level=logging.INFO,
                            force=True)

        self.logger = logging.getLogger()

    def get_model(self):
        """Build the model matching ``args.ml_task`` via the model factory.

        Returns:
            The model returned by the corresponding ``ModelFactory``
            initialiser.

        Raises:
            ValueError: If ``args.ml_task`` is not ``'extrap'``, ``'biclass'``
                or ``'synthetic'``.
        """
        print('current_ehr_variable_num: ' + str(self.variable_num))
        task = self.args.ml_task
        if task == 'extrap':
            return self.mf.initialize_extrap_model()
        if task == 'biclass':
            return self.mf.initialize_biclass_model()
        if task == 'synthetic':
            return self.mf.initialize_synthetic_model()
        raise ValueError(f"Unknown ml_task: {task!r}")

    def get_data(self, args):
        """Load the train/val/test splits and wrap them in data loaders.

        Args:
            args: The experiment namespace; only the synthetic branch reads
                from this parameter, the other branches use ``self.args``.

        Returns:
            A 4-tuple ``(variable_num, train, val, test)``. For the synthetic
            experiment the first element is ``self.dim`` and the remaining
            three are whatever ``get_data_loaders`` returned.

        Raises:
            ValueError: If the dataset or the task is not supported.
        """
        self.args.num_times = self.args.time_max + 1
        data_name = self.args.data
        loader_args = (self.args, self.proj_path, self.logger)

        # Branch order matters: exact names are tested before prefix matches.
        if data_name.startswith("m4"):
            train_data, val_data, test_data = get_m4_tvt_datasets(*loader_args)
        elif data_name == "physionet12":
            train_data, val_data, test_data = get_physionet12_tvt_datasets(*loader_args)
        elif data_name == "physionet12_full":
            train_data, val_data, test_data = get_physionet12_tvt_datasets_full(*loader_args)
        elif data_name.startswith("physionet12_full_hour"):
            train_data, val_data, test_data = get_physionet12_tvt_datasets_full_hour(*loader_args)
        elif data_name.startswith("p12"):
            train_data, val_data, test_data = get_p12_tvt_datasets(*loader_args)
        elif data_name.startswith("eicu"):
            train_data, val_data, test_data = get_eicu_tvt_datasets(*loader_args)
        elif self.args.experiment == 'synthetic':
            _args = {'n_ts': args.n_ts, 'training_scheme': args.training_scheme, 'dag_data': args.dag_data, 'data': args.data, 'seed':args.random_state}
            self.dim, self.n_classes, train_data, val_data, test_data = get_data_loaders(
                args.data, _args, args.batch_size)
            return self.dim, train_data, val_data, test_data
        else:
            raise ValueError("Unsupported Dataset!")

        task = self.args.ml_task
        if task == 'extrap':
            collate_fn = collate_fn_extrap
        elif task == 'biclass':
            collate_fn = collate_fn_biclass
        elif task == 'synthetic':
            collate_fn = collate_fn_synthetic
        else:
            raise ValueError(f"Unknown ml_task: {task!r}")

        def make_loader(dataset):
            # Each loader collates with its own dataset's variable_num.
            # NOTE(review): val/test loaders are shuffled as well; kept as-is
            # because turning it off changes RNG consumption between runs.
            return DataLoader(
                dataset=dataset,
                collate_fn=lambda batch: collate_fn(
                    batch, dataset.variable_num, self.args),
                shuffle=True,
                batch_size=self.args.batch_size)

        dl_train = make_loader(train_data)
        dl_val = make_loader(val_data)
        dl_test = make_loader(test_data)

        return train_data.variable_num, dl_train, dl_val, dl_test

    def training_step(self, batch):
        """Run the model on one batch and return its ``'loss'`` entry."""
        results = self.model.compute_prediction_results(batch)
        return results['loss']

    def validation_step(self, epoch=None) -> Tensor:
        """Evaluate on the validation split; must be overridden.

        Args:
            epoch: Index of the epoch that just finished. ``run`` passes it
                positionally, so the base signature accepts it too.
        """
        raise NotImplementedError

    def test_step(self) -> Tensor:
        """Evaluate on the test split; must be overridden."""
        raise NotImplementedError

    @staticmethod
    def _to_2d_array(values):
        """Return ``values`` as a NumPy array with one row per sample."""
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        else:
            values = np.asarray(values)
        if values.ndim == 1:
            return values[:, None]
        if values.ndim > 2:
            return values.reshape(values.shape[0], -1)
        return values

    def _save_tsne_plot(self, feat_batches, predictions, n_traj_samples):
        """Write a t-SNE scatter of the collected features to ``save_dir``.

        Points are coloured by the predicted probability averaged over the
        trajectory samples. The plot is skipped when the inputs are too small
        for PCA initialisation or for the perplexity constraint.
        """
        if not feat_batches:
            return

        import inspect
        import warnings

        features = np.concatenate(feat_batches, axis=0)
        n_samples = features.shape[0]
        if n_samples < 3 or features.shape[1] < 2:
            self.logger.info(
                't-SNE plot skipped: need at least 3 samples and 2 feature columns')
            return

        pred = predictions.detach().cpu().numpy().reshape(n_traj_samples, -1, 1)
        if pred.shape[1] == n_samples:
            colors = pred.mean(axis=0).reshape(-1)
        else:
            # Fall back to the first n_samples flattened predictions.
            colors = pred.reshape(-1)[:n_samples]
            if colors.shape[0] != n_samples:
                self.logger.warning(
                    't-SNE plot skipped: predictions do not cover all samples')
                return

        # Scope the filter so it does not leak into the rest of the process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)

            imputed = SimpleImputer(strategy="median").fit_transform(features)
            scaled = StandardScaler().fit_transform(imputed)
            # The imputer may drop columns, so re-check the width.
            if scaled.shape[1] < 2:
                self.logger.info(
                    't-SNE plot skipped: fewer than 2 usable feature columns')
                return

            # PCA cannot keep more components than samples or features.
            n_components = min(50, scaled.shape[1], scaled.shape[0])
            reduced = PCA(n_components=n_components,
                          random_state=0).fit_transform(scaled)
            if reduced.shape[1] < 2:
                self.logger.info(
                    't-SNE plot skipped: fewer than 2 principal components')
                return

            # Newer scikit-learn releases renamed `n_iter` to `max_iter`.
            tsne_params = inspect.signature(TSNE.__init__).parameters
            iter_kw = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'
            tsne = TSNE(
                n_components=2,
                # Perplexity must stay below the number of samples.
                perplexity=min(30, n_samples - 1),
                learning_rate=max(200, reduced.shape[0] // 12),
                init='pca',
                random_state=0,
                method='barnes_hut',
                **{iter_kw: 750}
            )
            embedding = tsne.fit_transform(reduced)

        rc = {
            'text.usetex': False,
            'mathtext.default': 'regular',
            'font.family': 'DejaVu Sans',
        }
        # rc_context restores the previous rcParams on exit.
        with mpl.rc_context(rc):
            fig, ax = plt.subplots(figsize=(6, 5))
            try:
                sc = ax.scatter(embedding[:, 0], embedding[:, 1],
                                c=colors, s=10, cmap='viridis')
                ax.set_title("t-SNE of model features — colored by predicted probability")
                ax.set_xlabel("t-SNE 1")
                ax.set_ylabel("t-SNE 2")
                cbar = fig.colorbar(sc, ax=ax)
                cbar.set_label("Predicted probability (array_predict mean)")

                save_dir = getattr(self.args, "save_dir", ".")
                os.makedirs(save_dir, exist_ok=True)
                fig.tight_layout()
                fig.savefig(os.path.join(save_dir, "tsne_predictions.png"), dpi=300)
            finally:
                plt.close(fig)

    def compute_results_all_batches(self, dl):
        """Average the model's validation metrics over every batch of ``dl``.

        For the ``'biclass'`` task, AUROC and AUPRC are additionally computed
        over all collected predictions, and a best-effort t-SNE plot is saved.

        Args:
            dl: Iterable of batches accepted by ``self.model.run_validation``.

        Returns:
            Dict of metric name to batch-averaged value. Metrics the model
            never reported stay at 0. For ``'biclass'`` the ``'auroc'`` and
            ``'auprc'`` entries hold the scores, or 0.0 when fewer than two
            distinct label values are present.
        """
        metric_keys = (
            'loss', 'likelihood', 'mse', 'auroc', 'kl_first_p', 'std_first_p',
            'ce_loss', 'mse_reg', 'mae_reg', 'mse_extrap', 'forward_time',
            'kldiv_z0', 'loss_ae', 'loss_vae', 'loss_ll_z', 'val_loss',
            'lat_variance',
        )
        total = dict.fromkeys(metric_keys, 0)

        is_biclass = self.args.ml_task == 'biclass'
        n_traj_samples = self.args.k_iwae
        n_labels = 1
        n_test_batches = 0

        pred_chunks = []
        label_chunks = []
        vis_feat_list = []

        for batch in dl:
            results = self.model.run_validation(batch)

            if is_biclass:
                # NOTE(review): the "features" fed to t-SNE are taken from
                # batch['truth'], the same tensor used as labels below —
                # confirm this is the intended input.
                vis_feat_list.append(self._to_2d_array(batch['truth']))
                pred_chunks.append(
                    results["label_predictions"].reshape(n_traj_samples, -1, n_labels))
                label_chunks.append(batch['truth'].reshape(-1, n_labels))

            for key in total:
                var = results.get(key)
                if var is None:
                    continue
                if isinstance(var, torch.Tensor):
                    var = var.detach()
                total[key] += var
            n_test_batches += 1

        if n_test_batches > 0:
            for key in total:
                total[key] = total[key] / n_test_batches

        if is_biclass:
            # Concatenate once instead of re-copying on every batch.
            if pred_chunks:
                classif_predictions = torch.cat(pred_chunks, dim=1)
                all_test_labels = torch.cat(label_chunks, dim=0)
            else:
                classif_predictions = torch.Tensor([]).to(self.args.device)
                all_test_labels = torch.Tensor([]).to(self.args.device)

            # Repeat the labels so they line up with every trajectory sample.
            all_test_labels = all_test_labels.repeat(n_traj_samples, 1, 1)
            total["auroc"] = 0.0
            total["auprc"] = 0.0
            # Both metrics need at least two distinct label values.
            if torch.unique(all_test_labels).numel() > 1:
                from sklearn.metrics import average_precision_score, roc_auc_score

                print("Number of labeled examples: {}".format(
                    int(len(all_test_labels.reshape(-1)) / n_traj_samples)))
                print("Number of examples with mortality 1: {}".format(
                    int(torch.sum(all_test_labels == 1.0) / n_traj_samples)))

                array_truth = all_test_labels.cpu().numpy().reshape(-1)
                array_predict = classif_predictions.cpu().numpy().reshape(-1)

                total["auroc"] = roc_auc_score(array_truth, array_predict)
                total["auprc"] = average_precision_score(array_truth, array_predict)
            else:
                print("Warning: Couldn't compute AUC -- all examples are from the same class")

            # The plot is diagnostic only: never let it abort evaluation, but
            # do record why it failed.
            try:
                self._save_tsne_plot(
                    vis_feat_list, classif_predictions, n_traj_samples)
            except Exception:
                self.logger.warning(
                    'Could not produce the t-SNE prediction plot', exc_info=True)

        return total

    def run(self) -> None:
        """Train with early stopping, then evaluate the best weights.

        After each epoch the validation loss is compared with the best one
        seen so far; training stops once it has failed to improve for
        ``self.patience`` consecutive epochs. The best state dict is restored
        before ``test_step`` is called.
        """
        best_loss = float('inf')
        waiting = 0
        durations = []
        best_model = deepcopy(self.model.state_dict())
        # NOTE(review): range(1, epochs_max) runs epochs_max - 1 epochs;
        # left unchanged so existing results stay reproducible.
        for epoch in range(1, self.epochs_max):
            self.model.train()
            start_time = time.time()

            for iteration, batch in enumerate(self.dltrain, start=1):
                self.optim.zero_grad()
                train_loss = self.training_step(batch)
                train_loss.backward()
                if self.args.clip_gradient:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.args.clip)
                self.optim.step()

                self.logger.info(
                    f'[epoch={epoch:04d}|iter={iteration:04d}] train_loss={train_loss:.5f}')

            epoch_duration = time.time() - start_time
            durations.append(epoch_duration)
            self.logger.info(
                f'[epoch={epoch:04d}] epoch_duration={epoch_duration:.5f}')

            self.model.eval()
            val_loss = self.validation_step(epoch)
            self.logger.info(f'[epoch={epoch:04d}] val_loss={val_loss:.5f}')

            if self.scheduler:
                self.scheduler.step()

            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(self.model.state_dict())
                waiting = 0
            else:
                waiting += 1

            if waiting >= self.patience:
                break

        self.model.load_state_dict(best_model)
        # Make sure evaluation mode is active even if no epoch was run.
        self.model.eval()
        test_loss = self.test_step()

        mean_duration = np.mean(durations) if durations else float('nan')
        self.logger.info(f'epoch_duration_mean={mean_duration:.5f}')
        self.logger.info(f'test_loss={test_loss:.5f}')

    def finish(self):
        """Record the experiment, save the model weights and close logging."""
        record_experiment(self.args, self.model)
        model_dir = self.proj_path / 'temp/model'
        # The checkpoint directory may not exist on a fresh checkout.
        model_dir.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(),
                   model_dir / (self.args.exp_name+'.pt'))
        logging.shutdown()


