import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import time
from sklearn.metrics import f1_score, accuracy_score, balanced_accuracy_score
import os

from dmhp import model, utils, hmmglm

import argparse


## arguments
# A single integer job index selects one (method, trial) pair from the
# len(method_list) x len(trial_list) grid below (row-major: trial varies fastest).
parser = argparse.ArgumentParser(description='Fit one inference method on one trial of the dataset.')
parser.add_argument('idx', type=int, help='flat job index into the (method, trial) grid')
args = parser.parse_args()
n_trials = 10

method_list = ['GLM', 'HMM-GLM', 'Gaussian HMM-GLM', 'One-hot HMM-GLM', 'HMM-GLM L1', 'Gaussian HMM-GLM L1', 'Engel', 'Ashwood', 'Scott HMM-GLM']
trial_list = np.arange(n_trials)

# Reject out-of-range indices up front with a readable usage error instead of
# letting np.unravel_index fail with a bare ValueError.
n_jobs = len(method_list) * len(trial_list)
if not 0 <= args.idx < n_jobs:
    parser.error(f'idx must be in [0, {n_jobs - 1}], got {args.idx}')

arg_index = np.unravel_index(args.idx, (len(method_list), len(trial_list)))
method, trial = method_list[arg_index[0]], trial_list[arg_index[1]]


## hyperparameters
dt = 0.01        # forwarded to every model constructor as `dt` (presumably the time-bin width -- confirm units)
window_size = 5  # forwarded to utils.exp_basis as `window_size`
T = 50           # NOTE(review): not referenced anywhere else in this script
basis = utils.exp_basis(
    decay=1,
    window_size=window_size,
    time_span=window_size * dt,
)


## read data
# Row `trial` of the pickled frame holds the spikes, the true states and the
# state dict of the generative model for that trial.
df_data = pd.read_pickle('data/data.pkl')
spikes_list, states_list = (df_data.at[trial, column] for column in ('spikes_list', 'states_list'))
convolved_spikes_list = utils.convolve_spikes_with_basis(spikes_list, basis)

n_seq, n_time_bins, n_neurons = spikes_list.shape
n_seq_train = n_seq // 2  # first half of the sequences is used for training

# Single-state methods get one state; every other method gets five.
n_states = 1 if method in ('GLM', 'One-hot GLM') else 5

# Rebuild the generative model with placeholder parameters, then restore the
# saved parameters from the data frame.
n_gen_states = 5
gen_model = hmmglm.OnehotHMMGLM(
    n_gen_states,
    n_neurons,
    dt=dt,
    basis=basis,
    logit_strength=torch.zeros((n_gen_states, n_neurons, n_neurons)),
)
gen_model.load_state_dict(df_data.at[trial, 'gen_model'])

def _fit_glm(model, optimizer, n_epochs=200, print_freq=20):
    """Fit a single-state model by gradient descent on the negative log-likelihood.

    One optimizer step is taken per training sequence per epoch. Reads the
    module-level ``spikes_list``, ``convolved_spikes_list``, ``n_seq_train``
    and ``n_time_bins``. The loss of the last sequence is printed every
    ``print_freq`` epochs.
    """
    # Every time bin is assigned to state 0, the only state of the model.
    states = torch.zeros(n_time_bins, dtype=torch.int64)
    for epoch in range(n_epochs):
        for seq in range(n_seq_train):
            firing_rates = model.firing_rates(convolved_spikes_list[seq], states=states)
            loss = -utils.log_likelihood(spikes_list[seq], firing_rates).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % print_freq == 0:
            print(epoch, loss.item(), flush=True)


def _fit_em(model, optimizer, loss_fn, n_epochs=10, n_sub_epochs=100, print_freq=1, gumbel_softmax=False):
    """Run the EM-style loop shared by all multi-state methods.

    For every training sequence, ``model.forward_backward`` is called once and
    the resulting ``gamma``/``xi`` are held fixed for ``n_sub_epochs`` gradient
    steps on ``loss_fn(model, spikes, convolved_spikes, gamma, xi)``.

    When ``gumbel_softmax`` is true, ``model.gumbel_softmax_weight`` is set to
    True for the gradient steps and back to False afterwards. Reads the
    module-level ``spikes_list``, ``convolved_spikes_list`` and ``n_seq_train``.
    The most recent loss is printed every ``print_freq`` epochs.
    """
    for epoch in range(n_epochs):
        for seq in range(n_seq_train):
            spikes, convolved = spikes_list[seq], convolved_spikes_list[seq]
            gamma, xi = model.forward_backward(spikes, convolved)
            if gumbel_softmax:
                model.gumbel_softmax_weight = True
            for _ in range(n_sub_epochs):
                loss = loss_fn(model, spikes, convolved, gamma, xi)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if gumbel_softmax:
                model.gumbel_softmax_weight = False
        if epoch % print_freq == 0:
            print(epoch, loss.item(), flush=True)


