import itertools
import random
import torch
import tqdm
import time
import copy
import sys
import os
import argparse
import numpy as np
import pandas as pd
import torch.nn.functional as F
from collections import Counter
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from torch.utils.data import DataLoader, Subset
from FaCTHMM import FaCTHMM_model
from utils.dataset_facthmm import *


def train(model, optimizer, data_loader):
    """Run one training epoch and return the observation-weighted mean loss.

    Args:
        model: torch module whose forward pass takes ``[data, batch_sizes]``
            (data already on the model's device, batch sizes as a Python list)
            and returns a scalar loss tensor. It must expose ``q_param`` and
            ``emit.sqrt2_sigma_inv`` parameters.
        optimizer: optimiser stepping the model's parameters.
        data_loader: iterable of batches where ``batch[0]`` is the data tensor
            (first dimension = number of observations) and ``batch[1]`` a tensor
            of batch sizes.

    Returns:
        Sum over batches of (per-observation loss * observations) divided by the
        total number of observations seen in the epoch.
    """
    device = next(model.parameters()).device
    model.train()
    loss_sum, obs_count = 0, 0
    for batch in data_loader:
        packed_data, batch_sizes = batch[0], batch[1]
        batch_obs = packed_data.shape[0]
        # Normalise by the batch's observation count so the gradient scale does
        # not depend on batch size (the model output is presumably a sum over
        # observations -- confirm against the model).
        mean_loss = model([packed_data.to(device), batch_sizes.tolist()]) / batch_obs
        model.zero_grad(set_to_none=True)
        mean_loss.backward()
        optimizer.step()
        with torch.no_grad():
            obs_count += batch_obs
            loss_sum += mean_loss.item() * batch_obs
            # Project these parameters back to >= 1e-6 after each step;
            # presumably they must stay strictly positive.
            for positive_param in (model.q_param, model.emit.sqrt2_sigma_inv):
                positive_param.data.clamp_(min=1e-6)
    return loss_sum / obs_count


def test(model, data_loader, samp_size):
    """Evaluate ``model`` on every batch of ``data_loader`` without gradients.

    Args:
        model: torch module whose forward pass takes ``[data, batch_sizes]`` and
            returns a scalar loss tensor, and which provides
            ``sum_batch_crps(samp_size, [data, batch_sizes])``.
        data_loader: iterable of batches where ``s[0]`` is the data tensor
            (first dimension = number of observations) and ``s[1]`` a tensor of
            batch sizes, passed on as a Python list.
        samp_size: sample count forwarded to ``model.sum_batch_crps``.

    Returns:
        ``(loss, crps)``: the accumulated loss and the accumulated CRPS, each
        divided by the total number of observations. The model output is
        presumably a sum over observations, which makes these per-observation
        means (confirm against the model).
    """
    total_loss = 0
    total_samples = 0
    crps_sum = 0
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        for s in data_loader:
            # Transfer the batch to the device once and reuse it for both the
            # loss and the CRPS computation (previously copied twice).
            data = s[0].to(device)
            sizes = s[1].tolist()
            total_loss += model([data, sizes]).item()
            total_samples += s[0].shape[0]
            crps_sum += model.sum_batch_crps(samp_size, [data, sizes])
    return total_loss / total_samples, crps_sum / total_samples


