import argparse
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import pandas as pd
import numpy as np
from tqdm import tqdm


def parse_arguments():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace carrying the I/O paths (``dataset``,
        ``save_pickle``, ``save_csv``), the random ``seed``, the data limits
        (``num_seq``, ``seq_len``) and the model/training hyper-parameters
        (``epochs``, ``num_layers``, ``lr``, ``embedding_dim``,
        ``hidden_dim``).
    """
    # (flag, type, default) triples, registered in the order listed.
    option_specs = (
        ("--dataset", str, "generations.pickle"),
        ("--save_pickle", str, "results.pickle"),
        ("--save_csv", str, "results.csv"),
        ("--seed", int, 11111),
        ("--num_seq", int, 16),
        ("--seq_len", int, 2048),
        ("--epochs", int, 10),
        ("--num_layers", int, 2),
        ("--lr", float, 1e-3),
        ("--embedding_dim", int, 16),
        ("--hidden_dim", int, 8),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()


def set_seed(seed=5775709):
    """Seed every random number generator used by this script.

    Applies the same ``seed`` to Python's ``random`` module, NumPy's global
    generator, PyTorch's default generator and PyTorch's CUDA generators.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def add_to_results(results, new_result):
    """Append the values of ``new_result`` to the per-key lists in ``results``.

    Each value is converted with ``np.array(value).tolist()`` and stored as
    its ``str`` representation, so every column holds strings. Keys missing
    from ``results`` are created on first use.

    Args:
        results: Dict mapping key -> list of strings; modified in place.
        new_result: Dict mapping key -> value to record.

    Returns:
        The same ``results`` dict that was passed in.
    """
    for key, value in new_result.items():
        serialized = str(np.array(value).tolist())
        results.setdefault(key, []).append(serialized)
    return results


def kl_divergence(logp, logq):
    """Return KL(p || q) over the last dimension, given log-probabilities.

    Computes ``sum_i p_i * (log p_i - log q_i)`` with the usual convention
    that terms with ``p_i == 0`` contribute 0. Without that convention a zero
    probability (``log p_i == -inf``) produces ``0 * -inf = NaN`` and poisons
    the whole sum.

    Args:
        logp: Tensor of log-probabilities of the reference distribution p.
        logq: Tensor of log-probabilities of q, broadcastable against ``logp``.

    Returns:
        Tensor with the last dimension reduced. An entry is ``inf`` when q
        assigns zero mass to an outcome that p supports.
    """
    p = torch.exp(logp)
    terms = p * (logp - logq)
    # Zero-mass outcomes contribute nothing; this also masks the NaNs produced
    # by 0 * (-inf) and by (-inf) - (-inf).
    # NOTE: intended for evaluation; gradients through the masked entries
    # would still be NaN.
    terms = torch.where(p > 0, terms, torch.zeros_like(terms))
    return terms.sum(-1)


def hellinger_distance(p, q):
    """Return the Hellinger distance between ``p`` and ``q`` over the last dim.

    Uses ``H(p, q) = sqrt(1 - sum_i sqrt(p_i * q_i))``. The quantity under the
    outer square root is clamped at 0 so rounding error cannot make it
    negative.

    Args:
        p: Tensor of probabilities.
        q: Tensor of probabilities, broadcastable against ``p``.

    Returns:
        Tensor with the last dimension reduced.
    """
    bhattacharyya = (p * q).sqrt().sum(-1)
    return (1.0 - bhattacharyya).clamp(min=0.0).sqrt()


class AutoregressiveLSTM(nn.Module):
    """Embedding -> multi-layer LSTM -> linear head scoring the next token.

    The LSTM is built with ``batch_first`` left at its default (``False``),
    so inputs are time-major: the first dimension of ``x`` is always the
    sequence position. Elsewhere in this file the model is fed a single
    unbatched 1-D sequence of token ids.

    Args:
        vocab_size: Number of distinct token ids (and of output logits).
        embedding_dim: Size of the token embeddings.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return next-token logits after consuming the sequence ``x``.

        Args:
            x: Long tensor of token ids, shape ``(seq_len,)`` for a single
                sequence or ``(seq_len, batch)`` for a time-major batch.
                A ``(batch, seq_len)`` tensor would be misread as time-major.
            hidden: Optional ``(h_0, c_0)`` tuple to resume from; ``None``
                starts from the LSTM's default (zero) state.

        Returns:
            ``(logits, hidden)``: ``logits`` has shape ``(vocab_size,)`` or
            ``(batch, vocab_size)`` and is computed from the last time step
            only; ``hidden`` is the LSTM's final ``(h_n, c_n)``.
        """
        x_embed = self.embedding(x)

        # nn.LSTM treats hx=None as "use the zero initial state", so no
        # separate branch is needed for the hidden=None case.
        lstm_out, hidden = self.lstm(x_embed, hidden)

        # Only the output at the last time step feeds the prediction head.
        logits = self.fc(lstm_out[-1])

        return logits, hidden


