import argparse
import os
import time
import copy

import torch
from torch.utils.data import DataLoader
from torch import optim
from torchvision import transforms

from models.foldingnet import SkipVariationalFoldingNet
from datasets.sncore_4k import ShapeNetCore4k
from datasets.sncore_splits import *
from loss import ChamferDistance, KLDivergence
from utils.data_utils import *
from utils.common import set_random_seed, format_time, init_np_seed, save_results

from infer import eval_anomaly

# Device every model, loss module and batch in this file is moved to.
# Chosen by availability instead of being hard-coded to a single backend, so the
# script runs unchanged on CUDA machines, Apple-silicon (MPS) machines and
# CPU-only hosts. On a machine where MPS is the only accelerator this still
# resolves to 'mps'.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


def get_args():
    """Parse the command-line options for VAE training (Synth to Synth benchmark).

    Returns:
        argparse.Namespace: the parsed options. ``data_root`` is passed through
        ``os.path.expanduser``, and two extra attributes are added: ``tar1`` and
        ``tar2``, the two class sets among SN1/SN2/SN3 that were *not* selected
        with ``--src`` (kept in SN1, SN2, SN3 order).
    """
    # Seeds used for our paper experiments are: 1 -> SN1, 41 -> SN2, 13718 -> SN3.
    # ShapeNetCore train_dataset length: SN1 -> 7342, SN2 -> 16381, SN3 -> 11581
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    pretrained_path = 'model_logs/pretrain_SN1.pt'
    save_path = 'model_logs/vae_SN1.pt'
    save_path1 = 'model_logs/vae_SN1.pt'
    # Single source of truth for the accepted class sets; used both for the
    # argparse ``choices`` and for deriving the two target sets below.
    class_sets = ['SN1', 'SN2', 'SN3']

    parser = argparse.ArgumentParser(description='Train VAE')
    parser.add_argument('--data_root', type=str, default=root)
    parser.add_argument('--pretrained_path', type=str, default=pretrained_path)
    parser.add_argument('--src', type=str, default='SN1', choices=class_sets)
    parser.add_argument('--augment_set', type=str, default='st', choices=['st', 'rw'])
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--save_path', type=str, default=save_path)
    # Previously this option took its default from ``save_path``; it now uses
    # its own ``save_path1`` default (the two defaults currently hold the same value).
    parser.add_argument('--save_path1', type=str, default=save_path1)
    args = parser.parse_args()
    args.data_root = os.path.expanduser(args.data_root)

    # argparse ``choices`` has already rejected any --src outside class_sets,
    # so exactly two names remain once the source set is filtered out.
    args.tar1, args.tar2 = [name for name in class_sets if name != args.src]
    return args


def get_sncore_train_loader(opt, split="train"):
    """Build the augmented training DataLoader for the source class set.

    Args:
        opt: options object; this function reads ``augment_set``, ``num_points``,
            ``data_root``, ``src``, ``batch_size`` and ``num_workers``.
        split: dataset split name forwarded to ``ShapeNetCore4k``.

    Returns:
        DataLoader over ``ShapeNetCore4k`` with the selected augmentation
        pipeline; incomplete final batches are dropped (``drop_last=True``).

    Raises:
        ValueError: if ``opt.augment_set`` is neither ``'st'`` nor ``'rw'``.
    """
    if opt.augment_set == 'st':
        print("Augmentation set ST")
        set_transforms = [
            PointCloudToTensor(),
            RandomSample(opt.num_points),
            AugmentScale(low=2/3, high=3/2),
            AugmentTranslate(translate_range=0.2)]
    elif opt.augment_set == 'rw':
        print("Augmentation set RW")
        set_transforms = [
            PointCloudToTensor(),
            RandomSample(opt.num_points),
            AugmentScale(),
            AugmentRotate(axis=[0.0, 1.0, 0.0]),
            AugmentRotatePerturbation(),
            AugmentTranslate(),
            AugmentJitter()]
    else:
        # The message used to read the non-existent ``opt.augm_set``, which
        # turned this intended ValueError into an AttributeError.
        raise ValueError(f"Unknown augmentation set: {opt.augment_set}")

    train_transforms = transforms.Compose(set_transforms)
    # NOTE(security): eval() resolves the class-set name (e.g. "SN1") to the
    # object of that name visible in this module and takes its keys. It is only
    # safe while ``opt.src`` is restricted to trusted values (the CLI limits it
    # via argparse ``choices``); never pass unvalidated input here.
    train_data = ShapeNetCore4k(data_root=opt.data_root, split=split, class_choice=list(eval(opt.src).keys()),
                                num_points=opt.num_points, transforms=train_transforms, pretrain=False)
    # NOTE(review): neither ``shuffle`` nor ``sampler`` is passed, so DataLoader
    # falls back to its default (no shuffling) -- confirm this is intended for training.
    train_loader = DataLoader(train_data, batch_size=opt.batch_size, drop_last=True, num_workers=opt.num_workers,
                              worker_init_fn=init_np_seed)
    return train_loader


