import argparse
import os
import warnings
import time
import copy

import torch
from torch.utils.data import DataLoader, Dataset
from torch import optim
from torchvision import transforms

from datasets.modelnet import ModelNet40_OOD, H5_Dataset
from datasets.scanobject import ScanObject
from loss import MultimodalSupConLoss
from models.model import UniL
from utils.data_utils import *
from utils.common import set_random_seed, format_time

# Compute device used by pretrain(). Prefer CUDA (the cudnn flags and
# pin_memory handling in pretrain() target it), then Apple MPS, and fall back
# to CPU so the script still runs on machines without an accelerator.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', 'yes' or '0' to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--use_SONN False`` silently meant True). This
    converter accepts the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


def get_args():
    """Parse command-line options for Synth-to-Real (SR1/SR2) pretraining.

    Returns:
        argparse.Namespace holding the parsed options, with ``data_root``
        passed through ``os.path.expanduser`` so '~' paths work.
    """
    # Synth to Real Benchmark
    # Pretraining settings
    # ModelNet train_dataset length: SR1 -> 2378, SR2 -> 1916
    # ModelNet-C train_dataset length: SR1 -> 19020, SR2 -> 15325
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    save_path = 'model_logs/pretrain_SR1.pt'
    parser = argparse.ArgumentParser(description='Pretrain with CLIP')
    parser.add_argument('--data_root', type=str, default=root)
    # Only SR1/SR2 are understood by the dataset helpers; reject others at parse time.
    parser.add_argument('--src', type=str, default='SR1', choices=['SR1', 'SR2'])
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=8)  # 64->8, 128->8
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--save_path', type=str, default=save_path)
    # None (the default) means no corrupted ModelNet data is added.
    parser.add_argument("--corruption", type=str, default=None, choices=['lidar', 'occlusion', 'all'])
    # Use a real string->bool parser: argparse's type=bool treats "False" as True.
    parser.add_argument('--use_SONN', type=_str2bool, default=True)
    args = parser.parse_args()
    args.data_root = os.path.expanduser(args.data_root)
    return args


def get_corrupted_data_list(opt, severity=None, split="train"):
    """Build one H5_Dataset per corrupted ModelNet40 file for the given source split.

    Args:
        opt: parsed options; reads ``src`` ('SR1' or 'SR2'), ``corruption``
            ('lidar', 'occlusion' or 'all'), ``data_root`` and ``num_points``.
        severity: iterable of severity levels to load; defaults to [1, 2, 3, 4, 5].
        split: 'train' or 'test'.

    Returns:
        list of H5_Dataset objects (built with ``pretrain=True`` and no
        transforms), one per file, ordered by corruption type then severity.

    Raises:
        ValueError: on an unknown split, source, or corruption type.
    """
    # Raise instead of assert so the check survives `python -O`.
    if split not in ('train', 'test'):
        raise ValueError(f"Expected split 'train' or 'test' but received: {split}")

    prefixes = {"SR1": "modelnet_set1", "SR2": "modelnet_set2"}
    if opt.src not in prefixes:
        raise ValueError(f"Expected SR source but received: {opt.src} ")
    prefix = prefixes[opt.src]

    print(f"get_list_corr_data for {prefix} - split {split}")

    # loads corrupted data
    # Materialise once: with corruption='all' the levels are iterated twice,
    # which would silently yield nothing the second time for a generator.
    severity = [1, 2, 3, 4, 5] if severity is None else list(severity)

    if opt.corruption in ('lidar', 'occlusion'):
        print(f"loading {opt.corruption} data")
        corruption_types = [opt.corruption]
    elif opt.corruption == 'all':
        print("loading both lidar and occlusion data")
        corruption_types = ['lidar', 'occlusion']
    else:
        raise ValueError(f"Unknown corruption specified: {opt.corruption}")

    file_names = []
    for corruption in corruption_types:
        root = os.path.join(opt.data_root, "ModelNet40_corrupted", corruption)
        file_names.extend(f"{root}/{prefix}_{split}_{corruption}_sev{i}.h5" for i in severity)
    print(f"corr list files: {file_names}\n")

    return [H5_Dataset(h5_file=h5_path, num_points=opt.num_points, class_choice=opt.src,
                       transforms=None, pretrain=True)
            for h5_path in file_names]


def get_sonn_data(args):
    """Load the ScanObjectNN training split mapped onto the source class set.

    Args:
        args: parsed options; reads ``src`` ('SR1' or 'SR2'), ``data_root``
            and ``num_points``.

    Returns:
        A one-element list holding the ScanObject dataset, so the caller can
        append further datasets and wrap them in a ConcatDataset.

    Raises:
        ValueError: if ``args.src`` is neither 'SR1' nor 'SR2'.
    """
    print("loading sonn data")
    shared_kwargs = dict(
        data_root=args.data_root,
        sonn_split="main_split",
        h5_file="objectdataset.h5",
        split='train',               # training split is what gets added for pretraining
        num_points=args.num_points,  # forwarded unchanged from the CLI
        transforms=None,             # no augmentation is applied
    )
    # Each synthetic source has its own ScanObjectNN -> ModelNet class mapping.
    class_choice = {'SR1': "sonn_2_mdSet1", 'SR2': "sonn_2_mdSet2"}.get(args.src)
    if class_choice is None:
        raise ValueError(f"SONN data for pretrain - wrong src: {args.src}")
    print(f"Src is {args.src}\n")
    return [ScanObject(class_choice=class_choice, **shared_kwargs)]


