import argparse
import copy
import os
import time
import warnings

import torch
from torch import optim
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms

from models.foldingnet import SkipVariationalFoldingNet
from datasets.modelnet import ModelNet40_OOD, H5_Dataset
from loss import ChamferDistance, KLDivergence
from utils.data_utils import *
from utils.common import set_random_seed, format_time, init_np_seed, save_results

from infer import eval_anomaly

# Compute device for the model and every batch. The cudnn settings in train()
# target CUDA, so prefer it; fall back to Apple-silicon MPS, then CPU, instead of
# hard-coding 'mps' (which fails on any non-macOS host).
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


def _parse_corruption(value):
    """Convert the ``--corruption`` CLI string; ``'none'`` (any case) maps to ``None``."""
    return None if value.lower() == 'none' else value


def get_args(argv=None):
    """Parse the command-line options for VAE training.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace``. ``data_root`` has ``~`` expanded, ``save_path1``
        holds the best-validation checkpoint path read by ``train()``, and
        ``corruption`` is ``None`` when ``--corruption none`` is passed, which
        skips the loaders' corrupted-data branches.
    """
    # Synth to Real Benchmark
    # Training settings
    # Seeds used for our paper experiments are: 1 -> SR1, 41 -> SR2
    # ModelNet train_dataset length: SR1 -> 2378, SR2 -> 1916
    # ModelNet-C train_dataset length: SR1 -> 19020, SR2 -> 15325
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    # NOTE(review): these default paths name SR1 while --src defaults to SR2;
    # confirm the pairing is intended.
    pretrained_path = 'model_logs/pretrain_SR1.pt'
    save_path = 'model_logs/vae_SR1.pt'
    save_path1 = 'model_logs/vae_SR1_best.pt'
    parser = argparse.ArgumentParser(description='Train VAE')
    parser.add_argument('--data_root', type=str, default=root)
    parser.add_argument('--pretrained_path', type=str, default=pretrained_path)
    parser.add_argument('--src', type=str, default='SR2', choices=['SR1', 'SR2'])
    parser.add_argument("--augment_set", type=str, default="rw", choices=["st", "rw"])
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--save_path', type=str, default=save_path)
    # Best-validation checkpoint; previously built above but never registered,
    # so train() failed on args.save_path1.
    parser.add_argument('--save_path1', type=str, default=save_path1)
    # Adopt Corrupted data
    # this flag should be set also during evaluation if testing Synth->Real Corr/LIDAR Augmented models
    # Pass 'none' to train/validate on clean data only.
    parser.add_argument("--corruption", type=_parse_corruption, default='all')
    parser.add_argument('--script_mode', type=str, default='train')
    args = parser.parse_args(argv)  # `choices` above already validates --src
    args.data_root = os.path.expanduser(args.data_root)
    return args


def get_corrupted_data_list(opt, severity=None, split="train"):
    """Build one ``H5_Dataset`` per corrupted ModelNet40 h5 file.

    Args:
        opt: parsed options; reads ``src``, ``corruption``, ``data_root``,
            ``num_points`` and ``script_mode``.
        severity: iterable of severity levels to load; ``None`` means 1..5.
        split: ``'train'`` or ``'test'``.

    Returns:
        List of datasets ordered by corruption kind, then by severity
        (for ``'all'``: every lidar level, then every occlusion level).

    Raises:
        ValueError: on an unknown ``opt.src`` or ``opt.corruption``.
    """
    assert split in ['train', 'test']

    src_to_prefix = {"SR1": "modelnet_set1", "SR2": "modelnet_set2"}
    if opt.src not in src_to_prefix:
        raise ValueError(f"Expected SR source but received: {opt.src} ")
    prefix = src_to_prefix[opt.src]

    print(f"get_list_corr_data for {prefix} - split {split}")

    levels = [1, 2, 3, 4, 5] if severity is None else severity

    # Resolve which corruption sub-folders to read.
    if opt.corruption in ('lidar', 'occlusion'):
        print(f"loading {opt.corruption} data")
        kinds = [opt.corruption]
    elif opt.corruption == 'all':
        print("loading both lidar and occlusion data")
        kinds = ["lidar", "occlusion"]
    else:
        raise ValueError(f"Unknown corruption specified: {opt.corruption}")

    file_names = [
        f"{os.path.join(opt.data_root, 'ModelNet40_corrupted', kind)}/{prefix}_{split}_{kind}_sev{level}.h5"
        for kind in kinds
        for level in levels
    ]
    print(f"corr list files: {file_names}\n")

    # Augment only when training; eval modes and the test split see raw points.
    skip_augmentation = opt.script_mode.startswith("eval") or split == 'test'
    augment_set = None
    if not skip_augmentation:
        # synth -> real augmentation
        warnings.warn("Using RW augmentation set for corrupted data")
        augment_set = transforms.Compose([
            PointCloudToTensor(),
            AugmentScale(),
            AugmentRotate(axis=[0.0, 1.0, 0.0]),
            AugmentRotatePerturbation(),
            AugmentTranslate(),
            AugmentJitter(),
        ])

    return [
        H5_Dataset(h5_file=path, num_points=opt.num_points, class_choice=opt.src,
                   transforms=augment_set, pretrain=False)
        for path in file_names
    ]


