import os
import json
import re
import logging

import numpy as np
import torch
import torch.utils.data as torch_data
from torch import nn as nn
from torch.optim import AdamW, SGD
from tqdm.auto import tqdm
from transformers import get_scheduler
import evaluation as evaluate


# Forward
def forward_model(model, batch, generative=False, verbalizer_indices=None):
    """Run *model* on the model-input fields of *batch* and return ``(logits, loss)``.

    Only ``input_ids``, ``attention_mask`` and ``token_type_ids`` (whichever are
    present) are forwarded; in generative mode ``batch['labels']`` is forwarded too.

    Args:
        model: callable whose output exposes ``.logits`` (and ``.loss`` in
            generative mode).
        batch: mapping of tensors produced by the data collator.
        generative: when True, keep only position 0 of the output logits along
            dim 1 and return the model's loss.
        verbalizer_indices: optional sequence / tensor / array of vocabulary
            indices; in generative mode the logits are restricted to these
            columns. ``None`` or an empty sequence means "no restriction".

    Returns:
        ``(logits, loss)`` where ``loss`` is None unless ``generative``.
    """
    useful_keys = ('input_ids', 'attention_mask', 'token_type_ids')
    input_batch = {key: batch[key] for key in useful_keys if key in batch}
    if generative:
        input_batch['labels'] = batch['labels']

    outputs = model(**input_batch)

    logits = outputs.logits
    loss = None
    if generative:
        # Presumably the first decoded position carries the class token
        # (labels are a single token in eval_acc) — TODO confirm for training data.
        logits = logits[:, 0]
        loss = outputs.loss
        # Explicit None/len check: truthiness is ambiguous for multi-element
        # tensors/arrays and wrongly False for a tensor holding index 0.
        if verbalizer_indices is not None and len(verbalizer_indices) > 0:
            logits = logits[:, verbalizer_indices]
    return logits, loss


def compute_logits(model, data, batch_size, data_collator, accelerator, generative_forward=False, verbalizer_indices=None):
    """Run *model* over *data* and return a logits matrix indexed by ``data_idx``.

    Args:
        model: model passed to ``forward_model``; switched to eval mode here.
        data: dataset supporting column access (``data['data_idx']``) and row
            access (``data[0]``), fed to a non-shuffled DataLoader.
        batch_size: DataLoader batch size.
        data_collator: DataLoader ``collate_fn``.
        accelerator: object providing ``prepare``, ``gather`` and
            ``is_local_main_process`` (used as an accelerate ``Accelerator``).
        generative_forward / verbalizer_indices: forwarded to ``forward_model``.

    Returns:
        float32 CPU tensor of shape ``(max(data_idx) + 1, num_classes)``; row
        ``i`` holds the logits of the example whose ``data_idx`` is ``i``. Rows
        for indices never seen stay zero.
    """
    model.eval()
    num_rows = max(data['data_idx']) + 1
    forward_dataloader = accelerator.prepare(torch_data.dataloader.DataLoader(
        data, shuffle=False, batch_size=batch_size, collate_fn=data_collator))

    logits = None
    with torch.no_grad():
        for batch in tqdm(forward_dataloader, disable=not accelerator.is_local_main_process):
            batch_logits, _ = forward_model(
                model, batch, generative=generative_forward, verbalizer_indices=verbalizer_indices)
            data_indices, batch_logits = accelerator.gather((batch['data_idx'], batch_logits))
            if logits is None:
                # Size the class dimension from what the model emits. Counting
                # distinct labels in `data` undercounts on subsets that lack some
                # classes (e.g. single-label groups), and the row assignment
                # below would then fail with a shape mismatch.
                logits = torch.zeros((num_rows, batch_logits.size(-1)))
            logits[data_indices.cpu()] = batch_logits.detach().type(torch.float32).cpu()

    if logits is None:
        # No batches were produced: fall back to the label count for the width.
        data_labels = data['restricted_labels'] if 'restricted_labels' in data[0] else data['labels']
        logits = torch.zeros((num_rows, len(set(data_labels))))
    return logits


