import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import time
import os
import sys
from sklearn.metrics import f1_score, accuracy_score

# NWB imports and DANDI imports
from pynwb import NWBHDF5IO

from dmhp import model, utils, hmmglm

import argparse


## arguments
# A single positional integer selects one cell of the
# (method x session x seq_range) job grid, e.g. an array-job task index.
parser = argparse.ArgumentParser()
parser.add_argument('idx', type=int)
args = parser.parse_args()

folder = 'precut_spiketrains/type2_discrimination'
method_list = ['GLM', 'HMM-GLM', 'Gaussian HMM-GLM', 'One-hot HMM-GLM']
# os.listdir order is arbitrary, so sort to make idx -> session stable across
# machines and runs. Only .pkl files are sessions: each one is re-opened below
# as f'{folder}/{session}.pkl'.
session_list = sorted(f for f in os.listdir(folder) if f.endswith('.pkl'))
seq_range_list = ['nonrandom', 'random', 'all']

grid_shape = (len(method_list), len(session_list), len(seq_range_list))
n_jobs = int(np.prod(grid_shape))
if not 0 <= args.idx < n_jobs:
    parser.error(f'idx must be in [0, {n_jobs}), got {args.idx}')

arg_index = np.unravel_index(args.idx, grid_shape)
method = method_list[arg_index[0]]
session = os.path.splitext(session_list[arg_index[1]])[0]  # file name without '.pkl'
seq_range = seq_range_list[arg_index[2]]
# The plain GLM is the single-state special case; all HMM variants use 5 states.
n_states = 1 if method == 'GLM' else 5


# ## read data
# io = NWBHDF5IO('./sub-228CR_ses-20190716T182623_behavior+ecephys+image.nwb', mode='r', load_namespaces=True)

# # Read the file
# nwbfile = io.read()

df_data = pd.read_pickle(f'{folder}/{session}.pkl')
# The permute moves axis 0 to the end; given the unpacking below
# ((n_seq, n_time_bins, n_neurons)), the stored array is laid out as
# (n_neurons, n_seq, n_time_bins).
spikes_list = torch.from_numpy(np.array(df_data['trial_spike_trains'])).to(torch.float32).permute((1, 2, 0))

# Upper bound on trials kept for the 'random' / 'nonrandom' subsets
# ('all' is deliberately not capped, matching the original behaviour).
max_seqs = 30
if seq_range in ('nonrandom', 'random'):
    # One flag per trial is assumed (the mask must match the trial axis).
    # Cast to bool explicitly: on an integer tensor `~` is a bitwise NOT, and an
    # integer index tensor selects positions instead of acting as a mask.
    is_random = torch.from_numpy(np.asarray(df_data['randoms'], dtype=bool))
    trial_mask = is_random if seq_range == 'random' else ~is_random
    spikes_list = spikes_list[trial_mask][:max_seqs]
n_seq, n_time_bins, n_neurons = spikes_list.shape

## hyperparameters
# Coupling-filter basis: exponential decay rate and window length in bins.
decay, window_size = 5, 5
# Span of one trial; the bin width follows from the number of bins.
# NOTE(review): presumably seconds — confirm against the pre-cut data.
T = 6
dt = T / n_time_bins
basis = utils.exp_basis(
    decay=decay,
    window_size=window_size,
    time_span=dt * window_size,
)


# ## data preparation
# n_neurons = len(nwbfile.units)
# timestamps_list = []
# for neuron in range(n_neurons):
#     timestamps_list.append(nwbfile.units[neuron].spike_times.values[0])

# n_time_bins = int(T/dt)
# n_seq = 21
# start_seq = 50
# spikes_list = torch.zeros((n_seq, n_time_bins, n_neurons))
# states_list = torch.zeros((n_seq, n_time_bins), dtype=torch.int64)
# for seq in range(n_seq):
#     start = nwbfile.trials[start_seq + seq].response_window_open_time.values[0] - 2
#     spikes_list[seq] = torch.from_numpy(utils.continuous_to_discrete(timestamps_list, dt, T, start=start)).to(torch.float32)
#     states_list[seq, 150:250] = 1

# Pre-filter every trial with the basis once; the models consume these inputs.
convolved_spikes_list = utils.convolve_spikes_with_basis(spikes_list, basis)

