"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

"""
adapted from https://github.com/wooseungw/blip2
"""
import argparse
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="2"
import sys

import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms as torchvision_transforms

import lavis.tasks as tasks
# from lavis.common.config import Config
from ..config.config import Config
from lavis.common.dist_utils import get_rank, init_distributed_mode
from lavis.common.logger import setup_logger
from lavis.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from lavis.common.registry import registry
from lavis.common.utils import now

from lavis.datasets.builders import *
from lavis.models import *
from lavis.processors import *
from lavis.runners import *
from lavis.tasks import *

### Model
from model.modeling_fgclip_adapter import FGCLIPModel
from tqdm import tqdm

from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
)

def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace with two attributes:
          * ``cfg_path`` - path of the configuration file
            (defaults to ``model/config/default.yaml``).
          * ``options`` - list of ``key=value`` override strings, or ``None``
            when ``--options`` is not given.
    """
    cli = argparse.ArgumentParser(description="Training")

    cli.add_argument(
        "--cfg-path",
        default="model/config/default.yaml",
        required=False,
        help="path to configuration file.",
    )

    options_help = (
        "override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead."
    )
    cli.add_argument("--options", nargs="+", help=options_help)

    return cli.parse_args()


def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and fix the cuDNN flags.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()``, so each rank gets its own seed.

    Args:
        config: object exposing ``run_cfg.seed``.
    """
    seed = config.run_cfg.seed + get_rank()

    # Apply the same per-rank seed to every random number generator in use.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    cudnn.deterministic = True
    cudnn.benchmark = False

def get_runner_class(cfg):
    """
    Look up the runner class in the registry.

    The runner name is read from ``cfg.run_cfg`` under the key ``"runner"``;
    ``"runner_base"`` is used as the default name.
    """
    runner_name = cfg.run_cfg.get("runner", "runner_base")
    return registry.get_runner_class(runner_name)

def get_optimizer(cfg, model):
    """Create the AdamW optimizer for every parameter of ``model``.

    Args:
        cfg: config object; ``cfg.config['run']["init_lr"]`` is converted to
            ``float`` and used as the learning rate.
        model: module whose ``parameters()`` are optimized.

    Returns:
        torch.optim.AdamW configured with betas ``(0.9, 0.999)``.
    """
    learning_rate = float(cfg.config['run']["init_lr"])
    return torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.999),
    )

def get_lr_scheduler(cfg, optimizer):
    """Build the linear-warmup + cosine learning-rate scheduler.

    All settings are read from ``cfg.config['run']``: ``max_epoch``,
    ``init_lr``, ``min_lr``, ``warmup_steps`` and ``warmup_lr``.

    Args:
        cfg: config object exposing ``cfg.config['run']``.
        optimizer: optimizer whose learning rate the scheduler drives.

    Returns:
        A ``LinearWarmupCosineLRScheduler`` instance.
    """
    run = cfg.config['run']

    # Every learning-rate value is coerced to float, including ``warmup_lr``:
    # depending on the config loader, scientific notation such as ``1e-6``
    # may arrive as a string, which must not reach the scheduler's arithmetic.
    init_lr = float(run["init_lr"])
    min_lr = float(run["min_lr"])
    warmup_start_lr = float(run["warmup_lr"])

    return LinearWarmupCosineLRScheduler(
        optimizer=optimizer,
        max_epoch=run["max_epoch"],
        min_lr=min_lr,
        init_lr=init_lr,
        warmup_steps=run["warmup_steps"],
        warmup_start_lr=warmup_start_lr,
    )

def _parse_total_loss(line):
    """Return the loss value stored in one ``key:value,`` log line, or None.

    The field whose key is exactly ``loss`` wins. If the line has no such
    field, the first field whose key ends with ``loss`` is used instead.
    Fields whose value is not a valid float are ignored.
    """
    fallback = None
    for field in line.split(','):
        pieces = field.split(':')
        if len(pieces) < 2:
            continue
        key = pieces[0].strip()
        if not key.endswith('loss'):
            continue
        try:
            value = float(pieces[1].strip())
        except ValueError:
            # Truncated or corrupt entry (e.g. an interrupted write).
            continue
        if key == 'loss':
            return value
        if fallback is None:
            fallback = value
    return fallback