# Optimization
def get_optimizer_grouped_parameters(model, weight_decay):
    """Split trainable parameters of *model* into decay / no-decay groups.

    A parameter goes to the zero-decay group when its qualified name contains
    ``"bias"`` or ``"LayerNorm.weight"``; everything else gets *weight_decay*.
    Parameters with ``requires_grad == False`` are left out entirely.

    Returns:
        A two-element list of optimizer param-group dicts: the decayed group
        first, then the no-decay group.
    """
    no_decay = ["bias", "LayerNorm.weight"]
    decayed, not_decayed = [], []
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not optimized
        if any(marker in param_name for marker in no_decay):
            not_decayed.append(param)
        else:
            decayed.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]


def construct_optimizer(models, args):
    """Build the optimizer over one model or a list of models.

    Parameters are grouped with ``get_optimizer_grouped_parameters`` so bias and
    LayerNorm weights get no weight decay.

    Args:
        models: a single model or a list of models whose groups are concatenated.
        args: namespace read for ``optimizer`` ('adam' -> ``AdamW``, 'sgd' ->
            ``SGD`` with momentum 0.9), ``learning_rate``, ``weight_decay`` and,
            for 'adam', ``adam_eps``.

    Raises:
        ValueError: if ``args.optimizer`` is neither 'adam' nor 'sgd'.
    """
    # Some parameters may not need weight decay, so build per-model groups.
    if not isinstance(models, list):
        models = [models]
    optimizer_grouped_parameters = []
    for model in models:
        optimizer_grouped_parameters += get_optimizer_grouped_parameters(model, args.weight_decay)

    if args.optimizer == 'adam':
        return AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_eps)
    if args.optimizer == 'sgd':
        return SGD(optimizer_grouped_parameters, lr=args.learning_rate, momentum=0.9)
    # Fail fast: returning None here would only surface later as an obscure
    # error when the scheduler is built on top of it.
    raise ValueError(f'Unsupported optimizer {args.optimizer}')


def construct_scheduler(optimizer, args):
    """Create a transformers learning-rate scheduler for *optimizer*.

    ``args.warmup_rate`` below 1 is read as a fraction of
    ``args.max_train_steps``; otherwise it is an absolute number of warmup
    steps. The aliases 'linear_with_warmup' / 'cosine_with_warmup' map to the
    'linear' / 'cosine' names; any other ``args.scheduler`` is passed through.
    """
    warmup = args.warmup_rate
    if warmup < 1:
        num_warmup_steps = int(warmup * args.max_train_steps)
    else:
        num_warmup_steps = int(warmup)

    scheduler_aliases = {'linear_with_warmup': 'linear', 'cosine_with_warmup': 'cosine'}
    scheduler_name = scheduler_aliases.get(args.scheduler, args.scheduler)

    total_steps = int(args.max_train_steps)
    return get_scheduler(
        name=scheduler_name,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )


def construct_optimizer_scheduler(models, args):
    """Build the optimizer for *models* and its LR scheduler; return both as a tuple."""
    built_optimizer = construct_optimizer(models, args)
    return built_optimizer, construct_scheduler(built_optimizer, args)


# IO
def setup_logger(logger, log_filename=None):
    """Attach console and (optionally) file handlers to ``logger.logger``.

    Args:
        logger: adapter-like object exposing the underlying ``logging.Logger``
            as ``.logger`` (e.g. a ``logging.LoggerAdapter``); it is returned.
        log_filename: optional path of a log file to add. A file that is
            already attached through a ``FileHandler`` is not attached again,
            so repeated calls do not duplicate every log line in the file.

    A console handler is added only when ``hasHandlers()`` is False, i.e. when
    neither this logger nor any ancestor it propagates to has a handler.
    """
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    base_logger = logger.logger
    if not base_logger.hasHandlers():
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        base_logger.addHandler(console_handler)
    if log_filename is not None:
        # FileHandler stores the absolute path in `baseFilename`.
        target_path = os.path.abspath(log_filename)
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == target_path
            for handler in base_logger.handlers)
        if not already_attached:
            file_handler = logging.FileHandler(log_filename)
            file_handler.setFormatter(formatter)
            base_logger.addHandler(file_handler)
    return logger


