import itertools
import random
import torch
import tqdm
import glob
import time
import copy
import sys
import os
import argparse
import numpy as np
import torch.nn.functional as F
from collections import Counter
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from torch.utils.data import DataLoader, Subset
from utils.dataset_facthmm import *
# Never wrap printed numpy arrays across lines, so each row of the confusion
# matrices printed during evaluation stays on a single line.
np.set_printoptions(linewidth=np.inf)


class nn_clf(torch.nn.Module):
    """Residual MLP classifier.

    The input has ``n_bits`` features. ``layers`` residual blocks
    (Linear -> ReLU -> Linear -> ReLU -> Linear, each mapping
    ``n_bits -> n_h -> n_h -> n_bits``) are applied as ``x = x + block(x)``.
    A final linear head then maps ``n_bits`` features to ``n_class`` logits.

    All sub-modules live in ``self.net``: the residual blocks first, the head
    last.
    """

    def __init__(self, n_bits, n_h, layers, n_class):
        super().__init__()

        def residual_block():
            # Output width equals input width so the block can be added to x.
            return torch.nn.Sequential(
                torch.nn.Linear(n_bits, n_h),
                torch.nn.ReLU(),
                torch.nn.Linear(n_h, n_h),
                torch.nn.ReLU(),
                torch.nn.Linear(n_h, n_bits),
            )

        self.net = torch.nn.ModuleList(residual_block() for _ in range(layers))
        self.net.append(torch.nn.Linear(n_bits, n_class))

    def forward(self, x):
        """Return unnormalised class logits for the input ``x``."""
        *blocks, head = self.net
        for block in blocks:
            x = x + block(x)
        return head(x)