def train_eval_model(model, train_seq, label, epochs=1, lr=0.001):
    """Fit ``model`` on one sequence, then return its logits for the next token.

    Training is online next-token prediction: for every position ``i`` the
    model reads the prefix ``train_seq[:i+1]``, predicts ``train_seq[i+1]``
    and takes one Adam step on the cross-entropy loss. Every prefix is
    re-encoded from scratch, so one epoch costs a number of LSTM steps that
    is quadratic in ``len(train_seq)``.

    Args:
        model: Module whose call ``model(seq)`` returns ``(logits, hidden)``,
            with ``logits`` scoring the token that follows ``seq``.
        train_seq: 1-D long tensor of token ids. The model is moved to this
            tensor's device, so the function runs wherever the data lives.
        label: Unused; kept so existing callers keep working.
        epochs: Number of passes over all prefixes.
        lr: Adam learning rate.

    Returns:
        Logits for the token following the full ``train_seq``, computed in
        eval mode without gradient tracking. A sequence of length 1 gives no
        training steps and returns the untrained model's logits.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Follow the data's device instead of assuming a GPU is present.
    device = train_seq.device
    model = model.to(device)

    # train: one optimizer step per (prefix, next-token) pair
    model.train()
    for _ in tqdm(range(epochs)):
        for i in range(len(train_seq) - 1):
            prefix = train_seq[:i + 1]
            target = train_seq[i + 1]

            optimizer.zero_grad()
            logits, _ = model(prefix)
            loss = criterion(logits, target)

            loss.backward()
            optimizer.step()

    # compute the logits of the next action
    model.eval()
    with torch.no_grad():
        logits, _ = model(train_seq)

    return logits


def evaluate_lstm(args, emission_prob, state_seq, emission_seq, num_observation, place_to_eval):
    """Score per-prefix LSTM predictions against the true emission distributions.

    For every sequence and every position ``place`` in ``place_to_eval`` a
    fresh ``AutoregressiveLSTM`` is trained on ``emission_seq[idx][:place]``
    and asked for the distribution of the emission at index ``place``. That
    prediction is compared with the observed emission and with the true
    emission distribution of the hidden state at the same index.

    Args:
        args: Namespace providing ``embedding_dim``, ``hidden_dim``,
            ``num_layers``, ``epochs`` and ``lr``.
        emission_prob: NumPy array whose row ``s`` is the emission
            distribution of hidden state ``s``.
        state_seq: Per-sequence hidden-state ids, indexable by position.
        emission_seq: Per-sequence NumPy arrays of emission ids.
        num_observation: Vocabulary size of the emissions.
        place_to_eval: Positions to evaluate; each must be a valid index into
            every sequence.

    Returns:
        ``(summary, all_results)``. ``summary`` maps each metric name to its
        mean over sequences (one value per position). ``all_results`` lists
        the per-sequence values in the order accuracy, probability of the
        true emission, reverse KL, forward KL, Hellinger distance.

    Raises:
        ValueError: If a position in ``place_to_eval`` lies beyond the end of
            a sequence.
    """
    # Use the GPU when present, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Plain ints for slicing the NumPy sequences; a tensor copy for
    # fancy-indexing the label tensors.
    places = [int(place) for place in place_to_eval]
    place_index = torch.tensor(places, dtype=torch.long, device=device)

    emission_prob = torch.from_numpy(emission_prob).to(device)
    emission_logprob = torch.log(emission_prob)

    metric_names = (
        'lstm_emission_acc',
        'lstm_emission_prob',
        'lstm_emission_reverse_kl',
        'lstm_emission_forward_kl',
        'lstm_emission_hellinger_distance',
    )
    per_sequence = {name: [] for name in metric_names}

    for idx in tqdm(range(len(emission_seq))):
        # Fail before the expensive training instead of silently training on
        # truncated prefixes and crashing at the label lookup.
        available = min(len(emission_seq[idx]), len(state_seq[idx]))
        if places and max(places) >= available:
            raise ValueError(
                f"place_to_eval contains position {max(places)} but sequence "
                f"{idx} only has {available} elements"
            )

        all_logprob = []
        for place in places:
            # prepare batch
            x = torch.from_numpy(emission_seq[idx][:place]).long().to(device)
            y = torch.from_numpy(emission_seq[idx][place:place + 1]).long().to(device)

            # a fresh model per (sequence, position)
            model = AutoregressiveLSTM(num_observation, args.embedding_dim, args.hidden_dim, args.num_layers)

            # train and evaluate model
            logits = train_eval_model(model, x, y, epochs=args.epochs, lr=args.lr)
            all_logprob.append(F.log_softmax(logits, dim=-1))

        # compute probs, renormalised to absorb rounding error
        all_logprob = torch.stack(all_logprob)
        all_prob = torch.exp(all_logprob)
        all_prob = all_prob / all_prob.sum(-1, keepdim=True)

        # gather labels at the evaluated positions
        batch_label = torch.tensor(emission_seq[idx]).long().to(device)[place_index]
        batch_state_label = torch.tensor(state_seq[idx]).long().to(device)[place_index]

        # compute accuracy
        predicted_emission = torch.argmax(all_prob, dim=-1)
        per_sequence['lstm_emission_acc'].append(predicted_emission == batch_label)

        # probability assigned to the emission that actually occurred
        prob = torch.gather(all_prob, 1, batch_label.unsqueeze(-1)).squeeze(-1)
        per_sequence['lstm_emission_prob'].append(prob)

        # divergences from the true emission distribution of the hidden state
        all_logprob = torch.log(all_prob)
        label_logprob_label = emission_logprob[batch_state_label]
        per_sequence['lstm_emission_reverse_kl'].append(kl_divergence(all_logprob, label_logprob_label))
        per_sequence['lstm_emission_forward_kl'].append(kl_divergence(label_logprob_label, all_logprob))
        per_sequence['lstm_emission_hellinger_distance'].append(
            hellinger_distance(all_prob, emission_prob[batch_state_label])
        )

    # Stack each metric once and reuse it for both the raw and averaged views.
    stacked = [torch.stack(per_sequence[name]).float() for name in metric_names]
    all_results = [values.cpu().tolist() for values in stacked]
    summary = {
        name: values.mean(0).cpu().tolist()
        for name, values in zip(metric_names, stacked)
    }

    return summary, all_results


def main():
    """Run the LSTM evaluation over every model in the dataset and save results.

    Loads the pickled dataset, evaluates each model's observation sequences
    with ``evaluate_lstm``, then writes the per-model summary (metadata plus
    averaged metrics) to ``args.save_csv`` and the per-sequence metrics to
    ``args.save_pickle``.

    Raises:
        ValueError: If the sequences are too short for any evaluation
            position.
    """
    # init
    args = parse_arguments()
    set_seed(args.seed)

    SEQ_LENGTH = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

    # load dataset
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only point --dataset at trusted files.
    with open(args.dataset, 'rb') as file:
        object_file = pickle.load(file)
    (
        num_states, steady_states, lambda2s, Us, Sigmas, U_invs, As,
        A_entropys, num_observations, observations, hidden_states, Bs,
        B_entropys, pi_0s,
    ) = object_file

    # evaluation
    results, lstm_all_results = {}, []
    records = zip(
        num_states, steady_states, lambda2s, Us, Sigmas, U_invs, As,
        A_entropys, num_observations, observations, hidden_states, Bs,
        B_entropys, pi_0s,
    )
    for (
        num_state, steady_state, lambda2, U, Sigma, U_inv, A, A_entropy,
        num_observation, observation, hidden_state, B, B_entropy, pi_0,
    ) in tqdm(records, total=len(num_states)):

        # make every row of A sum to 1
        A = (np.array(A) / np.sum(A, axis=1, keepdims=True)).tolist()

        # record meta info
        meta_info = {
            'num_state': num_state,
            'steady_state': steady_state,
            'lambda2': lambda2,
            'U': U,
            'Sigma': Sigma,
            'U_inv': U_inv,
            'A': A,
            'A_entropy': A_entropy,
            'num_observation': num_observation,
            'B': B,
            'B_entropy': B_entropy,
            'pi_0': pi_0,
        }
        results = add_to_results(results, meta_info)

        # prep
        B = np.array(B)
        observation = np.array(observation)[:args.num_seq, :args.seq_len+1]
        hidden_state = np.array(hidden_state)[:args.num_seq, :args.seq_len+1]

        # Keep only positions that exist after truncation to --seq_len (or in
        # shorter data); a position p needs a label at index p.
        usable_len = min(observation.shape[1], hidden_state.shape[1])
        places = [place for place in SEQ_LENGTH if place < usable_len]
        if not places:
            raise ValueError(
                f"sequences of length {usable_len} are too short for the "
                f"smallest evaluation position {SEQ_LENGTH[0]}"
            )

        # record lstm_result
        lstm_result, lstm_all_result = evaluate_lstm(args, B, hidden_state, observation, num_observation, places)
        results = add_to_results(results, lstm_result)
        lstm_all_results.append(lstm_all_result)
        print('done lstm')

    # Save the results to a CSV file
    df = pd.DataFrame(results)
    df.to_csv(args.save_csv, index=False)

    # Save the per-sequence LSTM results
    with open(args.save_pickle, 'wb') as f:
        pickle.dump(lstm_all_results, f)


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()