def save_model(save_path, accelerator, prefix='validation', model=None, eval_result=None, eval_logits_dict=None, train_logits=None):
    """Save a model checkpoint and/or evaluation artifacts under *save_path*.

    Args:
        save_path: output directory; created if missing when any artifact is written.
        accelerator: object providing ``wait_for_everyone``, ``unwrap_model``,
            ``save`` and ``is_main_process`` (used as an accelerate ``Accelerator``).
        prefix: filename prefix for ``{prefix}_result.json`` and
            ``{prefix}_{dataset}_logits.pt``.
        model: if given, saved via ``save_pretrained`` after a barrier.
        eval_result: JSON-serializable result written with ``write_json``.
        eval_logits_dict: mapping dataset name -> object saved with ``torch.save``.
        train_logits: object saved as ``train_logits.pt`` (no prefix).
    """
    if model is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_path,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )

    # Only the global main process writes the auxiliary files, matching the
    # is_main_process gate used for save_pretrained above. Gating on the local
    # main process would give one writer per node racing on the same paths
    # when save_path is on a shared filesystem.
    if not accelerator.is_main_process:
        return
    if eval_result is None and eval_logits_dict is None and train_logits is None:
        return
    # save_pretrained is skipped when model is None, so nothing else ensures the
    # directory exists before the writes below.
    os.makedirs(save_path, exist_ok=True)

    if eval_result is not None:
        write_json(eval_result, os.path.join(save_path, f'{prefix}_result.json'))
    if eval_logits_dict is not None:
        for dataset_name in eval_logits_dict:
            torch.save(eval_logits_dict[dataset_name], os.path.join(save_path, f'{prefix}_{dataset_name}_logits.pt'))
    if train_logits is not None:
        torch.save(train_logits, os.path.join(save_path, 'train_logits.pt'))


def write_json(data, filename, formatted=True):
    """Serialize *data* as JSON into *filename* and return *filename*.

    With ``formatted`` the output is indented by 4 spaces and uses compact
    separators (``','`` and ``':'`` with no trailing space); otherwise
    ``json.dump`` defaults are used.
    """
    dump_options = {'indent': 4, 'separators': (',', ':')} if formatted else {}
    with open(filename, 'w') as out_file:
        json.dump(data, out_file, **dump_options)
    return filename


def load_json(file_path):
    """Parse and return the JSON document stored at *file_path*."""
    with open(file_path) as json_file:
        parsed = json.load(json_file)
    return parsed


def generate_save_name(args):
    """Return the next unused run directory ``ckpt/{task}/{mode}/{model}/run_{k}``.

    '/' in ``args.model_name`` is replaced by '-'. ``k`` is one more than the
    largest index among entries whose names start with ``run_<digits>``, or 0
    when there are none. This holds whether or not the parent directory exists
    yet, so an existing but empty directory also starts at ``run_0``.
    """
    model_name = args.model_name.replace('/', '-')
    save_dir = f'ckpt/{args.task_name}/{args.mode}/{model_name}'
    max_run_idx = -1
    if os.path.exists(save_dir):
        for entry in os.listdir(save_dir):
            match = re.match(r'run_(\d+)', entry)
            if match:
                max_run_idx = max(max_run_idx, int(match.group(1)))
    return os.path.join(save_dir, f'run_{max_run_idx + 1}')


def split_data_into_groups(dataset, label2name):
    """Partition *dataset* into one sub-dataset per (confounder, label) pair.

    Args:
        dataset: dataset with ``confounder`` and ``labels`` columns and
            ``filter`` / ``remove_columns`` methods.
        label2name: mapping from the label rendered as ``str(int(label))`` to
            a display name.

    Returns:
        Dict keyed ``'{confounder}_label_{label_name}'``, ordered by confounder
        then label (both as returned by ``np.unique``). Each value is the
        filtered subset with the ``confounder`` column dropped; combinations
        with no rows still get an (empty) entry.
    """
    groups = {}
    for confounder_value in np.unique(dataset['confounder']):
        for label_value in np.unique(dataset['labels']):
            readable_label = label2name[str(int(label_value))]
            key = f'{confounder_value}_label_{readable_label}'
            print('Extracting group:', key)
            # filter() runs eagerly, so the closure sees the current loop values.
            subset = dataset.filter(
                lambda x: x['confounder'] == confounder_value and x['labels'] == label_value)
            groups[key] = subset.remove_columns('confounder')
    return groups