def train(model, optimizer, data_loader, bit_base):
    """Run one training epoch over ``data_loader``.

    Each batch is turned into bit-frequency features, passed through
    ``model``, and one optimiser step is taken on the cross-entropy loss.

    Args:
        model: classifier producing class logits; its first parameter's
            device is used for all batch tensors.
        optimizer: optimiser over ``model``'s parameters.
        data_loader: iterable of ``(data, bn)`` batches with a
            ``.dataset.task`` attribute. The last column of ``data`` holds
            the target label; the remaining column(s) hold integer codes.
        bit_base: tensor of bit masks. Each code is AND-ed with every mask to
            obtain its bits (presumably one power of two per bit,
            most-significant first; confirm against the caller).

    Returns:
        float: average of the per-batch mean cross-entropy losses. Every
        batch is weighted equally, regardless of its size.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    device = next(model.parameters()).device
    total_loss = 0.0
    n_batches = 0
    model.train()
    for batch in data_loader:
        with torch.no_grad():
            bn_list = batch[1].tolist()
            # Expand each integer code into 0/1 indicators, one per mask in
            # bit_base. This broadcasts, so it assumes the code part is a single
            # column (or exactly len(bit_base) columns). TODO confirm.
            feats = batch[0][:, :-1].int().to(device).bitwise_and(bit_base).ne(0).float()
            targets = batch[0][:, -1].long().to(device)
            if data_loader.dataset.task != 4:  # task4: breeding stage
                # Row j of the k-th group (group sizes from bn_list) is summed
                # into output row j, giving bn_list[0] aggregated rows.
                # This looks like packed-sequence batch sizes (one row per
                # sequence after aggregation). Verify against CustomLoader.
                seq_id = torch.cat(list(map(torch.arange, bn_list))).to(device)
                feats = feats.new_zeros((bn_list[0], len(bit_base))).index_add(0, seq_id, feats)
                # Keep one target per aggregated row (those of the first group).
                targets = targets[:bn_list[0]]
            # Normalise counts to frequencies. The epsilon keeps all-zero rows
            # at zero instead of dividing by zero.
            feats = feats / (feats.sum(dim=-1, keepdim=True) + 1e-7)
        preds = model(feats)
        loss = F.cross_entropy(preds, targets)
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute a training loss")
    return total_loss / n_batches


def test(model, data_loader, bit_base):
    """Evaluate ``model`` on every batch of ``data_loader`` and return metrics.

    Batches are featurised the same way as during training. The function
    prints the distinct predicted labels, the distinct targets and the
    confusion matrix.

    Args:
        model: classifier producing class logits. It is switched to eval mode
            and left there.
        data_loader: iterable of ``(data, bn)`` batches with ``.dataset.task``
            and ``.dataset.n_class`` attributes.
        bit_base: tensor of bit masks used to expand the integer codes.

    Returns:
        tuple: ``(mean_loss, auc, accuracy)``.
            ``mean_loss`` is the summed cross-entropy divided by the number of
            evaluated rows. ``auc`` is the ROC AUC: positive-class probability
            when ``n_class == 2``, one-vs-rest otherwise. ``accuracy`` uses the
            arg-max prediction.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    device = next(model.parameters()).device
    total_loss = 0.0
    total_samples = 0
    all_targets, all_predicts = [], []
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            bn_list = batch[1].tolist()
            # Expand each integer code into 0/1 indicators, one per mask in bit_base.
            feats = batch[0][:, :-1].int().to(device).bitwise_and(bit_base).ne(0).float()
            targets = batch[0][:, -1].long().to(device)
            if data_loader.dataset.task != 4:  # task4: breeding stage
                # Sum row j of every bn_list group into output row j
                # (presumably packed-sequence batch sizes; TODO confirm).
                seq_id = torch.cat(list(map(torch.arange, bn_list))).to(device)
                feats = feats.new_zeros((bn_list[0], len(bit_base))).index_add(0, seq_id, feats)
                targets = targets[:bn_list[0]]
            # Counts -> frequencies; the epsilon guards all-zero rows.
            feats = feats / (feats.sum(dim=-1, keepdim=True) + 1e-7)
            preds = model(feats)
            total_loss += F.cross_entropy(preds, targets, reduction="sum").item()
            total_samples += feats.shape[0]
            all_targets.extend(targets.tolist())
            all_predicts.extend(F.softmax(preds, 1).tolist())
    if total_samples == 0:
        raise ValueError("data_loader yielded no samples; cannot compute evaluation metrics")
    all_predicts = np.array(all_predicts)
    pred_labels = all_predicts.argmax(1)

    print(np.unique(pred_labels))
    print(np.unique(all_targets))
    print(confusion_matrix(all_targets, pred_labels))
    mean_loss = total_loss / total_samples
    if data_loader.dataset.n_class == 2:
        auc = roc_auc_score(all_targets, all_predicts[:, 1])
    else:
        auc = roc_auc_score(all_targets, all_predicts, multi_class="ovr")
    return mean_loss, auc, accuracy_score(all_targets, pred_labels)


