"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import argparse
import os
import random
import sys
import copy

import numpy as np
import torch
import torch.backends.cudnn as cudnn

import lavis.tasks as tasks
from lavis.common.config import Config
from lavis.common.dist_utils import get_rank, init_distributed_mode
from lavis.common.logger import setup_logger
from lavis.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from lavis.common.registry import registry
from lavis.common.utils import now

# imports modules for registration
from lavis.datasets.builders import *
from lavis.models import *
from lavis.processors import *
from lavis.runners import *
from lavis.tasks import *

from lavis.constants import SQA_PARTITION_CKPT


def parse_args():
    """Parse the command line for the federated Fisher-merging script.

    Apart from ``--cfg-path``, every option is optional; ``main()`` copies the
    ones that are set into the run/model config loaded from ``--cfg-path``.

    Returns:
        argparse.Namespace: the parsed arguments (read from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    # Run-config overrides. --init_lr / --min_lr declare no ``type`` and are
    # therefore parsed as strings.
    parser.add_argument("--init_lr", help="override the run config's init_lr.")
    parser.add_argument("--min_lr", help="override the run config's min_lr.")
    parser.add_argument("--max_epoch", type=int, help="override the run config's max_epoch.")
    parser.add_argument("--batch_size_train", type=int, help="override the run config's batch_size_train.")
    parser.add_argument("--batch_size_eval", type=int, help="override the run config's batch_size_eval.")
    parser.add_argument("--iters_per_epoch", type=int, help="override the run config's iters_per_epoch.")
    parser.add_argument("--debug_sign", help="suffix appended to the job id.")
    # Model-config overrides.
    parser.add_argument("--ada_rank", type=int, help="rank of fednano (overrides the model config's ada_rank).")
    parser.add_argument("--ckpt", help="path to checkpoint file (overrides the model config's pretrained).")
    parser.add_argument("--evaluation_only", action="store_true", help="set the run config's evaluate flag.")
    parser.add_argument("--output_dir", help="path to output directory (created if missing).")
    # Adapter flags: main() copies each of these into the model config.
    parser.add_argument("--zero_init_ada", action="store_true", help="set the model config's zero_init_ada.")
    parser.add_argument("--pretrained_init_ada", action="store_true", help="use the same weight as linear projector.")
    parser.add_argument("--bypass", action="store_true", help="set the model config's bypass.")
    parser.add_argument("--sequential", action="store_true", help="set the model config's sequential.")
    parser.add_argument("--ada_linear", action="store_true", help="set the model config's ada_linear.")
    parser.add_argument("--ada_downup", action="store_true", help="set the model config's ada_downup.")
    parser.add_argument("--FedAvg", action="store_true", help="not read by main() in this script.")

    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file.",
    )

    return parser.parse_args()


def setup_seeds(config):
    """Seed Python's, NumPy's and PyTorch's RNGs and configure cuDNN for determinism.

    The seed is ``config.run_cfg.seed`` offset by the process rank from
    ``get_rank()``, so each rank gets its own stream.
    """
    base_seed = config.run_cfg.seed
    seed = base_seed + get_rank()

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Disable autotuning and request deterministic cuDNN kernels.
    cudnn.benchmark, cudnn.deterministic = False, True

def convert_vector_to_params(vector_dict, params_dict):
    """Point each tensor in ``params_dict`` at the data of its same-named vector.

    Each vector is reshaped with ``view_as`` to the target tensor's shape, and
    the target's ``.data`` is rebound to it. The target therefore shares
    storage with the vector rather than receiving a copy.

    Args:
        vector_dict: name -> tensor holding the new values, with the same
            number of elements as the matching entry of ``params_dict``.
        params_dict: name -> tensor/parameter to update in place.

    Raises:
        AssertionError: if the two dicts do not have exactly the same keys.
    """
    assert vector_dict.keys() == params_dict.keys(), "The keys of vector_dict and params_dict should be the same."
    for key in vector_dict:
        target = params_dict[key]
        target.data = vector_dict[key].view_as(target).data

def get_runner_class(cfg, runner_name=None):
    """
    Look up a runner class in the registry.

    An explicit (truthy) ``runner_name`` wins; otherwise the name comes from
    ``cfg.run_cfg["runner"]``, falling back to the epoch-based "runner_base".
    """
    name = runner_name if runner_name else cfg.run_cfg.get("runner", "runner_base")
    return registry.get_runner_class(name)


# Defaults for the iterative Fisher-merging optimisation run by main(); they
# reproduce the values that used to be hard-coded inside main().
_FISHER_MERGE_STEPS = 200000
_FISHER_MERGE_LR = 0.05
_FISHER_MERGE_EVAL_EVERY = 1000

# Boolean adapter flags copied from the CLI into the model config, each paired
# with the suffix it contributes to the job id when enabled.
_ADAPTER_FLAG_SUFFIXES = (
    ("zero_init_ada", "_0init"),
    ("pretrained_init_ada", "_pretrainedinit"),
    ("bypass", "_bypass"),
    ("sequential", "_sequential"),
    ("ada_linear", "_ada_linear"),
    ("ada_downup", "_ada_downup"),
)


def _apply_cli_overrides(cfg):
    """Copy command-line overrides from ``cfg.args`` into ``cfg.run_cfg`` / ``cfg.model_cfg``."""
    args = cfg.args
    # Learning rates may arrive from argparse as strings, so coerce them to
    # float. Checking ``is not None`` (not truthiness) lets an explicit zero,
    # e.g. ``--min_lr 0``, take effect.
    for key in ("init_lr", "min_lr"):
        value = getattr(args, key, None)
        if value is not None:
            cfg.run_cfg[key] = float(value)
    for key in ("max_epoch", "batch_size_train", "batch_size_eval", "iters_per_epoch"):
        value = getattr(args, key, None)
        if value is not None:
            cfg.run_cfg[key] = value
    if getattr(args, "ada_rank", None) is not None:
        cfg.model_cfg["ada_rank"] = args.ada_rank
    # String / store_true options keep truthiness checks: an empty path or an
    # unset flag leaves the config untouched.
    if getattr(args, "ckpt", None):
        cfg.model_cfg["pretrained"] = args.ckpt
    if getattr(args, "evaluation_only", None):
        cfg.run_cfg["evaluate"] = args.evaluation_only
    if getattr(args, "output_dir", None):
        cfg.run_cfg["output_dir"] = args.output_dir
        os.makedirs(cfg.run_cfg["output_dir"], exist_ok=True)
    # These two are only applied if the args object provides them; getattr
    # makes them no-ops otherwise.
    for key in ("use_fednano_I", "use_fednano_T"):
        value = getattr(args, key, None)
        if value:
            cfg.model_cfg[key] = value
    for key, _ in _ADAPTER_FLAG_SUFFIXES:
        cfg.model_cfg[key] = getattr(args, key)


def _build_job_id(cfg):
    """Return the job id: a timestamp plus key hyper-parameters and enabled flags.

    Expects ``_apply_cli_overrides`` to have run, so that every adapter flag key
    exists in ``cfg.model_cfg``. In evaluation mode the id is replaced by the
    timestamp plus the checkpoint's parent directory name. ``debug_sign`` is
    appended in both cases.

    Raises:
        ValueError: in evaluation mode when no checkpoint path is configured.
    """
    run_cfg, model_cfg = cfg.run_cfg, cfg.model_cfg
    job_id = now() + (
        f"_init_lr_{run_cfg['init_lr']}_maxepoch_{run_cfg['max_epoch']}"
        f"_bs_{run_cfg['batch_size_train']}_minlr_{run_cfg['min_lr']}"
    )
    # fednano: compared with ``== True`` rather than truthiness, so that e.g.
    # the string "False" does not count as enabled.
    if getattr(model_cfg, "use_fednano_I", None) == True:  # noqa: E712
        job_id += "_adaI"
    if getattr(model_cfg, "use_fednano_T", None) == True:  # noqa: E712
        job_id += "_adaT"
    if getattr(model_cfg, "ada_rank", None):
        job_id += f"_rank_{model_cfg['ada_rank']}"
    for key, suffix in _ADAPTER_FLAG_SUFFIXES:
        if model_cfg[key]:
            job_id += suffix

    if getattr(run_cfg, "evaluate", None):
        ckpt_path = model_cfg.get("pretrained", None)
        if not ckpt_path:
            raise ValueError(
                "Evaluation mode needs a checkpoint: pass --ckpt or set 'pretrained' in the model config."
            )
        # Name the job after the directory that contains the checkpoint.
        job_id = now() + "_" + os.path.basename(os.path.dirname(ckpt_path))

    if getattr(cfg.args, "debug_sign", None):
        job_id += f"_{cfg.args.debug_sign}"
    return job_id


def _accumulate_into(total, update):
    """Add ``update`` into ``total`` key-wise (dicts of tensors) and return ``total``.

    On the first call (``total is None``) the tensors are detached and cloned,
    so later in-place additions and divisions never write through to tensors
    owned by the caller. Without this, tensors that share storage with model
    parameters would be overwritten when the next checkpoint is loaded.
    """
    with torch.no_grad():
        if total is None:
            return {name: tensor.detach().clone() for name, tensor in update.items()}
        for name in total:
            total[name] += update[name]
    return total


def _fisher_merge_step(w_t, mom, delta, fisher_sum, fisher_weights_product_sum, names, eta,
                       beta1=0.9, beta2=0.99, eps=0.01):
    """Apply one Adam-style update to ``w_t`` in place (dicts keyed by ``names``).

    The step direction is ``fisher_sum * w - fisher_weights_product_sum``. It is
    zero where ``w == fisher_weights_product_sum / fisher_sum``. That point is
    the Fisher-weighted average of the client weights, assuming the runner's
    ``fisher_weights_product`` is F_i * w_i (TODO confirm). The moments are plain
    decayed sums, without Adam's (1 - beta) factors or bias correction.
    """
    for name in names:
        grad = fisher_sum[name] * w_t[name] - fisher_weights_product_sum[name]
        mom[name] = grad + mom[name] * beta1
        delta[name] = grad * grad + delta[name] * beta2
        w_t[name] = w_t[name] - eta * mom[name] / (torch.sqrt(delta[name]) + eps)


def _collect_client_statistics(cfg, job_id, base_model, fed_datasets):
    """Load each client's checkpoint into ``base_model`` and sum its Fisher statistics.

    Returns:
        (fisher_sum, weights_sum, fisher_weights_product_sum, trainset_sizes):
        key-wise sums over clients of the three dicts returned by
        ``runner.compute_fedfisher()``, plus each client's train annotation count.
    """
    fisher_sum = None
    weights_sum = None
    fisher_weights_product_sum = None
    trainset_sizes = []
    runner_cls = get_runner_class(cfg)
    for dataset_name in fed_datasets:
        client_task = tasks.setup_task_fed(cfg, dataset_name)
        client_dataset = client_task.build_datasets(cfg, dataset_name)
        trainset_sizes.append(len(client_dataset[dataset_name]["train"].annotation))

        base_model.load_checkpoint(SQA_PARTITION_CKPT[dataset_name])
        runner = runner_cls(
            cfg=cfg, job_id=job_id, task=client_task, model=base_model,
            datasets=client_dataset, dataset_name=dataset_name,
        )
        fisher, params_vector, fisher_weights_product = runner.compute_fedfisher()

        fisher_sum = _accumulate_into(fisher_sum, fisher)
        weights_sum = _accumulate_into(weights_sum, params_vector)
        fisher_weights_product_sum = _accumulate_into(fisher_weights_product_sum, fisher_weights_product)
    return fisher_sum, weights_sum, fisher_weights_product_sum, trainset_sizes


def _build_eval_inputs(cfg, fed_datasets):
    """Build each client's (dataset name, task, datasets) once, for reuse by every evaluation round."""
    eval_inputs = []
    for client_dataset in fed_datasets:
        task = tasks.setup_task_fed(cfg, client_dataset)
        eval_inputs.append((client_dataset, task, task.build_datasets(cfg, client_dataset)))
    return eval_inputs