def get_step_dirs(run_dir):
    """Return the ``step*`` sub-directories of *run_dir*, sorted by step number.

    The step number is the integer after the last '_' in the directory name
    (e.g. ``step_1200`` -> 1200). Directories starting with 'step' whose suffix
    is not an integer (e.g. ``step_best``) are skipped rather than aborting the
    sort. Plain files are ignored.

    Note: this module defines ``get_step_dirs`` a second time further down; the
    later definition is the one bound at import time.
    """
    numbered_dirs = []
    for entry in os.listdir(run_dir):
        entry_path = os.path.join(run_dir, entry)
        if not (entry.startswith('step') and os.path.isdir(entry_path)):
            continue
        try:
            step_number = int(entry.rsplit('_', 1)[-1])
        except ValueError:
            continue
        numbered_dirs.append((step_number, entry_path))
    numbered_dirs.sort(key=lambda item: item[0])
    return [entry_path for _, entry_path in numbered_dirs]


def _infer_model_family(name):
    """Return 'ELECTRA' / 'DeBERTa' if that family's lowercase name occurs in *name*, else None."""
    if 'electra' in name:
        return 'ELECTRA'
    if 'deberta' in name:
        return 'DeBERTa'
    return None


def _infer_model_size(name):
    """Return the first of 'large', 'base', 'small' that occurs in *name*, else None."""
    for size in ('large', 'base', 'small'):
        if size in name:
            return size
    return None


def wandb_process_args(args):
    """Build a wandb config dict from *args* plus derived model-family/size tags.

    Adds ``ref_model`` / ``ref_size`` (from ``args.reference_run_dir``) and
    ``main_model`` / ``main_size`` (from ``args.model_name``). Tags that cannot
    be inferred are set to None, so every run logs the same keys.

    The returned dict is a copy: ``args`` itself is not modified (``vars(args)``
    alone would return the namespace's own ``__dict__``).
    """
    wandb_config = dict(vars(args))

    reference_run_dir = wandb_config['reference_run_dir']
    if reference_run_dir is None:
        wandb_config['ref_model'] = None
        wandb_config['ref_size'] = None
    else:
        wandb_config['ref_model'] = _infer_model_family(reference_run_dir)
        wandb_config['ref_size'] = _infer_model_size(reference_run_dir)

    model_name = wandb_config['model_name']
    wandb_config['main_model'] = _infer_model_family(model_name)
    wandb_config['main_size'] = _infer_model_size(model_name)

    return wandb_config


# Evaluation
def merge_logits(logits, merge_classes_list):
    """Merge groups of class columns of *logits* by summing them.

    For each group in *merge_classes_list*, the columns of every class after
    the first are added into the first class's column and then dropped. The
    surviving columns keep ascending class order. Values are added as-is, so
    merging raw logits differs from adding probabilities — NOTE(review):
    confirm that is intended by callers.

    Args:
        logits: tensor whose last dimension indexes classes; it is not modified.
        merge_classes_list: list of class-index groups, e.g. ``[[1, 2]]`` to fold
            class 2 into class 1. A class may appear in at most one group.

    Returns:
        A new tensor with ``num_classes - (number of merged-away classes)``
        columns. An empty merge list returns an unmodified copy.
    """
    all_merge_classes = [merge_class for merge_classes in merge_classes_list for merge_class in merge_classes]
    if not all_merge_classes:
        return logits.clone()
    num_classes = logits.size(-1)
    assert len(all_merge_classes) == len(set(all_merge_classes)), 'duplicate class in merge groups'
    assert min(all_merge_classes) >= 0 and max(all_merge_classes) < num_classes, 'merge class index out of range'

    merged = logits.clone()
    for merge_classes in merge_classes_list:
        for other_class in merge_classes[1:]:
            merged[..., merge_classes[0]] += merged[..., other_class]

    delete_classes = {delete_class for merge_classes in merge_classes_list for delete_class in merge_classes[1:]}
    # Build the kept indices in ascending order explicitly: iterating a set
    # difference does not guarantee sorted order (e.g. {1, 8} iterates 8 first),
    # which would permute columns and corrupt argmax-based predictions.
    keep_classes = [class_idx for class_idx in range(num_classes) if class_idx not in delete_classes]
    return merged[..., keep_classes]