def main(args):
    """Train the residual MLP classifier and print train/test metrics.

    Args:
        args: parsed command-line namespace. Reads ``task``, ``epoch``,
            ``device``, ``dataset_path``, ``decoded_path``, ``batch_size``,
            ``n_bits``, ``n_h``, ``optim``, ``lr`` and ``wdcy``.

    Model selection uses the *training* loss. After each epoch, a deep copy is
    kept of the model with the lowest mean per-batch training loss seen so far;
    that copy is the one evaluated. On the third consecutive epoch without
    improvement, the learning rate is multiplied by 0.9 and the counter
    restarts.

    Raises:
        ValueError: if ``args.task`` is not in 1-4 or ``args.epoch`` is below 1.
        RuntimeError: if no epoch produced a finite improving loss (e.g. the
            loss was NaN from the start), leaving no model to evaluate.
    """
    if args.task not in [1, 2, 3, 4]:
        raise ValueError('Task should be in {1,2,3,4}.')
    if args.epoch < 1:
        raise ValueError('Number of epochs should be at least 1.')
    device = torch.device(args.device)
    dataset = get_data_n_stats(args.dataset_path, args.task, decoded_path=args.decoded_path)
    train_data_loader = CustomLoader(max_obs=args.batch_size, dataset=dataset, mode="train", shuffle=True, progress_bar=False)
    test_data_loader = CustomLoader(max_obs=args.batch_size, dataset=dataset, mode="test", shuffle=False, progress_bar=False)
    model = nn_clf(args.n_bits, args.n_h, 3, dataset.n_class).to(device)
    # Descending powers of two: the mask for the most significant bit comes first.
    bit_base = (2 ** torch.arange(args.n_bits - 1, -1, -1)).to(device)
    if args.optim == "adam":
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdcy, amsgrad=True)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wdcy)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)

    best_loss = float("inf")
    best_model = None
    stall_cnt = 0  # consecutive epochs without improving the training loss
    for _ in range(args.epoch):
        tr_loss = train(model, optimizer, train_data_loader, bit_base)
        if tr_loss < best_loss:
            stall_cnt = 0
            best_loss = tr_loss
            best_model = copy.deepcopy(model)
        elif stall_cnt < 2:
            stall_cnt += 1
        else:
            # Third stalled epoch in a row: decay the learning rate and restart counting.
            stall_cnt = 0
            scheduler.step()
    if best_model is None:
        raise RuntimeError("Training loss never improved (non-finite loss?); no model to evaluate.")
    del model

    tr_result = test(best_model, train_data_loader, bit_base)
    te_result = test(best_model, test_data_loader, bit_base)
    print("epoch: {0}, train_loss: {1}, train_auc: {2}, train_acc: {3}".format(args.epoch, tr_result[0], tr_result[1], tr_result[2]))
    print("epoch: {0}, test_loss: {1}, test_auc: {2}, test_acc: {3}".format(args.epoch, te_result[0], te_result[1], te_result[2]))


def find_decoded_path(n_bits, directory=""):
    """Return the decoded hidden-state CSV for an ``n_bits``-chain FactHMM.

    Looks for files matching ``*facthmm_decoded<n_bits>.csv`` inside
    ``directory`` (the current working directory when empty). Matches are
    sorted so the choice is deterministic when several files qualify. The
    first one is returned.

    Raises:
        FileNotFoundError: if no file matches the pattern.
    """
    # Escape the directory so glob metacharacters in it are matched literally.
    pattern = os.path.join(glob.escape(directory), "*facthmm_decoded" + str(n_bits) + ".csv")
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError("No decoded hidden-state file matches {0!r}".format(pattern))
    return matches[0]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train and test the model")
    parser.add_argument('--model-name', type=str, default="facthmm", help="name of model")
    parser.add_argument('--task', type=int, default=1, help="type of task (1-4; 4 = breeding stage)")
    parser.add_argument('--n-bits', type=int, default=2, help="number of HMM chains")
    parser.add_argument('--t-resolution', type=int, default=1, help="granularity of discretisation")
    parser.add_argument('--n-h', type=int, default=256, help="hidden units")
    parser.add_argument('--batch-size', type=int, default=20002, help="batch-size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--wdcy', type=float, default=0, help="weight decay")
    parser.add_argument('--epoch', type=int, default=50, help="epoch")
    parser.add_argument('--optim', type=str, default="adam", help="optimiser: 'adam' (AdamW); any other value selects SGD")
    parser.add_argument('--device', type=str, default="cuda", help="device")
    parser.add_argument('--dataset-path', type=str, default="scar.csv", help="processed dataset path")

    args = parser.parse_args()

    # Decoded hidden-states file path; exit with a usage error if it is missing.
    try:
        args.decoded_path = find_decoded_path(args.n_bits)
    except FileNotFoundError as err:
        parser.error(str(err))
    print(args)
    start_time = time.time()
    main(args)
    print("--- %s seconds ---" % (time.time() - start_time))

    
