import random
from typing import Any, List, Union, Dict

from tqdm import tqdm

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.functional import softmax
from torch.utils.data import Dataset, DataLoader
from transformers import DataCollatorForLanguageModeling

from data import load_split_dataset, TensorDataset, PEFTDataset, ourPEFTDataset
from examples import get_examples
from templates import get_templates
from train_peft import OurDataCollatorForLanguageModeling


# @torch.inference_mode()
def get_loss_label(generator, batch, list_labels, labels_loss=False, args=None):
    """Score one batch at the positions selected by the label mask.

    The model is run once on ``batch['input_ids']``. For every sample the
    positions where the (shifted) loss mask is true are selected and four
    quantities are derived from them.

    Args:
        generator: object exposing ``.model`` and ``.tokenizer``.
        batch: mapping with ``'input_ids'``, ``'sample_mask'`` and ``'labels'``
            tensors.
        list_labels: label strings; the first token id of each one must be
            unique.
        labels_loss: unused, kept for interface compatibility.
        args: namespace with ``method_boost_type``. ``None`` or an unlisted
            value falls back to a plain forward pass masked by ``sample_mask``.

    Returns:
        A 4-tuple of CPU tensors:
        * mean cross-entropy over all masked positions, one value per sample;
        * cross-entropy at the first masked position over the full vocabulary;
        * cross-entropy at the first masked position restricted to the label
          tokens;
        * softmax over the label-token logits at the first masked position,
          flattened to ``batch_size * len(list_labels)`` entries.
    """
    model = generator.model.eval()
    loss_fct = CrossEntropyLoss(reduction='none', ignore_index=-100)

    # Each label is represented by its first token only, so those must not collide.
    encoded_labels = generator.tokenizer(list_labels, add_special_tokens=False).data['input_ids']
    label_token_ids = [ids[0] for ids in encoded_labels]
    assert len(np.unique(label_token_ids)) == len(label_token_ids), \
        'labels must start with distinct tokens'

    input_ids = batch['input_ids'].to(model.device)
    sample_mask = batch['sample_mask'].to(model.device)
    labels = batch['labels'].to(model.device)

    method = getattr(args, 'method_boost_type', 'none')
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        if method in ('our_prompt_tuning', 'IPT'):
            # These models return their own (extended) mask and ground-truth labels.
            outputs = model(input_ids=input_ids, labels=labels, sample_mask=sample_mask)
            loss_mask = outputs['extended_input_mask'] == outputs['extended_input_mask'].max()
            labels = outputs['gt_labels']
        elif method in ('LORA', 'prompt_tuning', 'prefix_tuning'):
            outputs = model(input_ids=input_ids, labels=labels)
            # Keep only the trailing positions that line up with `labels`.
            outputs.logits = outputs.logits[..., -labels.shape[1]:, :]  # TODO should we always cut the output?
            loss_mask = sample_mask == sample_mask.max()
        else:
            # 'none' and any unlisted method: plain forward pass. The mask must be
            # defined here as well, otherwise the code below cannot run.
            outputs = model(input_ids=input_ids)
            loss_mask = sample_mask == sample_mask.max()

    batch_size = input_ids.shape[0]
    # Standard next-token shift: the logit at position t is scored against label t + 1.
    logits = outputs.logits[..., :-1, :].contiguous().to(model.dtype)
    device = logits.device
    shift_labels = labels[..., 1:].contiguous().to(device)
    loss_mask = loss_mask[..., 1:]
    label_tokens = torch.tensor(label_token_ids, dtype=torch.long, device=device)
    vocab_size = logits.size(-1)

    list_loss = []
    list_first_label_loss = []
    list_first_label_adapted_loss = []
    list_prob = []
    for i in range(batch_size):
        row_mask = loss_mask[i]
        masked_logits = logits[i][row_mask]
        masked_labels = shift_labels[i][row_mask]

        # Distribution over the candidate label tokens at the first masked position.
        label_logits = masked_logits[0][label_tokens]
        list_prob.append(torch.softmax(label_logits, dim=-1).detach().cpu())

        # Index of the true label inside `label_tokens`.
        target_index = torch.where(label_tokens == masked_labels[0])[0]
        first_adapted_loss = loss_fct(label_logits.view(-1, len(label_tokens)), target_index)

        all_loss = loss_fct(masked_logits.view(-1, vocab_size),
                            masked_labels.view(len(masked_logits), )).mean()
        first_label_loss = loss_fct(masked_logits[0].view(-1, vocab_size),
                                    masked_labels[0].view(1, ))

        list_loss.append(all_loss.detach().cpu())
        list_first_label_loss.append(first_label_loss.detach().cpu())
        list_first_label_adapted_loss.append(first_adapted_loss.detach().cpu())

    tensor_samlpes_losses1 = torch.tensor(list_loss)
    tensor_samlpes_losses2 = torch.tensor(list_first_label_loss)
    tensor_samlpes_losses3 = torch.tensor(list_first_label_adapted_loss)
    tensor_samlpes_probs = torch.cat(list_prob)

    return tensor_samlpes_losses1, tensor_samlpes_losses2, tensor_samlpes_losses3, tensor_samlpes_probs


