import os
from argparse import ArgumentParser
import torch.optim as optim
from networks import *
from training_encoder import *
from data_utils.dataloader import *
from evaluate import *
import torch
import numpy as np

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean-valued options
    need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ArgumentTypeError('expected a boolean value, got %r' % (value,))


def get_args(argv=None):
    """Parse the configuration for a training / evaluation run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and notebooks).

    Returns:
        argparse.Namespace holding the experiment, latent, network, training
        and loss settings declared below.
    """
    parser = ArgumentParser(description='SymmetryReg')
    parser.add_argument('--exp_name', type=str, default='Experiment_name')
    parser.add_argument('--model_name', type=str, default='drlim',
                        help='Available options are 1) symmetryreg 2) autoencoder 3)drlim')
    parser.add_argument('--dataset', type=str, default='chair',
                        help='Available options are 1) chair')
    # expanduser is applied to the string default as well, so '~' resolves to
    # the home directory (file APIs such as open()/os.listdir do not expand it).
    parser.add_argument('--data_dir', type=os.path.expanduser,
                        default='~/scratch/Chair_SymmetryReg')
    parser.add_argument('--continue_training', type=int, default=0)
    parser.add_argument('--use_action_pred', type=int, default=0)
    parser.add_argument('--param_type', type=str, default='Lie')
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--seed', type=int, default=0)

    # latent config
    parser.add_argument('--code_size', type=int, default=16)  # use 16 for chair
    parser.add_argument('--proj_size', type=int, default=3)

    # network config
    parser.add_argument('--num_channels', type=int, default=128)
    parser.add_argument('--mlp_hidden_dim', type=int, default=256)

    # training config
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--steps_per_epoch', type=int, default=200)
    parser.add_argument('--sample_training_data', type=int, default=0)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--train_cluster_size', type=int, default=32)
    parser.add_argument('--train_num_actions', type=int, default=4)
    parser.add_argument('--train_batch_size', type=int, default=64)  # set this to 1 for SymmetryReg

    # loss config
    parser.add_argument('--hinge_thresh', type=float, default=5)  # set this to 5 for SymmetryReg
    parser.add_argument('--barrier_type', type=str, default='log')
    parser.add_argument('--barrier_coef', type=float, default=1)
    parser.add_argument('--cosine_sim', type=int, default=0)
    # Boolean options: accept an explicit value (`--conformal_map false`) or a
    # bare flag (`--conformal_map`, meaning True).
    parser.add_argument('--conformal_map', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--rotation_map', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--decompositions', type=int, default=1)
    parser.add_argument('--temperature', type=float, default=0.05)

    return parser.parse_args(argv)


# Model names that have a training routine in `_train` below.
_TRAINABLE_MODELS = ('symmetryreg', 'autoencoder', 'drlim')


def _build_train_dataloader(args):
    """Return the training batch source for ``args.model_name``.

    * ``drlim``: a shuffled ``DataLoader`` (``drop_last=True``) over
      ``ChairDRLIMDataset``, or over ``ChairSampleDataset`` sized to
      ``steps_per_epoch * train_batch_size`` when ``--sample_training_data``
      is set.
    * other models with ``--sample_training_data``: a plain list of
      ``steps_per_epoch`` step indices (presumably the trainer generates the
      data itself in this mode -- TODO confirm against ``EncoderTrainer``).
    * other models otherwise: ``get_train_dataloaders(args)``. Side effect:
      ``args.img_w`` and ``args.img_h`` are overwritten with the last
      dimension of ``next(iter(loader))[0][0]``.
    """
    if args.model_name == 'drlim':
        if args.sample_training_data:
            dataset = ChairSampleDataset(args.data_dir, args.steps_per_epoch * args.train_batch_size)
        else:
            dataset = ChairDRLIMDataset(args.data_dir)
        return DataLoader(
            dataset, batch_size=args.train_batch_size, shuffle=True,
            num_workers=args.workers, drop_last=True,
        )

    if args.sample_training_data:
        return list(range(args.steps_per_epoch))

    loader = get_train_dataloaders(args)
    # Python 3 iterator protocol: builtin next(), not the legacy `.next()`.
    args.img_w = next(iter(loader))[0][0].shape[-1]
    args.img_h = args.img_w
    return loader


