import torch
from torch.utils.data import DataLoader
from datasets import concatenate_datasets
import argparse
import copy
from models import VariationalIRT
from logger import Logger
from data_utils import VIRTDataset, shuffle_examinee_collate_fn



def evaluate(model, test_size, test_dataloader, logger, best_metric):
    """Evaluate ``model`` on ``test_dataloader``, log the metrics and checkpoint on improvement.

    Args:
        model: the VIRT model. It is called as ``model(y)`` and must return a
            6-tuple whose last four entries are single-element tensors
            ``(rec_y, kl_a, kl_d, loss_1)``. ``predict_a``, ``predict_d`` and
            ``accuracy`` are then applied once to all evaluated responses,
            concatenated along dim 0.
        test_size: number of rows the accumulated losses are divided by;
            presumably ``len(dataset)`` behind ``test_dataloader`` -- confirm
            at the call site.
        test_dataloader: iterable of batches, each a mapping whose ``'y'`` entry
            is a tensor with the batch dimension first.
        logger: object providing ``print(...)`` and ``save(model, filename)``.
        best_metric: lowest reconstruction loss seen so far, or ``None`` to
            evaluate and log without ever saving a checkpoint.

    Returns:
        ``(best_metric, saved)``. If this evaluation's reconstruction loss beats
        ``best_metric``, the model is saved to ``'virt_model.bin'`` and the new
        loss is returned with ``True``. Otherwise the unchanged ``best_metric``
        is returned with ``False``.
    """
    all_rec_y, all_kl_a, all_kl_d, all_loss_1 = 0, 0, 0, 0
    data = []
    with torch.no_grad():
        for batch in test_dataloader:
            y = batch['y']
            batch_rows = y.size(0)
            data.append(y)
            _, _, rec_y, kl_a, kl_d, loss_1 = model(y)
            # Weight each batch's value by its row count; assuming the model
            # returns per-batch means, dividing by test_size below gives a
            # per-row average even when the last batch is short.
            all_rec_y += rec_y.item() * batch_rows
            all_kl_a += kl_a.item() * batch_rows
            all_kl_d += kl_d.item() * batch_rows
            all_loss_1 += loss_1.item() * batch_rows
        data = torch.cat(data, dim=0)
        a = model.predict_a(data)
        d = model.predict_d(data)
        acc = model.accuracy(data).item()

    all_rec_y /= test_size
    all_kl_a /= test_size
    all_kl_d /= test_size
    all_loss_1 /= test_size

    logger.print(f'[EVALUATION] Acc={acc}, Rec_y={all_rec_y}, KL_a={all_kl_a}, KL_d={all_kl_d}, Total={all_loss_1}')
    logger.print(f'[EVALUATION] a={a}')
    # Only the first 10 entries of d are logged to keep the log readable.
    logger.print(f'[EVALUATION] d={d[:10]}')

    # Explicit None check: a best metric of 0.0 is a real value, not "no metric".
    if best_metric is not None and all_rec_y < best_metric:
        logger.print('[EVALUATION] New best model! Saving...\n')
        logger.save(model, 'virt_model.bin')
        return all_rec_y, True

    logger.print()
    return best_metric, False