def predict_label(generator, eval_dataset, labels, batch_size=1, method='direct', labels_loss=False,
                  calibrate_dataset=None, mode='diagonal_W', args=None):
    """Predict a label for every example of ``eval_dataset``.

    Batches are built with ``OurDataCollatorForLanguageModeling`` and scored by
    ``get_loss_label``; the prediction is the label whose token has the highest
    probability.

    Args:
        generator: object exposing ``.model`` and ``.tokenizer``.
        eval_dataset: dataset handed to the ``DataLoader``.
        labels: candidate label strings.
        batch_size: evaluation batch size.
        method, calibrate_dataset, mode: unused here, kept for interface
            compatibility.
        labels_loss: forwarded to ``get_loss_label``.
        args: run configuration forwarded to the collator and to
            ``get_loss_label``.

    Returns:
        ``(predictions, probs, losses1, losses2, losses3)`` where
        ``predictions`` is a numpy array of label strings, ``probs`` has one
        row per example and ``len(labels)`` columns, and the three loss tensors
        are the per-example losses returned by ``get_loss_label``.
    """
    # No get_collator() call here: its result was never used, and it could
    # fail on configurations that do not define a collate function name.
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=batch_size,
        collate_fn=OurDataCollatorForLanguageModeling(generator.tokenizer, training=False, args=args, mlm=False),
    )

    list_losses1, list_losses2, list_losses3, list_probs = [], [], [], []
    for batch in tqdm(eval_dataloader):
        batch_l1, batch_l2, batch_l3, batch_probs = get_loss_label(generator, batch, labels, labels_loss, args)
        list_losses1.append(batch_l1)
        list_losses2.append(batch_l2)
        list_losses3.append(batch_l3)
        list_probs.append(batch_probs)

    losses1 = torch.cat(list_losses1, dim=0)
    losses2 = torch.cat(list_losses2, dim=0)
    losses3 = torch.cat(list_losses3, dim=0)

    # get_loss_label returns the probabilities flattened; restore one row per example.
    probs = torch.cat(list_probs, dim=0).reshape(-1, len(labels))
    # Index with a numpy array so a single-example result is still an array.
    predicted_index = probs.argmax(dim=1).cpu().numpy()
    return np.array(labels)[predicted_index], probs, losses1, losses2, losses3


# @torch.inference_mode()
def get_loss(generator, batch, labels_loss=False, args=None):
    """Return one language-modeling loss per sequence of ``batch``.

    Args:
        generator: object exposing ``.model``.
        batch: mapping with ``'input_ids'`` and ``'attention_mask'``. It must
            also hold ``'ICL_mask'`` when ``args.method_boost_type`` is
            ``'ours'`` and ``'token_type_ids'`` when ``labels_loss`` is true.
        labels_loss: when true, average the token losses only over positions
            flagged by ``token_type_ids``; otherwise average over all shifted
            positions.
        args: namespace with ``method_boost_type``.

    Returns:
        A detached CPU tensor with one loss value per sequence.
    """
    model = generator.model
    loss_fct = CrossEntropyLoss(reduction='none', ignore_index=-100)

    input_ids = batch['input_ids'].to(model.device)
    attention_mask = batch['attention_mask'].to(model.device)
    # Positions that are not attended are excluded through the ignore index.
    labels = torch.where(attention_mask == 1, input_ids, -100)

    forward_kwargs = {'input_ids': input_ids, 'attention_mask': attention_mask}
    if getattr(args, 'method_boost_type', None) in ['ours']:
        # Only this mode consumes the ICL mask, so only this mode requires the key.
        forward_kwargs['ICL_mask'] = batch['ICL_mask'].to(model.device)

    with torch.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(**forward_kwargs)

    # Next-token shift: the logit at position t is scored against token t + 1.
    logits = outputs.logits[..., :-1, :].contiguous().to(model.dtype)
    shift_labels = labels[..., 1:].contiguous().to(logits.device)
    losses = loss_fct(logits.view(-1, logits.size(-1)), shift_labels.view(-1))
    losses = losses.view(logits.size(0), logits.size(1))

    if labels_loss:
        # Read lazily so batches without token_type_ids work when labels_loss is false.
        label_mask = batch['token_type_ids'][..., 1:].contiguous().to(model.device)
        losses = losses * label_mask
        # NOTE(review): a row without any flagged position divides by zero here.
        losses = losses.sum(dim=-1) / label_mask.sum(dim=-1)
    else:
        # Ignored positions contribute 0 but still count in the denominator.
        losses = losses.mean(dim=-1)

    return losses.detach().cpu()


