import numpy as np
import torch
import torch.nn as nn
import os

from torch.utils.data import DataLoader, random_split
import argparse
import json

from dataset import TSDataset, TSMDataset
from Model import ShapeConvClassifier, ShapeConv, ShapeConvClassifierM, ShapeConvM
from utils import DBILoss, NMI

# Command-line configuration for one training run.  The option values that
# select a code path (--task, --var, --init) are restricted to the values the
# rest of this script actually handles, so a typo is rejected immediately
# instead of surfacing later as a bare NotImplementedError.
parser = argparse.ArgumentParser()
parser.add_argument('--cuda', type=int, default=0, help='cuda device id')
parser.add_argument('--dir', type=str, default='logs', help='directory to save logs')
parser.add_argument('--dataset', type=str, default='', help='dataset name')
parser.add_argument('--task', type=str, default='S', choices=['S', 'U'],
                    help='S: supervised, U: unsupervised')
parser.add_argument('--var', type=str, default='U', choices=['U', 'M'],
                    help='U: univariate, M: multivariate')
parser.add_argument('--num_class', type=int, default=2, help='number of classes')
parser.add_argument('--init', type=str, default='None',
                    choices=['None', 'normal', 'sample', 'allsample', 'cut', 'class'],
                    help='initialization methods')
parser.add_argument('--l', type=float, default=3., help="kernel size = dataset length / l")
parser.add_argument('--d', type=float, default=2., help='dim = num_class * d')
parser.add_argument('--hid', type=int, default=32, help='hidden layer size')
parser.add_argument('--epochs', type=int, default=1000, help='number of epochs')
parser.add_argument('--lr', type=float, default=0.05, help='learning rate')
parser.add_argument('--lam_div', type=float, default=0.01, help='lambda for diverse regularizer')
parser.add_argument('--lam_shape', type=float, default=0.01, help='lambda for shape regularizer')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--seed', type=int, default=10, help='random seed')


args = parser.parse_args()

# Device selected by the --cuda index.
device = torch.device("cuda:" + str(args.cuda))

# Each run writes into its own numbered sub-directory of --dir:
#   * --dir does not exist yet          -> "<dir>/0"
#   * --dir exists, no numbered subdirs -> "<dir>/1"
#   * otherwise                         -> "<dir>/<largest number + 1>"
models_dir = args.dir
if not os.path.exists(models_dir):
    dir_id = '0'
else:
    # Only sub-directories whose name is a plain decimal number count as
    # earlier runs; anything else (e.g. "checkpoints", ".git") is ignored
    # rather than crashing the integer conversion.
    run_ids = [
        int(d)
        for d in os.listdir(models_dir)
        if d.isdecimal() and os.path.isdir(os.path.join(models_dir, d))
    ]
    dir_id = str(max(run_ids) + 1) if run_ids else "1"
save_dir = os.path.join(models_dir, dir_id)
# Raises if the directory already exists, so two runs never share a folder.
os.makedirs(save_dir)

# Record the exact configuration next to the run's outputs.
with open(os.path.join(save_dir, 'args.txt'), 'w') as f:
    json.dump(vars(args), f, indent=2)

# Pick the dataset class for the requested variate setting: 'U' uses
# TSDataset, every other value uses TSMDataset.
dataset_cls = TSDataset if args.var == 'U' else TSMDataset
train_dataset = dataset_cls(args.dataset, 'TRAIN')
test_dataset = dataset_cls(args.dataset, 'TEST')

# 80 % of the TRAIN partition is used for fitting, the remainder for validation.
num_train_samples = len(train_dataset)
train_len = int(num_train_samples * 0.8)
val_len = int(num_train_samples - train_len)

# The train/validation split is seeded with --seed.
split_rng = torch.Generator().manual_seed(args.seed)
train_dataset, val_dataset = random_split(
    train_dataset, [train_len, val_len], generator=split_rng
)
# The TEST partition is wrapped through random_split as a single subset
# (no explicit generator is passed here).
test_parts = random_split(test_dataset, [len(test_dataset)])
test_dataset = test_parts[0]

train_loader = DataLoader(train_dataset, batch_size=args.batch_size)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

