import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import time
from sklearn.metrics import f1_score, accuracy_score
import os

from dmhp import model, utils, hmmglm

import argparse


## arguments
# A single flat job index selects one (method, seed) pair, so each combination
# can be launched as its own run. The index is decoded row-major:
# idx = method_position * len(seed_list) + seed_position.
method_list = ['GLM', 'HMM-GLM', 'Gaussian HMM-GLM', 'One-hot HMM-GLM', 'HMM-GLM L1', 'Gaussian HMM-GLM L1', 'Engel', 'Ashwood']
seed_list = np.arange(10)
folder = 200104  # inserted into the input data paths and the output file names
n_jobs = len(method_list) * len(seed_list)

parser = argparse.ArgumentParser()
parser.add_argument('idx', type=int, help=f'flat job index in [0, {n_jobs - 1}] selecting a (method, seed) pair')
args = parser.parse_args()
# Reject invalid indices with a usage message instead of an opaque numpy error.
if not 0 <= args.idx < n_jobs:
    parser.error(f'idx must be in [0, {n_jobs - 1}], got {args.idx}')

arg_index = np.unravel_index(args.idx, (len(method_list), len(seed_list)))
method, seed = method_list[arg_index[0]], seed_list[arg_index[1]]
# The listed GLM methods use a single state; every other method uses five.
if method in ['GLM', 'One-hot GLM']:
    n_states = 1
else:
    n_states = 5


## read data
# Both files are read without a header row; column names are assigned afterwards.
spike_path = f'../../test/pfc-6/mPFC_Data/{folder}/{folder}_SpikeData.dat'
behavior_path = f'../../test/pfc-6/mPFC_Data/{folder}/{folder}_Behavior.dat'

# One row per spike: its time and the id of the cell that fired.
spike_data = pd.read_table(spike_path, header=None)
spike_data.columns = ['time', 'cell']

# One row per behavioural trial.
behavior = pd.read_table(behavior_path, header=None)
behavior.columns = ['start', 'end', 'rule', 'correct', 'left/right', 'light position']


## hyperparameters
dt = 0.02        # width of one time bin (presumably seconds -- TODO confirm)
T = 15           # length of each extracted sequence, in the same unit as dt
window_size = 5  # number of bins spanned by the history basis
decay = 5        # decay parameter passed to the exponential basis
# The basis covers window_size bins, i.e. a time span of window_size * dt.
basis = utils.exp_basis(
    decay=decay,
    window_size=window_size,
    time_span=window_size*dt,
)


## data preparation
# Split the flat spike table into one timestamp array per neuron.
# Cell ids are looked up as 1..n_neurons; times are divided by 1000
# (presumably milliseconds -> seconds -- TODO confirm).
n_neurons = len(spike_data.cell.unique())
timestamps_list = [
    spike_data[spike_data.cell == neuron + 1].time.values / 1000
    for neuron in range(n_neurons)
]

# Each sequence starts `pre_trial` time units before the trial's 'start'.
# The matching bin offset is derived from dt so the window start and the state
# labels cannot drift apart if dt changes (with dt = 0.02 this is 250 bins).
pre_trial = 5
n_pre_bins = int(round(pre_trial / dt))

n_time_bins = int(T/dt)
n_seq = behavior.shape[0]
spikes_list = torch.zeros((n_seq, n_time_bins, n_neurons))
states_list = torch.zeros((n_seq, n_time_bins), dtype=torch.int64)
for seq in range(n_seq):
    behavior_start = behavior.at[seq, 'start']/1000
    behavior_end = behavior.at[seq, 'end']/1000
    # Bin the spikes of all neurons over a window of length T starting before the trial.
    spikes_list[seq] = torch.from_numpy(utils.continuous_to_discrete(timestamps_list, dt, T, start=behavior_start-pre_trial)).to(torch.float32)
    # Label the bins between the trial's 'start' and 'end' as state 1; all other bins stay 0.
    n_trial_bins = int((behavior_end - behavior_start)/dt)
    states_list[seq, n_pre_bins:(n_pre_bins + n_trial_bins)] = 1

convolved_spikes_list = utils.convolve_spikes_with_basis(spikes_list, basis)
# The first two thirds of the sequences are used for training, the rest for testing.
n_seq_train = int(n_seq * 2 / 3)