def eval_acc(model, dataloader, accelerator,
             merge_class_list=None, return_logits=False, metric_name='acc',
             generative_forward=False, verbalizer_indices=None):
    """Evaluate *model* on *dataloader* and return ``(performance, logits)``.

    Args:
        model: model passed to ``forward_model``; switched to eval mode here.
        dataloader: iterable of batches (already prepared by *accelerator*).
        accelerator: object providing ``gather``, ``gather_for_metrics``,
            ``wait_for_everyone`` and ``is_local_main_process``.
        merge_class_list: optional class groups passed to ``merge_logits``
            before argmax (e.g. contradiction+neutral -> non-entailment).
        return_logits: also return the (merged) logits ordered by ``data_idx``.
        metric_name: 'acc', 'macro_f1' or 'micro_f1'.
        generative_forward / verbalizer_indices: forwarded to ``forward_model``.

    Returns:
        ``(performance, logits)``. ``logits`` is None unless ``return_logits``;
        otherwise it is a float32 CPU tensor with ``max(data_idx) + 1`` rows
        where row ``i`` holds the logits of the example with ``data_idx == i``
        (indices never seen stay zero).

    Raises:
        ValueError: for an unsupported ``metric_name``.
    """
    model.eval()
    if metric_name == 'acc':
        metric = evaluate.load('accuracy')
    elif metric_name in ('macro_f1', 'micro_f1'):
        metric = evaluate.load('f1')
    else:
        raise ValueError(f'Unsupported metric {metric_name}')

    gathered_logits, gathered_indices = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, disable=not accelerator.is_local_main_process):
            if generative_forward:
                # One all-zero label token per example (shape (batch, 1)).
                # NOTE(review): this overwrites batch['labels'], so references
                # below come from these zeros unless 'restricted_labels' exists.
                batch['labels'] = batch['input_ids'][:, 0].clone().fill_(0).unsqueeze(dim=1)
            logits, _ = forward_model(model, batch, generative=generative_forward, verbalizer_indices=verbalizer_indices)
            # merge logits if needed (e.g., merge contradiction and neutral to non-entailment for HANS)
            if merge_class_list is not None:
                logits = merge_logits(logits, merge_class_list)
            predictions = logits.argmax(dim=-1)
            batch_labels = batch['restricted_labels'] if 'restricted_labels' in batch else batch['labels']
            predictions, references = accelerator.gather_for_metrics((predictions, batch_labels))
            metric.add_batch(predictions=predictions.cpu(), references=references.cpu())
            if return_logits:
                gathered_logits.append(accelerator.gather(logits))
                gathered_indices.append(accelerator.gather(batch['data_idx']))
    accelerator.wait_for_everyone()

    if metric_name == 'acc':
        performance = metric.compute()['accuracy']
    else:
        average = 'macro' if metric_name == 'macro_f1' else 'micro'
        performance = metric.compute(average=average)['f1']

    if not return_logits:
        return performance, None

    # Reorder the gathered logits by data_idx.
    all_logits = torch.cat(gathered_logits, dim=0).type(torch.float32).cpu()
    all_indices = torch.cat(gathered_indices, dim=0).cpu()
    # Size by the largest data_idx (as compute_logits does), not by the number
    # of gathered rows: plain gather may include duplicated padding samples
    # (extra rows), and data_idx need not be contiguous from 0 (too few rows
    # -> out-of-bounds assignment).
    num_rows = int(all_indices.max()) + 1
    ordered_logits = torch.zeros((num_rows,) + tuple(all_logits.shape[1:]), dtype=all_logits.dtype)
    ordered_logits[all_indices] = all_logits
    return performance, ordered_logits