def get_md_train_loader(opt):
    """Build the shuffled training DataLoader for the chosen synthetic class set.

    Uses the clean ModelNet40 train split for ``opt.src`` with the augmentation
    set selected by ``opt.augment_set``. When ``opt.corruption`` is not ``None``,
    the corrupted train sets are concatenated in front of the clean data.

    Raises:
        ValueError: if ``opt.augment_set`` is neither ``'st'`` nor ``'rw'``.
    """
    if opt.augment_set == 'st':
        print("Augmentation set ST")
        set_transforms = [
            PointCloudToTensor(),
            RandomSample(opt.num_points),
            AugmentScale(low=2/3, high=3/2),
            AugmentTranslate(translate_range=0.2)]
    elif opt.augment_set == 'rw':
        # transformation used for Synthetic->Real-World
        print("Augmentation set RW")
        set_transforms = [
            PointCloudToTensor(),
            RandomSample(opt.num_points),
            AugmentScale(),
            AugmentRotate(axis=[0.0, 1.0, 0.0]),
            AugmentRotatePerturbation(),
            AugmentTranslate(),
            AugmentJitter()]
    else:
        # Was `opt.augm_set`, an attribute that does not exist (AttributeError).
        raise ValueError(f"Unknown augmentation set: {opt.augment_set}")

    train_transforms = transforms.Compose(set_transforms)
    train_dataset = ModelNet40_OOD(data_root=opt.data_root, num_points=opt.num_points, train=True, class_choice=opt.src,
                                   transforms=train_transforms, pretrain=False)

    if opt.corruption is not None:
        # load corrupted datasets
        assert opt.augment_set == 'rw'
        l_corr_data = get_corrupted_data_list(opt)
        assert isinstance(l_corr_data, list)
        assert isinstance(l_corr_data[0], data.Dataset)
        l_corr_data.append(train_dataset)
        train_dataset = data.ConcatDataset(l_corr_data)
        print(f"{opt.src} + corruption {opt.corruption} - train data len: {len(train_dataset)}")

    # shuffle=True: the dataset above is a concatenation ordered by corruption
    # kind/severity followed by clean data. The default sequential sampler would
    # make every batch come from a single one of those blocks.
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, drop_last=True,
                              num_workers=opt.num_workers, worker_init_fn=init_np_seed)
    return train_loader


def get_md_val_loader(opt):
    """Build the unshuffled validation DataLoader.

    Always contains the clean ModelNet40 test split for ``opt.src`` with no
    transforms. When ``opt.corruption`` is not ``None``, the corrupted test sets
    are concatenated in front of that clean split.
    """
    val_set = ModelNet40_OOD(data_root=opt.data_root, num_points=opt.num_points, transforms=None,
                             train=False, class_choice=opt.src, pretrain=False)

    if opt.corruption is not None:
        assert opt.augment_set == 'rw'
        parts = get_corrupted_data_list(opt, split='test')
        assert isinstance(parts, list)
        assert isinstance(parts[0], data.Dataset)
        parts.append(val_set)
        val_set = data.ConcatDataset(parts)
        print(f"{opt.src} + corruption {opt.corruption} - val data len: {len(val_set)}")

    return DataLoader(val_set, batch_size=opt.batch_size, drop_last=False,
                      num_workers=opt.num_workers, worker_init_fn=init_np_seed)


@torch.no_grad()
def val_model(model, criterion, val_loader, val_data_length):
    """Return the mean reconstruction loss over ``val_loader``.

    The summed per-batch ``criterion`` values are divided by ``val_data_length``
    (callers pass ``len(val_loader)``, i.e. the number of batches).

    Side effect: switches ``model`` to eval mode and leaves it there. Callers
    must call ``model.train()`` again before further training.

    NOTE(review): this calls ``criterion(reconstruction, batch)``, while the
    training loop calls ``criterion_cd(pc, folding2)``. The two agree only if the
    criterion is symmetric — confirm.
    """
    model.eval()

    def _batch_loss(batch):
        batch = batch.to(device)
        reconstruction, _, _, _ = model(batch)
        return criterion(reconstruction, batch).item()

    return sum(_batch_loss(b) for b in val_loader) / val_data_length


def _load_model(args):
    """Build the VAE, load pretrained encoder weights, and freeze the encoder."""
    print("===> Loading pretrained encoder...")
    model = SkipVariationalFoldingNet(args.dim, args.num_points).to(device)
    pretrained = torch.load(args.pretrained_path, map_location=device)
    # The checkpoint uses a 'point_encoder' key prefix; this model names it 'encoder'.
    pretrained = {name.replace('point_encoder', 'encoder'): param for name, param in pretrained.items()}
    # strict=False: presumably the checkpoint covers only part of the model
    # (e.g. no decoder weights) — confirm.
    model.load_state_dict(pretrained, strict=False)
    print("===> Loading complete!")
    for param in model.encoder.parameters():
        param.requires_grad = False
    return model


