import torch
from collections import defaultdict, OrderedDict
import tqdm
import re
import torch.nn as nn
import copy
import sparsify
import utils
import sys
import transformers
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer
import os
import functools
from collections import defaultdict, OrderedDict
from param import param
import torch.nn.functional as F 
import torch
from collections import defaultdict
import numpy as np
from merge import MergingMethod
from model import load_classifier
import inspect
import datasets
import pandas as pd
import utils

# Device that every model in this script is moved to.
DEVICE = 'cuda:0'
# Module-level handle to the merged run configuration; `run_merge` declares it
# `global` and assigns it before calling `run_twin_vector`.
args = None

@torch.inference_mode()
def extract_twin_vector(
    model: nn.Module, 
    merged_model: param,
    mask_rate: float,
    mask_strategy: str = 'magnitude',
):
    """Return the sparsified residual between a fine-tuned model and the merged model.

    The residual ``model - merged_model`` (theta^t - theta^*) is mapped entry by
    entry through ``sparsify.<mask_strategy>``, which receives the tensor and
    ``1 - mask_rate`` as its two positional arguments.

    Args:
        model: the task's fine-tuned model (theta^t).
        merged_model: the merged parameters (theta^*).
        mask_rate: the sparsifier is called with ``1 - mask_rate`` (presumably
            the fraction of entries kept — confirm against ``sparsify``).
        mask_strategy: name of the function in ``sparsify`` to apply; also
            passed as the ``desc`` label of ``map``.

    Returns:
        Whatever ``(model - merged_model).map(...)`` produces — presumably a
        ``param`` holding the twin vector (TODO confirm against ``param.map``).
    """
    def _sparsify_entry(n, p):
        # The strategy is resolved per call (as before), so an unknown name only
        # fails once an entry is actually processed.
        sparsifier = getattr(sparsify, mask_strategy)
        return sparsifier(p, 1 - mask_rate)

    residual = model - merged_model  # theta^t - theta^*
    twin = residual.map(_sparsify_entry, desc=mask_strategy)
    return twin

