import os
import numpy as np
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D
from torch.nn import functional as F
import pickle
from heads import get_classification_head
import re
import logging
from tqdm import tqdm
from vision_datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle
from vision_datasets.registry import get_dataset
from vision_datasets.registry import split_train_into_train_val

# export PYTHONPATH="$PYTHONPATH:$PWD"

import time
import sys
# Root of the task-arithmetic checkout. It is appended to sys.path, presumably so
# that the project-local `eval` / `args` imports that follow can be resolved.
root = '/data/common/task-arithmetic'
sys.path.append(root)
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from eval import eval_single_dataset
from args import parse_arguments


def create_log_dir(path, filename='log.txt'):
    """Create ``path`` if needed and return a DEBUG logger writing to a file and stderr.

    The logger is looked up by name (``logging.getLogger(path)``), so repeated
    calls return the same object. Handlers are only attached when an equivalent
    one is not already present. Without that check, a second call would attach
    a second file/console handler and every record would be emitted twice.

    Args:
        path: directory that holds the log file; created with parents if missing.
        filename: name of the log file inside ``path``.

    Returns:
        logging.Logger: logger named ``path`` with a FileHandler on
        ``path/filename`` and a console StreamHandler (stderr by default).
    """
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    # FileHandler stores its target as an absolute path in `baseFilename`.
    log_file = os.path.abspath(os.path.join(path, filename))
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_file
        for h in logger.handlers
    )
    if not has_file_handler:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, so compare the exact type here.
    has_console_handler = any(
        type(h) is logging.StreamHandler for h in logger.handlers
    )
    if not has_console_handler:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger

# Tasks whose fine-tuned checkpoints are merged and evaluated:
# SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']

model = 'ViT-B-32'
args = parse_arguments()
args.data_location = root + '/data'
args.model = model
args.save = root + '/task_vectors_checkpoints/' + model
args.logs_path = '../logs/' + model  # relative to the current working directory
pretrained_checkpoint = root+'/task_vectors_checkpoints/'+model+'/zeroshot.pt'

str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
log = create_log_dir(args.logs_path, 'log_regmean_{}.txt'.format(str_time_))

# Deserialize the zero-shot checkpoint once and share it. `final_model` is only
# read for `val_preprocess` and `pretrained_model` only for its state dict, and
# neither is modified by this script, so a second torch.load was pure overhead.
final_model = torch.load(pretrained_checkpoint)
pretrained_model = final_model
pretrained_model_dic = pretrained_model.state_dict()

# State-dict keys excluded from merging. By name these look like the CLIP
# text-side / logit-scale entries; confirm if the model wrapper changes.
frozen = ["model.positional_embedding", "model.text_projection", "model.logit_scale", "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"]
# NOTE(review): trainable_params is not read anywhere else in this script.
trainable_params = {k: v for k, v in pretrained_model_dic.items() if k not in frozen}


def filter_modules_by_regex(base_module, include_patterns, include_type):
    """Select the named submodules of ``base_module`` that pass a name and a type filter.

    Args:
        base_module: module whose ``named_modules()`` are scanned. The root
            module itself is visited under the name ``''``.
        include_patterns: iterable of regex strings tried with ``re.match``,
            which is anchored at the start of the name. A falsy value means
            "any name".
        include_type: iterable of classes checked with ``isinstance``. A falsy
            value means "any type".

    Returns:
        dict: module name -> module for every module that passes both filters,
        in ``named_modules()`` order.
    """
    # isinstance accepts a tuple of classes; build it once rather than per module.
    type_tuple = tuple(include_type) if include_type else ()
    modules = {}
    for name, module in base_module.named_modules():
        # Generator (not list) inside any() so matching stops at the first hit.
        if include_patterns and not any(re.match(patt, name) for patt in include_patterns):
            continue
        if type_tuple and not isinstance(module, type_tuple):
            continue
        modules[name] = module
    return modules