# Two thirds of the trials (rounded down) form the training split, drawn with
# a fixed seed. The checkpoints loaded later were presumably fitted on this
# same split, so keep the seed and draw order unchanged.
n_seq_train = 2 * n_seq // 3
torch.manual_seed(0)
training_idx_list = torch.randperm(n_seq)[:n_seq_train]

## model construction and checkpoint loading
# Each branch builds an untrained model with the class and tensor shapes the
# checkpoint expects. The constructor tensors are only initial values; the
# state dict loaded at the end of this block replaces the model's parameters.
# NOTE(review): the torch.rand / torch.randn draws below still advance the RNG
# seeded above, so keep them (and their order) if later random draws, if any,
# must stay reproducible.
checkpoint_path = f'model/{method}_{session}_{n_states}_{seq_range}.pt'

# Logit of each neuron's mean of spikes_list over the training trials and
# time bins. A neuron with no spikes in those trials yields -inf here.
# NOTE(review): presumably harmless because the checkpoint overwrites it — confirm.
init_bg_intensity = spikes_list[training_idx_list].mean(dim=(0, 1)).logit()

if method == 'GLM':
    bg_intensity = init_bg_intensity
    weight = 0.001 * (torch.rand((1, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLM(1, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    # Training loop used to fit the checkpoint (kept for reference):
    # optimizer = torch.optim.Adam(inf_model.parameters())

    # n_epochs = 200
    # print_freq = 20
    # for epoch in range(n_epochs):
    #     for seq in training_idx_list:
    #         firing_rates = inf_model.firing_rates(convolved_spikes_list[seq], states=torch.zeros(n_time_bins, dtype=torch.int64)) # n_states x n_time_bins x n_neurons
    #         loss = -utils.log_likelihood(spikes_list[seq], firing_rates).sum()
    #         if loss.isnan() == True:
    #                 sys.exit(0)
    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()
    #     if epoch % print_freq == 0:
    #         print(epoch, loss.item(), flush=True)

elif method == 'HMM-GLM':
    bg_intensity = init_bg_intensity.repeat((n_states, 1))
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    # Training loop used to fit the checkpoint (kept for reference):
    # optimizer = torch.optim.Adam(inf_model.parameters())

    # n_epochs = 10
    # print_freq = 1
    # for epoch in range(n_epochs):
    #     for seq in training_idx_list:
    #         gamma, xi = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
    #         for sub_epoch in range(100):
    #             loss = -inf_model.m_step(spikes_list[seq], convolved_spikes_list[seq], gamma, xi, update_transition_matrix=True)
    #             if loss.isnan() == True:
    #                 sys.exit(0)
    #             optimizer.zero_grad()
    #             loss.backward()
    #             optimizer.step()
    #     if epoch % print_freq == 0:
    #         print(epoch, loss.item(), flush=True)

elif method == 'Gaussian HMM-GLM':
    bg_intensity = init_bg_intensity.repeat((n_states, 1))
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLMGlobal(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight, sigma=torch.tensor(dt*3))
    # Training loop used to fit the checkpoint (kept for reference):
    # optimizer = torch.optim.Adam(inf_model.parameters())

    # n_epochs = 10
    # print_freq = 1
    # for epoch in range(n_epochs):
    #     for seq in training_idx_list:
    #         gamma, xi = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
    #         for sub_epoch in range(100):
    #             inf_model.update_w_prior()
    #             loss = -inf_model.m_step(spikes_list[seq], convolved_spikes_list[seq], gamma, xi, update_transition_matrix=True) - inf_model.prior_log_likelihood()
    #             if loss.isnan() == True:
    #                 sys.exit(0)
    #             optimizer.zero_grad()
    #             loss.backward()
    #             optimizer.step()
    #     if epoch % print_freq == 0:
    #         print(epoch, loss.item(), flush=True)

elif method == 'One-hot HMM-GLM':
    bg_intensity = init_bg_intensity.repeat((n_states, 1))
    logit_strength = torch.logit(0.5 * torch.ones((n_states, n_neurons, n_neurons)))
    log_adjacency = torch.randn((n_states, n_neurons, n_neurons, 3))
    log_adjacency[:, :, :, 1] += 2  # bias the initial logits toward category index 1

    weight_tau = 0.2
    inf_model = hmmglm.HMMOnehotGLMGlobal(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, log_adjacency=log_adjacency, logit_strength=logit_strength, weight_tau=weight_tau, strength_nonlinearity='softplus')
    # Training loop used to fit the checkpoint (kept for reference):
    # optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.01)

    # n_epochs = 10
    # print_freq = 1
    # for epoch in range(n_epochs):
    #     for seq in training_idx_list:
    #         gamma, xi = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
    #         inf_model.gumbel_softmax_weight = True
    #         for sub_epoch in range(100):
    #             loss = -inf_model.m_step(spikes_list[seq], convolved_spikes_list[seq], gamma, xi, update_transition_matrix=True) - 1 * inf_model.prior_log_likelihood() + 1 * inf_model.prior_entropy()
    #             if loss.isnan() == True:
    #                 sys.exit(0)
    #             optimizer.zero_grad()
    #             loss.backward()
    #             optimizer.step()
    #         inf_model.gumbel_softmax_weight = False
    #     if epoch % print_freq == 0:
    #         print(epoch, loss.item(), flush=True)

else:
    # Unreachable with the current method_list; fail loudly instead of hitting
    # a NameError on inf_model further down.
    raise ValueError(f'unknown method: {method!r}')

# One load for every method. map_location='cpu' lets checkpoints that were
# saved from GPU tensors load on CPU-only nodes.
inf_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))


# torch.save(inf_model.state_dict(), f'model/{method}_{session}_{n_states}_{seq_range}.pt')


# df = pd.DataFrame(index=np.arange(n_seq), columns=['log-likelihood', 'train/test', 'seq'])

# with torch.no_grad():
#     gamma_list = torch.zeros((n_seq, n_time_bins, n_states))
#     for seq in range(n_seq):
#         if method != 'GLM':
#             gamma, xi = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
#             gamma_list[seq] = gamma
#             states_pred = gamma.argmax(dim=1)
#             firing_rates_pred = inf_model.firing_rates(convolved_spikes_list[seq], states_pred)
#         else:
#             firing_rates_pred = inf_model.firing_rates(convolved_spikes_list[seq], states=torch.zeros(n_time_bins, dtype=torch.int64)) # n_states x n_time_bins x n_neurons
#         df.at[seq, 'log-likelihood'] = utils.log_likelihood(spikes_list[seq], firing_rates_pred).sum().item()
#         df.at[seq, 'seq'] = seq
#         if seq in training_idx_list:
#             df.at[seq, 'train/test'] = 'train'
#         else:
#             df.at[seq, 'train/test'] = 'test'

# df.to_csv(f'csv/{method}_{session}_{n_states}_{seq_range}.csv', index=False)


## extra work: export per-trial state posteriors
# gamma_list[seq] receives the `gamma` returned by inf_model.forward_backward
# for that trial; it is written into an (n_time_bins, n_states) slot
# (presumably the per-bin posterior over hidden states — confirm in hmmglm).
# For 'GLM' no forward-backward pass is run, so its tensor stays all zeros.
# NOTE(review): a 1-state posterior is trivially 1; the zeros are kept only so
# existing consumers of these files see unchanged output.
# The per-trial firing-rate predictions the previous version computed here were
# never used, so they are no longer evaluated.
os.makedirs('state_pred', exist_ok=True)
with torch.no_grad():
    gamma_list = torch.zeros((n_seq, n_time_bins, n_states))
    if method != 'GLM':
        for seq in range(n_seq):
            gamma, _ = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
            gamma_list[seq] = gamma
torch.save(gamma_list, f'state_pred/{method}_{session}_{n_states}_{seq_range}.pt')

# df = pd.DataFrame(index=[0], columns=['weight', 'adjacency', 'adj_prior'], dtype=object)
# with torch.no_grad():
#     df.at[0, 'weight'] = inf_model.weight
#     df.at[0, 'adjacency'] = inf_model.adjacency.argmax(dim=-1) - 1
#     if method == 'GLM':
#         df.at[0, 'adj_prior'] = utils.weight_to_adjacency_index(inf_model.weight)
#     elif method == 'HMM-GLM':
#         df.at[0, 'adj_prior'] = utils.weight_to_adjacency_index(inf_model.weight.mean(dim=0))
#     elif method == 'Gaussian HMM-GLM':
#         df.at[0, 'adj_prior'] = utils.weight_to_adjacency_index(inf_model.w_prior)
#     elif method == 'One-hot HMM-GLM':
#         df.at[0, 'adj_prior'] = inf_model.adj_prior.argmax(dim=-1) - 1
# df.to_pickle(f'weight_pred/{method}_{session}_{n_states}_{seq_range}.pkl')