torch.manual_seed(seed)


def _initial_bg_intensity():
    """Return the logit of the per-neuron mean of the training spikes.

    The mean is taken over the first ``n_seq_train`` sequences and all time
    bins, and is clamped to at most 0.99 before the logit is applied.
    """
    # NOTE(review): a neuron whose training mean is exactly 0 maps to -inf here.
    return spikes_list[:n_seq_train].mean(dim=(0, 1)).clamp(max=0.99).logit()


def _fit_em(inf_model, optimizer, loss_fn, n_epochs=10, n_sub_epochs=100, print_freq=1, gumbel_softmax=False):
    """Fit ``inf_model`` on the training sequences with the shared EM-style loop.

    For every training sequence, ``forward_backward`` is run once and the
    resulting ``gamma``/``xi`` are then held fixed for ``n_sub_epochs``
    gradient steps on ``loss_fn``.

    Args:
        inf_model: model exposing ``forward_backward``.
        optimizer: torch optimizer over the model parameters.
        loss_fn: callable ``(spikes, convolved_spikes, gamma, xi) -> loss``.
        n_epochs: number of passes over the training sequences.
        n_sub_epochs: number of gradient steps per sequence.
        print_freq: print the most recent loss every ``print_freq`` epochs.
        gumbel_softmax: if True, set ``inf_model.gumbel_softmax_weight`` to
            True during the gradient steps and back to False afterwards.
    """
    for epoch in range(n_epochs):
        for seq in range(n_seq_train):
            spikes = spikes_list[seq]
            convolved = convolved_spikes_list[seq]
            gamma, xi = inf_model.forward_backward(spikes, convolved)
            if gumbel_softmax:
                inf_model.gumbel_softmax_weight = True
            for _ in range(n_sub_epochs):
                loss = loss_fn(spikes, convolved, gamma, xi)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if gumbel_softmax:
                inf_model.gumbel_softmax_weight = False
        if epoch % print_freq == 0:
            # Reports the loss of the last gradient step of the last sequence.
            print(epoch, loss.item(), flush=True)


if method == 'GLM':
    bg_intensity = _initial_bg_intensity()
    weight = 0.001 * (torch.rand((1, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLM(1, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    optimizer = torch.optim.Adam(inf_model.parameters())

    n_epochs = 200
    print_freq = 20
    # Single-state model: every time bin is assigned to state 0.
    zero_states = torch.zeros(n_time_bins, dtype=torch.int64)
    for epoch in range(n_epochs):
        for seq in range(n_seq_train):
            firing_rates = inf_model.firing_rates(convolved_spikes_list[seq], states=zero_states)
            loss = -utils.log_likelihood(spikes_list[seq], firing_rates).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % print_freq == 0:
            print(epoch, loss.item(), flush=True)

elif method in ['HMM-GLM', 'HMM-GLM L1']:
    # The 'L1' variant adds an L1 penalty on the weights; otherwise its coefficient is 0.
    penalty = 1 if method == 'HMM-GLM L1' else 0
    bg_intensity = _initial_bg_intensity()
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    optimizer = torch.optim.Adam(inf_model.parameters())

    def _loss(spikes, convolved, gamma, xi):
        return -inf_model.m_step(spikes, convolved, gamma, xi) + penalty * inf_model.weight.abs().sum()

    _fit_em(inf_model, optimizer, _loss)

elif method in ['Gaussian HMM-GLM', 'Gaussian HMM-GLM L1']:
    penalty = 1 if method == 'Gaussian HMM-GLM L1' else 0
    bg_intensity = _initial_bg_intensity()
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.GaussianHMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    optimizer = torch.optim.Adam(inf_model.parameters())

    def _loss(spikes, convolved, gamma, xi):
        # The weight prior is refreshed before every gradient step.
        inf_model.update_w_prior()
        return -inf_model.m_step(spikes, convolved, gamma, xi) - inf_model.prior_log_likelihood() + penalty * inf_model.weight.abs().sum()

    _fit_em(inf_model, optimizer, _loss)

elif method in ['One-hot HMM-GLM', 'Scott HMM-GLM']:
    # NOTE: 'Scott HMM-GLM' is not in method_list at the top of this script, so
    # that variant is only reachable if it is added there.
    bg_intensity = _initial_bg_intensity()
    weight_tau = 0.2
    if method == 'One-hot HMM-GLM':
        logit_strength = torch.logit(0.5 * torch.ones((n_states, n_neurons, n_neurons)))
        log_adjacency = torch.randn((n_states, n_neurons, n_neurons, 3))
        log_adjacency[:, :, :, 1] += 2  # bias the initial logits towards category 1
        inf_model = hmmglm.OnehotHMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, log_adjacency=log_adjacency, logit_strength=logit_strength, weight_tau=weight_tau, strength_nonlinearity='softplus')
    else:
        weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)
        log_connection = torch.randn((n_states, n_neurons, n_neurons, 2))
        log_connection[:, :, :, 0] += 2  # bias the initial logits towards category 0
        inf_model = hmmglm.ScottHMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight, log_connection=log_connection, weight_tau=weight_tau)
    optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.01)

    def _loss(spikes, convolved, gamma, xi):
        return -inf_model.m_step(spikes, convolved, gamma, xi, update_transition_matrix=True) - 1 * inf_model.prior_log_likelihood() + 1 * inf_model.prior_entropy()

    _fit_em(inf_model, optimizer, _loss, gumbel_softmax=True)