def compute_grams(finetuned_model, train_dataloader, n_step=1000):
    """Accumulate the input Gram matrix of every linear-like layer of a model.

    A forward hook is placed on each ``nn.Linear`` / ``nn.Conv1d`` / ``Conv1D``
    submodule. The hook flattens all leading dims of the layer input ``x`` into
    rows ``X`` of shape ``[N, h]`` and keeps the running mean ``X^T X / N`` over
    every row seen so far.

    Args:
        finetuned_model: model to probe. It is moved to ``device``
            (``nn.Module.to`` works in place). It runs in eval mode during
            the pass, and its previous train/eval mode is restored afterwards.
        train_dataloader: iterable of batches understood by
            ``maybe_dictionarize``, which must yield an ``'images'`` entry.
        n_step: maximum number of batches to run; ``<= 0`` means the whole
            loader. Defaults to 1000, the previously hard-coded value.

    Returns:
        dict: module name -> ``[h, h]`` Gram matrix, plus ``"meta_info"`` whose
        ``"conv1d"`` entry lists the names of ``nn.Conv1d`` / ``Conv1D`` modules.
    """
    covs = {}
    xn = {}

    def get_grams(name):
        def hook(module, input, output):
            """Fold this call's layer input into ``covs[name]``.

            Note: adhere to signature of hook functions
            """
            x = input[0].detach()  # e.g. [b, t, h]; any leading dims are flattened
            # NOTE(review): for torch nn.Conv1d inputs ([b, c, L]) the last dim is
            # the length, not the channels -- confirm such layers never occur here.
            # reshape (not view) also accepts non-contiguous activations.
            x = x.reshape(-1, x.size(-1))  # [N, h]
            xtx = torch.matmul(x.transpose(0, 1), x)  # [h, h]
            n_rows = x.size(0)
            if name not in covs:
                covs[name] = xtx / n_rows
                xn[name] = n_rows
            else:
                # Running mean over all rows seen so far for this layer.
                covs[name] = (covs[name] * xn[name] + xtx) / (n_rows + xn[name])
                xn[name] += n_rows

        return hook

    model = finetuned_model.to(device)
    was_training = model.training
    # Statistics should reflect inference-time activations (no dropout etc.).
    model.eval()

    linear_modules = filter_modules_by_regex(
        model, None, [nn.Linear, nn.Conv1d, Conv1D]
    )

    # Mark conv modules as special: Conv1D stores its weight as [in, out].
    covs["meta_info"] = {
        "conv1d": list(
            filter_modules_by_regex(model, None, [nn.Conv1d, Conv1D])
        )
    }

    total = n_step if n_step > 0 else len(train_dataloader)
    handles = []
    try:
        for name, module in linear_modules.items():
            handles.append(module.register_forward_hook(get_grams(name)))

        with torch.no_grad():
            for step, inputs in tqdm(
                enumerate(train_dataloader), total=total, desc="Computing gram matrix"
            ):
                if n_step > 0 and step == n_step:
                    break
                images = maybe_dictionarize(inputs)['images'].to(device)
                model(images)
    finally:
        # Always detach the hooks so the model is left unmodified, even on error.
        for handle in handles:
            handle.remove()
        model.train(was_training)

    return covs


def reduce_non_diag(cov_mat, a):
    """Return ``cov_mat`` with its off-diagonal entries multiplied by ``a``.

    The element-wise weight has ``(1 - a) + a`` on the diagonal and ``a``
    everywhere else. Diagonal entries are therefore kept up to float rounding.

    Args:
        cov_mat: square 2-D tensor.
        a: scale applied to the off-diagonal entries.

    Returns:
        A new tensor ``cov_mat * weight``; the input is not modified.
    """
    size = cov_mat.size(0)
    # (1 - a) on the diagonal, built on the CPU and then moved next to cov_mat.
    weight = torch.ones(size).sub(a).diag().to(cov_mat.device)
    # Add `a` everywhere, so the diagonal becomes (1 - a) + a and the rest `a`.
    weight = weight + torch.full_like(weight, a)
    return cov_mat * weight


def _regmean_solve(grams, weights, input_major, reduce_non_diag_ratio):
    """Return the RegMean merge of one layer from per-model Grams and weights.

    ``grams[i]`` is model ``i``'s ``[in, in]`` input Gram matrix. ``weights[i]``
    is its weight, laid out ``[out, in]`` (``nn.Linear``) or ``[in, out]`` when
    ``input_major`` (transformers ``Conv1D``). The result uses the same layout.
    """
    reduced_grams, gram_ws = [], []
    for gram, weight in zip(grams, weights):
        # For roberta this isn't needed, but it matters for deberta and t5.
        gram = reduce_non_diag(gram, a=reduce_non_diag_ratio)
        # The normal equations need the weight as [in, out].
        w_in_out = weight if input_major else weight.transpose(0, 1)
        gram_ws.append(torch.matmul(gram, w_in_out))
        reduced_grams.append(gram)
    # solve(A, B) == inv(A) @ B, but faster and numerically more stable.
    merged = torch.linalg.solve(sum(reduced_grams), sum(gram_ws))
    return merged if input_major else merged.transpose(0, 1)


def _average_tensors(tensors):
    """Element-wise mean of ``tensors``; non-float tensors keep the first model's value."""
    if not (tensors[0].is_floating_point() or tensors[0].is_complex()):
        # mean() rejects integer/bool dtypes (e.g. counters), so don't average them.
        return tensors[0]
    return torch.stack(tensors, 0).mean(0)