def _build_train_dataset(args):
    """Assemble the pretraining dataset: ModelNet40_OOD plus optional extras.

    Starts from the ModelNet40_OOD training split for ``args.src``. When
    ``args.corruption`` is set, the corrupted ModelNet datasets are
    concatenated in front of it. When ``args.use_SONN`` is truthy, the
    ScanObjectNN split is concatenated in front of that.
    """
    train_dataset = ModelNet40_OOD(args.data_root, num_points=args.num_points, train=True, class_choice=args.src, pretrain=True)
    if args.corruption is not None:
        l_corr_data = get_corrupted_data_list(args)
        assert isinstance(l_corr_data, list)
        assert isinstance(l_corr_data[0], Dataset)
        l_corr_data.append(train_dataset)
        train_dataset = torch.utils.data.ConcatDataset(l_corr_data)
        print(f"{args.src} + corruption {args.corruption} - train data len: {len(train_dataset)}")
    if args.use_SONN:
        sonn_data = get_sonn_data(args)
        assert isinstance(sonn_data, list)
        assert isinstance(sonn_data[0], Dataset)
        sonn_data.append(train_dataset)
        train_dataset = torch.utils.data.ConcatDataset(sonn_data)
        print(f"{args.src} + sonn - train data len: {len(train_dataset)}")
    return train_dataset


def _snapshot_state(model):
    """Return a CPU copy of ``model.state_dict()`` without 'visual_encoder' entries.

    Only the entries that end up in the checkpoint are copied, so keeping a
    best-so-far snapshot is much cheaper than deep-copying the whole model.
    The copies are independent of further training updates.
    """
    return {name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
            if 'visual_encoder' not in name}


def pretrain(args):
    """Pretrain UniL with MultimodalSupConLoss and save the lowest-loss weights.

    Builds the training set (ModelNet40_OOD, optionally extended with
    corrupted ModelNet data and ScanObjectNN). Trains for ``args.epochs``
    epochs with AdamW and records the epoch with the lowest mean training
    loss. Writes that epoch's state_dict to ``args.save_path``, excluding
    keys containing 'visual_encoder', with tensors stored on CPU. If no
    epoch produced a finite loss (or ``epochs`` is 0), the final weights are
    saved instead.

    Raises:
        ValueError: if the dataset is too small to yield a single batch
            (the DataLoader uses ``drop_last=True``).
    """
    set_random_seed(0)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # NOTE(review): anomaly detection slows every backward pass substantially;
    # it looks like a debugging leftover — consider disabling for real runs.
    torch.autograd.set_detect_anomaly(True)

    # Prepare Data.
    print("===> Creating dataset...")
    train_dataset = _build_train_dataset(args)
    # Pinned host memory only speeds up host->CUDA copies; other backends warn.
    use_pin_memory = str(device).startswith('cuda')
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=use_pin_memory, drop_last=True)
    num_batches = len(train_dataloader)
    if num_batches == 0:
        # Otherwise the per-epoch average below would divide by zero.
        raise ValueError(f"train dataset has {len(train_dataset)} samples, fewer than batch_size={args.batch_size}; no batches to train on")

    # Init Model.
    print("===> Loading CLAD...")
    model = UniL(embed_dim=args.dim).to(device)
    criterion = MultimodalSupConLoss().to(device)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.1)

    # Training Loop.
    print("===> Start training...")
    model.train()
    lowest_loss = float('inf')
    best_state = None
    for epoch in range(1, args.epochs + 1):
        total_loss = 0.
        start = time.time()
        for pc, text, depth, label in train_dataloader:
            optimizer.zero_grad()
            pc = pc.to(device)
            text = text.to(device)
            depth = depth.to(device)
            label = label.to(device)

            pc_text, pc_image = model(pc, text, depth)
            loss_dict = criterion(pc_text, pc_image, label)
            loss, pc_text_acc, pc_img_acc = loss_dict['loss'], loss_dict['pc_text_acc'], loss_dict['pc_img_acc']
            total_loss += loss.item()
            print(f"pc_text_acc:{pc_text_acc:.2f}, pc_img_acc:{pc_img_acc:.2f}")

            loss.backward()
            optimizer.step()

        end = time.time()
        avg_loss = total_loss / num_batches
        if avg_loss < lowest_loss:
            lowest_loss = avg_loss
            best_state = _snapshot_state(model)
        print(f"EPOCH {epoch} avg loss:{avg_loss}, time:{format_time(end-start)}")

    # Save Model.
    print("===> Saving the best model...")
    if best_state is None:
        # No epoch ran, or every epoch's loss was NaN: fall back to the final weights.
        best_state = _snapshot_state(model)
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(best_state, args.save_path)
    print("===> Saving complete!")


if __name__ == '__main__':
    # Script entry point: parse the command line, then run pretraining.
    cli_args = get_args()
    pretrain(cli_args)