# DM functions
def compute_data_map_scores(logits_list, labels):
    """Compute per-example data-map scores from logits logged across epochs.

    Args:
        logits_list: sequence of logit tensors, one per logged epoch, stacked on
            a new leading axis; the gather on dim 2 implies each is
            (num_examples, num_classes).
        labels: gold class index per example (presumably a 1-D long tensor).

    Returns:
        ``(ambiguous_scores, hard_scores)``: the std across epochs (torch's
        default sample std) of the softmax probability of the gold class, and
        the negated mean of that probability.
    """
    probs = torch.stack(logits_list, dim=0).softmax(dim=-1)
    label_index = labels.unsqueeze(0).unsqueeze(-1).expand_as(probs)
    gold_prob = probs.gather(2, label_index)[:, :, 0]
    return gold_prob.std(dim=0), gold_prob.mean(dim=0).neg()


def obtain_data_map_indices(run_dir, train_data, mode='ambiguous'):
    """Rank training examples by a data-map score, highest score first.

    Loads ``logits_log.pt`` from *run_dir* (a mapping whose values are per-epoch
    logits), drops its last entry, and scores examples with
    ``compute_data_map_scores`` against the gold labels (``restricted_labels``
    when present in the first row of *train_data*, else ``labels``).

    Args:
        mode: 'ambiguous' ranks by std of gold-class probability; 'hard' by the
            negated mean (lowest confidence first).

    Returns:
        Long tensor of example positions in descending score order.
    """
    assert mode in ['ambiguous', 'hard']
    logits_log = torch.load(os.path.join(run_dir, 'logits_log.pt'), map_location='cpu')
    # The final logged entry is excluded. NOTE(review): the reason is not
    # visible here — confirm what that entry holds.
    per_epoch_logits = list(logits_log.values())[:-1]
    label_column = 'restricted_labels' if 'restricted_labels' in train_data[0] else 'labels'
    labels = torch.LongTensor(train_data[label_column])
    ambiguous_scores, hard_scores = compute_data_map_scores(per_epoch_logits, labels)
    if mode == 'ambiguous':
        chosen_scores = ambiguous_scores
    else:
        chosen_scores = hard_scores
    ranking = torch.sort(chosen_scores, descending=True)[1]
    return ranking

# Analyses helpers
# Compute weights
def read_dm_train_weights(logits_log_ref, train_labels, mode='ambiguous', filter_rate=0.33):
    """Mark the top ``filter_rate`` fraction of training examples by a data-map score.

    Args:
        logits_log_ref: mapping whose values are per-epoch logit tensors,
            stacked in insertion order along a new leading epoch axis (each
            assumed (num_examples, num_classes) — TODO confirm with callers).
        train_labels: gold class per example (tensor or sequence of ints).
        mode: 'ambiguous' ranks by the std across epochs of the gold-class
            probability; any other value (e.g. 'hard_to_learn') ranks by its
            negated mean, i.e. least confident first.
        filter_rate: fraction of examples to mark; the count is
            ``int(num_examples * filter_rate)``.

    Returns:
        Long tensor of 0/1 weights, 1 for the selected examples.
    """
    labels = torch.as_tensor(train_labels, dtype=torch.long)
    weights = torch.zeros_like(labels)
    # Convert logits to probabilities before reading the gold-class confidence,
    # as compute_data_map_scores does for the logits it receives. Raw logits are
    # unbounded and shift-invariant per example, so their mean/std are not a
    # confidence measure.
    probs = torch.stack(list(logits_log_ref.values()), dim=0).softmax(dim=-1)
    gather_index = labels[None, ..., None].expand_as(probs)
    gold_conf = torch.gather(probs, 2, gather_index)[:, :, 0]
    scores = torch.std(gold_conf, dim=0) if mode == 'ambiguous' else -torch.mean(gold_conf, dim=0)
    _, order = torch.sort(scores, descending=True)
    weights[order[:int(order.size(0) * filter_rate)]] = 1
    return weights