@torch.inference_mode()
def run_twin_vector(
    args,
):
    """Run twin-vector inference on routed GLUE data and save per-task scores.

    Steps:
      1. Load one fine-tuned classifier per name in ``args.models_name``
         (theta_t). Each model is later used as the call skeleton for its own
         task's examples.
      2. Merge the models named in ``args.src_merge`` with
         ``MergingMethod.task_arithmetic_search`` to obtain theta^*.
      3. For each task in ``args.src_twin``, build its twin vector with
         ``extract_twin_vector(theta_t, theta^*)``.
      4. For each example in ``args.data_path``, add to theta^* either the
         single twin vector, or a softmax(``router_prob``)-weighted sum of all
         twin vectors. Then run the example's task model with those weights.
      5. Save the run settings and per-task scores via ``utils.save_excel``.

    Zero-expert routing is not supported: at least one twin vector is always
    added, because there is no router training dataset.

    Args:
        args: merged run configuration built by ``run_merge``; read both as
            attributes and as keys. Side effect: ``args.base_model`` and
            ``args.models_to_merge`` are replaced in place with filtered
            ``param`` objects.

    Raises:
        Exception: if ``args.src_merge`` has exactly one entry.
        ValueError: if ``args.src_twin`` is empty, or if an example's
            ``router_prob`` length differs from ``len(args.src_twin)``.
    """
    import eval_glue

    if len(args.src_merge) == 1:
        raise Exception('parameter Error')
    if not args.src_twin:
        raise ValueError('at least one twin task (src_twin) must be given')

    # Parameters whose name matches any `exclude_param` regex are left out of
    # merging. With no patterns every parameter is kept; previously
    # `filter_func` was simply undefined in that case (NameError).
    exclude_patterns = list(args.exclude_param or [])

    def filter_func(n, p):
        return not any(re.match(pattern, n) for pattern in exclude_patterns)

    # theta_t: one classifier per task, keyed by task name. The module itself is
    # the functional_call skeleton, so it supplies any parameter missing from the
    # merged dict (presumably the classifier head when it is excluded — confirm).
    models_finetuned = {
        name: load_classifier(
            eval_glue.model_path_template.format(name=name)
        ).to(DEVICE)
        for name in args.models_name
    }
    # Inputs to theta^*, in `src_merge` order (already on DEVICE).
    models_to_merge = [models_finetuned[name] for name in args.src_merge]
    # theta_0
    base_model = load_classifier(args.base_model).to(DEVICE)

    args.base_model = param(base_model)
    args.models_to_merge = [param(m) for m in models_to_merge]

    # Drop excluded parameters before merging.
    for model in args.models_to_merge:
        model.filter(filter_func)
    args.base_model.filter(filter_func)

    # theta^*
    merger = MergingMethod(**args)
    merged_param = merger.task_arithmetic_search(**args)

    # Run settings recorded next to the scores.
    metrics = {
        "_method": 'task_arithmetic_search',
        "scaling": ','.join([str(i) for i in args['scaling']]),
        **{
            f"_{k}": args[k] for k in [ 'mask_rate', 'mask_strategy', 'mask_scale', 'src_twin', 'src_merge' ]
        }
    }
    metrics['_mask_rate'] = 100*float(f"{metrics['_mask_rate']:.4f}")
    metrics['_src_twin'] = '+'.join(metrics['_src_twin'])
    metrics['_src_merge'] = '+'.join(metrics['_src_merge'])

    # Twin vector tv_t = sparsify(theta_t - theta^*), keyed by GLUE data_id.
    # theta_t is selected by task name. Indexing `models_to_merge` (ordered by
    # src_merge) with the global data_id picked the wrong model whenever
    # src_merge is not exactly the GLUE task list in data_id order.
    twin_vector = {}
    for data_name in args.src_twin:
        data_id = eval_glue.glue_data_id_map[data_name]
        twin_vector[data_id] = extract_twin_vector(
            model=models_finetuned[data_name],
            merged_model=merged_param,
            mask_rate=args.mask_rate,
            mask_strategy=args.mask_strategy,
        )
    # `router_prob` is ordered by data_id, so combine the twin vectors in that
    # order rather than in src_twin insertion order.
    ordered_twin_vectors = [twin_vector[i] for i in sorted(twin_vector)]

    single_twin = len(args.src_twin) == 1
    if single_twin:
        # The same weights serve every example.
        _infer_param = merged_param + ordered_twin_vectors[0]

    glue_names = list(eval_glue.glue_data_id_map.keys())
    data = utils.from_json(args.data_path)
    eval_pred = defaultdict(lambda: defaultdict(list))
    for data_item in tqdm.tqdm(data, desc='infer glue'):
        data_name = glue_names[data_item['dataset_ids']]

        if not single_twin:
            # Per-example mixture: softmax over the router scores weights each
            # twin vector.
            tv_weights = F.softmax(torch.tensor(data_item['router_prob']), dim=0)
            if len(tv_weights) != len(args.src_twin):
                raise ValueError(
                    f"router_prob has {len(tv_weights)} entries but "
                    f"{len(args.src_twin)} twin tasks were given"
                )
            twin_sum = sum([w * tv for tv, w in zip(ordered_twin_vectors, tv_weights)])
            _infer_param = merged_param + twin_sum

        # Run the task's own model with the merged + twin weights (batch of 1).
        model = models_finetuned[data_name]
        logits = torch.func.functional_call(
            model,
            _infer_param.param_dict,
            args=(
                torch.tensor(data_item['input_ids']).unsqueeze(0).to(model.device),
                torch.tensor(data_item['attention_mask']).unsqueeze(0).to(model.device),
            ),
        ).logits.cpu().numpy()

        eval_pred[data_name]['predictions'].append(logits)
        eval_pred[data_name]['label_ids'].append(data_item['label'])

    for data_name, preds in eval_pred.items():
        ans = eval_glue.compute_single_metrics(
            utils.SimpleNamespace(
                predictions=np.concatenate(preds['predictions']),
                label_ids=np.array(preds['label_ids'])
            ), data_name
        )['averaged_scores']
        metrics[data_name] = 100*float(f"{ans:.4f}")

    utils.save_excel(metrics, args.outdir)