def _make_optimizer(args, params):
    """Create the Adam optimizer and its ExponentialLR scheduler for ``params``."""
    opt = optim.Adam(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(opt, gamma=0.5)  # does nothing for now
    return opt, scheduler


def _train(args, enc, action_pred, train_dataloader, device):
    """Dispatch to the ``EncoderTrainer`` routine matching ``args.model_name``.

    Raises:
        ValueError: if ``args.model_name`` has no training routine.
    """
    if args.model_name == 'symmetryreg':
        if args.use_action_pred:
            # Order kept as (action predictor, encoder) so optimizer state
            # layout matches runs made with earlier versions of this script.
            params = list(action_pred.parameters()) + list(enc.parameters())
        else:
            params = list(enc.parameters())
        opt, scheduler = _make_optimizer(args, params)
        trainer = EncoderTrainer(
            args, enc, train_dataloader, opt,
            scheduler, action_predictor=action_pred
        )
        trainer.train_symmetryreg()
    elif args.model_name == 'autoencoder':
        dec = get_decoder(args)
        dec.to(device)
        opt, scheduler = _make_optimizer(args, list(enc.parameters()) + list(dec.parameters()))
        trainer = EncoderTrainer(args, enc, train_dataloader, opt, scheduler, dec)
        trainer.train_autoencoder()
    elif args.model_name == 'drlim':
        opt, scheduler = _make_optimizer(args, enc.parameters())
        trainer = EncoderTrainer(args, enc, train_dataloader, opt, scheduler)
        trainer.train_DRLIM()
    else:
        raise ValueError(
            'Unknown model_name %r; expected one of %s' % (args.model_name, _TRAINABLE_MODELS)
        )


def _visualize(args, enc):
    """Run ``EvalModule.visualize`` for the x, y and z axes in turn."""
    evaluator = EvalModule(args, enc)  # named to avoid shadowing builtin eval()
    for axis in ('x', 'y', 'z'):
        evaluator.visualize(axis=axis)


def main():
    """Train (or resume) the selected encoder, then visualize it per axis.

    Training runs when ``checkpoints/<model_name>/<exp_name>/model_final.tar``
    does not exist or ``--continue_training`` is set; otherwise that final
    checkpoint is loaded and only the visualization runs.
    """
    args = get_args()
    args.results_dir = os.path.join('results', args.model_name, args.exp_name)
    args.checkpoint_dir = os.path.join('checkpoints', args.model_name, args.exp_name)
    os.makedirs(args.results_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Defaults; the non-drlim real-data loader path overwrites img_w/img_h.
    args.img_w, args.img_h, args.image_channels = 48, 48, 3

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc = get_encoder(args)
    enc.to(device)
    if args.use_action_pred:
        action_pred = ActionPredictorSO3(args)
        action_pred.to(device)
    else:
        action_pred = None

    final_ckpt = os.path.join(args.checkpoint_dir, 'model_final.tar')
    recent_ckpt = os.path.join(args.checkpoint_dir, 'model_recent.tar')

    if os.path.exists(final_ckpt) and not args.continue_training:
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        enc.load_state_dict(torch.load(final_ckpt, map_location=device))
        print('Encoder Checkpoint Loaded for evaluation ....')
    else:
        # Fail fast instead of silently visualizing an untrained encoder.
        if args.model_name not in _TRAINABLE_MODELS:
            raise ValueError(
                'Unknown model_name %r; expected one of %s' % (args.model_name, _TRAINABLE_MODELS)
            )
        train_dataloader = _build_train_dataloader(args)

        if args.continue_training and os.path.exists(recent_ckpt):
            enc.load_state_dict(torch.load(recent_ckpt, map_location=device))
            print('Encoder Checkpoint Loaded for continued training ....')

        _train(args, enc, action_pred, train_dataloader, device)

    _visualize(args, enc)


if __name__ == '__main__':
    main()