# Build the model for the requested (task, variate) combination.
# In every supported case:
#   dim         = int(num_class * d)
#   kernel_size = int(series length / l)
# The series length is read from the first training sample.
run_mode = (args.task, args.var)
if run_mode == ('S', 'U'):
    # The first axis of a sample is taken as the series length.
    first_sample = train_dataset[0][0]
    seq_length = first_sample.shape[0]
    model = ShapeConvClassifier(
        num_class=args.num_class,
        hid=args.hid,
        dim=int(args.num_class * args.d),
        kernel_size=int(seq_length / args.l),
    )
    model = model.to(device)
elif run_mode == ('S', 'M'):
    # Axis 0 of a sample is passed as the channel count, axis 1 as the length.
    first_sample = train_dataset[0][0]
    seq_length = first_sample.shape[1]
    model = ShapeConvClassifierM(
        num_class=args.num_class,
        channel=first_sample.shape[0],
        hid=args.hid,
        dim=int(args.num_class * args.d),
        kernel_size=int(seq_length / args.l),
    )
    model = model.to(device)
elif run_mode == ('U', 'U'):
    first_sample = train_dataset[0][0]
    seq_length = first_sample.shape[0]
    model = ShapeConv(
        dim=int(args.num_class * args.d),
        kernel_size=int(seq_length / args.l),
    )
    model = model.to(device)
else:
    # Any other combination (e.g. unsupervised + multivariate) is unsupported.
    raise NotImplementedError

# Optional data-driven initialisation of the model, selected by --init.
# 'normal' and 'None' leave the model as constructed.
init_mode = args.init
if init_mode == 'sample':
    model.init_cluster(train_loader, True)
elif init_mode == 'allsample':
    model.init_cluster(train_loader, False, num_class=args.num_class)
elif init_mode == 'cut':
    model.init_cut(train_loader, args.num_class)
elif init_mode == 'class':
    model.init(train_loader)
elif init_mode not in ('normal', 'None'):
    raise NotImplementedError

# Move the model to the target device again after initialisation.
model.to(device)

epochs = args.epochs

# A single Adam optimizer over all model parameters.  (It used to be
# constructed twice in a row with identical arguments; the second
# construction simply replaced the first.)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Not referenced by the rest of this script (the supervised branch builds its
# own criterion); kept so the module namespace stays unchanged.
lossfn = nn.CrossEntropyLoss()

# Best validation score seen so far (accuracy or NMI, depending on --task).
vacc = 0
# Human-readable per-epoch log lines, later written to logs.txt.
verb = []

if args.task == 'S':
    criterion = nn.CrossEntropyLoss()
    hist = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    def train_step(loader, val = False):
        """Run one pass over ``loader`` for the supervised task.

        Args:
            loader: iterable of ``(data, label)`` batches that also exposes
                ``loader.dataset`` (its length normalises the results).
            val: when true, only evaluate; no backward pass and no optimizer
                step are performed.

        Returns:
            ``(mean_loss, accuracy)``: the batch-size weighted loss sum and
            the number of correct predictions, each divided by
            ``len(loader.dataset)``.
        """
        tot_loss = []
        tot_acc = []
        for data, label in loader:
            if args.var == 'U':
                # Univariate batches get an extra axis inserted at position 1.
                data = data.unsqueeze(1).to(device)
            else:
                data = data.to(device)
                # Multivariate batches have NaNs replaced before the forward pass.
                data = torch.nan_to_num(data)
            label = label.long().to(device)
            if not val:
                optimizer.zero_grad()
            h, knorm, out = model(data)
            # Flatten to (batch, -1) instead of calling squeeze(): squeeze()
            # would also drop the batch axis when a batch holds one sample.
            # NOTE(review): assumes `out` holds one score vector per sample
            # apart from singleton axes -- confirm against the model.
            logits = out.reshape(label.shape[0], -1)
            # Classification loss plus the two weighted regularisation terms.
            loss = (criterion(logits, label) +
                    args.lam_shape * h.min(-1)[0].sum(-1).mean() +
                    args.lam_div * knorm)
            if not val:
                loss.backward()
                optimizer.step()
            # Weight by batch size so the final division yields a per-sample mean.
            tot_loss.append(loss.item() * label.shape[0])
            tot_acc.append(torch.sum(logits.max(-1)[1] == label).item())
        return np.sum(tot_loss) / len(loader.dataset), np.sum(tot_acc) / len(loader.dataset)
    
    # Supervised training: after every epoch the model is scored on the
    # validation split, and the weights are checkpointed whenever the
    # validation accuracy is at least as good as the best one so far.
    best_path = os.path.join(save_dir, 'best.pt')
    for epoch in range(epochs):
        # Optimisation pass over the training split.
        model.train()
        train_loss, train_acc = train_step(train_loader)

        # Gradient-free pass over the validation split.
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = train_step(val_loader, True)
            if val_acc >= vacc:
                vacc = max(vacc, val_acc)
                torch.save(model.state_dict(), best_path)

        # Record this epoch's metrics.
        epoch_metrics = (train_loss, train_acc, val_loss, val_acc)
        for key, value in zip(('train_loss', 'train_acc', 'val_loss', 'val_acc'),
                              epoch_metrics):
            hist[key].append(value)

        report = ('Epoch: {}, Train Loss: {:.4f}, '
                  'Train Acc: {:.4f}, Val Loss: {:.4f}, '
                  'Val Acc: {:.4f}').format(epoch, *epoch_metrics)
        verb.append(report)
        print(report)

    # Write the complete epoch log once training has finished.
    with open(os.path.join(save_dir, 'logs.txt'), 'w') as f:
        f.writelines(line + '\n' for line in verb)

    # Restore the best checkpoint and score it on the test split.
    best_state = torch.load(best_path)
    model.load_state_dict(best_state)
    with torch.no_grad():
        test_loss, test_acc = train_step(test_loader, True)

    print("Best Val Acc: {:.4f}".format(vacc))
    print("Test Acc: {:.4f}".format(test_acc))