@utils.deprecated
@torch.inference_mode()
def run_twin_vector_iod_deprecated(
    args,
):
    """Deprecated: apply the twin vector of ``args.src_name`` to every task in ``args.iod_path``.

    Builds theta^* with ``args.merge_method`` over ``args.models_to_merge``. It
    adds the twin vector of the single source task ``args.src_name``, i.e. uses
    theta^*_t to handle other tasks. Each task found in the data is evaluated in
    batches of 4, and the scores are saved via ``utils.save_excel``.

    Args:
        args: run configuration; read as attributes and as keys. Side effect:
            ``args.base_model`` and ``args.models_to_merge`` are replaced in
            place with filtered ``param`` objects.
    """
    import eval_glue

    # Parameters whose name matches any `exclude_param` regex are left out of
    # merging. With no patterns every parameter is kept; previously
    # `filter_func` was simply undefined in that case (NameError).
    exclude_patterns = list(args.exclude_param or [])

    def filter_func(n, p):
        return not any(re.match(pattern, n) for pattern in exclude_patterns)

    # theta_t, keyed by task name
    models_to_merge = {
        name: load_classifier(model).to(DEVICE)
        for name, model in zip(args.models_name, args.models_to_merge)
    }
    # theta_0
    base_model = load_classifier(args.base_model).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    args.base_model = param(base_model)
    args.models_to_merge = [param(m) for m in models_to_merge.values()]

    # Drop excluded parameters before merging.
    for model in args.models_to_merge:
        model.filter(filter_func)
    args.base_model.filter(filter_func)

    # theta^*
    merger = MergingMethod(**args)
    merge_method = getattr(merger, args.merge_method)
    merged_param = merge_method(**args)

    # Run settings recorded next to the scores.
    metrics = {
        "model": args.src_name + '-' + args.merge_method,
        **{
            k: args[k] for k in [ 'mask_rate', 'mask_strategy', 'scaling', 'mask_scale', 'src_name' ]
        }
    }

    twin_vector = extract_twin_vector(
        model=models_to_merge[args.src_name],
        merged_model=merged_param,
        mask_rate=args.mask_rate,
        mask_strategy=args.mask_strategy,
    )
    # The same weights are used for every target task.
    _infer_param = merged_param + twin_vector

    # Group the examples by GLUE task.
    glue_names = list(eval_glue.glue_data_id_map.keys())
    dataset_list = defaultdict(list)
    for data_item in utils.from_json(args.iod_path):
        dataset_list[glue_names[data_item['dataset_ids']]].append(data_item)

    for data_name, dataset in dataset_list.items():
        dataset = datasets.Dataset.from_pandas(pd.DataFrame(dataset))
        # The task's own model is the functional_call skeleton.
        model = models_to_merge[data_name]

        def calculate_logits(batch):
            # Pad this batch's variable-length examples to a common length.
            input_ids = torch.nn.utils.rnn.pad_sequence(
                [torch.tensor(d) for d in batch['input_ids']],
                batch_first=True,
                padding_value=tokenizer.pad_token_id,
            )
            attention_mask = torch.nn.utils.rnn.pad_sequence(
                [torch.tensor(d) for d in batch['attention_mask']],
                batch_first=True,
                padding_value=0,
            )

            score = torch.func.functional_call(
                model,
                _infer_param.param_dict,
                args=(
                    input_ids.to(model.device),
                    attention_mask.to(model.device),
                ),
            ).logits.cpu().numpy()

            return {
                'predictions': score,
                'label_ids': batch['label']
            }

        dataset = dataset.map(
            calculate_logits,
            batched=True,
            batch_size=4,
        )

        ans = eval_glue.compute_single_metrics(
            utils.SimpleNamespace(
                predictions=torch.tensor(dataset['predictions']),
                label_ids=np.array(dataset['label_ids'])
            ), data_name
        )['averaged_scores']
        metrics[data_name] = 100*float(f"{ans:.4f}")

    utils.save_excel(metrics, args.outdir)


def run_merge(
    *, 
    # Command-line values have the highest priority; an argument counts as
    # "not given" when it is None.
    models_to_merge: list[str], 
    models_name: list[str],
    data_path: str,
    src_merge: list[str], 
    src_twin: list[str], 
    yaml_file: str = None,
    model_placeholder: str = None, 
    model_loader: str = None,
    eval_func: str = None,
    dtype: str = None,
    exclude_param: list[str] = None, 
    load_head: bool = None,
    seed: int=10,
    base_model: str = 'roberta-base',
    # for task-arithmetic:
    scaling: list[float] = None,
    # for dare-merge:
    mask_rate: float = None,
    mask_scale: float = None,
    mask_strategy: str = None,
    outdir: str = None,
):
    """Command-line entry point: build the run configuration and call ``run_twin_vector``.

    Each keyword argument that is ``None`` is filled from the same key in
    ``yaml_file``. Keys that exist only in the YAML file are passed through too.
    The result is stored in the module-level ``args`` as a
    ``utils.SimpleNamespace`` and passed to :func:`run_twin_vector`.

    Raises:
        ValueError: if ``outdir`` is given neither on the command line nor in
            the YAML file.
    """
    global args
    import inspect
    # Snapshot the call's parameters (only the parameter names are read from the frame).
    param_names, _, _, frame_locals = inspect.getargvalues(inspect.currentframe())
    cli_values = {name: frame_locals[name] for name in param_names}

    utils.fix_seed(seed)

    # Settings from the YAML file; no path, or a falsy loaded config, yields no overrides.
    merge_config = (utils.from_yaml(yaml_file) if yaml_file else None) or {}

    # Command-line values win; None means "not given", so fall back to the YAML value.
    merged = {}
    for key in sorted(set(cli_values).union(merge_config)):
        value = cli_values.get(key)
        merged[key] = merge_config.get(key) if value is None else value
    args = utils.SimpleNamespace(**merged)

    # `outdir` may come from the YAML file, so create it only after merging
    # (previously os.makedirs(None) crashed when it was set only in YAML).
    if args.outdir is None:
        raise ValueError('outdir must be given on the command line or in the YAML config')
    os.makedirs(args.outdir, exist_ok=True)

    print('>>> args\n', args)

    run_twin_vector(
        args,
    )


if __name__ == '__main__':
    import defopt
    try:
        defopt.run(run_merge)
    except SystemExit:
        # Normal CLI exits (e.g. `--help`, usage errors) are not crashes; a bare
        # `except:` used to drop them into the debugger.
        raise
    except BaseException:
        # Post-mortem debug any other failure (including Ctrl-C), unless the
        # failure is the user quitting a pdb session.
        import sys, pdb, bdb
        exc_type, exc_value, exc_tb = sys.exc_info()
        if issubclass(exc_type, bdb.BdbQuit):
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(exc_tb)