import os
import random
import numpy as np
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import visdom

from utility import PGMdataset, RAVENdataset, ToTensor
from sran import SRAN

# Command-line interface, declared as (flag, add_argument keyword options).
# The order of this table matters: argparse fills the namespace in registration
# order, and the namespace is later iterated when the settings are echoed and
# written to the log.
_CLI_SPEC = (
    # Model selection and shape
    ("--model", {"type": str, "default": "SRAN"}),
    ("--actfun", {"type": str, "default": "relu"}),
    ("--feature_width", {"type": float, "default": 1.0}),
    # Data
    ("--dataset", {"type": str, "default": "I-RAVEN", "choices": ["PGM", "I-RAVEN", "RAVEN"]}),
    ("--img_size", {"type": int, "default": 224}),
    # Training schedule and bookkeeping
    ("--epochs", {"type": int, "default": 200}),
    ("--batch_size", {"type": int, "default": 32}),
    ("--print_freq", {"type": int, "default": 10}),
    ("--seed", {"type": int, "default": 12345}),
    ("--load_workers", {"type": int, "default": 16}),
    ("--resume", {"type": str, "default": ""}),
    # The default path targets I-RAVEN; when running on PGM pass the PGM root
    # explicitly (e.g. /media/dsg3/datasets/PGM).
    ("--dataset_path", {"type": str, "default": "/media/dsg3/datasets/I-RAVEN"}),
    ("--save", {"type": str, "default": "/media/dsg3/hs"}),
    ("--num_checkpoints", {"type": int, "default": 3}),
    # Optimizer hyper-parameters
    ("--lr", {"type": float, "default": 1e-4}),
    ("--beta1", {"type": float, "default": 0.9}),
    ("--beta2", {"type": float, "default": 0.999}),
    ("--epsilon", {"type": float, "default": 1e-8}),
    ("--weight_decay", {"type": float, "default": 0.0}),
    ("--meta_beta", {"type": float, "default": 0.0}),
    # Switches
    ("--visdom", {"action": "store_true", "help": "Use visdom for visualization"}),
    # CUDA is on by default; --no-cuda clears args.cuda.
    ("--no-cuda", {"dest": "cuda", "action": "store_false"}),
    ("--debug", {"action": "store_true"}),
)

parser = argparse.ArgumentParser(description="our_model")
for _flag, _options in _CLI_SPEC:
    parser.add_argument(_flag, **_options)

# Parse the command line once; the rest of the script reads its settings from `args`.
args = parser.parse_args()

# Echo the effective configuration so it ends up in the run's stdout.
print("Arguments:")
for arg_name, arg_value in vars(args).items():
    print("{} = {}".format(arg_name, repr(arg_value)))
print("")

# Set RNG seeds
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Seed every GPU rather than only the current device; this call is silently
# ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(args.seed)

# Debug runs never use visdom, even if --visdom was passed.
if args.debug:
    args.visdom = False

# exist_ok=True removes the check-then-create race of the exists()/makedirs()
# pair (e.g. several jobs started at once with the same --save directory).
os.makedirs(args.save, exist_ok=True)

# Select the benchmark. Anything that is neither PGM nor a RAVEN variant is
# rejected before any dataset object is built.
is_pgm = args.dataset == 'PGM'
if not is_pgm and "RAVEN" not in args.dataset:
    raise ValueError("Unsupported dataset: {}".format(args.dataset))

if is_pgm:
    trainset = PGMdataset(
        args.dataset_path,
        "train",
        args.img_size,
        transform=transforms.Compose([ToTensor()]),
        shuffle=True,
    )
    validset = PGMdataset(
        args.dataset_path,
        "val",
        args.img_size,
        transform=transforms.Compose([ToTensor()]),
    )
    testset = PGMdataset(
        args.dataset_path,
        "test",
        args.img_size,
        transform=transforms.Compose([ToTensor()]),
    )