def regmean_merge(all_params, all_grams, reduce_non_diag_ratio=0.9):
    """Merge per-model parameters with RegMean, falling back to a plain average.

    Each ``<module>.weight`` whose module has a Gram matrix in ``all_grams[0]``
    is merged as ``W* = (sum_i G_i)^-1 (sum_i G_i W_i^T)``, then transposed
    back to the parameter's layout. Here ``G_i`` is model ``i``'s Gram with its
    off-diagonal entries scaled by ``reduce_non_diag_ratio``. Every other entry
    (biases, norms, weights of un-hooked modules) is averaged element-wise.

    Args:
        all_params: dict mapping parameter name -> list of tensors, one per
            model, in the same order as ``all_grams``.
        all_grams: list with one dict per model, as returned by
            ``compute_grams``. Each maps module name -> Gram matrix, and the
            optional ``"meta_info"["conv1d"]`` entry lists modules whose weight
            is stored ``[in, out]``.
        reduce_non_diag_ratio: off-diagonal scale passed to
            ``reduce_non_diag``. 0.9 is the previously hard-coded value.

    Returns:
        dict mapping parameter name -> merged tensor.
    """
    avg_params = {}
    ref_grams = all_grams[0] if all_grams else {}
    conv1d_modules = set(ref_grams.get("meta_info", {}).get("conv1d", ()))
    for name, tensors in all_params.items():
        merged = None
        if name.endswith('.weight'):
            print(f'Regmean: {name}')
            module_name = name[:-len('.weight')]
            if module_name != "meta_info" and module_name in ref_grams:
                merged = _regmean_solve(
                    [model_grams[module_name] for model_grams in all_grams],
                    tensors,
                    input_major=module_name in conv1d_modules,
                    reduce_non_diag_ratio=reduce_non_diag_ratio,
                )
        if merged is None:
            # Not averaged with RegMean, so use a simple average.
            merged = _average_tensors(tensors)
        avg_params[name] = merged
    return avg_params


# For every task, load its fine-tuned encoder and record per-layer input Gram
# matrices on a split produced by split_train_into_train_val.
all_grams, finetuned_models = [], []
for dataset_name in exam_datasets:
    finetuned_checkpoint = root+'/task_vectors_checkpoints/'+model+'/'+dataset_name+'/finetuned.pt'
    try:
        finetuned_model = torch.load(finetuned_checkpoint)
    except Exception:
        # Fall back to plain pickle for checkpoints torch.load cannot read.
        # `Exception` (not a bare except) lets Ctrl-C / SystemExit still abort.
        # NOTE: pickle executes arbitrary code -- only load trusted checkpoints.
        with open(finetuned_checkpoint, 'rb') as f:
            finetuned_model = pickle.load(f)

    base_dataset = get_dataset(dataset_name, final_model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
    dataset = split_train_into_train_val(
                base_dataset, dataset_name, args.batch_size, num_workers=2, val_fraction=0.1, max_val_samples=5000)
    # Presumably the held-out (val_fraction=0.1) portion of the train split -- confirm.
    valset = dataset.test_dataset

    val_dataloader = torch.utils.data.DataLoader(
                valset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=2
            )

    # NOTE(review): the head is not used below; the call is kept in case
    # get_classification_head has side effects (e.g. building/caching it).
    classifier_head = get_classification_head(args, dataset_name)
    with torch.no_grad():
        grams = compute_grams(finetuned_model, val_dataloader)
    all_grams.append(grams)
    finetuned_models.append(finetuned_model.state_dict())


# Keys left out of the merge (same list as used for trainable_params earlier).
frozen = ["model.positional_embedding", "model.text_projection", "model.logit_scale", "model.token_embedding.weight", "model.ln_final.weight", "model.ln_final.bias"]
# params: parameter name -> list of tensors, one per fine-tuned model, in
# exam_datasets order.
params = {}
for local_model in finetuned_models:
    n2p = dict(local_model.items())
    merge_param_names = [n for n in n2p if n not in frozen]
    for n in merge_param_names:
        params.setdefault(n, []).append(n2p[n])


# Start from a fresh zero-shot encoder, so that keys absent from the merged dict
# (the `frozen` entries) keep their pre-trained values under strict=False.
image_encoder = torch.load(pretrained_checkpoint)
merged_state_dict = regmean_merge(params, all_grams)
load_result = image_encoder.load_state_dict(merged_state_dict, strict=False)
# strict=False silently ignores unknown keys; surface them instead of hiding them.
if load_result.unexpected_keys:
    log.warning('Unexpected keys ignored while loading merged weights: {}'.format(load_result.unexpected_keys))
merged_path = root + '/merged_models/' + model + '/regmean.pt'
# Make sure the output directory exists before saving the merged model.
os.makedirs(os.path.dirname(merged_path), exist_ok=True)
image_encoder.save(merged_path)

accs = []
for dataset in exam_datasets:
    metrics = eval_single_dataset(image_encoder, dataset, args)
    top1 = metrics.get('top1') * 100
    log.info(str(dataset) + ':' + str(top1) + '%')
    accs.append(top1)
log.info('Avg ACC:' + str(np.mean(accs)) + '%')