def get_sncore_val_loader(opt):
    """Build the DataLoader over the "val" split of the source class set.

    Args:
        opt: options object; this function reads ``data_root``, ``num_points``,
            ``src``, ``batch_size`` and ``num_workers``.

    Returns:
        DataLoader over ``ShapeNetCore4k`` created with ``transforms=None``;
        the final, possibly smaller, batch is kept (``drop_last=False``).
    """
    dataset = ShapeNetCore4k(
        data_root=opt.data_root,
        split="val",
        num_points=opt.num_points,
        transforms=None,
        # eval() maps the class-set name in opt.src to the object of that name
        # visible in this module; its keys select the classes to load.
        class_choice=list(eval(opt.src).keys()),
        pretrain=False,
    )
    return DataLoader(
        dataset,
        batch_size=opt.batch_size,
        drop_last=False,
        num_workers=opt.num_workers,
        worker_init_fn=init_np_seed,
    )


@torch.no_grad()
def val_model(model, criterion, val_loader, val_data_length):
    """Return the summed validation loss divided by ``val_data_length``.

    The model is switched to eval mode and is NOT switched back afterwards;
    callers that keep training must call ``model.train()`` themselves.

    Args:
        model: network whose forward pass returns four values; only the first
            one is scored here.
        criterion: callable invoked as ``criterion(output, batch)`` whose result
            supports ``.item()``.
        val_loader: iterable yielding batches that support ``.to(device)``.
        val_data_length: divisor applied to the accumulated loss.

    Returns:
        float: accumulated loss over all batches divided by ``val_data_length``.
    """
    model.eval()
    running_loss = 0.
    for batch in val_loader:
        batch = batch.to(device)
        # Only the first of the four outputs is compared against the input batch.
        reconstruction, _, _, _ = model(batch)
        batch_loss = criterion(reconstruction, batch)
        running_loss += batch_loss.item()
    mean_loss = running_loss / val_data_length
    return mean_loss