def decoding(model, data_loader, model_path):
    """Decode every batch with ``model.decode`` and write the states to a CSV.

    The output file is ``model_path + "_facthmm_decoded" + str(model.n_bits) +
    ".csv"`` with a single ``state`` column. Rows are ordered sequence by
    sequence, in the order ``unpack_sequence`` restores them within each batch.

    Args:
        model: torch module with ``decode([data, batch_sizes])`` returning a
            pair whose second element is a tensor aligned with the packed data,
            and an ``n_bits`` attribute used in the file name.
        data_loader: iterable of packed batches (objects supporting
            ``_replace(data=...)``, e.g. ``PackedSequence``).
        model_path: path prefix for the output CSV (used verbatim).
    """
    device = next(model.parameters()).device
    model.eval()
    per_sequence = []
    with torch.no_grad():
        for packed in data_loader:
            _, states = model.decode([packed[0].to(device), packed[1].tolist()])
            # Re-attach the batch's packing metadata to the decoded states so they
            # can be split back into individual sequences.
            repacked = packed._replace(data=states.cpu())
            per_sequence += torch.nn.utils.rnn.unpack_sequence(repacked)
    out_file = "".join((model_path, "_facthmm_decoded", str(model.n_bits), ".csv"))
    frame = pd.DataFrame(torch.cat(per_sequence, 0).numpy(), columns=["state"])
    frame.to_csv(out_file, index=False)


def _make_optimizer(model, args):
    """Build the optimiser chosen by ``args.optim`` ("adam" -> AdamW+AMSGrad, else SGD).

    ``args.wdcy`` weight decay applies to the flow parameters and
    ``emit.sqrt2_sigma_inv``; ``emit.mu``, ``p_h_0`` and ``q_param`` are never decayed.
    """
    param_groups = [
        {'params': model.flow.parameters()},
        {'params': model.emit.sqrt2_sigma_inv},
        {'params': model.emit.mu, 'weight_decay': 0},
        {'params': model.p_h_0, 'weight_decay': 0.},
        {'params': model.q_param, 'weight_decay': 0.},
    ]
    if args.optim == "adam":
        return torch.optim.AdamW(param_groups, lr=args.lr, weight_decay=args.wdcy, amsgrad=True)
    return torch.optim.SGD(param_groups, lr=args.lr, weight_decay=args.wdcy)


def main(args):
    """Train (unless ``args.test``) and evaluate a FaCTHMM model, printing metrics.

    Trains for ``args.epoch`` epochs and checkpoints every new best training loss
    to a file under ``args.model_path``. That checkpoint is reloaded before
    evaluation when saving is enabled or in test-only mode. Loss and CRPS are
    then printed for both splits. For the ``scar`` dataset the test split is
    also decoded to CSV.

    The checkpoint name encodes the hyper-parameters, including the *requested*
    epoch count (it is computed before test mode resets ``args.epoch``). A
    ``--test`` run with the same flags therefore finds the matching checkpoint.
    """
    device = torch.device(args.device)
    dataset = get_data_n_stats(args.dataset_path)
    train_data_loader = CustomLoader(max_obs=args.batch_size, dataset=dataset, mode="train", shuffle=True, progress_bar=True)
    test_data_loader = CustomLoader(max_obs=args.batch_size, dataset=dataset, mode="test", shuffle=False, progress_bar=True)
    model_path = os.path.join(args.model_path,
                              "{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}_{8}.pt".format(
                                  args.n_bits, args.n_flows, args.n_h, args.flow_layers,
                                  args.base_layers, args.lr, args.wdcy, args.epoch, args.optim))
    model = FaCTHMM_model(data_dim=dataset.data_dim, n_bits=args.n_bits, n_flows=args.n_flows, n_h=args.n_h,
                          flow_layers=args.flow_layers, base_layers=args.base_layers, dataset_stats=dataset.dataset_stats).to(device)
    # 1-based epoch of the best training loss. Always bound, so the final report
    # works in test mode and when the loss never improves (e.g. NaN).
    epoch_i = 0
    if args.test:
        args.epoch = 0  # evaluation only: skip the training loop
    optimizer = _make_optimizer(model, args)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)
    best_loss = 1e10
    stop_cnt = 0
    if args.save:
        os.makedirs(args.model_path, exist_ok=True)
    start_time = time.time()
    for epoch in range(args.epoch):
        print("epoch:", epoch)
        tr_loss = train(model, optimizer, train_data_loader)
        print("train_logloss:", tr_loss)
        if tr_loss < best_loss:
            stop_cnt = 0
            best_loss = tr_loss
            epoch_i = epoch + 1
            # Save every new best. Saving only in the last few epochs left the
            # file missing or stale whenever the best epoch came earlier.
            if args.save:
                torch.save(model.state_dict(), model_path)
        elif stop_cnt < 2:
            stop_cnt += 1
        else:
            # Third consecutive epoch without improvement: decay the LR by 0.9.
            stop_cnt = 0
            scheduler.step()
    print("--- %s seconds training time---" % (time.time() - start_time))
    print("MAX memory used:", torch.cuda.max_memory_allocated())
    # Evaluate the best checkpoint. Test-only mode needs it even with --save 0;
    # otherwise it would evaluate a freshly initialised model.
    if args.save or args.test:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
    start_time = time.time()
    tr_result = test(model, train_data_loader, args.samp_size)
    te_result = test(model, test_data_loader, args.samp_size)
    print(tr_result[0], tr_result[1], te_result[0], te_result[1])

    print("--- %s seconds testing time---" % (time.time() - start_time))
    # Match on the file name only, so a directory such as ".../oscar/" does not
    # trigger decoding for other datasets.
    if "scar" in os.path.basename(args.dataset_path):
        start_time = time.time()
        decoding(model, test_data_loader, model_path)
        print("--- %s seconds decoding time---" % (time.time() - start_time))
    print("epoch: {0}, train_loss: {1}, train_crps: {2}, test_loss: {3}, test_crps:{4}".format(epoch_i, tr_result[0], tr_result[1], te_result[0], te_result[1]))