# Checkpoint loaders
def get_reference_dir(task_name, method_name, seed_idx, sep_dm=False):
    """Return the checkpoint directory of the reference model for a method/seed.

    References live under ``ckpt/{task}/erm/{backbone}`` where the backbone is
    'TinyBERT' for 'poe' and 'bert-base-uncased' otherwise. 'jtt', 'lff', 'poe'
    (and 'dm' when *sep_dm*) use ``{method}_reference_{seed}``; all other
    methods reuse the plain ERM run ``erm_{seed}``.
    """
    backbone = 'TinyBERT' if method_name == 'poe' else 'bert-base-uncased'
    base_dir = os.path.join(os.path.join('ckpt', task_name, 'erm'), backbone)
    has_own_reference = method_name in ['jtt', 'lff', 'poe'] or (method_name == 'dm' and sep_dm)
    if has_own_reference:
        run_name = f'{method_name}_reference_{seed_idx}'
    else:
        run_name = f'erm_{seed_idx}'
    return os.path.join(base_dir, run_name)


def get_final_dir(task_name, method_name, seed_idx):
    """Return ``ckpt/{task}/{method}/bert-base-uncased/<run>`` for the final model.

    The run is ``erm_{seed}`` for 'erm' and ``{method}_final_{seed}`` otherwise.
    """
    if method_name == 'erm':
        run_name = f'erm_{seed_idx}'
    else:
        run_name = f'{method_name}_final_{seed_idx}'
    method_dir = os.path.join('ckpt', task_name, method_name, 'bert-base-uncased')
    return os.path.join(method_dir, run_name)


def get_step_dirs(run_dir):
    """Return the ``step*`` sub-directories of *run_dir*, sorted by step number.

    The step number is the integer after the last '_' in the directory name
    (e.g. ``step_1200`` -> 1200). Directories starting with 'step' whose suffix
    is not an integer (e.g. ``step_best``) are skipped rather than aborting the
    sort. Plain files are ignored.

    Note: an earlier ``get_step_dirs`` definition exists in this module; this
    later one is the definition bound at import time.
    """
    numbered_dirs = []
    for entry in os.listdir(run_dir):
        entry_path = os.path.join(run_dir, entry)
        if not (entry.startswith('step') and os.path.isdir(entry_path)):
            continue
        try:
            step_number = int(entry.rsplit('_', 1)[-1])
        except ValueError:
            continue
        numbered_dirs.append((step_number, entry_path))
    numbered_dirs.sort(key=lambda item: item[0])
    return [entry_path for _, entry_path in numbered_dirs]


def load_train_logits(run_dir, mode='train'):
    """Load training-set logits saved under *run_dir*.

    Args:
        mode: 'train' loads ``logits_log.pt`` from *run_dir* as-is; 'inference'
            loads ``train_logits.pt`` from every step directory (see
            ``get_step_dirs``) into a dict keyed by the directory's position in
            that sorted list.

    Raises:
        ValueError: for any other *mode*.
    """
    if mode == 'inference':
        return {
            position: torch.load(f'{step_dir}/train_logits.pt', map_location='cpu')
            for position, step_dir in enumerate(get_step_dirs(run_dir))
        }
    if mode == 'train':
        return torch.load(os.path.join(run_dir, 'logits_log.pt'), map_location='cpu')
    raise ValueError('Mode must be either train or inference')


def load_config(run_dir):
    """Return the parsed ``config.json`` stored in *run_dir*."""
    config_path = os.path.join(run_dir, 'config.json')
    return load_json(config_path)
    

def load_data_indices(run_dir):
    """Load ``data_indices_log.pt`` from *run_dir*, mapping storages to CPU."""
    indices_path = os.path.join(run_dir, 'data_indices_log.pt')
    return torch.load(indices_path, map_location='cpu')


# Compute spurious features 
def lexical_overlap(premise, hypothesis):
    """Return the fraction of hypothesis tokens that also occur in the premise.

    Every hypothesis token is counted, duplicates included, so numerator and
    denominator count the same things. The result is 1.0 exactly when every
    hypothesis token appears in the premise. An empty hypothesis yields 0.0
    instead of raising ZeroDivisionError.

    Args:
        premise: iterable of hashable tokens (presumably a list of words —
            confirm tokenization with callers).
        hypothesis: sized sequence of hashable tokens.
    """
    if len(hypothesis) == 0:
        return 0.0
    premise_tokens = set(premise)
    shared = sum(1 for token in hypothesis if token in premise_tokens)
    return shared / len(hypothesis)