# ---- Imports ---- #
import os
import torch

from scipy.special import gammaln
import logging
logger = logging.getLogger(__name__)

# evaluate co-bps
def neg_log_likelihood(rates, spikes, zero_warning=True):
    """Calculates Poisson negative log likelihood given rates and spikes.
    formula: -log(e^(-r) / n! * r^n)
           = r - n*log(r) + log(n!)

    Entries whose spike count is NaN are dropped from both arrays before the
    likelihood is evaluated. Zero rate predictions are replaced by 1e-9 in a
    new array, so the array passed in by the caller is never modified.

    Parameters
    ----------
    rates : np.ndarray
        numpy array containing rate predictions
    spikes : np.ndarray
        numpy array containing true spike counts, same shape as `rates`
    zero_warning : bool, optional
        Whether to print out warning about 0 rate
        predictions or not

    Returns
    -------
    float
        Total negative log-likelihood of the data

    Raises
    ------
    AssertionError
        If the shapes differ, or if `rates` contains NaN or negative values
        at positions where `spikes` is not NaN.
    """
    assert (
        spikes.shape == rates.shape
    ), f"neg_log_likelihood: Rates and spikes should be of the same shape. spikes: {spikes.shape}, rates: {rates.shape}"

    # Boolean indexing flattens the arrays; only the summed total is returned,
    # so the lost shape does not matter.
    missing = np.isnan(spikes)
    if np.any(missing):
        rates = rates[~missing]
        spikes = spikes[~missing]

    assert not np.any(np.isnan(rates)), "neg_log_likelihood: NaN rate predictions found"

    assert np.all(rates >= 0), "neg_log_likelihood: Negative rate predictions found"
    zero_rate = rates == 0
    if np.any(zero_rate):
        if zero_warning:
            logger.warning(
                "neg_log_likelihood: Zero rate predictions found. Replacing zeros with 1e-9"
            )
        # np.where builds a new array: the caller's data stays untouched, and
        # integer-typed rates are promoted instead of truncating 1e-9 to 0.
        rates = np.where(zero_rate, 1e-9, rates)

    # gammaln(n + 1) == log(n!)
    result = rates - spikes * np.log(rates) + gammaln(spikes + 1.0)
    return np.sum(result)


def bits_per_spike(rates, spikes):
    """Computes bits per spike of rate predictions given spikes.

    The score is the Poisson log-likelihood gain of `rates` over a null model,
    expressed in base 2 and normalised by the total spike count. The null
    model predicts, for every entry, the NaN-ignoring mean of `spikes` taken
    over all axes except the last one.

    Parameters
    ----------
    rates : np.ndarray
        numpy array containing rate predictions
    spikes : np.ndarray
        numpy array containing true spike counts, same shape as `rates`

    Returns
    -------
    float
        Bits per spike of rate predictions
    """
    # Score the model first, exactly as the original ordering does.
    model_nll = neg_log_likelihood(rates, spikes)

    # Mean over every axis but the last, kept broadcast-shaped (1, ..., 1, C).
    leading_axes = tuple(range(spikes.ndim - 1))
    mean_per_channel = np.nanmean(spikes, axis=leading_axes, keepdims=True)

    # Repeat the means along the leading axes so the null rates match spikes.shape.
    repeats = spikes.shape[:-1] + (1,)
    null_rates = np.tile(mean_per_channel, repeats)
    null_nll = neg_log_likelihood(null_rates, spikes, zero_warning=False)

    spike_total = np.nansum(spikes)
    # Dividing by log(2) converts nats to bits.
    return (null_nll - model_nll) / spike_total / np.log(2)

# ---- Helper funcs ---- #
from ldns.data.latent_attractor import get_attractor_dataloaders

#--- compute mean and std ----#
def mean_std_pure(data):
    """Return the mean and population standard deviation of `data`.

    The variance is divided by ``len(data)`` (not ``len(data) - 1``). `data`
    must support ``len()`` and be iterable more than once.
    """
    count = len(data)
    center = sum(data) / count
    squared_deviations = [(value - center) ** 2 for value in data]
    spread = (sum(squared_deviations) / count) ** 0.5
    return center, spread