def train(args):
    """Train a VariationalIRT model, keeping the weights with the lowest test reconstruction loss.

    After every epoch the model is evaluated on the test split via
    ``evaluate``. Whenever the test reconstruction loss improves, an in-memory
    copy of the weights is kept (``evaluate`` also saves a checkpoint through
    the logger). At the end the best weights are restored, when any were
    recorded, and evaluated once more on the training split.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``proj_name``,
            ``lr``, ``data_path``, ``data_type``, ``filtered``, ``batch_size``
            and ``epochs``.
    """
    # initialization
    torch.cuda.manual_seed_all(args.seed)
    torch.manual_seed(args.seed)
    logger = Logger(args.proj_name)

    # build models
    model_virt = VariationalIRT()
    model_virt.train()
    optimizer_virt = torch.optim.Adam(model_virt.parameters(), lr=args.lr)
    logger.write_config(args, model_virt, args.proj_name)

    # load data
    train_set = VIRTDataset(data_path=args.data_path, type=args.data_type, filtered=args.filtered, split='train')
    test_set = VIRTDataset(data_path=args.data_path, type=args.data_type, filtered=args.filtered, split='test')
    train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4,
                                  collate_fn=shuffle_examinee_collate_fn)
    test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    # NOTE(review): despite its name, this loader iterates the *training* split
    # and is averaged with len(train_set) below. Confirm the final report is
    # meant to be on training data.
    test_dataloader_final = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # start training
    # inf (instead of an arbitrary sentinel such as 999) guarantees that the
    # first finite evaluation is always recorded as the best one.
    best_metric = float('inf')
    best_model = None
    for epoch in range(args.epochs):
        for step, batch in enumerate(train_dataloader):
            optimizer_virt.zero_grad()
            _, _, rec_y, kl_a, kl_d, loss_1 = model_virt(batch['y'])
            loss_1.backward()
            optimizer_virt.step()

            logger.print(f'[TRAIN] Epoch {epoch+1}/{args.epochs}, Iter {step+1}/{len(train_dataloader)}')
            logger.print(f'Rec_y={rec_y.item()}, KL_a={kl_a.item()}, KL_d={kl_d.item()}, Total={loss_1.item()}\n')

        model_virt.eval()
        best_metric, save = evaluate(model_virt, len(test_set), test_dataloader, logger, best_metric)
        if save:
            # deepcopy: state_dict() tensors share storage with the live
            # parameters, which the optimizer keeps updating in place.
            best_model = copy.deepcopy(model_virt.state_dict())
        model_virt.train()

    # If no epoch ever improved (e.g. epochs == 0 or every loss was NaN),
    # evaluate the current weights instead of crashing in load_state_dict(None).
    if best_model is not None:
        model_virt.load_state_dict(best_model)
    else:
        logger.print('[EVALUATION] No best checkpoint recorded; evaluating the final weights.')
    model_virt.eval()
    evaluate(model_virt, len(train_set), test_dataloader_final, logger, None)


def model_structure(model):
    blank = ' '
    print('-' * 90)
    print('|' + ' ' * 11 + 'weight name' + ' ' * 10 + '|' \
          + ' ' * 15 + 'weight shape' + ' ' * 15 + '|' \
          + ' ' * 3 + 'number' + ' ' * 3 + '|')
    print('-' * 90)
    num_para = 0
    type_size = 4  

    for index, (key, w_variable) in enumerate(model.named_parameters()):
        if len(key) <= 30:
            key = key + (30 - len(key)) * blank
        shape = str(w_variable.shape)
        if len(shape) <= 40:
            shape = shape + (40 - len(shape)) * blank
        each_para = 1
        for k in w_variable.shape:
            each_para *= k
        num_para += each_para
        str_num = str(each_para)
        if len(str_num) <= 10:
            str_num = str_num + (10 - len(str_num)) * blank

        print('| {} | {} | {} |'.format(key, shape, str_num))
    print('-' * 90)
    print('The total number of parameters: ' + str(num_para))
    print('The parameters of Model {}: {:4f}M'.format(model._get_name(), num_para * type_size / 1000 / 1000))
    print('-' * 90)

if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean (argparse's ``type=bool`` turns 'False' into True)."""
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='VIRT arguments')
    parser.add_argument('--proj_name', type=str, required=True, help='name of the project (log folder)')
    parser.add_argument('-t', '--data_type', type=str, required=True, choices=['bias', 'toxicity', 'justice', 'commonsense', 'virtue'], help='type of data')
    # nargs='?' with const=True accepts a bare `--filtered` as well as
    # `--filtered True` / `--filtered False`.
    parser.add_argument('--filtered', type=_str2bool, nargs='?', const=True, default=False, help='use filtered data')
    parser.add_argument('--data_path', type=str, default='', help='path to the data')

    parser.add_argument('--seed', type=int, default=77, help='manual random seed')
    parser.add_argument('-n', '--n_examinees', type=int, default=36, help='number of examinees, =responses per item')

    parser.add_argument('-b', '--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('-e', '--epochs', type=int, default=100, help='number of total epochs')
    parser.add_argument('--do_train', action='store_true', help='run training (default: only print the model structure)')
    args = parser.parse_args()

    if args.do_train:
        train(args)
    else:
        model_virt = VariationalIRT()
        model_structure(model_virt)