def classify(losses, labels, correction_factor=None, mode="diagonal_W"):
    """Turn per-label losses into (optionally calibrated) predictions.

    The losses are converted to probabilities with ``softmax(-losses)`` along
    the label axis and then mapped through ``probs @ W + b``.

    Args:
        losses: 2-D tensor with one row per example and one column per label.
        labels: label names, in column order.
        correction_factor: 1-D tensor with one entry per label, or ``None`` to
            skip calibration.
        mode: ``"diagonal_W"`` divides each label probability by its correction
            factor; ``"identity_W"`` subtracts the correction factor from it.

    Returns:
        ``(predictions, calibrated_probs)``: a numpy array with the predicted
        label per example, and the tensor of calibrated scores.

    Raises:
        NotImplementedError: if ``mode`` is not one of the two supported values.
    """
    num_classes = len(labels)
    # Identity mapping by default, i.e. no calibration.
    W = torch.eye(num_classes, dtype=losses.dtype)
    b = torch.zeros(num_classes, dtype=losses.dtype)

    if correction_factor is not None:
        if mode == "diagonal_W":
            W = torch.linalg.inv(torch.eye(num_classes, dtype=losses.dtype) * correction_factor)
        elif mode == "identity_W":
            # Probabilities are stored as rows (examples x labels), so the bias
            # must be a row vector that broadcasts over examples. A column
            # vector would subtract per example instead of per label.
            b = -1 * torch.as_tensor(correction_factor, dtype=losses.dtype).reshape(-1)
        else:
            raise NotImplementedError(f"{mode} is not implemented for calibration")

    uncalibrated_probs = softmax(-losses, dim=1)
    calibrated_probs = torch.matmul(uncalibrated_probs, W) + b

    # Index with a numpy array so a single-example result is still an array.
    predicted_index = calibrated_probs.argmax(1).detach().cpu().numpy()
    return np.array(labels)[predicted_index], calibrated_probs