def _evaluate_merged_model(cfg, job_id, base_model, w_t, params_template, eval_inputs, step,
                           clients_best_acc, clients_best_epoch):
    """Load the merged weights ``w_t`` into ``base_model`` and run ``fed_eval`` for every client."""
    convert_vector_to_params(w_t, params_template)
    base_model.load_state_dict(params_template, strict=False)
    runner_cls = get_runner_class(cfg, runner_name="runner_comm_after_each_epoch")
    for client_id, (client_dataset, task, dataset) in enumerate(eval_inputs):
        runner = runner_cls(
            cfg=cfg, job_id=job_id, task=task, model=base_model,
            datasets=dataset, client_dataset=client_dataset,
        )
        # The shared best-acc/epoch lists are passed through, presumably so
        # fed_eval can update them in place (TODO confirm).
        runner.fed_eval(
            cur_epoch=step, client_dataset=client_dataset,
            clients_best_acc=clients_best_acc, clients_best_epoch=clients_best_epoch,
            client_id=client_id, before_communication=False,
        )


def _fisher_merge(cfg, job_id, base_model, trainable_params, init_weights, fisher_sum,
                  fisher_weights_product_sum, fed_datasets, num_steps=_FISHER_MERGE_STEPS,
                  eta=_FISHER_MERGE_LR, eval_every=_FISHER_MERGE_EVAL_EVERY):
    """Refine ``init_weights`` for ``num_steps`` steps, evaluating every ``eval_every`` steps.

    ``init_weights`` is updated in place; the final merged weights are returned.
    The final iterate is always evaluated, even when ``num_steps - 1`` is not a
    multiple of ``eval_every``.
    """
    # Copy of the trainable parameters, used as the shape template that the
    # merged vectors are viewed into before loading them into base_model.
    params_template = copy.deepcopy(trainable_params)
    names = list(trainable_params)
    eval_inputs = _build_eval_inputs(cfg, fed_datasets)
    w_t = init_weights

    with torch.no_grad():
        clients_best_acc = [0.0] * len(fed_datasets)
        clients_best_epoch = [0] * len(fed_datasets)
        mom = {name: torch.zeros_like(w_t[name]) for name in names}
        delta = {name: torch.zeros_like(w_t[name]) for name in names}
        last_evaluated = None
        for step in range(num_steps):
            _fisher_merge_step(w_t, mom, delta, fisher_sum, fisher_weights_product_sum, names, eta)
            if step % eval_every == 0:
                _evaluate_merged_model(cfg, job_id, base_model, w_t, params_template, eval_inputs,
                                       step, clients_best_acc, clients_best_epoch)
                last_evaluated = step
        if num_steps > 0 and last_evaluated != num_steps - 1:
            _evaluate_merged_model(cfg, job_id, base_model, w_t, params_template, eval_inputs,
                                   num_steps - 1, clients_best_acc, clients_best_epoch)
    return w_t