def _m_step_loss(model, spikes, convolved, gamma, xi):
    """Return the negated ``m_step`` objective (used by 'Engel' and 'Ashwood')."""
    return -model.m_step(spikes, convolved, gamma, xi)


def _make_penalized_loss(penalty, with_gaussian_prior=False):
    """Build a loss: negated ``m_step`` objective plus ``penalty`` * L1 norm of the weights.

    With ``with_gaussian_prior`` the model's ``update_w_prior()`` is called
    before each evaluation and ``prior_log_likelihood()`` is subtracted.
    """
    def loss_fn(model, spikes, convolved, gamma, xi):
        if with_gaussian_prior:
            model.update_w_prior()
            return (-model.m_step(spikes, convolved, gamma, xi)
                    - model.prior_log_likelihood()
                    + penalty * model.weight.abs().sum())
        return -model.m_step(spikes, convolved, gamma, xi) + penalty * model.weight.abs().sum()
    return loss_fn


def _prior_regularized_loss(model, spikes, convolved, gamma, xi):
    """Return the loss used by 'One-hot HMM-GLM' and 'Scott HMM-GLM'.

    Negated ``m_step`` objective (with ``update_transition_matrix=True``) minus
    ``prior_log_likelihood()`` plus ``prior_entropy()``; the call order is kept
    fixed so any random sampling inside the model happens in the same sequence.
    """
    return (-model.m_step(spikes, convolved, gamma, xi, update_transition_matrix=True)
            - model.prior_log_likelihood()
            + model.prior_entropy())


torch.manual_seed(trial)

# Initial background intensity shared by all methods: logit of the mean spike
# value over the training sequences and time bins (one entry per neuron).
bg_intensity = spikes_list[:n_seq_train].mean(dim=(0, 1)).logit()

if method == 'GLM':
    weight = 0.001 * (torch.rand((1, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLM(1, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    optimizer = torch.optim.Adam(inf_model.parameters())

    _fit_glm(inf_model, optimizer)

elif method in ['HMM-GLM', 'HMM-GLM L1']:
    penalty = 100 if method == 'HMM-GLM L1' else 0
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.HMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    optimizer = torch.optim.Adam(inf_model.parameters())

    _fit_em(inf_model, optimizer, _make_penalized_loss(penalty))

elif method in ['Gaussian HMM-GLM', 'Gaussian HMM-GLM L1']:
    penalty = 100 if method == 'Gaussian HMM-GLM L1' else 0
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)

    inf_model = hmmglm.GaussianHMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight)
    optimizer = torch.optim.Adam(inf_model.parameters())

    _fit_em(inf_model, optimizer, _make_penalized_loss(penalty, with_gaussian_prior=True))

elif method == 'One-hot HMM-GLM':
    logit_strength = torch.logit(0.5 * torch.ones((n_states, n_neurons, n_neurons)))
    log_adjacency = torch.randn((n_states, n_neurons, n_neurons, 3))
    log_adjacency[:, :, :, 1] += 2  # bias the initial logits towards category 1

    weight_tau = 0.2
    inf_model = hmmglm.OnehotHMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, log_adjacency=log_adjacency, logit_strength=logit_strength, weight_tau=weight_tau, strength_nonlinearity='softplus')
    optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.01)

    _fit_em(inf_model, optimizer, _prior_regularized_loss, gumbel_softmax=True)

elif method == 'Scott HMM-GLM':
    weight = 0.001 * (torch.rand((n_states, n_neurons, n_neurons)) - 0.5)
    log_connection = torch.randn((n_states, n_neurons, n_neurons, 2))
    log_connection[:, :, :, 0] += 2  # bias the initial logits towards category 0

    weight_tau = 0.2
    inf_model = hmmglm.ScottHMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight, log_connection=log_connection, weight_tau=weight_tau)
    optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.01)

    _fit_em(inf_model, optimizer, _prior_regularized_loss, gumbel_softmax=True)