def train(args):
    """Train the variational FoldingNet, keep the best checkpoint, then evaluate.

    Steps: seed everything, build the train/val loaders, load the pretrained
    weights (keys containing ``point_encoder`` are renamed to ``encoder`` and
    loaded non-strictly), optimise with Adam for ``args.epochs`` epochs, validate
    after every epoch with the Chamfer criterion, save the checkpoints and run
    ``eval_anomaly``.

    Per-batch loss = Chamfer(pc, folding2) + Chamfer(pc, folding1)
                     + KLD(mu, log_sigma) + KLD of the re-encoded ``folding2``.

    Args:
        args: options namespace as produced by ``get_args``. Read here:
            ``seed``, ``dim``, ``num_points``, ``pretrained_path``, ``epochs``,
            ``save_path``, ``save_path1`` (optional) and ``src``; the namespace
            is also forwarded to the loader builders and to ``eval_anomaly``.

    Side effects:
        Writes the best-validation weights to ``args.save_path``. When
        ``args.save_path1`` names a different file, the last-epoch weights are
        written there. Results are stored through ``save_results``.
    """
    set_random_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # NOTE(review): anomaly detection adds overhead to every backward pass;
    # consider enabling it only while debugging.
    torch.autograd.set_detect_anomaly(True)

    # Prepare data.
    print("===> Creating training dataset...")
    train_dataloader = get_sncore_train_loader(args, split='train')
    # len() of a DataLoader is the number of batches, so the averages below
    # are per-batch averages.
    train_data_length = len(train_dataloader)
    print("===> Creating validating dataset...")
    val_dataloader = get_sncore_val_loader(args)
    val_data_length = len(val_dataloader)

    # Load pretrained model's encoder.
    print("===> Loading pretrained encoder...")
    model = SkipVariationalFoldingNet(args.dim, args.num_points).to(device)
    pretrained = torch.load(args.pretrained_path, map_location=device)
    pretrained = {param_name.replace('point_encoder', 'encoder'): param for param_name, param in pretrained.items()}
    # strict=False: parameters missing from the checkpoint keep their initial values.
    model.load_state_dict(pretrained, strict=False)

    # Loss function and optimizer.
    criterion_cd = ChamferDistance().to(device)
    criterion_kld = KLDivergence().to(device)
    optimizer = optim.Adam(model.parameters(), 0.001, (0.9, 0.999), weight_decay=1e-6)

    # Training Loop.
    print("===> Start training...")
    # Start at +inf so the first finite validation loss always becomes the best.
    lowest_val_loss = float('inf')
    best_model = model
    for epoch in range(1, args.epochs+1):
        print("===> Epoch {}/{}".format(epoch, args.epochs))
        # val_model() leaves the network in eval mode at the end of each epoch,
        # so training mode has to be re-enabled here; otherwise every epoch
        # after the first would train in eval mode.
        model.train()
        total_loss = 0.
        start = time.time()
        for pc in train_dataloader:
            optimizer.zero_grad()
            pc = pc.to(device)
            folding2, folding1, mu, log_sigma = model(pc)
            # Re-encode the reconstruction to obtain its latent statistics.
            _, _, fake_mu, fake_log_sigma = model(folding2)

            fold2_loss = criterion_cd(pc, folding2)
            fold1_loss = criterion_cd(pc, folding1)
            kld_loss = criterion_kld(mu, log_sigma)
            fake_kld_loss = criterion_kld(fake_mu, fake_log_sigma)
            loss = fold2_loss + fold1_loss + kld_loss + fake_kld_loss
            total_loss += loss.item()
            print("inner_loss:{:.6f}, outer_loss:{:.6f}, kld_loss:{:.6f}".format(fold1_loss, fold2_loss, kld_loss))

            loss.backward()
            optimizer.step()
        end = time.time()
        avg_loss = total_loss / train_data_length
        print(f"EPOCH {epoch} avg loss:{avg_loss}, time:{format_time(end - start)}")

        # Validate and log model.
        print("===> Validating model...")
        t1 = time.time()
        val_loss = val_model(model, criterion_cd, val_dataloader, val_data_length)
        elapsed = time.time() - t1
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            # Snapshot the weights; ``model`` keeps changing in later epochs.
            best_model = copy.deepcopy(model)
        print(f"EPOCH {epoch} val loss:{val_loss}, val time:{format_time(elapsed)}")
    print("===> Training complete!")

    # Save Model.
    print("===> Saving the best model...")
    # The last-epoch weights used to be written to save_path and then
    # immediately overwritten by the best weights. They now go to save_path1
    # when that names a different file; save_path still receives the best weights.
    last_model_path = getattr(args, 'save_path1', args.save_path)
    if last_model_path != args.save_path:
        torch.save(model.state_dict(), last_model_path)
    torch.save(best_model.state_dict(), args.save_path)
    print("===> Saving complete!")

    # Evaluate Model
    print("===> Evaluating model...")
    # NOTE(review): this evaluates the last-epoch ``model`` while the results
    # are saved against ``args.save_path`` (the best-model checkpoint) --
    # confirm whether ``best_model`` should be evaluated instead.
    res_tar1, res_tar2, res_big_tar = eval_anomaly(args, model)
    save_results(res_tar1, res_tar2, res_big_tar, args.src, args.save_path)


if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the full training pipeline.
    train(get_args())