def main():
    """Merge per-client checkpoints via Fisher statistics and evaluate the result.

    Each dataset in the config is one client. For every client, the checkpoint
    from ``SQA_PARTITION_CKPT`` is loaded and ``compute_fedfisher`` supplies its
    Fisher statistics. The client weights are averaged into W^0, which is then
    refined with Adam-style steps and evaluated on every client periodically.

    Raises:
        ValueError: if the config lists no datasets (clients).
    """
    # allow auto-dl completes on main process without timeout when using NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    cfg = Config(parse_args())
    _apply_cli_overrides(cfg)

    job_id = _build_job_id(cfg)
    print(f"Job ID {job_id}")

    # Distributed init is disabled here (init_distributed_mode is not called).
    print(cfg.run_cfg)
    setup_seeds(cfg)
    setup_logger()
    cfg.pretty_print()

    fed_datasets = list(cfg.datasets_cfg.keys())
    if not fed_datasets:
        raise ValueError("No datasets configured: each dataset in the config is treated as one client.")
    num_client = len(fed_datasets)

    # The base model is built once; each client's checkpoint is loaded into it in turn.
    base_task = tasks.setup_task_fed(cfg, fed_datasets[0])
    base_model = base_task.build_model(cfg)
    trainable_params = {
        name: param for name, param in base_model.named_parameters() if param.requires_grad
    }

    # The train-set sizes are collected but not used: clients are weighted uniformly below.
    fisher_sum, weights_sum, fisher_weights_product_sum, _trainset_sizes = _collect_client_statistics(
        cfg, job_id, base_model, fed_datasets
    )

    # W^0: uniform average of the client weights.
    for name in weights_sum:
        weights_sum[name] /= num_client

    _fisher_merge(cfg, job_id, base_model, trainable_params, weights_sum,
                  fisher_sum, fisher_weights_product_sum, fed_datasets)


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
