import os
import torch
from copy import deepcopy
from sklearn.metrics import r2_score
from collections import defaultdict

from test_neuronal_dynamics_synthetic import setup_seed, load_data_loader, bits_per_spike, mean_std_pure

def _load_split(data_file, split):
    """Load the spikes, rates and latents of one data split from a ``.pt`` file.

    The file is deserialized once and the three tensors are read from the keys
    ``'<split>_spikes'``, ``'<split>_rates'`` and ``'<split>_latents'``.
    Tensors are mapped to CPU so files saved on a GPU machine also load on a
    CPU-only one.

    Args:
        data_file: path to the ``.pt`` file.
        split: key prefix, e.g. ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        Tuple ``(spikes, rates, latents)`` as stored in the file.
    """
    data = torch.load(data_file, map_location="cpu")
    return data[f"{split}_spikes"], data[f"{split}_rates"], data[f"{split}_latents"]


@torch.no_grad()
def _evaluate(model, dataloader, sample_fn, n_chan, device, tag):
    """Generate latents with ``model`` for every batch of ``dataloader`` and score them.

    For each batch, Gaussian noise ``z_0`` of shape ``(batch, 1, model.hidden_size)``
    is integrated by ``sample_fn`` conditioned on the first ``n_chan`` rate channels.
    The final state is mapped back to latent space through the pseudo-inverse of
    ``model.linear_encoder`` (scored with R^2 against the true latents). It is also
    decoded by ``model.reconstruction_decoder`` into log-rates, which are
    exponentiated and scored with co-bps against the held-out rate channels
    (index ``n_chan`` onward, last time step). Both scores are printed, prefixed by ``tag``.

    Args:
        model: flow model exposing ``forward``, ``hidden_size``, ``linear_encoder``
            and ``reconstruction_decoder``; it is switched to eval mode here.
        dataloader: iterable of ``(spikes, rates, latents)`` batches.
        sample_fn: ODE sampler called as ``sample_fn(z_0, model_fn, y=...)``; its
            last element is taken as the final sample.
        n_chan: number of conditioning channels; channels from ``n_chan`` on are
            the held-out targets.
        device: device the batches are moved to.
        tag: prefix for the printed scores (``'val'`` or ``'test'``).

    Returns:
        List with one ``(co_bps, r2, y_pred, y_true)`` tuple per batch.
    """
    model.eval()
    model_fn = model.forward
    # The encoder is fixed during evaluation, so its pseudo-inverse is computed once.
    pinv_decoder = torch.linalg.pinv(model.linear_encoder.weight.t())

    results = []
    for _, batch_rates, batch_latents in dataloader:
        batch_rates = batch_rates.clone().detach().to(device)
        batch_latents = batch_latents.clone().detach().to(device)

        # noisy latent features
        sample_num = batch_rates.shape[0]
        z_0 = torch.randn(sample_num, model.hidden_size, device=device).unsqueeze(1)
        samples = sample_fn(z_0, model_fn, y=batch_rates[:, :, :n_chan])[-1]
        # Squeeze only the singleton axis added to z_0 above; a full squeeze would
        # also drop the batch axis when the batch holds a single sample.
        samples = samples.squeeze(1)

        y_pred = (samples - model.linear_encoder.bias) @ pinv_decoder
        y_true = torch.squeeze(batch_latents)

        pred_rates = torch.squeeze(model.reconstruction_decoder(samples)).exp().cpu().numpy()
        target_rates = torch.squeeze(batch_rates[:, -1, n_chan:]).cpu().numpy()
        co_bps = bits_per_spike(pred_rates, target_rates)
        print("%s co-bps: %.4f" % (tag, co_bps))

        r2 = r2_score(
            y_true.reshape(-1, y_true.size(-1)).cpu().numpy(),
            y_pred.reshape(-1, y_pred.size(-1)).cpu().numpy(),
        )
        print("%s r2 score: %.4f" % (tag, r2))

        results.append((co_bps, r2, y_pred, y_true))
    return results


