import os
import argparse

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
import clip
from data.prompt_dataset import PromptDataSet
from data.simple_dataset import SimpleDataSet

from model.Text_IF_model import ControlFusion as create_model
from scripts.utils import read_data, train_one_epoch, evaluate, create_lr_scheduler, merge_and_print_paths
import datetime
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import transforms as T

def main(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./experiments") is False:
        os.makedirs("./experiments")

    file_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    filefold_path = "./experiments/TextIF_train_{}".format(file_name)
    os.makedirs(filefold_path)
    file_img_path = os.path.join(filefold_path, "img")
    os.makedirs(file_img_path)
    file_weights_path = os.path.join(filefold_path, "weights")
    os.makedirs(file_weights_path)
    file_log_path = os.path.join(filefold_path, "log")
    os.makedirs(file_log_path)

    tb_writer = SummaryWriter(log_dir=file_log_path)

    best_val_loss = 1e5
    start_epoch = 0

    print("Loading IVF Fusion and noise slight Task!")
    if args.vi_noise_slight_path is not None:
        train_vi_noise_slight_path_list, val_vi_noise_slight_path_list = read_data(args.vi_noise_slight_path)
    else:
        train_vi_noise_slight_path_list = val_vi_noise_slight_path_list = None

    print("Loading IVF Fusion and noise moderate Task!")
    if args.vi_noise_moderate_path is not None:
        train_vi_noise_moderate_path_list, val_vi_noise_moderate_path_list = read_data(args.vi_noise_moderate_path)
    else:
        train_vi_noise_moderate_path_list = val_vi_noise_moderate_path_list = None

    print("Loading IVF Fusion and noise average Task!")
    if args.vi_noise_average_path is not None:
        train_vi_noise_average_path_list, val_vi_noise_average_path_list = read_data(args.vi_noise_average_path)
    else:
        train_vi_noise_average_path_list = val_vi_noise_average_path_list = None

    print("Loading IVF Fusion and noise serious Task!")
    if args.vi_noise_serious_path is not None:
        train_vi_noise_serious_path_list, val_vi_noise_serious_path_list = read_data(args.vi_noise_serious_path)
    else:
        train_vi_noise_serious_path_list = val_vi_noise_serious_path_list = None




    print("Loading IVF Fusion and Over-Exposure slight Task!")
    if args.over_exposure_slight_path is not None:
        train_over_exposure_slight_path_list, val_over_exposure_slight_path_list = read_data(args.over_exposure_slight_path)
    else:
        train_over_exposure_slight_path_list = val_over_exposure_slight_path_list = None

    print("Loading IVF Fusion and Over-Exposure moderate Task!")
    if args.over_exposure_moderate_path is not None:
        train_over_exposure_moderate_path_list, val_over_exposure_moderate_path_list = read_data(args.over_exposure_moderate_path)
    else:
        train_over_exposure_moderate_path_list = val_over_exposure_moderate_path_list = None

    print("Loading IVF Fusion and Over-Exposure average Task!")
    if args.over_exposure_average_path is not None:
        train_over_exposure_average_path_list, val_over_exposure_average_path_list = read_data(args.over_exposure_average_path)
    else:
        train_over_exposure_average_path_list = val_over_exposure_average_path_list = None

    print("Loading IVF Fusion and Over-Exposure serious Task!")
    if args.over_exposure_serious_path is not None:
        train_over_exposure_serious_path_list, val_over_exposure_serious_path_list = read_data(args.over_exposure_serious_path)
    else:
        train_over_exposure_serious_path_list = val_over_exposure_serious_path_list = None




    print("Loading IVF Fusion and ir_low_contrast slight Task!")
    if args.ir_low_contrast_slight_path is not None:
        train_ir_low_contrast_slight_path_list, val_ir_low_contrast_slight_path_list = read_data(args.ir_low_contrast_slight_path)
    else:
        train_ir_low_contrast_slight_path_list = val_ir_low_contrast_slight_path_list = None

    print("Loading IVF Fusion and ir_low_contrast moderate Task!")
    if args.ir_low_contrast_moderate_path is not None:
        train_ir_low_contrast_moderate_path_list, val_ir_low_contrast_moderate_path_list = read_data(args.ir_low_contrast_moderate_path)
    else:
        train_ir_low_contrast_moderate_path_list = val_ir_low_contrast_moderate_path_list = None

    print("Loading IVF Fusion and ir_low_contrast average Task!")
    if args.ir_low_contrast_average_path is not None:
        train_ir_low_contrast_average_path_list, val_ir_low_contrast_average_path_list = read_data(args.ir_low_contrast_average_path)
    else:
        train_ir_low_contrast_average_path_list = val_ir_low_contrast_average_path_list = None

    print("Loading IVF Fusion and ir_low_contrast serious Task!")
    if args.ir_low_contrast_serious_path is not None:
        train_ir_low_contrast_serious_path_list, val_ir_low_contrast_serious_path_list = read_data(args.ir_low_contrast_serious_path)
    else:
        train_ir_low_contrast_serious_path_list = val_ir_low_contrast_serious_path_list = None



    print("Loading IVF Fusion and ir_noise_slight Task!")
    if args.ir_noise_slight_path is not None:
        train_ir_noise_slight_path_list, val_ir_noise_slight_path_list = read_data(args.ir_noise_slight_path)
    else:
        train_ir_noise_slight_path_list = val_ir_noise_slight_path_list = None

    print("Loading IVF Fusion and ir_noise_moderate Task!")
    if args.ir_noise_moderate_path is not None:
        train_ir_noise_moderate_path_list, val_ir_noise_moderate_path_list = read_data(args.ir_noise_moderate_path)
    else:
        train_ir_noise_moderate_path_list = val_ir_noise_moderate_path_list = None

    print("Loading IVF Fusion and ir_noise_average Task!")
    if args.ir_noise_average_path is not None:
        train_ir_noise_average_path_list, val_ir_noise_average_path_list = read_data(args.ir_noise_average_path)
    else:
        train_ir_noise_average_path_list = val_ir_noise_average_path_list = None

    print("Loading IVF Fusion and ir_noise_serious Task!")
    if args.ir_noise_serious_path is not None:
        train_ir_noise_serious_path_list, val_ir_noise_serious_path_list = read_data(args.ir_noise_serious_path)
    else:
        train_ir_noise_serious_path_list = val_ir_noise_serious_path_list = None




    print("Loading IVF Fusion and ir_stripe_noise_slight Task!")
    if args.ir_stripe_noise_slight_path is not None:
        train_ir_stripe_noise_slight_path_list, val_ir_stripe_noise_slight_path_list = read_data(args.ir_stripe_noise_slight_path)
    else:
        train_ir_stripe_noise_slight_path_list = val_ir_stripe_noise_slight_path_list = None

    print("Loading IVF Fusion and ir_stripe_noise_moderate Task!")
    if args.ir_stripe_noise_moderate_path is not None:
        train_ir_stripe_noise_moderate_path_list, val_ir_stripe_noise_moderate_path_list = read_data(args.ir_stripe_noise_moderate_path)
    else:
        train_ir_stripe_noise_moderate_path_list = val_ir_stripe_noise_moderate_path_list = None

    print("Loading IVF Fusion and ir_stripe_noise_average Task!")
    if args.ir_stripe_noise_average_path is not None:
        train_ir_stripe_noise_average_path_list, val_ir_stripe_noise_average_path_list = read_data(args.ir_stripe_noise_average_path)
    else:
        train_ir_stripe_noise_average_path_list = val_ir_stripe_noise_average_path_list = None

    print("Loading IVF Fusion and ir_stripe_noise_serious Task!")
    if args.ir_stripe_noise_serious_path is not None:
        train_ir_stripe_noise_serious_path_list, val_ir_stripe_noise_serious_path_list = read_data(args.ir_stripe_noise_serious_path)
    else:
        train_ir_stripe_noise_serious_path_list = val_ir_stripe_noise_serious_path_list = None

    # 调用函数并接收返回值
    (
        vi_noise_train, vi_noise_val,
        over_exposure_train, over_exposure_val,
        ir_low_contrast_train, ir_low_contrast_val,
        ir_noise_train, ir_noise_val,
        ir_stripe_noise_train, ir_stripe_noise_val
    ) = merge_and_print_paths(
        train_vi_noise_slight_path_list, val_vi_noise_slight_path_list,
        train_vi_noise_moderate_path_list, val_vi_noise_moderate_path_list,
        train_vi_noise_average_path_list, val_vi_noise_average_path_list,
        train_vi_noise_serious_path_list, val_vi_noise_serious_path_list,
        train_over_exposure_slight_path_list, val_over_exposure_slight_path_list,
        train_over_exposure_moderate_path_list, val_over_exposure_moderate_path_list,
        train_over_exposure_average_path_list, val_over_exposure_average_path_list,
        train_over_exposure_serious_path_list, val_over_exposure_serious_path_list,
        train_ir_low_contrast_slight_path_list, val_ir_low_contrast_slight_path_list,
        train_ir_low_contrast_moderate_path_list, val_ir_low_contrast_moderate_path_list,
        train_ir_low_contrast_average_path_list, val_ir_low_contrast_average_path_list,
        train_ir_low_contrast_serious_path_list, val_ir_low_contrast_serious_path_list,
        train_ir_noise_slight_path_list, val_ir_noise_slight_path_list,
        train_ir_noise_moderate_path_list, val_ir_noise_moderate_path_list,
        train_ir_noise_average_path_list, val_ir_noise_average_path_list,
        train_ir_noise_serious_path_list, val_ir_noise_serious_path_list,
        train_ir_stripe_noise_slight_path_list, val_ir_stripe_noise_slight_path_list,
        train_ir_stripe_noise_moderate_path_list, val_ir_stripe_noise_moderate_path_list,
        train_ir_stripe_noise_average_path_list, val_ir_stripe_noise_average_path_list,
        train_ir_stripe_noise_serious_path_list, val_ir_stripe_noise_serious_path_list
    )


    data_transform = {
        "train": T.Compose([T.RandomCrop(224),
                            T.RandomHorizontalFlip(0.5),
                            T.RandomVerticalFlip(0.5),
                            T.ToTensor()]),

        "val": T.Compose([T.Resize_16(),
                          T.ToTensor()])}

    train_dataset = PromptDataSet(train_vi_noise_path_list=vi_noise_train,
                                  val_vi_noise_path_list=vi_noise_val,
                                  train_over_exposure_path_list=over_exposure_train,
                                  val_over_exposure_path_list=over_exposure_val,
                                  train_ir_low_contrast_path_list=ir_low_contrast_train,
                                  val_ir_low_contrast_path_list=ir_low_contrast_val,
                                  train_ir_noise_path_list=ir_noise_train,
                                  val_ir_noise_path_list=ir_noise_val,
                                  train_ir_stripe_noise_path_list = ir_stripe_noise_train,
                                  val_ir_stripe_noise_path_list = ir_stripe_noise_val,
                                  phase="train",
                              transform=data_transform["train"])

    val_dataset = PromptDataSet(train_vi_noise_path_list=vi_noise_train,
                                  val_vi_noise_path_list=vi_noise_val,
                                  train_over_exposure_path_list=over_exposure_train,
                                  val_over_exposure_path_list=over_exposure_val,
                                  train_ir_low_contrast_path_list=ir_low_contrast_train,
                                  val_ir_low_contrast_path_list=ir_low_contrast_val,
                                  train_ir_noise_path_list=ir_noise_train,
                                  val_ir_noise_path_list=ir_noise_val,
                                  train_ir_stripe_noise_path_list = ir_stripe_noise_train,
                                  val_ir_stripe_noise_path_list = ir_stripe_noise_val,
                                  phase="val",
                              transform=data_transform["val"])
    
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model_clip, _ = clip.load("ViT-L/14@336px", device=device)
    model = create_model(model_clip).to(device)

    for param in model.model_clip.parameters():
        param.requires_grad = False

    if args.use_dp == True:
        model = torch.nn.DataParallel(model).cuda()

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        print(model.load_state_dict(weights_dict, strict=False))


    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, warmup=True)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1

    for epoch in range(start_epoch, args.epochs):
        # train
        train_loss, train_ssim_loss, train_max_loss, train_color_loss, train_text_loss, lr = train_one_epoch(model=model,
                                              model_clip=model_clip,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                lr_scheduler=lr_scheduler,
                                                device=device,
                                                epoch=epoch)

        tb_writer.add_scalar("train_total_loss", train_loss, epoch)
        tb_writer.add_scalar("train_ssim_loss", train_ssim_loss, epoch)
        tb_writer.add_scalar("train_max_loss", train_max_loss, epoch)
        tb_writer.add_scalar("train_color_loss", train_color_loss, epoch)
        tb_writer.add_scalar("train_text_loss", train_text_loss, epoch)

        if epoch % args.val_every_epcho == 0:
            val_loss, val_ssim_loss, val_max_loss, val_color_loss, val_text_loss = evaluate(model=model,
                                         data_loader=val_loader,
                                         device=device,
                                         epoch=epoch, lr=lr, filefold_path=file_img_path)

            tb_writer.add_scalar("val_total_loss", val_loss, epoch)
            tb_writer.add_scalar("val_ssim_loss", val_ssim_loss, epoch)
            tb_writer.add_scalar("val_max_loss", val_max_loss, epoch)
            tb_writer.add_scalar("val_color_loss", val_color_loss, epoch)
            tb_writer.add_scalar("val_text_loss", val_text_loss, epoch)


            if val_loss < best_val_loss:
                if args.use_dp == True:
                    save_file = {"model": model.module.state_dict(),
                                 "optimizer": optimizer.state_dict(),
                                 "lr_scheduler": lr_scheduler.state_dict(),
                                 "epoch": epoch,
                                 "args": args}
                else:
                    save_file = {"model": model.state_dict(),
                                 "optimizer": optimizer.state_dict(),
                                 "lr_scheduler": lr_scheduler.state_dict(),
                                 "epoch": epoch,
                                 "args": args}
                torch.save(save_file, file_weights_path + "/" + "checkpoint.pth")
                best_val_loss = val_loss
            
            if args.use_dp == True:
                    save_file = {"model": model.module.state_dict(),
                                 "optimizer": optimizer.state_dict(),
                                 "lr_scheduler": lr_scheduler.state_dict(),
                                 "epoch": epoch,
                                 "args": args}
            else:
                    save_file = {"model": model.state_dict(),
                                 "optimizer": optimizer.state_dict(),
                                 "lr_scheduler": lr_scheduler.state_dict(),
                                 "epoch": epoch,
                                 "args": args}
            torch.save(save_file, file_weights_path + "/" + "checkpoint_lastest.pth")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=100)

    # set the appropriate batch-size value for your device
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lr', type=float, default=0.0001)

    parser.add_argument('--vi_noise_slight_path', type=str, default="./VI_Noise/VI_Noise_slight")
    parser.add_argument('--vi_noise_moderate_path', type=str, default="./VI_Noise/VI_Noise_moderate")
    parser.add_argument('--vi_noise_average_path', type=str, default="./VI_Noise/VI_Noise_average")
    parser.add_argument('--vi_noise_serious_path', type=str, default="./VI_Noise/VI_Noise_extreme")

    parser.add_argument('--over_exposure_slight_path', type=str, default="./VI_Over_exposure/VI_Over_exposure_slight")
    parser.add_argument('--over_exposure_moderate_path', type=str, default="./VI_Over_exposure/VI_Over_exposure_moderate")
    parser.add_argument('--over_exposure_average_path', type=str, default="./VI_Over_exposure/VI_Over_exposure_average")
    parser.add_argument('--over_exposure_serious_path', type=str, default="./VI_Over_exposure/VI_Over_exposure_extreme")



    parser.add_argument('--ir_low_contrast_slight_path', type=str, default="./IR_Low_contrast/IR_Low_contrast_slight")
    parser.add_argument('--ir_low_contrast_moderate_path', type=str, default="./IR_Low_contrast/IR_Low_contrast_moderate")
    parser.add_argument('--ir_low_contrast_average_path', type=str, default="./IR_Low_contrast/IR_Low_contrast_average")
    parser.add_argument('--ir_low_contrast_serious_path', type=str, default="./IR_Low_contrast/IR_Low_contrast_extreme")


    parser.add_argument('--ir_noise_slight_path', type=str, default="./IR_Noise/IR_Noise_slight")
    parser.add_argument('--ir_noise_moderate_path', type=str, default="./IR_Noise/IR_Noise_moderate")
    parser.add_argument('--ir_noise_average_path', type=str, default="./IR_Noise/IR_Noise_average")
    parser.add_argument('--ir_noise_serious_path', type=str, default="./IR_Noise/IR_Noise_extreme")

    parser.add_argument('--ir_stripe_noise_slight_path', type=str, default="./IR_Stripe_noise/IR_Stripe_noise_slight")
    parser.add_argument('--ir_stripe_noise_moderate_path', type=str, default="./IR_Stripe_noise/IR_Stripe_noise_moderate")
    parser.add_argument('--ir_stripe_noise_average_path', type=str, default="./IR_Stripe_noise/IR_Stripe_noise_average")
    parser.add_argument('--ir_stripe_noise_serious_path', type=str, default="./IR_Stripe_noise/IR_Stripe_noise_extreme")


    parser.add_argument('--weights', type=str, default='',
                        help='initial weights path')
    parser.add_argument('--val_every_epcho', type=int, default=1, help='val every epcho')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--use_dp', default = False, help='use dp-multigpus')
    parser.add_argument('--device', default='cuda', help='device (i.e. cuda or cpu)')
    parser.add_argument('--gpu_id', default='0', help='device id (i.e. 0, 1, 2 or 3)')

    opt = parser.parse_args()

    main(opt)
