import argparse
import copy
import os
import time
import warnings

import torch
from sklearn.model_selection import train_test_split
from torch import optim
from torch.utils import data
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from datasets.scanobject import ScanObject
from models.foldingnet import SkipVariationalFoldingNet
from loss import ChamferDistance, KLDivergence
from utils.data_utils import *
from utils.common import set_random_seed, format_time, init_np_seed, save_results

from infer import eval_real2real

# Device string shared by the model, the loss modules and every batch in this file.
# Prefer Apple's Metal backend (the original hard-coded choice), then CUDA, and
# finally fall back to the CPU so the script still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


def get_args():
    """Parse the command-line options for VAE training (Real to Real benchmark).

    Returns:
        argparse.Namespace with ``data_root`` (``~`` expanded), ``pretrained_path``,
        ``src``, ``seed``, ``num_points``, ``dim``, ``batch_size``, ``num_workers``,
        ``epochs``, ``save_path`` (last model) and ``save_path1`` (best model).

    Note:
        The default checkpoint paths all point at SR12 files regardless of
        ``--src``; pass them explicitly when training on SR13 or SR23.
    """
    # Seeds used for our paper experiments are: 1 -> SR12, 41 -> SR13, 13718 -> SR23
    # ScanObject train_dataset length: SR12 -> 1643, SR13 -> 1686, SR23 -> 1289
    accepted_src = ['SR12', 'SR13', 'SR23']
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    pretrained_path = 'model_logs/pretrain_SR12.pt'
    save_path = 'model_logs/vae_SR12.pt'
    save_path1 = 'model_logs/vae_SR12_best.pt'

    parser = argparse.ArgumentParser(description='Train VAE')
    parser.add_argument('--data_root', type=str, default=root)
    parser.add_argument('--pretrained_path', type=str, default=pretrained_path)
    parser.add_argument('--src', type=str, default='SR12', choices=accepted_src)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--save_path', type=str, default=save_path)
    parser.add_argument('--save_path1', type=str, default=save_path1)
    args = parser.parse_args()
    args.data_root = os.path.expanduser(args.data_root)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if args.src not in accepted_src:
        parser.error(f"Chosen class set {args.src} is not correct")
    return args


def get_sonn_loaders(opt):
    """Build the ScanObject training and validation loaders.

    The 'train' split for ``opt.src`` is loaded once with the augmentation
    pipeline attached, then divided 90/10 with a fixed ``random_state`` so the
    split is the same on every run. Both subsets wrap the same dataset object,
    so validation samples go through the same augmentations as training ones.

    Args:
        opt: options namespace; reads ``num_points``, ``data_root``, ``src``,
            ``batch_size`` and ``num_workers``.

    Returns:
        ``(train_loader, val_loader)``; both shuffle and drop the last
        incomplete batch.
    """
    # transformation used for Real->Real-World
    pipeline = [
        PointCloudToTensor(),
        RandomSample(opt.num_points),  # sampling as a data augmentation
        AugmentScale(),
        AugmentRotate(axis=[0.0, 1.0, 0.0]),
        AugmentRotatePerturbation(),
        AugmentTranslate(),
        AugmentJitter(),
    ]
    augment_set = transforms.Compose(pipeline)
    print("Train - transforms: ", augment_set)

    full_set = ScanObject(
        opt.data_root,
        num_points=2048,
        split='train',
        class_choice=opt.src,
        transforms=augment_set,
        pretrain=False,
    )

    # Hold out 10% of the train split for validation (deterministic split).
    total = len(full_set)
    holdout = int(total * 10 / 100)
    train_idx, val_idx = train_test_split(
        list(range(total)), test_size=holdout, shuffle=True, random_state=42
    )
    train_data = Subset(full_set, train_idx)
    val_data = Subset(full_set, val_idx)  # val_data is augmented as train_data during training
    print(f"Train - src: {opt.src} - train_data: {len(train_data)}, val_data: {len(val_data)}")

    def make_loader(subset):
        # Identical loader settings for both subsets.
        return DataLoader(
            subset,
            batch_size=opt.batch_size,
            drop_last=True,
            num_workers=opt.num_workers,
            worker_init_fn=init_np_seed,
            shuffle=True,
        )

    return make_loader(train_data), make_loader(val_data)


def val_model(model, criterion, val_loader, val_data_length):
    """Return the summed reconstruction loss over ``val_loader`` divided by ``val_data_length``.

    Side effect: the model is switched to eval mode and is NOT switched back;
    callers that keep training afterwards must call ``model.train()`` themselves.

    Args:
        model: network whose forward pass returns four values; only the first
            (the final reconstruction) is scored.
        criterion: callable ``criterion(reconstruction, batch)`` returning a
            scalar tensor.
        val_loader: iterable of input batches.
        val_data_length: divisor for the accumulated loss (the caller in this
            file passes the number of batches).

    Returns:
        float: accumulated loss divided by ``val_data_length``.
    """
    model.eval()
    running = 0.
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            recon, _fold1, _mu, _log_sigma = model(batch)
            running += criterion(recon, batch).item()
    return running / val_data_length