import numpy as np 
import random
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import r2_score

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with
    `seed` and switch cuDNN to deterministic mode."""
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)
    torch.backends.cudnn.deterministic = True

def load_data_loader(spikes, rates, latents, window_size, is_shuffle=True, is_batch=True, batch_size=128):
    """Build a DataLoader of sliding windows taken from every sequence.

    For each sequence ``i`` and each window end step ``j`` (from
    ``window_size - 1`` to the last step), one sample is produced:
    the spikes and rates over steps ``j - window_size + 1 .. j`` and the
    latents at step ``j`` only. Samples are ordered by sequence, then by ``j``.

    Parameters
    ----------
    spikes, rates, latents : torch.Tensor
        3-D tensors indexed as ``[sequence, step, feature]``; the number of
        sequences and steps is read from `rates`.
    window_size : int
        Number of consecutive steps in each spikes/rates window.
    is_shuffle : bool, optional
        Passed to the DataLoader as ``shuffle``.
    is_batch : bool, optional
        If True, use mini-batches of `batch_size`; otherwise the whole
        dataset is returned as a single batch.
    batch_size : int, optional
        Mini-batch size used when `is_batch` is True (default 128).

    Returns
    -------
    torch.utils.data.DataLoader
        Yields ``(spikes_window, rates_window, latents_at_end)`` batches.

    Raises
    ------
    ValueError
        If no window can be extracted (no sequences, ``window_size < 1`` or
        ``window_size`` larger than the number of steps).
    """
    n_seq, n_steps = rates.shape[0], rates.shape[1]
    if n_seq == 0 or window_size < 1 or window_size > n_steps:
        raise ValueError(
            f"load_data_loader: cannot extract windows of size {window_size} "
            f"from {n_seq} sequence(s) of length {n_steps}"
        )

    # Collect the slices first and concatenate once; concatenating inside the
    # loop would re-copy the accumulated tensor for every new window.
    spike_parts, rate_parts, latent_parts = [], [], []
    for i in range(n_seq):
        for end in range(window_size, n_steps + 1):
            start = end - window_size
            rate_parts.append(rates[i:i+1, start:end, :])
            spike_parts.append(spikes[i:i+1, start:end, :])
            latent_parts.append(latents[i:i+1, end-1:end, :])

    spike_win = torch.cat(spike_parts, dim=0)
    rates_win = torch.cat(rate_parts, dim=0)
    latents_win = torch.cat(latent_parts, dim=0)

    dataset = TensorDataset(spike_win, rates_win, latents_win)
    effective_batch = batch_size if is_batch else rates_win.shape[0]
    return DataLoader(dataset, batch_size=effective_batch, shuffle=is_shuffle)

def _stack_dataset_field(dataloader, key):
    """Stack field `key` of every item in `dataloader.dataset` and swap the last two axes."""
    items = dataloader.dataset
    return torch.stack([items[i][key] for i in range(len(items))]).permute(0, 2, 1)


def _save_split(data_dir, split, spikes, rates, latents):
    """Save one split to `<data_dir>/<split>_data.pt` under the `<split>_*` keys."""
    torch.save({
        f"{split}_spikes": spikes,
        f"{split}_rates": rates,
        f"{split}_latents": latents,
    }, os.path.join(data_dir, f"{split}_data.pt"))


def _load_split(data_dir, split):
    """Return (spikes, rates, latents) of one saved split, reading the file once."""
    saved = torch.load(os.path.join(data_dir, f"{split}_data.pt"))
    return saved[f"{split}_spikes"], saved[f"{split}_rates"], saved[f"{split}_latents"]


if __name__ == "__main__":
    # Script flow:
    #   1. (optional) generate synthetic attractor data and save it per mean rate;
    #   2. for every mean rate and seed, train the SiT flow model, evaluate it
    #      every 20 epochs, checkpoint the best evaluation, and collect the
    #      best latent R2 / co-bps per seed together with their mean and std.
    mean_rate_list = [0.05]
    gen_data_flag = False  # True: (re)generate and save the data before training
    simu_data_root = './data/simulation'

    # Rows of [label, per-seed best values..., mean, std]. Created once, outside
    # the loops, so the rows of every mean rate are kept.
    final_r2_score = []

    if gen_data_flag:
        for mean_rate in mean_rate_list:
            # ---- Run params ---- #
            # make sure to match these to the config!
            # these are the default values in the paper
            system_name = "Lorenz"
            signal_length = 32  # length of each sequence
            total_in = 96  # number of neurons (48 held-in, 48 held-out)
            n_ic = 600  # number of initial conditions (total sequences)
            split_frac_train = 0.7  # fraction of data for training
            split_frac_val = 0.1  # fraction of data for validation
            random_seed = 42  # for reproducibility
            softplus_beta = 2.0  # controls sharpness of rate nonlinearity

            # ---- Generate synthetic data ---- #
            # create dataloaders for train/val/test splits
            train_dataloader, val_dataloader, test_dataloader, _ = get_attractor_dataloaders(
                system_name=system_name,
                n_neurons=total_in,
                sequence_length=signal_length,
                n_ic=n_ic,
                mean_spike_count=mean_rate * signal_length,
                train_frac=split_frac_train,
                valid_frac=split_frac_val,  # test is 1 - train - valid
                random_seed=random_seed,
                batch_size=1,
                softplus_beta=softplus_beta,
            )

            # ---- Extract data from dataloaders ---- #
            # After the axis swap the tensors are expected to be
            # [batch, time, neurons] (latents: [batch, time, latent_dim]).
            train_spikes = _stack_dataset_field(train_dataloader, "signal")
            val_spikes = _stack_dataset_field(val_dataloader, "signal")
            test_spikes = _stack_dataset_field(test_dataloader, "signal")

            train_rates = _stack_dataset_field(train_dataloader, "rates")
            val_rates = _stack_dataset_field(val_dataloader, "rates")
            test_rates = _stack_dataset_field(test_dataloader, "rates")

            train_latents = _stack_dataset_field(train_dataloader, "latents")
            val_latents = _stack_dataset_field(val_dataloader, "latents")
            test_latents = _stack_dataset_field(test_dataloader, "latents")

            # print data shapes for verification
            print(f"Train data shape: {train_spikes.shape}")
            print(f"Valid data shape: {val_spikes.shape}")
            print(f"Test data shape: {test_spikes.shape}")
            print(f"Train rates shape: {train_rates.shape}")
            print(f"Train latents shape: {train_latents.shape}")

            # save original simulation data
            simu_data_dir = os.path.join(simu_data_root, 'mean_{}'.format(mean_rate))
            os.makedirs(simu_data_dir, exist_ok=True)
            _save_split(simu_data_dir, "train", train_spikes, train_rates, train_latents)
            _save_split(simu_data_dir, "val", val_spikes, val_rates, val_latents)
            _save_split(simu_data_dir, "test", test_spikes, test_rates, test_latents)

    # Project-local model code; imported once here instead of on every loop pass.
    from model.vanilla_iTransfomer import ConditionModel
    from config.model_config import ModelConfig
    from flow.models.SiT_models import SiT

    from flow.transport.transport import create_transport, Sampler

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')
    seed_list = [0, 1, 2, 3, 4]
    invert_flag = False
    training_step = 400  # number of passes over the training dataloader
    win_size = 5
    # Stateless loss module, built once instead of once per batch.
    nll_func = torch.nn.PoissonNLLLoss()

    for mean_rate in mean_rate_list:
        # Always read the data saved for THIS mean rate. When the data were just
        # generated above they were also saved, so this yields the same tensors
        # and avoids training every mean rate on the last generated set.
        simu_data_dir = os.path.join(simu_data_root, 'mean_{}'.format(mean_rate))
        if not os.path.exists(simu_data_dir):
            raise FileNotFoundError("Simulation data directory does not exist: {}".format(simu_data_dir))

        train_spikes, train_rates, train_latents = _load_split(simu_data_dir, "train")
        val_spikes, val_rates, val_latents = _load_split(simu_data_dir, "val")
        test_spikes, test_rates, test_latents = _load_split(simu_data_dir, "test")

        train_dataloader = load_data_loader(train_spikes, train_rates, train_latents, win_size, is_shuffle=False)
        # NOTE(review): val_dataloader is built but never used below; the best
        # checkpoint is selected on the test split. Consider selecting on validation.
        val_dataloader = load_data_loader(val_spikes, val_rates, val_latents, win_size, is_shuffle=False, is_batch=False)
        test_dataloader = load_data_loader(test_spikes, test_rates, test_latents, win_size, is_shuffle=False, is_batch=False)

        test_r2_tmp, test_co_bps_tmp = ['latent r2'], ['co-bps']
        for seed in seed_list:
            setup_seed(seed)

            # vanilla Transformer (encoder-only)
            context_size = win_size
            # The first half of the rate channels is the conditioning input;
            # the second half is the reconstruction target (see the slices below).
            n_chan = train_rates[0].shape[1] // 2 if not invert_flag else context_size
            seq_len = context_size if not invert_flag else train_rates[0].shape[1]
            configs = ModelConfig(
                seq_len=seq_len,
                enc_in=n_chan,
                training_step=training_step,
                e_layers=2,
                factor=1,
                n_heads=8,
                d_model=n_chan,
            )
            # Not used afterwards. Kept because constructing it presumably draws
            # from the seeded RNG, so removing it could change the initial
            # weights of flow_model -- TODO confirm, then drop.
            transformer_model = ConditionModel(configs)

            # SiT model settings
            flow_model = SiT(
                in_channels=n_chan if not invert_flag else seq_len,
                window_size=1,
                hidden_size=n_chan,
                out_dim=train_latents.shape[-1],
                diff_dim=n_chan,
                depth=5,
                mlp_ratio=2.0,
                num_heads=8,
                model_config=configs,
                target_latent_config=configs,
                invert_flag=invert_flag,
            )
            model_fn = flow_model.forward
            flow_model.to(device)

            # set optimizer
            optimizer = torch.optim.Adam(flow_model.parameters(), lr=2e-3, weight_decay=1e-5)

            transport = create_transport(
                path_type="Linear",
                prediction="velocity",
                loss_weight=None,
                train_eps=None,
                sample_eps=1e-1,
            )  # default: velocity
            transport_sampler = Sampler(transport)

            train_loss = []
            # Start from -inf / NaN so that the first finite R2 always becomes the
            # best one and best_test_co_bps is always bound for this seed (it can
            # neither be undefined nor carry over from the previous seed).
            best_test_r2 = float('-inf')
            best_test_co_bps = float('nan')
            for global_step in range(training_step):
                flow_model.train()
                for batch_idx, (_, train_batch_rates, train_batch_latents) in enumerate(train_dataloader):
                    train_batch_rates = train_batch_rates.clone().detach().to(device)
                    train_batch_latents = train_batch_latents.clone().detach().to(device)

                    # The encoded latents are the flow target; no gradient
                    # reaches linear_encoder through this path.
                    with torch.no_grad():
                        exp_z_manifold = flow_model.linear_encoder(train_batch_latents)

                    model_kwargs = dict(y=train_batch_rates[:, :, :n_chan])
                    loss_dict = transport.training_losses(flow_model, exp_z_manifold, model_kwargs)
                    loss = loss_dict["loss"].mean()

                    # reconstruction loss
                    # NOTE(review): the Poisson target is the ground-truth rates of
                    # the second half of the channels at the last window step, not
                    # spike counts -- confirm this is intended.
                    pred_rates = torch.squeeze(flow_model.reconstruction_decoder(exp_z_manifold))
                    loss += nll_func(pred_rates, train_batch_rates[:, -1, n_chan:])
                    train_loss.append(loss.item())

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    print(f'Epoch: {global_step+1} | Batch: {batch_idx+1} | Loss: {loss.item():.4f}')

                if (global_step+1) % 20 == 0:
                    flow_model.eval()
                    sample_fn = transport_sampler.sample_ode(num_steps=2, sampling_method="euler")

                    with torch.no_grad():
                        for _, test_batch_rates, test_batch_latents in test_dataloader:
                            test_batch_rates = test_batch_rates.clone().detach().to(device)
                            test_batch_latents = test_batch_latents.clone().detach().to(device)

                            # noisy latent features
                            sample_num = test_batch_rates.shape[0]
                            z_0 = torch.randn(sample_num, flow_model.hidden_size, device=device)
                            z_0 = torch.unsqueeze(z_0, dim=1)

                            sample_model_kwargs = dict(y=test_batch_rates[:, :, :n_chan])
                            samples = sample_fn(z_0, model_fn, **sample_model_kwargs)[-1]
                            samples = torch.squeeze(samples)

                            # Invert the linear encoder with a pseudo-inverse to
                            # map the samples back to latent space.
                            pinv_decoder = torch.linalg.pinv(flow_model.linear_encoder.weight.t())
                            dec_out_valid = (samples - flow_model.linear_encoder.bias) @ pinv_decoder

                            y_true = torch.squeeze(test_batch_latents[:sample_num]).clone().detach()
                            y_pred = dec_out_valid[:sample_num].clone().detach()

                            # The decoder output is exponentiated to obtain rates.
                            pred_rates = torch.squeeze(flow_model.reconstruction_decoder(samples))
                            pred_rates = pred_rates.exp().clone().detach().cpu().numpy()
                            # NOTE(review): scored against ground-truth rates, whereas
                            # bits_per_spike documents spike counts -- confirm.
                            co_bps = bits_per_spike(pred_rates, torch.squeeze(test_batch_rates[:sample_num, -1, n_chan:]).clone().detach().cpu().numpy())
                            print("test co-bps: %.4f" % co_bps)

                            r2_score_test_tmp = r2_score(torch.reshape(y_true, (-1, y_true.size(-1))).cpu().detach().numpy(), torch.reshape(y_pred, (-1, y_pred.size(-1))).clone().cpu().detach().numpy())
                            print("test r2 score: %.4f" % r2_score_test_tmp)

                            if best_test_r2 < r2_score_test_tmp:
                                best_test_r2 = r2_score_test_tmp
                                best_test_co_bps = co_bps

                                # Regroup the windows per test sequence for later plotting.
                                y_pred_plot = torch.reshape(y_pred, (test_latents.shape[0], -1, test_latents.shape[-1]))
                                y_true_plot = torch.reshape(y_true, (test_latents.shape[0], -1, test_latents.shape[-1]))

                                ckpt_dir = './checkpoints/simulation/pretrain/mean_{}/seed_{}'.format(mean_rate, seed)
                                os.makedirs(ckpt_dir, exist_ok=True)
                                # save the model
                                torch.save({
                                    'model_state_dict': flow_model.state_dict(),
                                    'optimizer_state_dict': optimizer.state_dict(),
                                    'loss_curve': train_loss,
                                    'mean_rate': mean_rate,
                                    'y_pred': y_pred_plot,
                                    'y_true': y_true_plot,
                                }, os.path.join(ckpt_dir, 'best_fm_model.pt'))

            print(f"best test r2 score: {best_test_r2:.4f}, best test co-bps: {best_test_co_bps:.4f}")
            test_r2_tmp.append(best_test_r2)
            test_co_bps_tmp.append(best_test_co_bps)

        # Append mean and std over the seeds (element 0 of each row is its label).
        for metric_row in (test_r2_tmp, test_co_bps_tmp):
            mean, std = mean_std_pure(metric_row[1:])
            metric_row.append(mean)
            metric_row.append(std)
            final_r2_score.append(metric_row)

    # Report the aggregated rows, which were previously computed but never output.
    for metric_row in final_r2_score:
        print(metric_row)