def _train_one_epoch(model, loader, optimizer, criterion_cd, criterion_kld):
    """Run one optimisation pass over ``loader``; return the summed batch loss."""
    # Must be re-enabled every epoch: val_model() leaves the model in eval mode.
    model.train()
    total_loss = 0.
    for pc in loader:
        optimizer.zero_grad()
        pc = pc.to(device)
        folding2, folding1, mu, log_sigma = model(pc)
        # Re-encode the reconstruction and apply the KL term to its latent too.
        _, _, fake_mu, fake_log_sigma = model(folding2)

        fold2_loss = criterion_cd(pc, folding2)
        fold1_loss = criterion_cd(pc, folding1)
        kld_loss = criterion_kld(mu, log_sigma)
        fake_kld_loss = criterion_kld(fake_mu, fake_log_sigma)
        loss = fold2_loss + fold1_loss + kld_loss + fake_kld_loss
        total_loss += loss.item()
        print("inner_loss:{:.6f}, outer_loss:{:.6f}, kld_loss:{:.6f}".format(fold1_loss, fold2_loss, kld_loss))

        loss.backward()
        optimizer.step()
    return total_loss


def train(args):
    """Train the VAE decoder on top of a frozen pretrained encoder, then evaluate.

    Saves the final weights to ``args.save_path`` and the lowest-validation-loss
    weights to ``args.save_path1``. If ``args.save_path1`` is missing, ``_best`` is
    inserted before the extension of ``args.save_path``. The final model is then
    passed to ``eval_anomaly`` and the results to ``save_results``.
    """
    set_random_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    # Debug aid: anomaly detection slows down every backward pass.
    torch.autograd.set_detect_anomaly(True)

    # Resolve the best-checkpoint path up front so a missing option cannot
    # crash the run after training has finished.
    best_save_path = getattr(args, 'save_path1', None)
    if best_save_path is None:
        stem, ext = os.path.splitext(args.save_path)
        best_save_path = f"{stem}_best{ext or '.pt'}"

    # Prepare data.
    print("===> Creating training dataset...")
    train_dataloader = get_md_train_loader(args)
    num_train_batches = len(train_dataloader)
    print("===> Creating validating dataset...")
    val_dataloader = get_md_val_loader(args)
    num_val_batches = len(val_dataloader)

    model = _load_model(args)

    # Loss function and optimizer; only the non-frozen parameters are optimised.
    criterion_cd = ChamferDistance().to(device)
    criterion_kld = KLDivergence().to(device)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, 0.001, (0.9, 0.999), weight_decay=1e-6)

    # Training Loop.
    print("===> Start training...")
    lowest_val_loss = float('inf')
    best_state = None  # snapshot of the best weights; None -> use final weights
    for epoch in range(1, args.epochs + 1):
        print("===> Epoch {}/{}".format(epoch, args.epochs))
        start = time.time()
        total_loss = _train_one_epoch(model, train_dataloader, optimizer, criterion_cd, criterion_kld)
        end = time.time()
        avg_loss = total_loss / num_train_batches
        print(f"EPOCH {epoch} avg loss:{avg_loss}, time:{format_time(end - start)}")

        # Validate and log model.
        print("===> Validating model...")
        t1 = time.time()
        val_loss = val_model(model, criterion_cd, val_dataloader, num_val_batches)
        elapsed = time.time() - t1
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            # Copy only the weights; cheaper than deep-copying the whole module.
            best_state = copy.deepcopy(model.state_dict())
        print(f"EPOCH {epoch} val loss:{val_loss}, val time:{format_time(elapsed)}")
    print("===> Training complete!")

    # Save Model.
    print("===> Saving the best model...")
    torch.save(model.state_dict(), args.save_path)
    torch.save(best_state if best_state is not None else model.state_dict(), best_save_path)
    print("===> Saving complete!")

    # Evaluate Model
    print("===> Evaluating model...")
    model.eval()
    res_tar1, res_tar2, res_big_tar = eval_anomaly(args, model)
    save_results(res_tar1, res_tar2, res_big_tar, args.src, args.save_path)


if __name__ == '__main__':
    # Debug entry point: load the severity-5 occlusion corrupted set and display
    # one sample. Training is currently disabled; call train(cli_args) to run it.
    cli_args = get_args()

    from utils.common import show_point_cloud

    # Override CLI options: occlusion only, eval mode so no augmentation is applied.
    cli_args.corruption = 'occlusion'
    cli_args.script_mode = 'eval'
    occlusion_sets = get_corrupted_data_list(cli_args, severity=[5])
    merged = data.ConcatDataset(occlusion_sets)
    print(len(merged))
    sample = merged[100]
    show_point_cloud(sample.transpose(1, 0))
