from ndt1.src.config.settings import settings
from ndt1.validate_test_submission import get_top_model_metadata, get_ckpt_and_datafile
from ndt1.src.inference import get_data_generator, get_batch_output_factors
from ndt1.src.dataset import DATASET_MODES
from nlb_tools.nwb_interface import NWBDataset
from nlb_tools.evaluation import evaluate

import os
import torch
import numpy as np

from config.model_config import ModelConfig
from flow.models.SiT_models import SiT
from flow.transport.transport import create_transport, Sampler
from preprocess.data_loader import load_behavioral_variables

from copy import deepcopy
from collections import OrderedDict

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every parameter of ``ema_model`` one EMA step towards ``model``.

    For each parameter name of ``model``, the same-named parameter of
    ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``.

    Args:
        ema_model: module holding the running averages (modified in place).
        model: module whose current parameters are blended in.
        decay: weight kept from the previous average.

    Raises:
        KeyError: if ``model`` has a parameter name that ``ema_model`` lacks.
    """
    averaged = dict(ema_model.named_parameters())
    blend = 1 - decay

    for param_name, live_param in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid
        # small numerical changes of pos_embed.
        running = averaged[param_name]
        running.mul_(decay).add_(live_param.data, alpha=blend)

def set_seed(seed):
    """
    Seed every random number generator used by this script for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (CPU and,
    when available, all CUDA devices). It also forces cuDNN into deterministic
    mode and disables its autotuner (``benchmark``). The autotuner can pick
    different, non-deterministic algorithms from run to run.

    Args:
        seed: integer seed applied to all generators.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Required together with `deterministic` for reproducible cuDNN results.
    torch.backends.cudnn.benchmark = False