else:
    # RAVEN / I-RAVEN: every split uses figure-configuration ids 0..6.
    # The lists are stored on `args`, which is afterwards handed to the model
    # constructor and dumped into results.log. Validation deliberately shares
    # the training list object; the test split gets its own list.
    args.train_figure_configurations = list(range(7))
    args.val_figure_configurations = args.train_figure_configurations
    args.test_figure_configurations = list(range(7))
    trainset = RAVENdataset(
        args.dataset_path,
        "train",
        args.train_figure_configurations,
        args.img_size,
        transform=transforms.Compose([ToTensor()]),
        shuffle=True,
    )
    validset = RAVENdataset(
        args.dataset_path,
        "val",
        args.val_figure_configurations,
        args.img_size,
        transform=transforms.Compose([ToTensor()]),
    )
    testset = RAVENdataset(
        args.dataset_path,
        "test",
        args.test_figure_configurations,
        args.img_size,
        transform=transforms.Compose([ToTensor()]),
    )

# All loaders share batch size and worker count; only the training loader shuffles.
loader_settings = {"batch_size": args.batch_size, "num_workers": args.load_workers}
trainloader = DataLoader(trainset, shuffle=True, **loader_settings)
validloader = DataLoader(validset, shuffle=False, **loader_settings)
testloader = DataLoader(testset, shuffle=False, **loader_settings)

print('Dataset:', args.dataset)
split_sizes = (len(trainset), len(validset), len(testset))
print('Train/Validation/Test:{0}/{1}/{2}'.format(*split_sizes))
print('Image size:', args.img_size)

# Instantiate the requested network; SRAN is the only architecture wired up here.
if args.model == 'SRAN':
    model = SRAN(args)
else:
    raise ValueError("Unsupported model: {}".format(args.model))

# Write the run configuration as a header block in results.log.
# When resuming from an existing checkpoint the file is appended to, so the
# per-epoch lines of the earlier run are kept (results.csv is likewise only
# re-created for a fresh run). A fresh run still starts with an empty file.
_resuming = bool(args.resume) and os.path.isfile(args.resume)
with open(os.path.join(args.save, 'results.log'), 'a' if _resuming else 'w') as f:
    for key, value in vars(args).items():
        f.write('{0}: {1}\n'.format(key, value))
    f.write('--------------------------------------------------\n')

# `pmodel` is the callable used for forward passes; `model` keeps the
# optimizer, loss and checkpoint helpers.
if args.cuda:
    # Let cuDNN auto-tune its kernels and replicate the model across the
    # visible GPUs.
    torch.backends.cudnn.benchmark = True
    pmodel = torch.nn.DataParallel(model).cuda()
else:
    # Honour --no-cuda: keep the model on the CPU. DataParallel is skipped
    # because it rejects CPU-resident parameters when GPUs are visible.
    pmodel = model

# Optionally resume from a checkpoint. `start_epoch` stays 0 for a fresh run
# and for a --resume path that does not exist (reported below, not fatal).
start_epoch = 0
if args.resume:
    if not os.path.isfile(args.resume):
        print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        print("=> loading checkpoint '{}'".format(args.resume))
        # With --no-cuda, remap the stored tensors to the CPU so a checkpoint
        # saved on a GPU can still be restored; otherwise keep torch's
        # default placement.
        checkpoint = torch.load(
            args.resume, map_location=None if args.cuda else "cpu"
        )
        # The checkpoint stores the last finished epoch; continue with the next.
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["state_dict"])
        model.optimizer.load_state_dict(checkpoint["optimizer"])
        print(
            "=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]
            )
        )

# The visdom client is only created on request; `viz` does not exist otherwise.
if args.visdom:
    viz = visdom.Visdom(port = 9527, env = args.dataset)