class OursDataCollator(DataCollatorForLanguageModeling):
    """Causal-LM collator that also batches an optional per-example ``ICL_mask``.

    The ``ICL_mask`` entry is withheld from the parent collator and, after
    padding, re-attached as a zero-padded long tensor with the same shape as
    ``input_ids``.
    """

    def __init__(self, tokenizer, args):
        super().__init__(tokenizer, mlm=False)
        self.args = args

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Collate ``examples`` and add the padded ``ICL_mask`` when present.

        The incoming example objects are not modified: the mask is removed
        from shallow copies, so the same dataset item can be collated again.
        """
        key_ICL = 'ICL_mask'
        icl_masks = []  # (row index, mask) pairs, so rows stay aligned
        stripped_examples = []
        for row, example in enumerate(examples):
            if hasattr(example, 'keys') and key_ICL in example.keys():
                icl_masks.append((row, example[key_ICL]))
                example = {k: v for k, v in example.items() if k != key_ICL}
            stripped_examples.append(example)

        ret_dict = super().torch_call(stripped_examples)

        if icl_masks:
            batch_size, max_len = ret_dict['input_ids'].size()
            tensor_icl = torch.zeros(batch_size, max_len).long()
            for row, mask in icl_masks:
                # NOTE(review): the mask is written from column 0, which assumes
                # right padding -- confirm against the tokenizer's padding side.
                tensor_icl[row, :len(mask)] = torch.as_tensor(mask, dtype=torch.long)

            ret_dict[key_ICL] = tensor_icl

        return ret_dict


def get_collator(tokenizer, args):
    """Build the data collator selected by ``args.collate_fn_name``.

    Args:
        tokenizer: tokenizer handed to the collator.
        args: namespace with ``collate_fn_name`` set to ``'standard'`` or
            ``'ours'``.

    Returns:
        A ``DataCollatorForLanguageModeling`` for ``'standard'`` or an
        ``OursDataCollator`` for ``'ours'``.

    Raises:
        ValueError: if ``collate_fn_name`` is neither of the supported names.
    """
    name = args.collate_fn_name
    # Compare for equality: `name in 'standard'` is a substring test and would
    # also accept values such as '' or 'and'.
    if name == 'standard':
        return DataCollatorForLanguageModeling(tokenizer, mlm=False)
    if name == 'ours':
        return OursDataCollator(tokenizer, args)
    raise ValueError(f"Unknown collate_fn_name {name!r}; expected 'standard' or 'ours'")


def _collect_losses(generator, dataset, collator, batch_size, labels_loss, num_labels, args):
    """Run ``get_loss`` over ``dataset`` and return a float32 matrix with ``num_labels`` columns."""
    dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, collate_fn=collator)
    flat_losses = []
    for batch in tqdm(dataloader):
        flat_losses.extend(get_loss(generator, batch, labels_loss, args).tolist())
    # Consecutive groups of `num_labels` entries are reshaped into one row each.
    return torch.tensor(flat_losses, dtype=torch.float32).reshape(-1, num_labels)


def predict(generator, eval_dataset, labels, batch_size=1, method='direct', labels_loss=False,
            calibrate_dataset=None, mode='diagonal_W', args=None):
    """Classify ``eval_dataset`` by comparing the loss obtained for every label.

    Args:
        generator: object exposing ``.model`` and ``.tokenizer``.
        eval_dataset: dataset whose items are grouped in runs of ``len(labels)``.
        labels: candidate label names.
        batch_size: evaluation batch size.
        method: ``'calibrate'`` estimates a correction factor from
            ``calibrate_dataset`` first; any other value skips calibration.
        labels_loss: forwarded to ``get_loss``.
        calibrate_dataset: context-free inputs, required for ``'calibrate'``.
        mode: calibration mode forwarded to ``classify``.
        args: run configuration; selects the collator and is forwarded to
            ``get_loss``.

    Returns:
        ``(results, probs)`` as returned by ``classify``.

    Raises:
        ValueError: if ``method == 'calibrate'`` but no ``calibrate_dataset``
            was given.
    """
    collator = get_collator(generator.tokenizer, args=args)
    num_labels = len(labels)

    correction_factor = None
    if method == 'calibrate':
        if calibrate_dataset is None:
            raise ValueError("method='calibrate' requires a calibrate_dataset")
        # Probability distribution the model assigns to context-free inputs.
        cf_losses = _collect_losses(generator, calibrate_dataset, collator, batch_size,
                                    labels_loss, num_labels, args)
        cf_label_probs = softmax(-cf_losses, dim=1)
        # The calibration correction term is the mean context-free distribution.
        correction_factor = torch.mean(cf_label_probs, dim=0)

    losses = _collect_losses(generator, eval_dataset, collator, batch_size,
                             labels_loss, num_labels, args)
    results, probs = classify(losses, labels, correction_factor, mode)

    return results, probs


def eval_seq_languge(dataset, generator, seed, template, num_shots, selection_method,
                     ICL_mask=None, ICL_ids=None,
                     example_ids=None, examples_path=None,
                     prediction_method='direct',
                     labels_loss=False,
                     batch_size=16,
                     cache_dir=None,
                     args=None,
                     few_shot_train=None
                     ):
    """Measure label accuracy at the first label position of every example.

    For each example the logits at the first masked position are restricted
    to the label tokens; the example counts as correct when the highest-scoring
    label token is the ground-truth one.

    Args:
        dataset: dataset identifier handed to ``load_split_dataset``.
        generator: object exposing ``.model`` and ``.tokenizer``.
        template: prompt template handed to ``ourPEFTDataset``.
        batch_size: evaluation batch size.
        cache_dir: forwarded to ``load_split_dataset``.
        args: run configuration (``use_train_for_eval``, ``method_boost_type``).
        few_shot_train: replaces the validation split when
            ``args.use_train_for_eval`` is set.
        seed, num_shots, selection_method, ICL_mask, ICL_ids, example_ids,
        examples_path, prediction_method, labels_loss: unused here, kept so the
            signature matches ``evaluate_setup``.

    Returns:
        A dict whose ``'score'`` is the mean accuracy; the remaining entries
        are placeholders (zeros).
    """
    train, val, labels_mp = load_split_dataset(dataset, cache_dir=cache_dir, args=args, tokenizer=generator.tokenizer)

    # Each label is represented by its first token only, so those must not collide.
    encoded_labels = generator.tokenizer(list(labels_mp.values()), add_special_tokens=False).data['input_ids']
    label_token_ids = [ids[0] for ids in encoded_labels]
    assert len(np.unique(label_token_ids)) == len(label_token_ids)
    tensor_all_label_tok = torch.tensor(sorted(label_token_ids), dtype=torch.long)

    if args.use_train_for_eval:
        assert few_shot_train is not None
        val = few_shot_train

    eval_dataset = ourPEFTDataset(
        [(val.iloc[i]['input'].strip(), val.iloc[i]['target'].strip()) for
         i in range(len(val))],
        generator.tokenizer, [], template, args=args, bool_is_test=False)

    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=False,
        batch_size=batch_size,
        collate_fn=OurDataCollatorForLanguageModeling(generator.tokenizer, training=False, args=args, mlm=False)
    )
    list_losses = []
    list_pred = []
    list_acc = []
    model = generator.model.eval()

    for batch in tqdm(eval_dataloader):
        input_ids = batch['input_ids'].to(model.device)
        sample_mask = batch['sample_mask'].to(model.device)
        labels = batch['labels'].to(model.device)

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            if args.method_boost_type in ['our_prompt_tuning', 'IPT']:
                # These models return their own (extended) mask and ground-truth labels.
                outputs = model(input_ids=input_ids, labels=labels, sample_mask=sample_mask)
                loss_mask = outputs['extended_input_mask'] == outputs['extended_input_mask'].max()
                labels = outputs['gt_labels']
            elif args.method_boost_type in ['LORA', 'prompt_tuning', 'prefix_tuning']:
                outputs = model(input_ids=input_ids, labels=labels)
                outputs.logits = outputs.logits[..., -labels.shape[1]:, :]  # TODO should we always cut the output?
                loss_mask = sample_mask == sample_mask.max()
            else:
                # Plain forward pass, same as the 'none' branch of get_loss_label;
                # without it `outputs` and `loss_mask` would be undefined below.
                outputs = model(input_ids=input_ids)
                loss_mask = sample_mask == sample_mask.max()

        # Next-token shift: the logit at position t is scored against label t + 1.
        logits = outputs.logits[..., :-1, :].contiguous().to(model.dtype)
        shift_labels = labels[..., 1:].contiguous().to(logits.device)
        loss_mask = loss_mask[..., 1:]
        label_tokens = tensor_all_label_tok.to(logits.device)

        for i in range(input_ids.shape[0]):
            # Select per sample: indexing the whole batch with the mask would
            # always pick the first masked position of the first sample.
            row_mask = loss_mask[i]
            first_label = shift_labels[i][row_mask][0]
            selected_label = torch.where(label_tokens == first_label)[0]
            selected_logits = logits[i][row_mask][0][label_tokens]
            list_acc.append(int(selected_logits.argmax() == selected_label))
            list_pred.append(0)
            list_losses.append(0)

    print(f'Test ACC: {np.mean(list_acc)} Loss avg: {np.mean(list_losses)} Pred avg: {np.mean(list_pred)}')

    return {"score": np.mean(list_acc), "probs": np.mean(list_pred), "predicts": 0, "example_ids": 0,
            'avg_loss_true_label': np.mean(list_losses), 'avg_loss_true_label1': 0, 'avg_loss_true_label2': 0,
            'avg_loss_true_label3': 0}


def evaluate_setup(dataset, generator, seed, template, num_shots, selection_method,
                   ICL_mask=None, ICL_ids=None,
                   example_ids=None, examples_path=None,
                   prediction_method='direct',
                   labels_loss=False,
                   batch_size=16,
                   cache_dir=None,
                   args=None,
                   few_shot_train=None
                   ):
    """Build the evaluation dataset for one setup and score it with ``predict_label``.

    Args:
        dataset: dataset identifier handed to ``load_split_dataset``.
        generator: object exposing ``.model`` and ``.tokenizer``.
        seed, num_shots, selection_method, example_ids, examples_path: control
            the in-context example selection (non-PEFT methods only).
        template: prompt template; may be swapped for a different one when
            ``args.use_different_template_for_eval`` is set.
        prediction_method, labels_loss, batch_size: forwarded to
            ``predict_label``.
        cache_dir: forwarded to ``load_split_dataset``.
        args: run configuration.
        few_shot_train: replaces the validation split when
            ``args.use_train_for_eval`` is set.
        ICL_mask, ICL_ids: unused, kept for interface compatibility.

    Returns:
        A dict with the accuracy (``'score'``), the predictions, the example
        ids and the mean of each of the three losses from ``predict_label``.
    """
    train, val, labels_mp = load_split_dataset(dataset, cache_dir=cache_dir, args=args)
    if args.use_train_for_eval:
        assert few_shot_train is not None
        val = few_shot_train

    if args.hpt:
        # Hyper-parameter tuning only looks at the first 20% of the split.
        cur_size = val.shape[0]
        val = val.iloc[:int(cur_size * 0.2)]

    labels = list(labels_mp.values())

    if args.method_boost_type in ['IPT', 'LORA', 'prompt_tuning', 'prefix_tuning', 'our_prompt_tuning']:
        templates = get_templates(dataset, num_shots, args, args.num_templates, args.templates_path,
                                  args.template_seed)
        if args.use_different_template_for_eval:
            # Draw among the templates that actually differ from the current one.
            # Sampling until a match is found would return the same template, or
            # never terminate when it is not in the list.
            alternatives = [candidate for candidate in templates if str(candidate) != str(template)]
            if alternatives:
                template = random.choice(alternatives)
                print('Changed eval format')
            else:
                print('No alternative eval template available, keeping the current one')
        eval_dataset = ourPEFTDataset(
            [(val.iloc[i]['input'].strip(), val.iloc[i]['target'].strip()) for
             i in range(len(val))],
            generator.tokenizer, labels, template, args=args, bool_is_test=False)
    else:
        selected_examples = get_examples(dataset, train, selection_method, seed, num_shots,
                                         example_ids=example_ids,
                                         examples_path=examples_path,
                                         )
        examples, example_ids = selected_examples["examples"], selected_examples["example_ids"]

        eval_dataset = TensorDataset(
            [x.strip() for x in val['input']],
            generator.tokenizer, labels, template,
            examples=examples,
            method=prediction_method,
        )

    results, probs, losses1, losses2, losses3 = predict_label(generator, eval_dataset, labels, batch_size=batch_size,
                                                              method=prediction_method,
                                                              labels_loss=labels_loss,
                                                              calibrate_dataset=None, args=args)
    score = (results == val['target']).mean()
    mean_loss1, mean_loss2, mean_loss3 = torch.mean(losses1), torch.mean(losses2), torch.mean(losses3)

    print(
        f'Test ACC: {score} Loss1 avg: {mean_loss1}, Loss2 avg: {mean_loss2} Loss3 avg: {mean_loss3}')
    return {"score": score, "probs": 0, "predicts": results, "example_ids": example_ids,
            'avg_loss_true_label1': mean_loss1,
            'avg_loss_true_label2': mean_loss2,
            'avg_loss_true_label3': mean_loss3}


def evaluate(val, eval_dataset, labels, generator,
             example_ids=None,
             prediction_method='direct',
             labels_loss=False,
             batch_size=16,
             cache_dir=None,
             args=None
             ):
    """Score an already-built ``eval_dataset`` against ``val['target']``.

    Args:
        val: table whose ``'target'`` column holds the gold labels.
        eval_dataset: dataset handed to ``predict_label``.
        labels: candidate label names.
        generator: object exposing ``.model`` and ``.tokenizer``.
        example_ids: passed through to the result unchanged.
        prediction_method, labels_loss, batch_size, args: forwarded to
            ``predict_label``.
        cache_dir: unused, kept for interface compatibility.

    Returns:
        A dict with ``'score'`` (accuracy), ``'probs'``, ``'predicts'`` and
        ``'example_ids'``.
    """
    # predict_label also returns three loss tensors, which are not needed here.
    results, probs, *_ = predict_label(generator, eval_dataset, labels, batch_size=batch_size,
                                       method=prediction_method,
                                       labels_loss=labels_loss, calibrate_dataset=None, args=args)
    score = (results == val['target']).mean()
    print(f'Test ACC: {score}')
    return {"score": score, "probs": probs, "predicts": results, "example_ids": example_ids}