else:
    hist = {'train_loss': [], 'train_nmi': [], 'val_loss': [], 'val_nmi': []}

    def train_step(loader, val = False):
        tot_loss = []
        for i, (data, label) in enumerate(loader):
            optimizer.zero_grad()
            data = data.unsqueeze(1).to(device)
            out, knorm = model(data)
            loss = (DBILoss(out.min(-1)[0], args.num_class) + \
                    args.lam_shape * out.min(-1)[0].sum(-1).mean() + \
                    args.lam_div * knorm)
            if not val:
                loss.backward()
                optimizer.step()
            tot_loss.append(loss.item() * label.shape[0])
        return np.sum(tot_loss) / len(train_dataset)
    
    # Unsupervised training: each epoch is scored with NMI, and the weights
    # are checkpointed whenever the score is at least as good as the best so far.
    log_path = os.path.join(save_dir, 'logs.txt')
    for epoch in range(epochs):
        model.train()
        train_loss = train_step(train_loader)
        model.eval()
        with torch.no_grad():
            val_loss = train_step(val_loader, True)
        # NOTE(review): the score used for checkpoint selection is computed
        # with `test_dataset`, not `val_dataset`.  If NMI() evaluates on its
        # fourth argument this selects the model on test data -- confirm
        # against utils.NMI before changing.
        train_nmi, val_nmi = NMI(model, args.num_class, train_dataset, test_dataset, device)
        if vacc <= val_nmi:
            vacc = max(vacc, val_nmi)
            torch.save(model.state_dict(), os.path.join(save_dir, 'best.pt'))
        hist['train_loss'].append(train_loss)
        hist['train_nmi'].append(train_nmi)
        hist['val_loss'].append(val_loss)
        hist['val_nmi'].append(val_nmi)
        verb.append('Epoch: {}, Train Loss: {:.4f}, '
                    'Train NMI: {:.4f}, Val Loss: {:.4f}, '
                    'Val NMI: {:.4f}'.format(epoch, train_loss,
                                            train_nmi, val_loss,
                                            val_nmi)
                    )
        print(verb[-1])

        # Append only this epoch's line.  The log on disk is still current
        # after every epoch, but the whole history is no longer rewritten
        # each time; save_dir is freshly created, so the file starts empty.
        with open(log_path, 'a') as f:
            f.write(verb[-1] + '\n')

    # Restore the best checkpoint and evaluate it.
    model.load_state_dict(torch.load(os.path.join(save_dir, 'best.pt')))
    with torch.no_grad():
        test_loss = train_step(test_loader, True)
    test_nmi = NMI(model, args.num_class, train_dataset, test_dataset, device, True)

    print("Best Val NMI: {:.4f}".format(vacc))
    print("Test NMI: {:.4f}".format(test_nmi))