def count_parameters(model, only_trainable=True):
    """
    Return how many scalar parameters a module (including its children) holds.

    Arguments:
        model (torch.nn.Module): the module whose ``parameters()`` are counted.
        only_trainable (bool, optional): when ``True`` (the default) only
            parameters with ``requires_grad`` set are counted; when ``False``
            every parameter is counted.
    Returns:
        int: the summed element count of the selected parameters.
    """
    total = 0
    for param in model.parameters():
        # Skip frozen parameters when only the trainable ones are requested.
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Arguments:
        epoch (int): index of the current epoch; only used in progress output.
    Returns:
        tuple: ``(mean_loss, mean_accuracy)``, both averaged over batches
            (every batch has equal weight, whatever its size); accuracy is in
            percent. ``(nan, nan)`` if the loader produced no batches.
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    num_batches = 0
    for batch_idx, (image, target, meta_target) in enumerate(trainloader):
        num_batches += 1
        if args.cuda:
            image = image.cuda()
            target = target.cuda()
            meta_target = meta_target.cuda()

        # Forward through the (possibly data-parallel) wrapper; the loss and
        # the optimizer live on the underlying model.
        model.optimizer.zero_grad()
        output = pmodel(image)
        loss = model.compute_loss(output, target, meta_target)
        loss.backward()
        model.optimizer.step()

        # The prediction is the index of the largest entry along dim 1 of the
        # first output element; it is compared element-wise with `target`.
        pred = output[0].detach().max(1)[1]
        correct = pred.eq(target).sum().item()
        batch_acc = correct * 100. / target.size(0)
        batch_loss = loss.item()

        # Always show the first three batches, then every print_freq-th one.
        if batch_idx <= 2 or batch_idx % args.print_freq == 0:
            print(
                "Train: Epoch:{:4d}, Batch:{:4d}, Loss:{:9.6f}, Acc:{:6.2f}.".format(
                    epoch, batch_idx, batch_loss, batch_acc
                ),
                flush=True,
            )
        loss_sum += batch_loss
        acc_sum += batch_acc

    if num_batches == 0:
        # An empty loader used to end in a ZeroDivisionError; report
        # "no data" explicitly instead.
        return float("nan"), float("nan")

    mean_loss = loss_sum / num_batches
    print("Avg Training Loss: {:.6f}".format(mean_loss), flush=True)
    return mean_loss, acc_sum / num_batches

def _evaluate(loader, split_label):
    """Measure accuracy on ``loader`` in eval mode with gradients disabled.

    Arguments:
        loader: iterable yielding ``(image, target, meta_target)`` batches.
        split_label (str): name of the split, inserted into the summary line
            (``"Total <split_label> Acc: ..."``).
    Returns:
        float: accuracy in percent, averaged over batches (every batch has
            equal weight, whatever its size); ``nan`` if the loader produced
            no batches.
    """
    model.eval()
    acc_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for image, target, meta_target in loader:
            num_batches += 1
            if args.cuda:
                image = image.cuda()
                target = target.cuda()
                meta_target = meta_target.cuda()
            output = pmodel(image)
            # The prediction is the index of the largest entry along dim 1 of
            # the first output element.
            pred = output[0].max(1)[1]
            correct = pred.eq(target).sum().item()
            acc_sum += correct * 100. / target.size(0)

    if num_batches == 0:
        # An empty loader used to end in a ZeroDivisionError; report
        # "no data" explicitly instead.
        return float("nan")

    mean_acc = acc_sum / num_batches
    print("Total {} Acc: {:.4f}".format(split_label, mean_acc), flush=True)
    return mean_acc


def validate(epoch):
    """Return the mean per-batch accuracy (percent) on ``validloader``.

    Arguments:
        epoch (int): accepted for symmetry with ``train``; not used.
    Returns:
        float: see ``_evaluate``.
    """
    return _evaluate(validloader, "Validation")


def test(epoch):
    """Return the mean per-batch accuracy (percent) on ``testloader``.

    Arguments:
        epoch (int): accepted for symmetry with ``train``; not used.
    Returns:
        float: see ``_evaluate``.
    """
    return _evaluate(testloader, "Testing")