# Cross-day setting: the model is pre-trained on data with one mean firing rate
# (the reference, ``pre_mean_rate``) and fine-tuned / tested on data whose
# mean-firing-rate distribution has changed (``mean_rate_list``).
if __name__ == "__main__":
    from model.vanilla_iTransfomer import ConditionModel
    from config.model_config import ModelConfig
    from flow.models.SiT_models import SiT
    from flow.transport.transport import create_transport, Sampler

    win_size = 5  # context window length passed to load_data_loader / the models

    mean_rate_list = [0.1, 0.2]  # mean rates of the fine-tuning / test datasets
    pre_mean_rate = 0.05  # mean rate of the dataset the checkpoints were pre-trained on

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    seed_list = [0, 1, 2, 3, 4]
    invert_flag = False
    # Number of leading training entries used for fine-tuning (0 = evaluate the
    # pre-trained model as is).
    ft_num_list = [0, 3, 5]
    # PoissonNLLLoss defaults to log_input=True: the decoder output is treated as
    # a log-rate, consistent with the ``.exp()`` applied at evaluation time.
    nll_func = torch.nn.PoissonNLLLoss()

    transport = create_transport(
        path_type="Linear",
        prediction="velocity",
        loss_weight=None,
        train_eps=None,
        sample_eps=1e-1,
    )  # default: velocity
    transport_sampler = Sampler(transport)
    sample_fn = transport_sampler.sample_ode(num_steps=2, sampling_method="euler")

    simu_root_dir = "./data/simulation"
    out_test_r2 = defaultdict(list)
    out_test_co_bps = defaultdict(list)
    for mean_rate in mean_rate_list:
        # load synthetic data (each file is read once)
        simu_data_dir = os.path.join(simu_root_dir, "mean_{}".format(mean_rate))
        if not os.path.exists(simu_data_dir):
            raise FileNotFoundError("Simulation data directory does not exist: {}".format(simu_data_dir))
        train_spikes, train_rates, train_latents = _load_split(os.path.join(simu_data_dir, "train_data.pt"), "train")
        val_spikes, val_rates, val_latents = _load_split(os.path.join(simu_data_dir, "val_data.pt"), "val")
        test_spikes, test_rates, test_latents = _load_split(os.path.join(simu_data_dir, "test_data.pt"), "test")

        # generate dataloaders (train_dataloader is not used by the fine-tuning loop below)
        train_dataloader = load_data_loader(train_spikes, train_rates, train_latents, win_size, is_shuffle=False)
        val_dataloader = load_data_loader(val_spikes, val_rates, val_latents, win_size, is_shuffle=False, is_batch=False)
        test_dataloader = load_data_loader(test_spikes, test_rates, test_latents, win_size, is_shuffle=False, is_batch=False)

        # build the model skeleton; pre-trained weights are loaded per seed below
        context_size = win_size
        # The first n_chan rate channels condition the model (``[..., :n_chan]``);
        # the remaining channels are the held-out targets (``[..., n_chan:]``).
        n_chan = train_rates[0].shape[1] // 2 if not invert_flag else context_size
        seq_len = context_size if not invert_flag else train_rates[0].shape[1]
        configs = ModelConfig(
            seq_len=seq_len,
            enc_in=n_chan,
            e_layers=2,
            factor=1,
            n_heads=8,
            d_model=n_chan,
        )
        # NOTE(review): not used in this script's visible code; kept in case later code relies on it.
        transformer_model = ConditionModel(configs)

        # SiT model settings
        flow_model = SiT(
            in_channels=n_chan if not invert_flag else seq_len,
            window_size=1,
            hidden_size=n_chan,
            out_dim=train_latents.shape[-1],
            diff_dim=n_chan,
            depth=5,
            mlp_ratio=2.0,
            num_heads=8,
            model_config=configs,
            target_latent_config=configs,
            invert_flag=invert_flag,
        )

        r2_list, co_bps_list = [], []
        for ft_num in ft_num_list:
            # Row layout: [ft_num, score_seed_0, ..., score_seed_k, mean, std]
            test_r2_list, test_co_bps_list = [ft_num], [ft_num]
            print(f"Fine-tuning with {ft_num} samples")
            for seed in seed_list:
                setup_seed(seed)
                # Reset to the pre-trained weights for every (ft_num, seed) run.
                ckpt_dir = "./checkpoints/simulation/pretrain/mean_{}/seed_{}".format(pre_mean_rate, seed)
                model_weight = torch.load(os.path.join(ckpt_dir, "best_fm_model.pt"), map_location="cpu")["model_state_dict"]
                flow_model.load_state_dict(model_weight, strict=True)

                model_fn = flow_model.forward
                flow_model.to(device)
                # only fine-tune the conditional embedder and the rate decoder
                optimizer_cond = torch.optim.Adam(flow_model.dynamic_embedder.parameters(), lr=1e-3, weight_decay=1e-4)
                optimizer_decoder = torch.optim.Adam(flow_model.reconstruction_decoder.parameters(), lr=1e-3, weight_decay=1e-4)
                # Fallback when no validation round runs (e.g. ft_num == 0).
                best_flow_model = deepcopy(flow_model).to(device)

                train_loss = []
                if ft_num > 0:
                    train_spikes_ft, train_rates_ft, train_latents_ft = train_spikes[:ft_num], train_rates[:ft_num], train_latents[:ft_num]
                    train_dataloader_ft = load_data_loader(train_spikes_ft, train_rates_ft, train_latents_ft, win_size, is_shuffle=False, is_batch=True)

                    training_step = 300
                    sample_every = 10
                    # co-bps is unbounded below, so start from -inf rather than a
                    # finite sentinel that a poor model might never beat.
                    best_val_co_bps = float("-inf")
                    for global_step in range(training_step):
                        flow_model.train()

                        for batch_idx, (train_batch_spikes, train_batch_rates, train_batch_latents) in enumerate(train_dataloader_ft):
                            train_batch_rates = train_batch_rates.clone().detach().to(device)
                            train_batch_latents = train_batch_latents.clone().detach().to(device)

                            # the latent encoder is frozen: its output is the flow target
                            with torch.no_grad():
                                exp_z_manifold = flow_model.linear_encoder(train_batch_latents)

                            model_kwargs = dict(y=train_batch_rates[:, :, :n_chan])
                            loss_dict = transport.training_losses(flow_model, exp_z_manifold, model_kwargs)
                            loss = loss_dict["loss"].mean()

                            # Reconstruction branch: sample latents from noise and decode
                            # them to held-out rates. The samples only feed the decoder
                            # (gradients are not propagated through sampling), so the ODE
                            # is integrated without building an autograd graph.
                            sample_num = train_batch_rates.shape[0]
                            z_0 = torch.randn(sample_num, flow_model.hidden_size, device=device).unsqueeze(1)
                            with torch.no_grad():
                                samples = sample_fn(z_0, model_fn, y=train_batch_rates[:, :, :n_chan], is_cond=False)[-1]
                            # keep the batch axis even when the batch holds one sample
                            samples = samples.squeeze(1)

                            train_pred_rates = torch.squeeze(flow_model.reconstruction_decoder(samples))
                            loss = loss + nll_func(train_pred_rates, train_batch_rates[:, -1, n_chan:])
                            train_loss.append(loss.item())

                            print(f"global step: {global_step+1}/{training_step}, batch idx: {batch_idx}, loss: {loss.item():.4f}")

                            optimizer_cond.zero_grad()
                            optimizer_decoder.zero_grad()
                            loss.backward()
                            optimizer_cond.step()
                            optimizer_decoder.step()

                        if (global_step + 1) % sample_every == 0:
                            # model selection on validation co-bps
                            for co_bps, _, _, _ in _evaluate(flow_model, val_dataloader, sample_fn, n_chan, device, "val"):
                                if best_val_co_bps < co_bps:
                                    best_val_co_bps = co_bps
                                    best_flow_model = deepcopy(flow_model).to(device)
                    print("best val co-bps: %.4f" % best_val_co_bps)

                # test the selected model and save it together with its predictions
                test_results = _evaluate(best_flow_model, test_dataloader, sample_fn, n_chan, device, "test")
                for co_bps, r2_score_test_tmp, y_pred, y_true in test_results:
                    y_pred_plot = torch.reshape(y_pred, (test_latents.shape[0], -1, test_latents.shape[-1]))
                    y_true_plot = torch.reshape(y_true, (test_latents.shape[0], -1, test_latents.shape[-1]))

                    # save the model
                    ckpt_dir = "./checkpoints/simulation/ft/mean_{}/ft_num_{}/seed_{}".format(mean_rate, ft_num, seed)
                    os.makedirs(ckpt_dir, exist_ok=True)
                    torch.save({
                        "model_state_dict": best_flow_model.state_dict(),
                        "optimizer_state_dict": optimizer_cond.state_dict(),
                        "optimizer_decoder_state_dict": optimizer_decoder.state_dict(),
                        "loss_curve": train_loss,
                        "mean_rate": mean_rate,
                        "y_pred": y_pred_plot,
                        "y_true": y_true_plot,
                    }, os.path.join(ckpt_dir, "best_fm_model.pt"))

                    test_r2_list.append(r2_score_test_tmp)
                    test_co_bps_list.append(co_bps)

            # compute mean and std over seeds (entry 0 is ft_num)
            mean, std = mean_std_pure(test_r2_list[1:])
            test_r2_list.append(mean)
            test_r2_list.append(std)

            mean, std = mean_std_pure(test_co_bps_list[1:])
            test_co_bps_list.append(mean)
            test_co_bps_list.append(std)

            r2_list.append(test_r2_list)
            co_bps_list.append(test_co_bps_list)

        out_test_r2[mean_rate] = r2_list
        out_test_co_bps[mean_rate] = co_bps_list