def train(args):
    """Train a SkipVariationalFoldingNet from a pretrained checkpoint, save and evaluate it.

    Steps: seed the RNGs, build the train/val loaders, load the pretrained
    weights (``point_encoder`` keys renamed to ``encoder``, non-strict), train
    for ``args.epochs`` epochs with Adam, validate after every epoch while
    keeping a deep copy of the model with the lowest validation loss, save the
    last model to ``args.save_path`` and the best one to ``args.save_path1``,
    then run ``eval_real2real`` on both.

    Args:
        args: namespace produced by ``get_args``; reads ``seed``,
            ``pretrained_path``, ``dim``, ``num_points``, ``epochs``,
            ``save_path`` and ``save_path1`` (plus whatever
            ``get_sonn_loaders`` and ``eval_real2real`` read).
    """
    set_random_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True)
    # NOTE(review): anomaly detection is a debugging aid and slows training;
    # confirm it is meant to stay enabled for full runs.
    torch.autograd.set_detect_anomaly(True)

    # Create the checkpoint directories up front so a missing folder fails
    # immediately instead of after the whole training run.
    for ckpt_path in (args.save_path, args.save_path1):
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

    # Prepare data.
    print("===> Creating dataset...")
    train_dataloader, val_dataloader = get_sonn_loaders(args)
    # Both "lengths" are batch counts, so the averages below are per batch.
    train_data_length, val_data_length = len(train_dataloader), len(val_dataloader)

    # Load pretrained model's encoder.
    print("===> Loading pretrained encoder...")
    model = SkipVariationalFoldingNet(args.dim, args.num_points).to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pretrained = torch.load(args.pretrained_path, map_location=device)
    pretrained = {param_name.replace('point_encoder', 'encoder'): param for param_name, param in pretrained.items()}
    # strict=False: keys that do not match the model are skipped silently.
    model.load_state_dict(pretrained, strict=False)
    print("===> Loading complete!")
    # for name, param in model.encoder.named_parameters():
    #     param.requires_grad = False

    # Loss function and optimizer.
    criterion_cd = ChamferDistance().to(device)
    criterion_kld = KLDivergence().to(device)
    optimizer = optim.Adam(model.parameters(), 0.001, (0.9, 0.999), weight_decay=1e-6)

    # Training Loop.
    print("===> Start training...")
    lowest_val_loss = float('inf')
    best_model = model
    for epoch in range(1, args.epochs+1):
        print("===> Epoch {}/{}".format(epoch, args.epochs))
        # val_model() leaves the network in eval mode at the end of every
        # epoch, so training mode has to be re-enabled here each time.
        model.train()
        total_loss = 0.
        start = time.time()
        for pc in train_dataloader:
            optimizer.zero_grad()
            pc = pc.to(device)
            folding2, folding1, mu, log_sigma = model(pc)
            # Re-encode the reconstruction to get a second KL term.
            _, _, fake_mu, fake_log_sigma = model(folding2)

            fold2_loss = criterion_cd(pc, folding2)
            fold1_loss = criterion_cd(pc, folding1)
            kld_loss = criterion_kld(mu, log_sigma)
            fake_kld_loss = criterion_kld(fake_mu, fake_log_sigma)
            loss = fold2_loss + fold1_loss + kld_loss + fake_kld_loss
            total_loss += loss.item()
            print("inner_loss:{:.6f}, outer_loss:{:.6f}, kld_loss:{:.6f}".format(fold1_loss, fold2_loss, kld_loss))

            loss.backward()
            optimizer.step()
        end = time.time()
        avg_loss = total_loss / train_data_length
        print(f"EPOCH {epoch} avg loss:{avg_loss}, time:{format_time(end - start)}")

        # Validate and log model.
        print("===> Validating model...")
        t1 = time.time()
        val_loss = val_model(model, criterion_cd, val_dataloader, val_data_length)
        elapsed = time.time() - t1
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            # Snapshot the weights; later epochs keep mutating ``model``.
            best_model = copy.deepcopy(model)
        print(f"EPOCH {epoch} val loss:{val_loss}, val time:{format_time(elapsed)}")
    print("===> Training complete!")

    # Save Model.
    print("===> Saving the best model...")
    torch.save(model.state_dict(), args.save_path)
    torch.save(best_model.state_dict(), args.save_path1)
    print("===> Saving complete!")

    # Evaluate Model
    print("===> Evaluating model...")
    # Make the inference mode explicit instead of relying on val_model()'s side effect.
    model.eval()
    best_model.eval()
    eval_real2real(args, model)
    eval_real2real(args, best_model)


# Script entry point: parse the command-line options and run training
# (train() also saves the checkpoints and runs the final evaluation).
if __name__ == '__main__':
    args = get_args()
    train(args)
