import argparse
import os
import time
import copy

import torch
from torch.utils.data import DataLoader
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
# import torch_optimizer

from datasets.sncore_4k import ShapeNetCore4k
from datasets.sncore_splits import *
from loss import ULIPContrastiveLoss, MultimodalSupConLoss
from models.model import UniL
from utils.common import set_random_seed, format_time, cosine_scheduler, init_np_seed

# Pick the compute device automatically instead of editing this file per
# machine: prefer CUDA, then Apple-Silicon MPS, and fall back to CPU so the
# script still runs on hosts without an accelerator.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    # `torch.backends.mps` is absent on older PyTorch builds, hence the getattr guard.
    device = 'mps'
else:
    device = 'cpu'


def get_args():
    """Parse the command-line options for the pretraining run.

    Returns:
        argparse.Namespace: the parsed options, with ``data_root`` passed
        through ``os.path.expanduser`` and two derived attributes ``tar1`` and
        ``tar2`` holding the two class sets other than ``--src`` (kept in
        SN1/SN2/SN3 order).

    Raises:
        SystemExit: raised by argparse when an option is invalid, including a
            ``--src`` value outside SN1/SN2/SN3.
    """
    # Synth to Synth Benchmark
    # Pretraining settings
    # ShapeNetCore train_dataset length: SN1 -> 7342, SN2 -> 16381, SN3 -> 11581
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    save_path = 'model_logs/pretrain_SN1.pt'
    accepted_src = ["SN1", "SN2", "SN3"]

    parser = argparse.ArgumentParser(description='Pretrain with CLIP')
    parser.add_argument('--data_root', type=str, default=root)
    # `choices` lets argparse reject bad values itself; unlike an `assert`,
    # this check is not stripped when Python runs with -O.
    parser.add_argument('--src', type=str, default='SN1', choices=accepted_src)
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=16)  # 64->8, 128->8
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--save_path', type=str, default=save_path)
    args = parser.parse_args()

    args.data_root = os.path.expanduser(args.data_root)
    # The two remaining class sets are exposed as tar1/tar2.
    args.tar1, args.tar2 = (name for name in accepted_src if name != args.src)
    return args


def _state_without_visual_encoder(model):
    """Return a cloned copy of ``model``'s state dict without ``visual_encoder`` entries."""
    return {
        name: tensor.detach().clone()
        for name, tensor in model.state_dict().items()
        if 'visual_encoder' not in name
    }


def pretrain(args):
    """Pretrain ``UniL`` on the ShapeNetCore4k train split and save the best weights.

    The model is trained for ``args.epochs`` epochs with ``MultimodalSupConLoss``
    and AdamW. After every epoch the average loss is compared with the lowest
    seen so far; the weights of the lowest-loss epoch (excluding every entry
    whose name contains ``visual_encoder``) are written to ``args.save_path``.

    Args:
        args: namespace providing ``data_root``, ``src``, ``batch_size``,
            ``num_workers``, ``dim``, ``epochs`` and ``save_path``.
    """
    set_random_seed(0)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    # NOTE(review): anomaly detection is a debugging aid that PyTorch documents
    # as slowing training down; kept to preserve the existing behavior.
    torch.autograd.set_detect_anomaly(True)

    # Create the checkpoint directory up front so a missing folder fails (or is
    # fixed) before training rather than at the final torch.save call.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Prepare Data.
    print("===> Creating dataset...")
    # Look the class-set dict (SN1/SN2/SN3, star-imported at module level) up by
    # name instead of calling eval() on a command-line string.
    class_choice = list(globals()[args.src].keys())
    train_dataset = ShapeNetCore4k(args.data_root, split='train', class_choice=class_choice, pretrain=True)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True)

    # Init Model.
    print("===> Loading CLAD...")
    model = UniL(embed_dim=args.dim).to(device)
    criterion = MultimodalSupConLoss().to(device)
    # Only parameters that require gradients are handed to the optimizer.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, betas=(0.9, 0.98), eps=1e-8, weight_decay=0.1)
    #scheduler = cosine_scheduler(3e-3, 1e-5, epochs, len(train_dataloader) // 1, warmup_epochs=1, start_warmup_value=1e-6)

    #optimizer = torch_optimizer.Lamb(model.parameters(), lr=0.006, weight_decay=1e-4)
    #scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2*len(train_dataloader), T_mult=1, eta_min=max(1e-2*1e-3, 1e-6), last_epoch=-1)

    # Training Loop.
    print("===> Start training...")
    model.train()
    # Start from +inf so the first finished epoch always becomes the baseline,
    # whatever the scale of the loss.
    lowest_loss = float('inf')
    best_state = None
    for epoch in range(1, args.epochs+1):
        total_loss = 0.
        start = time.time()
        for step, (pc, text, depth, label) in enumerate(train_dataloader):
            optimizer.zero_grad()
            pc = pc.to(device)
            text = text.to(device)
            depth = depth.to(device)
            label = label.to(device)

            pc_text, pc_image = model(pc, text, depth)
            loss_dict = criterion(pc_text, pc_image, label)
            loss, pc_text_acc, pc_img_acc = loss_dict['loss'], loss_dict['pc_text_acc'], loss_dict['pc_img_acc']
            total_loss += loss.item()
            print(f"pc_text_acc:{pc_text_acc:.2f}, pc_img_acc:{pc_img_acc:.2f}")

            loss.backward()
            optimizer.step()

        end = time.time()
        avg_loss = total_loss/len(train_dataloader)
        if avg_loss < lowest_loss:
            lowest_loss = avg_loss
            # Snapshot only the tensors that will be saved; cloning decouples
            # them from later optimizer updates without copying the whole model.
            best_state = _state_without_visual_encoder(model)
        print(f"EPOCH {epoch} avg loss:{avg_loss}, time:{format_time(end-start)}")

    # Save Model.
    print("===> Saving the best model...")
    if best_state is None:
        # No epoch ever improved on the baseline (e.g. zero epochs or a NaN
        # loss throughout): fall back to the current weights.
        best_state = _state_without_visual_encoder(model)
    # Write the best epoch's weights, not the final epoch's.
    torch.save(best_state, args.save_path)
    print("===> Saving complete!")


if __name__ == '__main__':
    # Script entry point: parse the CLI options and run pretraining with them.
    pretrain(get_args())