def main():
    """Print a model summary, then train, evaluate and log every epoch.

    For each epoch from ``start_epoch`` up to ``args.epochs`` this runs
    ``train``, ``validate`` and ``test``, saves a checkpoint, appends one line
    to ``results.log`` and one row to ``results.csv`` (both under
    ``args.save``), and, when visdom is enabled, appends the three accuracies
    to the live plot.
    """
    print("Model architecture:")
    print(model)
    print()
    print(
        "Number of model parameters (total):     {}".format(
            count_parameters(model, only_trainable=False)
        )
    )
    print(
        "Number of model parameters (trainable): {}".format(
            count_parameters(model, only_trainable=True)
        )
    )
    # Per-submodule breakdown for whichever of these attributes the model has.
    for submodel_name in ("conv1", "conv2", "conv3", "h1", "h2", "h3"):
        if not hasattr(model, submodel_name):
            continue
        submodel = getattr(model, submodel_name)
        print(
            "model.{} has {} parameters, of which {} are trainable".format(
                submodel_name,
                count_parameters(submodel, only_trainable=False),
                count_parameters(submodel, only_trainable=True),
            )
        )
    print()

    if args.visdom:
        # `start_time` was previously never defined, so enabling visdom
        # raised a NameError here. Timestamp the title so that runs on the
        # same dataset can be told apart.
        start_time = time.strftime("%Y-%m-%d %H:%M:%S")
        vis_title = args.dataset + ' ' + start_time
        vis_legend = ['Train Acc', 'Val Acc', 'Test Acc']
        epoch_plot = create_vis_plot('Epoch', 'Acc', vis_title, vis_legend)

    log_fname = os.path.join(args.save, 'results.log')
    results_fname = os.path.join(args.save, "results.csv")
    # Only a fresh run (re)creates the CSV; a resumed run keeps the rows
    # written so far and appends to them.
    if start_epoch == 0:
        with open(results_fname, "w") as f:
            f.write("dataset,model,nparams,actfun,feature_width,seed,batch_size,lr,beta1,beta2,epsilon,meta_beta,epoch,loss_train,acc_train,acc_val,acc_test,weight_decay\n")

    for epoch in range(start_epoch, args.epochs):
        avg_train_loss, avg_train_acc = train(epoch)
        avg_val_acc = validate(epoch)
        avg_test_acc = test(epoch)
        model.save_model(args.save, epoch, remove_old=args.num_checkpoints)

        with open(log_fname, 'a') as f:
            f.write(
                'Epoch {}, Training loss: {:.6f}, Validation Acc: {:.4f}, Testing Acc: {:.4f}\n'.format(
                    epoch, avg_train_loss, avg_val_acc, avg_test_acc
                )
            )

        # One CSV row per epoch; the field order must match the header above.
        csv_row = [
            args.dataset,
            args.model,
            count_parameters(model, only_trainable=True),
            args.actfun,
            args.feature_width,
            args.seed,
            args.batch_size,
            args.lr,
            args.beta1,
            args.beta2,
            args.epsilon,
            args.meta_beta,
            epoch,
            avg_train_loss,
            avg_train_acc,
            avg_val_acc,
            avg_test_acc,
            args.weight_decay,
        ]
        with open(results_fname, "a") as f:
            f.write(",".join("{}".format(field) for field in csv_row) + "\n")

        if args.visdom:
            # Append one point per curve (train / val / test) at x = epoch.
            viz.line(
                X=torch.ones((1, 3)) * epoch,
                Y=torch.Tensor([avg_train_acc, avg_val_acc, avg_test_acc]).unsqueeze(0),
                win=epoch_plot,
                update='append'
            )

def create_vis_plot(_xlabel, _ylabel, _title, _legend):
    """Create a visdom line plot and return what ``viz.line`` returns.

    The plot is created with one initial row of three zero-valued points.
    Uses the module-level ``viz`` client, which only exists when visdom is
    enabled.

    Arguments:
        _xlabel: label for the x axis.
        _ylabel: label for the y axis.
        _title: plot title.
        _legend: legend entries, forwarded unchanged in the plot options.
    """
    plot_options = {
        "xlabel": _xlabel,
        "ylabel": _ylabel,
        "title": _title,
        "legend": _legend,
    }
    initial_x = torch.zeros((1,))
    initial_y = torch.zeros((1, 3))
    return viz.line(X=initial_x, Y=initial_y, opts=plot_options)
# Start the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
