import torch
import torch.nn.functional as F
import copy
from specific.tensor import make_dataloader


def _build_alter_examples(args, example, z_list):
    """Return the examples to score: one deep copy of ``example`` per triple."""
    if not z_list:
        # No candidate triples (empty or None): score the original example,
        # which is mutated in place, with a blank triple placeholder.
        example.causal_text['triples_temp'] = " "
        example.alter_causal_prompt(args.model_type)
        return [example]

    examples = []
    for z in z_list:
        example_temp = copy.deepcopy(example)
        example_temp.causal_text['triples'] = [z]
        # Each copy carries exactly one triple, so no '[SEP]' joiner is
        # needed between triples.
        example_temp.causal_text['triples_temp'] = ' '.join(z)
        example_temp.alter_causal_prompt(args.model_type)
        examples.append(example_temp)
    return examples


def _relation_probability(args, trainer, tokenizer, prefix, relation, device):
    """Return the mean softmax probability of ``relation``'s tokens at the trailing prompt positions."""
    # The target tokens may get a leading space for gpt2 models; the prompt
    # itself is always built from the unmodified relation.
    target = relation
    if "gpt2" in args.model_type and not target.startswith(" "):
        target = " " + target
    r_tok = tokenizer.encode(target, add_special_tokens=False)
    r_len = len(r_tok)
    if r_len == 0:
        # Nothing to average over; score it as 0 instead of dividing by zero.
        return 0.0

    prompt_tok = tokenizer(
        f"{prefix} {relation}",
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = trainer.model.model(**prompt_tok).logits[0]

    # NOTE(review): the logits at the j-th position from the end are paired
    # with the j-th relation token from the end. For a causal LM the logits at
    # position i usually score token i + 1, and a tokenizer that appends
    # special tokens would shift the tail -- confirm this alignment is intended.
    total = 0.0
    for j in range(1, r_len + 1):
        total += F.softmax(logits[-j, :], dim=0)[r_tok[-j]].item()
    return total / r_len


def Eval_alter_causal(args, trainer, tokenizer, example):
    """Evaluate ``example`` once per causal triple and report the most likely triple's result.

    One variant of ``example`` is built for every triple in
    ``example.causal_text['triples']`` and all variants are run through
    ``trainer._forward`` in ``'dev'`` mode as a single batch. Each triple's
    relation (``z[1]``) is then scored under ``trainer.model.model`` and the
    ``right_num`` entry of the highest-scoring triple decides the outcome.

    Args:
        args: Run configuration; reads ``model_type``, ``experiment``,
            ``max_seq_length``, ``vary_segment_id`` and ``data_version``.
        trainer: Provides ``device``, ``_forward`` and ``model.model``.
        tokenizer: Tokenizer used for the dataloader and for relation scoring.
        example: Example exposing ``causal_text``, ``texts`` and
            ``alter_causal_prompt``. It is mutated in place when it has no
            triples; otherwise deep copies are modified instead.

    Returns:
        * ``(None, None, None)`` if the forward pass raises or the output batch
          size does not match the number of variants.
        * ``right_num[0]`` when the example has no triples (empty or ``None``).
        * ``True``/``False`` otherwise: whether ``right_num`` at the index of
          the highest-probability relation sums to more than zero.
    """
    device = trainer.device
    z_list = example.causal_text['triples']
    examples = _build_alter_examples(args, example, z_list)

    # A single batch holding every variant, in triple order.
    dataloader = make_dataloader(
        args.experiment, examples, tokenizer, total_batch_size=len(examples),
        drop_last=False, max_seq_length=args.max_seq_length, shuffle=False,
        vary_segment_id=args.vary_segment_id, config=args, dev=True)

    try:
        with torch.no_grad():
            for batch in dataloader:
                _, right_num, _, output, _ = trainer._forward(
                    batch, None, mode='dev', dataset_name=args.data_version,
                    return_all=True)
                # Explicit check instead of ``assert`` so it survives ``-O``.
                if output.scores[0].shape[0] != len(examples):
                    return None, None, None
    except Exception:
        # Best-effort evaluation: any model/runtime failure yields the
        # sentinel, while KeyboardInterrupt/SystemExit still propagate.
        return None, None, None

    if not z_list:
        return right_num[0]

    # Calculate p_z: the model's score for each candidate triple's relation.
    prefix = f"To answer the question: {example.texts[0]['question_text']} I need information {example.texts[0]['question_concept']} "
    p_z = torch.tensor([
        _relation_probability(args, trainer, tokenizer, prefix, z[1], device)
        for z in z_list
    ]).to(device)

    max_p_z_idx = torch.argmax(p_z)
    return bool(right_num[max_p_z_idx].sum() > 0)