if __name__ == "__main__":
    # Dataset to train on; must be one of the keys of AESMTE3_ENSEMBLE_SIZES.
    dataset_name = "dmfc_rsg"
    model_checkpoint_root = settings.SUBMISSION_VALIDATION_ROOT
    data_pth_dict = settings.NLB_DATA_PTH_DICT
    save_ckpt_dir = settings.CHECKPOINT_DIR

    # Number of top-ranked NDT checkpoints (ensemble members) used per dataset.
    AESMTE3_ENSEMBLE_SIZES = {
        "area2_bump": 21, # 2 (num_heads), 4(depth)
        "dmfc_rsg": 13, # 2 (num_heads), 4(depth)
        "mc_maze": 8, # 2 (num_heads), 4(depth)
        "mc_maze_large": 8,
        "mc_maze_medium": 8,
        "mc_maze_small": 7,
        "mc_rtt": 13, # 13 5(num_heads), 4(depth)
    }
    n_models = AESMTE3_ENSEMBLE_SIZES[dataset_name]
    models_df = get_top_model_metadata(dataset_name, n_models, model_checkpoint_root)

    phase = "val"
    rows = list(models_df.itertuples())
    # The first member's runner is only used to build the shared data generators.
    runner, *_ = get_ckpt_and_datafile(
        dataset_name=dataset_name,
        phase=phase,
        model_checkpoint_root=model_checkpoint_root,
        trial_id=rows[0].trial_id,
        val_ckpt=rows[0].val_ckpt,
    )

    # Suffix appended to the dataset name in the NLB output dict (empty for 5 ms bins).
    bin_width = 5  # in ms
    suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'

    # Behavioral variables are not used here (None for both splits).
    datapath = data_pth_dict[dataset_name]
    # NOTE(review): `dataset` is never used below; it is kept only because
    # constructing it opens the NWB file — confirm whether it can be dropped.
    dataset = NWBDataset(datapath)
    train_beh, eval_beh = None, None # area2_bump

    batch_size = 32
    train_data_generator = get_data_generator(
        runner=runner,
        mode=DATASET_MODES.train,
        batch_size=batch_size,
        beh_vars=train_beh,
    )
    eval_data_generator = get_data_generator(
        runner=runner,
        mode=DATASET_MODES.val,
        batch_size=batch_size,
        beh_vars=eval_beh,
    )

    # In the 'val' phase, load the precomputed NLB evaluation targets.
    target_dict = None

    import pickle
    if phase == 'val':
        # NOTE: pickle.load can execute arbitrary code; only use trusted local files.
        with open(os.path.join(datapath, 'val_target_dict.pkl'), 'rb') as f:
            target_dict = pickle.load(f)

    seed_list = [0, 1, 2, 3, 4]
    for seed in seed_list:
        for m_idx, row in enumerate(models_df.itertuples()):
            # Resolve this ensemble member's runner/checkpoint once; the result is
            # identical for every batch, so there is no need to redo it per batch.
            runner, ckpt_path, ndt_file_path = get_ckpt_and_datafile(
                dataset_name=dataset_name,
                phase=phase,
                model_checkpoint_root=model_checkpoint_root,
                trial_id=row.trial_id,
                val_ckpt=row.val_ckpt,
            )

            # Latent factors of this member on the training split; these are the
            # flow-matching targets. NOTE(review): they are later looked up by
            # batch_idx, which assumes train_data_generator yields batches in the
            # same order on every pass (no shuffling) — confirm.
            batch_output_factors = []
            for spikes, _, heldout_spikes, forward_spikes, _ in train_data_generator:
                batch_output = get_batch_output_factors(
                    ckpt_path=ckpt_path,
                    data_file=ndt_file_path,
                    batch_data=(spikes, None, heldout_spikes, forward_spikes),
                )
                batch_output_factors.append(batch_output['batch_factors'])

            # Latent factors of this member on the validation split.
            # NOTE(review): target_batch_output_factors is never used below.
            target_batch_output_factors = []
            for spikes, _, heldout_spikes, forward_spikes, _ in eval_data_generator:
                batch_output = get_batch_output_factors(
                    ckpt_path=ckpt_path,
                    data_file=ndt_file_path,
                    batch_data=(spikes, None, heldout_spikes, forward_spikes),
                )
                target_batch_output_factors.append(batch_output['batch_factors'])

            set_seed(seed=seed)

            # get the first batch to get the shape of the data
            spikes, _, heldout_spikes, forward_spikes, _ = next(iter(train_data_generator))

            window_size = spikes.size(1) + forward_spikes.size(1)  # total time steps in the window
            window_size_heldin = spikes.size(1)  # number of time steps in the heldin data

            n_chan = forward_spikes.size(-1)
            # Width of the NDT embedder output for the held-in spikes.
            n_chan_heldin = runner.model.embedder(spikes.clone().detach().to(runner.device)).size(-1)

            # target latent dims
            latent_dim = batch_output_factors[0].size(-1)

            model_configs = ModelConfig(
                seq_len=window_size_heldin,
                pred_len=window_size,
                enc_in=n_chan_heldin,
                d_model=latent_dim,

                e_layers=2,
                d_pos=n_chan,
                factor=8,
            )
            invert_flag = False

            pred_len=1
            flow_model = SiT(
                in_channels=n_chan_heldin,
                window_size=pred_len,
                hidden_size=latent_dim,
                out_dim=n_chan,
                beh_dim=2,

                num_heads=2,
                depth=4,
                mlp_ratio=2.0,
                model_config=model_configs,
                target_latent_config=None,
                cond_model=None,
                beh_config=None,
                invert_flag=invert_flag,
            )

            # set the device for the model
            flow_model.to(runner.device)

            # Note that parameter initialization is done within the SiT constructor.
            # The EMA copy is what the ODE sampler runs (model_fn). NOTE(review):
            # ema is never switched to eval mode, even during evaluation.
            ema = deepcopy(flow_model).to(runner.device)
            model_fn = ema.forward

            transport = create_transport(
                path_type='Linear',
                prediction='velocity',
                loss_weight=None,
                train_eps=None,
                sample_eps=1e-2,
            )
            transport_sampler = Sampler(transport)
            sample_fn = transport_sampler.sample_ode(num_steps=2, sampling_method="euler")

            # Only flow_model's parameters are optimized.
            optimizer = torch.optim.Adam(flow_model.parameters(), lr=1e-3)

            # train flow models
            training_epoch = 1500
            sample_every = 20
            best_val_rel = None
            nll_func = torch.nn.PoissonNLLLoss()

            for epoch_id in range(training_epoch):
                # Train rates are only consumed on evaluation epochs.
                is_eval_epoch = (epoch_id + 1) % sample_every == 0
                total_train_rates_heldin, total_train_rates_heldout = [], []
                flow_model.train()

                for batch_idx, (spikes, _, heldout_spikes, forward_spikes, beh_vars) in enumerate(train_data_generator):
                    # Ground truth: held-in + held-out channels, forward steps appended in time.
                    gt_spikes = torch.cat([spikes, heldout_spikes], -1)
                    gt_spikes = torch.cat([gt_spikes, forward_spikes], 1)
                    gt_spikes = gt_spikes.detach().clone().to(dtype=torch.float32, device=runner.device)

                    with torch.no_grad():
                        # Conditioning input from the NDT embedder. Its parameters are
                        # not in the optimizer, so no autograd graph is needed.
                        in_spikes = spikes.detach().clone().to(device=runner.device)
                        in_spikes = in_spikes.permute(1, 0, 2)
                        in_spikes = runner.model.embedder(in_spikes) * runner.model.scale
                        in_spikes = in_spikes.permute(1, 0, 2)

                        # Flatten target factors to one latent vector per (trial, step).
                        batch_factors = batch_output_factors[batch_idx].to(runner.device)
                        exp_z_manifold = torch.reshape(batch_factors, (-1, batch_factors.size(-1)))
                        exp_z_manifold = torch.unsqueeze(exp_z_manifold, dim=1)

                    model_kwargs = dict(y=in_spikes, is_cond=False)
                    loss_dict = transport.training_losses(flow_model, exp_z_manifold, model_kwargs)
                    loss = loss_dict['loss'].mean()

                    # Sample latents from noise with the EMA model. The samples do
                    # not depend on flow_model's parameters (ema is a deepcopy), so
                    # an autograd graph through the ODE solver would only waste
                    # memory. Gradients reach flow_model via linear_decoder below.
                    train_num = int(spikes.size(0)*window_size)
                    with torch.no_grad():
                        z_0 = torch.randn(train_num, latent_dim, device=runner.device)
                        z_0 = torch.unsqueeze(z_0, dim=1)

                        sample_model_kwargs = dict(y=in_spikes, is_cond=False)
                        samples = sample_fn(z_0, model_fn, **sample_model_kwargs)[-1]
                        samples = torch.reshape(samples, (spikes.size(0), -1, latent_dim))

                    # Decoder outputs log-rates (PoissonNLLLoss defaults to log_input=True).
                    train_log_rates = flow_model.linear_decoder(samples)
                    loss += nll_func(train_log_rates, gt_spikes)

                    if is_eval_epoch:
                        # Detached so stored rates don't keep this batch's graph alive.
                        train_rates_all = train_log_rates.detach().exp()
                        train_rates, _ = torch.split(train_rates_all, [spikes.size(1), train_rates_all.size(1) - spikes.size(1)], 1)
                        train_rates_heldin, train_rates_heldout = torch.split(train_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)

                        total_train_rates_heldin.append(train_rates_heldin)
                        total_train_rates_heldout.append(train_rates_heldout)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    update_ema(ema, flow_model, decay=0.999)

                    print(f"Epoch {epoch_id+1}/{training_epoch}, Batch {batch_idx+1}/{len(train_data_generator)}, Loss: {loss.item():.4f}")

                if is_eval_epoch:
                    total_train_rates_heldin = torch.cat(total_train_rates_heldin, dim=0)
                    total_train_rates_heldout = torch.cat(total_train_rates_heldout, dim=0)

                    flow_model.eval()
                    total_eval_rates = []

                    with torch.no_grad():
                        for batch_idx, (spikes, _, heldout_spikes, forward_spikes, beh_vars) in enumerate(eval_data_generator):

                            eval_in_spikes = spikes.detach().clone().to(device=runner.device)
                            eval_in_spikes = eval_in_spikes.permute(1, 0, 2)
                            eval_in_spikes = runner.model.embedder(eval_in_spikes) * runner.model.scale
                            eval_in_spikes = eval_in_spikes.permute(1, 0, 2)

                            # noisy latent factors
                            eval_num = int(eval_in_spikes.size(0)*window_size)
                            z_0 = torch.randn(eval_num, latent_dim, device=runner.device)
                            z_0 = torch.unsqueeze(z_0, dim=1)

                            sample_model_kwargs = dict(y=eval_in_spikes, is_cond=False)
                            samples = sample_fn(z_0, model_fn, **sample_model_kwargs)[-1]

                            samples = torch.reshape(samples, (spikes.size(0), -1, latent_dim))

                            pred_rates = flow_model.linear_decoder(samples).exp()

                            # add to total eval rates
                            total_eval_rates.append(pred_rates)

                    total_eval_rates = torch.cat(total_eval_rates, dim=0)

                    # Split sizes come from the last eval batch (same per batch).
                    eval_rates, eval_rates_forward = torch.split(total_eval_rates, [spikes.size(1), total_eval_rates.size(1) - spikes.size(1)], 1)
                    eval_rates_heldin_forward, eval_rates_heldout_forward = torch.split(eval_rates_forward, [spikes.size(-1), heldout_spikes.size(-1)], -1)
                    eval_rates_heldin, eval_rates_heldout = torch.split(eval_rates, [spikes.size(-1), heldout_spikes.size(-1)], -1)

                    output_dict = {
                        dataset_name + suffix: {
                            'train_rates_heldin': total_train_rates_heldin.detach().cpu().numpy(),
                            'train_rates_heldout': total_train_rates_heldout.detach().cpu().numpy(),
                            'eval_rates_heldin': eval_rates_heldin.detach().cpu().numpy(),
                            'eval_rates_heldout': eval_rates_heldout.detach().cpu().numpy(),
                            'eval_rates_heldin_forward': eval_rates_heldin_forward.detach().cpu().numpy(),
                            'eval_rates_heldout_forward': eval_rates_heldout_forward.detach().cpu().numpy(),
                        }
                    }
                    # evaluate
                    eval_rel = evaluate(target_dict, output_dict)
                    print(f"Epoch {epoch_id+1}/{training_epoch}, Output: {eval_rel[0]}")

                    # Keep the checkpoint with the best co-bps seen so far.
                    split_rel = eval_rel[0][dataset_name+'_split']
                    co_bps = split_rel['co-bps']
                    if best_val_rel is None or co_bps > best_val_rel['co-bps']:
                        best_val_rel = split_rel
                        save_model_dir = os.path.join(
                            save_ckpt_dir,
                            'nlb',
                            dataset_name,
                            str(seed),
                        )
                        os.makedirs(save_model_dir, exist_ok=True)
                        save_model_file = os.path.join(save_model_dir, f"best_val_bps_fm_model_{m_idx}.pth")
                        torch.save({
                            'model_state_dict': ema.state_dict(),
                            # Scored rates were decoded with flow_model.linear_decoder
                            # (not the EMA copy's), so store the raw model as well to
                            # make the reported score reproducible.
                            'flow_model_state_dict': flow_model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'epoch': epoch_id,
                            'seed': seed,
                            'eval_rel': split_rel,
                        }, save_model_file)
                        print(f"Saved best model to {save_model_file}")

            print(f"Best validation results: {best_val_rel}")