import argparse
import os
import warnings
import time
import copy

import torch
from torch.utils.data import DataLoader, Dataset
from torch import optim
from torchvision import transforms

from datasets.modelnet import ModelNet40_OOD, H5_Dataset
from datasets.scanobject import ScanObject
from loss import MultimodalSupConLoss
from models.model import UniL
from utils.data_utils import *
from utils.common import set_random_seed, format_time, init_np_seed

device = 'mps'


def get_args(argv=None):
    """Parse the command-line options for pretraining.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; passing a list is handy for tests and
            notebooks.

    Returns:
        argparse.Namespace with ``data_root`` (``~`` expanded), ``src``,
        ``num_points``, ``dim``, ``batch_size``, ``num_workers``, ``epochs``
        and ``save_path``. When ``--save_path`` is not given, it defaults to
        ``model_logs/pretrain_<src>.pt`` so different splits do not overwrite
        each other's checkpoint.
    """
    # Real to Real Benchmark
    # Pretraining settings
    # ScanObject train_dataset length: SR12 -> 1643, SR13 -> 1686, SR23 -> 1289
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    parser = argparse.ArgumentParser(description='Pretrain with CLIP')
    parser.add_argument('--data_root', type=str, default=root)
    parser.add_argument('--src', type=str, default='SR12')
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=8)  # 64->8, 128->8
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=100)
    # None -> derived from --src below (previously always pretrain_SR12.pt).
    parser.add_argument('--save_path', type=str, default=None)
    args = parser.parse_args(argv)
    args.data_root = os.path.expanduser(args.data_root)
    if args.save_path is None:
        args.save_path = f'model_logs/pretrain_{args.src}.pt'
    return args


def pretrain(args):
    """Contrastively pretrain ``UniL`` on ScanObject and save the best weights.

    Trains for ``args.epochs`` epochs with AdamW and ``MultimodalSupConLoss``,
    keeps a snapshot of the weights from the epoch with the lowest mean
    training loss, and writes that snapshot (excluding every parameter whose
    name contains ``visual_encoder``) to ``args.save_path``.

    Args:
        args: Namespace as produced by ``get_args`` (data_root, src,
            num_points, dim, batch_size, num_workers, epochs, save_path).

    Raises:
        ValueError: if the dataloader yields no batches (e.g. the dataset is
            smaller than ``batch_size`` while ``drop_last=True``).
    """
    set_random_seed(0)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # NOTE(review): anomaly detection is a debugging aid that slows every
    # backward pass considerably; consider disabling it for real runs.
    torch.autograd.set_detect_anomaly(True)

    # Prepare Data.
    print("===> Creating dataset...")
    train_dataset = ScanObject(args.data_root, num_points=args.num_points, class_choice=args.src, pretrain=True)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # Pinned host memory only speeds up host->CUDA copies; other
        # backends (e.g. MPS) warn about it and ignore it.
        pin_memory=str(device).startswith('cuda'),
        drop_last=True,
    )
    num_batches = len(train_dataloader)
    if num_batches == 0:
        # Would otherwise surface as a ZeroDivisionError after epoch 1.
        raise ValueError(
            f"Training dataloader is empty: {len(train_dataset)} samples with "
            f"batch_size={args.batch_size} and drop_last=True."
        )

    # Init Model.
    print("===> Loading CLAD...")
    model = UniL(embed_dim=args.dim).to(device)
    criterion = MultimodalSupConLoss().to(device)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.1)

    # Training Loop.
    print("===> Start training...")
    model.train()
    lowest_loss = float('inf')
    best_state = None  # state_dict snapshot of the lowest-loss epoch
    for epoch in range(1, args.epochs + 1):
        total_loss = 0.
        start = time.time()
        for pc, text, depth, label in train_dataloader:
            optimizer.zero_grad()
            pc = pc.to(device)
            text = text.to(device)
            depth = depth.to(device)
            label = label.to(device)

            pc_text, pc_image = model(pc, text, depth)
            loss_dict = criterion(pc_text, pc_image, label)
            loss, pc_text_acc, pc_img_acc = loss_dict['loss'], loss_dict['pc_text_acc'], loss_dict['pc_img_acc']
            total_loss += loss.item()
            print(f"pc_text_acc:{pc_text_acc:.2f}, pc_img_acc:{pc_img_acc:.2f}")

            loss.backward()
            optimizer.step()

        end = time.time()
        avg_loss = total_loss / num_batches
        if avg_loss < lowest_loss:
            lowest_loss = avg_loss
            # Snapshot only the weights (not the whole module); deepcopy so
            # later optimizer steps cannot mutate the saved tensors in place.
            best_state = copy.deepcopy(model.state_dict())
        print(f"EPOCH {epoch} avg loss:{avg_loss}, time:{format_time(end-start)}")

    if best_state is None:
        # No epoch ran (epochs == 0) or the loss was never finite (NaN):
        # fall back to the current weights.
        best_state = model.state_dict()

    # Save Model.
    print("===> Saving the best model...")
    # Take values from the *best* snapshot (the original read them from the
    # final model) and drop the visual encoder's parameters.
    save_state = {
        name: tensor
        for name, tensor in best_state.items()
        if 'visual_encoder' not in name
    }
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        # Avoid losing a full training run to a missing output directory.
        os.makedirs(save_dir, exist_ok=True)
    torch.save(save_state, args.save_path)
    print("===> Saving complete!")


if __name__ == '__main__':
    # Script entry point: parse the CLI options and run pretraining with them.
    pretrain(get_args())