elif method in ['Engel', 'Ashwood']:
    # Both baselines learn per-state background intensities only: the weights
    # start at zero and are frozen. They differ in activation and learning rate.
    bg_intensity = bg_intensity * torch.ones((n_states, 1))
    weight = torch.zeros((n_states, n_neurons, n_neurons))

    if method == 'Engel':
        inf_model = hmmglm.HMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight, activation='softplus')
        inf_model._weight.requires_grad = False
        optimizer = torch.optim.Adam(inf_model.parameters(), lr=0.1)
    else:
        inf_model = hmmglm.HMMGLM(n_states, n_neurons, dt=dt, basis=basis, bg_intensity=bg_intensity, weight=weight, activation='sigmoid')
        inf_model._weight.requires_grad = False
        optimizer = torch.optim.Adam(inf_model.parameters())

    _fit_em(inf_model, optimizer, _m_step_loss)

else:
    # Unreachable for the methods listed in method_list; guards against a new
    # method being added there without a training branch here.
    raise ValueError(f'no training procedure defined for method {method!r}')


# Result tables, filled in by the evaluation below:
#   df_inf   -- one row per sequence (inference metrics)
#   df_learn -- a single row (parameter-recovery metrics)
inf_columns = ['log-likelihood', 'state accuracy', 'state log-likelihood', 'train/test', 'seq']
learn_columns = [
    'weight error',
    'adjacency log-likelihood',
    'adjacency balanced accuracy',
    'adjacency accuracy',
    'adjacency prior balanced accuracy',
    'adjacency prior accuracy',
]
df_inf = pd.DataFrame(columns=inf_columns, index=np.arange(n_seq))
df_learn = pd.DataFrame(columns=learn_columns, index=np.arange(1))