elif method in ['Engel', 'Ashwood']:
    # Both baselines use zero, frozen coupling weights and a per-state
    # background intensity; they differ only in the activation function.
    activation = 'softplus' if method == 'Engel' else 'sigmoid'
    bg_intensity = _initial_bg_intensity() * torch.ones((n_states, 1))
    weight = torch.zeros((n_states, n_neurons, n_neurons))

    inf_model = hmmglm.HMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight, activation=activation)
    inf_model._weight.requires_grad = False

    optimizer = torch.optim.Adam(inf_model.parameters())

    def _loss(spikes, convolved, gamma, xi):
        return -inf_model.m_step(spikes, convolved, gamma, xi)

    _fit_em(inf_model, optimizer, _loss)

else:
    # Without this, an unhandled method would only fail later with a NameError
    # on inf_model.
    raise ValueError(f'Unknown method: {method}')


# Create the output directory first so a missing folder cannot discard the
# trained model after the whole training run has finished.
os.makedirs('model', exist_ok=True)
torch.save(inf_model.state_dict(), f'model/{method}_{n_states}_{folder}_{seed}.pt')


# One result row per sequence (training and test sequences alike).
df = pd.DataFrame(index=np.arange(n_seq), columns=['log-likelihood', 'state accuracy', 'train/test', 'seq'])

# The listed GLM methods have a single state, so no state inference is run for them.
single_state_method = method in ['GLM', 'One-hot GLM']
zero_states = torch.zeros(n_time_bins, dtype=torch.int64)

with torch.no_grad():
    for seq in range(n_seq):
        if single_state_method:
            firing_rates_pred = inf_model.firing_rates(convolved_spikes_list[seq], states=zero_states)
        else:
            # Use the highest-scoring state in gamma for each time bin to predict the firing rates.
            gamma, xi = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
            states_pred = gamma.argmax(dim=1)
            firing_rates_pred = inf_model.firing_rates(convolved_spikes_list[seq], states_pred)
        df.at[seq, 'log-likelihood'] = utils.log_likelihood(spikes_list[seq], firing_rates_pred).sum().item()
        df.at[seq, 'seq'] = seq
        df.at[seq, 'train/test'] = 'test' if seq >= n_seq_train else 'train'
        # State accuracy is only computed for two-state models.
        # NOTE(review): n_states is set to 1 or 5 earlier in this script, so this
        # branch does not run and the column stays empty -- confirm intent.
        if n_states == 2:
            one_hot_states = F.one_hot(states_list[seq])
            true_to_learned = utils.match_states(one_hot_states, gamma)
            df.at[seq, 'state accuracy'] = accuracy_score(states_list[seq], gamma[:, true_to_learned].argmax(dim=1))

# Create the output directory first so a missing folder cannot discard the results.
os.makedirs('csv', exist_ok=True)
df.to_csv(f'csv/{method}_{n_states}_{folder}_{seed}.csv', index=False)