if __name__ == '__main__':
    # Dataset name -> CSV file name; the file is looked up inside --test-dir.
    data_path_dict = {"taxi": "taxi.csv", "scar": "scar.csv", "lrff": "lrff.csv"}
    # Base directory for per-dataset checkpoint folders ('' = current directory).
    model_path_base = ''

    parser = argparse.ArgumentParser(description="Train and test the model")
    # `choices` turns an unknown dataset name into a clear CLI error instead of
    # a KeyError after parsing.
    parser.add_argument('--dataset-name', type=str, default="taxi", choices=sorted(data_path_dict), help="name of dataset")
    parser.add_argument('--n-bits', type=int, default=2, help="number of HMM chains")
    parser.add_argument('--n-flows', type=int, default=8, help="number of normalising flows")
    parser.add_argument('--n-h', type=int, default=8, help="hidden units")
    parser.add_argument('--flow-layers', type=int, default=3, help="number of normalising flow layers")
    parser.add_argument('--base-layers', type=int, default=3, help="number of basenet layers in realnvp. >=2")
    parser.add_argument('--samp-size', type=int, default=50, help="number of samples for crps")
    parser.add_argument('--batch-size', type=int, default=20002, help="batch-size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--wdcy', type=float, default=1e-12, help="weight decay")
    parser.add_argument('--epoch', type=int, default=100, help="epoch")
    parser.add_argument('--optim', type=str, default="adam", help="optimiser")
    parser.add_argument('--save', type=int, default=1, help="save model")
    parser.add_argument('--device', type=str, default="cuda", help="device")
    parser.add_argument('--test', type=int, default=0, help="test trained model (test only)")
    parser.add_argument('--test-dir', type=str, default='', help="test dir")

    args = parser.parse_args()

    # Cap the flow count at 2**n_bits (presumably the number of joint hidden
    # states of n_bits binary chains -- confirm against FaCTHMM_model).
    args.n_flows = min(2**args.n_bits, args.n_flows)
    # os.path.join works whether or not --test-dir ends with a separator; plain
    # concatenation turned "data" + "taxi.csv" into "datataxi.csv".
    args.dataset_path = os.path.join(args.test_dir, data_path_dict[args.dataset_name])
    args.model_path = os.path.join(model_path_base, args.dataset_name)

    print(args)
    main(args)

    