# Evaluate the fitted model: per-sequence inference metrics go into df_inf,
# parameter-recovery metrics into df_learn, and the model is saved to disk.
with torch.no_grad():
    # gamma returned by forward_backward for every sequence; stays all-zero for 'GLM'.
    gamma_list = torch.zeros((n_seq, n_time_bins, n_states))
    for seq in range(n_seq):
        if method != 'GLM':
            gamma, xi = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
            gamma_list[seq] = gamma
            states_pred = gamma.argmax(dim=1)  # most probable state per time bin
            firing_rates_pred = inf_model.firing_rates(convolved_spikes_list[seq], states_pred)
        else:
            # Single-state model: every time bin is assigned to state 0.
            firing_rates_pred = inf_model.firing_rates(convolved_spikes_list[seq], states=torch.zeros(n_time_bins, dtype=torch.int64))
        df_inf.at[seq, 'log-likelihood'] = utils.log_likelihood(spikes_list[seq], firing_rates_pred).sum().item()
        df_inf.at[seq, 'seq'] = seq
        df_inf.at[seq, 'train/test'] = 'test' if seq >= n_seq_train else 'train'

    if method != 'GLM':
        # Learned state labels are arbitrary, so match them to the true states
        # on the training half before scoring. num_classes is given explicitly
        # so the one-hot encoding keeps n_states columns even if the highest
        # state never occurs in the training sequences.
        one_hot_states = F.one_hot(states_list[:n_seq_train].reshape((n_seq_train*n_time_bins,)), num_classes=n_states)
        true_to_learned = utils.match_states(one_hot_states, gamma_list[:n_seq_train].reshape((n_seq_train*n_time_bins, n_states)))
        for seq in range(n_seq):
            gamma_matched = gamma_list[seq, :, true_to_learned]
            df_inf.at[seq, 'state accuracy'] = accuracy_score(states_list[seq], gamma_matched.argmax(dim=1))
            # cross_entropy applies log_softmax to its input, hence the log of gamma is passed.
            df_inf.at[seq, 'state log-likelihood'] = -F.cross_entropy(gamma_matched.log(), states_list[seq]).item()
        inf_model.permute_states(true_to_learned)

    # Make sure the output directory exists so a finished run is not lost at save time.
    os.makedirs('model', exist_ok=True)
    torch.save(inf_model.state_dict(), f'model/{method}_{trial}.pt')

    if method in ['Engel', 'Ashwood']:
        # These baselines were trained with frozen zero weights; the weights are
        # now filled in per state from the spikes of the training time bins
        # assigned to that state. This happens after the model has been saved.
        gamma_list = torch.zeros((n_seq_train, n_time_bins, n_states))
        for seq in range(n_seq_train):
            gamma_list[seq], __ = inf_model.forward_backward(spikes_list[seq], convolved_spikes_list[seq])
        states_list_pred = gamma_list.argmax(dim=-1)
        for state in range(n_states):
            spikes_state = spikes_list[:n_seq_train].reshape(-1, n_neurons)[states_list_pred.reshape(-1) == state]
            try:
                inf_model._weight[state] = torch.from_numpy(utils.ccg_diff(spikes_state)).to(torch.float32)
            except Exception as exc:
                # Best-effort fallback, but report it and do not swallow
                # KeyboardInterrupt / SystemExit as a bare `except` would.
                print(f'ccg_diff failed for state {state} ({exc!r}); falling back to correlation coefficients', flush=True)
                inf_model._weight[state] = torch.from_numpy(np.corrcoef(spikes_state.T)).to(torch.float32)
        # Rescale so that the largest weight equals 0.01.
        inf_model._weight.data = inf_model._weight / inf_model._weight.max() * 0.01

    df_learn.at[0, 'weight error'] = (gen_model.weight - inf_model.weight).abs().mean().item()

    adjacency_index_flattened_true = gen_model.log_adjacency.argmax(dim=-1).flatten() # (n_states x n_neurons x n_neurons,)
    if method == 'GLM':
        # Repeat the single-state prediction to compare against all 5 generative states.
        adjacency_flattened_pred = inf_model.adjacency.repeat((5, 1, 1, 1)).view((-1, 3)) # (n_states x n_neurons x n_neurons, 3)
    else:
        adjacency_flattened_pred = inf_model.adjacency.view((-1, 3)) # (n_states x n_neurons x n_neurons, 3)
    adjacency_index_flattened_pred = adjacency_flattened_pred.argmax(dim=-1)
    # NOTE(review): F.cross_entropy applies log_softmax to its input; if
    # `adjacency` already holds probabilities rather than logits, this value is
    # not a plain log-likelihood -- confirm against hmmglm.
    df_learn.at[0, 'adjacency log-likelihood'] = -F.cross_entropy(adjacency_flattened_pred, adjacency_index_flattened_true).item()
    df_learn.at[0, 'adjacency balanced accuracy'] = balanced_accuracy_score(adjacency_index_flattened_true, adjacency_index_flattened_pred)
    df_learn.at[0, 'adjacency accuracy'] = accuracy_score(adjacency_index_flattened_true, adjacency_index_flattened_pred)

    adj_prior_index_flattened_true = gen_model.adj_prior.argmax(dim=-1).flatten() - 1

    # The predicted adjacency prior is derived differently for each model family.
    if method == 'GLM':
        adj_prior_index_flattened_pred = utils.weight_to_adjacency_index(inf_model.weight).flatten()
    elif method in ['HMM-GLM', 'HMM-GLM L1', 'Engel', 'Ashwood', 'Scott HMM-GLM']:
        adj_prior_index_flattened_pred = utils.weight_to_adjacency_index(inf_model.weight.mean(dim=0)).flatten()
    elif method in ['Gaussian HMM-GLM', 'Gaussian HMM-GLM L1']:
        adj_prior_index_flattened_pred = utils.weight_to_adjacency_index(inf_model.w_prior).flatten()
    elif method == 'One-hot HMM-GLM':
        adj_prior_index_flattened_pred = inf_model.adj_prior.argmax(dim=-1).flatten() - 1

    df_learn.at[0, 'adjacency prior balanced accuracy'] = balanced_accuracy_score(adj_prior_index_flattened_true, adj_prior_index_flattened_pred)
    df_learn.at[0, 'adjacency prior accuracy'] = accuracy_score(adj_prior_index_flattened_true, adj_prior_index_flattened_pred)


# Write both result tables; create the output directory first so that a
# completed run does not fail at the very last step when `csv/` is missing.
os.makedirs('csv', exist_ok=True)
df_inf.to_csv(f'csv/{method}_{trial}_inf.csv', index=False)
df_learn.to_csv(f'csv/{method}_{trial}_learn.csv', index=False)