def get_last_epoch_loss(file_path):
    """Read the loss of the most recently checkpointed epoch from a log file.

    The log is scanned from the end. Once the last line containing
    ``checkpoint`` has been found, the nearest earlier line containing
    ``loss:`` is parsed. Lines after the last checkpoint marker belong to
    epochs that were not saved and are ignored.

    Args:
        file_path: path of the UTF-8 log file.

    Returns:
        The loss as ``float``, or ``None`` when the file does not exist or
        holds no checkpoint marker preceded by a parsable loss line.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except FileNotFoundError:
        # No log means there is no recorded best loss to resume from.
        return None

    seen_checkpoint = False
    for line in reversed(lines):
        if 'checkpoint' in line:
            seen_checkpoint = True
            continue
        if not seen_checkpoint or 'loss:' not in line:
            continue
        loss_value = _parse_total_loss(line)
        if loss_value is not None:
            return loss_value
    return None

def train(cfg, model, dataloader, optimizer,
          lr_scheduler = None,
          last_best_loss = None,
          **kwargs):
    """Train ``model`` and checkpoint it whenever the epoch loss improves.

    For each epoch, every batch of ``dataloader`` is passed through the
    model, the ``loss`` entry of the result is back-propagated and the
    optimizer (and optional scheduler) is stepped. All result entries whose
    key contains ``"loss"`` are averaged over the epoch and appended to
    ``log.txt`` in the output directory as ``key:value,`` fields. When the
    averaged ``loss`` is lower than the best seen so far, a
    ``save checkpoint`` marker is appended to the log and the model
    ``state_dict`` is written to ``checkpoint.pth``.

    Training stops early once the best loss drops below ``0.0001`` or after
    more than 30 consecutive epochs without improvement.

    Args:
        cfg: config object; ``cfg.config['run']`` supplies ``output_dir``,
            ``max_epoch`` and ``forward_mode``.
        model: called as ``model(samples, mode=..., **kwargs)``; the result
            must support ``.items()`` and expose ``.loss``.
        dataloader: iterable of training batches.
        optimizer: optimizer updating the model parameters.
        lr_scheduler: optional scheduler, stepped as ``step(epoch, itr)``
            after each optimizer step.
        last_best_loss: best loss from a previous run; ``None`` starts from
            a very large value so the first epoch always checkpoints.
        **kwargs: forwarded unchanged to every model call.

    Returns:
        ``(losses, epoches)``: the averaged ``loss`` of each logged epoch and
        the matching epoch indices.
    """
    run = cfg.config['run']
    model.train()

    best_loss = 10000000.0 if last_best_loss is None else last_best_loss
    print(f"init best loss: {best_loss}")

    losses = []
    epoches = []
    count = 0  # consecutive epochs without improvement

    output_dir = os.path.join("lavis/", run["output_dir"])
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, 'log.txt')
    checkpoint_path = os.path.join(output_dir, 'checkpoint.pth')

    for epoch in range(run["max_epoch"]):
        # Per-key list of batch losses collected during this epoch.
        temp_loss = {}

        for itr, samples in enumerate(tqdm(dataloader)):
            results = model(samples, mode=run["forward_mode"], **kwargs)
            loss_dict = {k: v for k, v in results.items() if "loss" in k}

            loss = results.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step(epoch, itr)

            # Accumulate across the whole epoch; creating the list only on
            # first sight of a key keeps earlier batches in the average.
            for k, v in loss_dict.items():
                temp_loss.setdefault(k, []).append(v.detach().cpu().numpy())

            print_output = [f"lr: {optimizer.param_groups[0]['lr']:.6f}"]
            print_output.extend(f"{k}: {v:.6f}" for k, v in loss_dict.items())
            print(" ".join(print_output))

        if not temp_loss:
            # No batch produced a loss (e.g. empty dataloader): there is
            # nothing to log or compare, so skip the bookkeeping below.
            print(f"epoch_{epoch} produced no loss values; skipping.")
            continue

        res_dict = {k: round(sum(v) / len(v), 6) for k, v in temp_loss.items()}
        res = res_dict["loss"]
        print("epoch_{} loss is {}.".format(epoch, res))

        with open(log_path, "a") as f:
            content = [f"epoch:{epoch},"]
            content.extend(f"{k}:{v}," for k, v in res_dict.items())
            f.write(" ".join(content) + '\n')

        losses.append(res)
        epoches.append(epoch)

        if res < best_loss:
            best_loss = res

            # The marker line is what get_last_epoch_loss() searches for
            # when a later run resumes from this checkpoint.
            with open(log_path, "a") as f:
                f.write(f"save checkpoint at epoch {epoch}\n")

            torch.save(model.state_dict(), checkpoint_path)
            count = 0
        else:
            count += 1
        if best_loss < 0.0001 or count > 30:
            break
    return losses, epoches

def train_main():
    """Run the whole training pipeline from the command line.

    Steps, in order: parse the CLI arguments into a ``Config``, initialise
    distributed mode, seed the RNGs, set up logging, load the pretrained
    ``FGCLIPModel`` from ``run.model_root``, optionally resume from
    ``lavis/<output_dir>/checkpoint.pth``, build the dataset selected by
    ``run.dataset_type`` (``"DSR"`` or ``"flickr30k"``), then create the
    optimizer and scheduler and call ``train``.

    Raises:
        ValueError: if ``run.dataset_type`` is not a supported value.
    """
    # allow auto-dl completes on main process without timeout when using NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
    # NOTE: job_id is not used further in this function.
    job_id = now()

    cfg = Config(parse_args())

    init_distributed_mode(cfg.run_cfg)

    setup_seeds(cfg)

    # set after init_distributed_mode() to only log on master.
    setup_logger()

    cfg.pretty_print()

    run_cfg = cfg.config['run']
    output_dir = os.path.join("lavis/", run_cfg["output_dir"])

    ### build model
    model_root = run_cfg['model_root']
    model = FGCLIPModel.from_pretrained(model_root,
                                        # ignore_mismatched_sizes=True,
                                        local_files_only=True,
                                        use_safetensors=True)

    model.cuda()

    # Resume from an earlier run when a checkpoint is present.
    ft_path = os.path.join(output_dir, "checkpoint.pth")
    last_best_loss = None
    if os.path.exists(ft_path):
        model.load_state_dict(torch.load(ft_path, map_location=model.device))
        print(f"--- sucessfully load checkpoint from {ft_path}")
        txt_path = os.path.join(output_dir, "log.txt")
        # A checkpoint may exist without its log (e.g. copied on its own);
        # in that case train() starts from its default best loss.
        if os.path.exists(txt_path):
            last_best_loss = get_last_epoch_loss(txt_path)
        else:
            print(f"--- no log file at {txt_path}, best loss is re-initialised")

    ### build dataset
    tokenizer = AutoTokenizer.from_pretrained(model_root)
    image_processor = AutoImageProcessor.from_pretrained(model_root)

    dataset_type = run_cfg["dataset_type"]
    if dataset_type == "DSR":
        print(f"training at DSR")
        from model.data import StyleImageTextDataset
        train_dataset_path = "/train_dataset_path"
        train_json_path = "/train_json_path"

        style_dataset = StyleImageTextDataset(train_dataset_path, train_json_path, image_processor)
    elif dataset_type == "flickr30k":
        print(f"training at flickr30k")
        from model.f30k_data import StyleImageTextDataset
        train_dataset_path = "/train_dataset_path"
        train_json_path = "/train_json_path"
        style_dataset = StyleImageTextDataset(train_dataset_path, train_json_path,
                                            image_transform=image_processor,
                                            select_sketch=True)
    else:
        raise ValueError(f"invalid dataset type of {dataset_type}")

    # num_workers is fixed to 1 here rather than read from the config.
    train_loader = DataLoader(dataset=style_dataset, batch_size=run_cfg['batch_size_train'],
                            num_workers=1, shuffle=True)

    optimizer = get_optimizer(cfg, model)
    lr_scheduler = get_lr_scheduler(cfg, optimizer)

    # The tokenizer reaches the model through train()'s **kwargs.
    train(cfg, model, train_loader, optimizer,
          lr_scheduler=lr_scheduler,
          tokenizer=tokenizer,
          last_best_loss=last_best_loss,)


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_main()
