import torch
import torch.nn.functional as F
import copy
from specific.tensor import make_dataloader

def Reason_triples(args, trainer, tokenizer, example, z_list):
    """Run a dev-mode forward pass over triple variants of ``example``.

    For every candidate ``z`` in ``z_list`` a deep copy of ``example`` is
    made whose ``'triples'`` entry is replaced by ``[z]`` and whose
    ``'triples_temp'`` entry is rebuilt from it. When ``args.causal`` is
    truthy the entries live in ``causal_text``; otherwise they live in
    ``texts[example.label]``. If ``z_list`` is ``None`` the unmodified
    ``example`` is evaluated instead.

    Args:
        args: Configuration object; ``causal``, ``experiment``,
            ``max_seq_length``, ``vary_segment_id`` and ``data_version``
            are read here.
        trainer: Object exposing ``_forward``.
        tokenizer: Tokenizer handed to ``make_dataloader``.
        example: Example to copy and modify.
        z_list: Candidate triples, or ``None``.

    Returns:
        ``None`` when the argmax of the first row of logits equals the
        label; otherwise ``logits[:, label]``.
    """
    sep_word = '[SEP]'
    joiner = f' {sep_word} '

    if z_list is None:
        examples = [example]
    else:
        examples = []
        for candidate in z_list:
            is_causal = args.causal
            variant = copy.deepcopy(example)
            # Both modes perform the same edit; only the target dict differs.
            slot = variant.causal_text if is_causal else variant.texts[example.label]
            slot['triples'] = [candidate]
            # Keep the flattened text form in sync with the new triple list.
            slot['triples_temp'] = joiner.join(' '.join(trip) for trip in slot['triples'])
            examples.append(variant)

    # The batch size is chosen so that all variants go through in one batch.
    if args.causal:
        total_batch_size = len(examples)
    else:
        total_batch_size = len(examples) * len(examples[0].texts)

    dataloader = make_dataloader(
        args.experiment,
        examples,
        tokenizer,
        total_batch_size=total_batch_size,
        drop_last=False,
        max_seq_length=args.max_seq_length,
        shuffle=False,
        vary_segment_id=args.vary_segment_id,
        config=args,
        dev=True,
    )

    with torch.no_grad():
        for batch in dataloader:
            _, right_num, _, logits, _ = trainer._forward(
                batch, None, mode='dev', dataset_name=args.data_version, return_all=True)
            assert logits.shape[0] == len(examples)
            print(right_num)

    label = example.causal_label if args.causal else example.label
    p_lambda = logits[:, label]

    # Nothing to report when the first row already selects the label.
    top_choice = torch.argmax(logits[0]).squeeze()
    if top_choice == label:
        return None

    return p_lambda


def _mean_tail_token_prob(model, tokenizer, device, text, target_ids):
    """Average the softmax probability of ``target_ids`` over the tail of ``text``.

    ``text`` is tokenized and fed through ``model``; the k-th logit row from
    the end is paired with the k-th id from the end of ``target_ids``.

    NOTE(review): for a causal LM the row at position t normally scores the
    token at t+1, so this pairing may be shifted by one position, and any
    special tokens the tokenizer appends would shift it further - confirm
    this is the intended alignment before changing it.

    Raises:
        ZeroDivisionError: if ``target_ids`` is empty.
    """
    encoded = tokenizer(text, padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**encoded).logits[0]

    total = 0
    for offset in range(1, len(target_ids) + 1):
        total += F.softmax(logits[-offset, :], dim=0)[target_ids[-offset]].item()
    return total / len(target_ids)


def Reason_triples_causal(args, trainer, tokenizer, example, z_list):
    """Evaluate candidate triples for a causal example and score them.

    For every candidate ``z`` in ``z_list`` a deep copy of ``example`` is
    made with ``causal_text['triples']`` set to ``[z]``,
    ``causal_text['triples_temp']`` rebuilt from it, and
    ``update_causal_prompt(args.model_type)`` called. All copies are pushed
    through ``trainer._forward`` in dev mode, then two scores are computed
    per candidate with ``trainer.model.model``:

    * ``p_lambda`` - mean probability of the tokens of ``z[2]`` at the tail
      of ``"<causal_prompt> <z[2]>"``.
    * ``p_z`` - mean probability of the tokens of ``z[1]`` at the tail of a
      fixed question prefix followed by ``z[1]``.

    Args:
        args: Configuration object; ``experiment``, ``max_seq_length``,
            ``vary_segment_id``, ``data_version`` and ``model_type`` are
            read here.
        trainer: Object exposing ``device``, ``_forward`` and
            ``model.model``.
        tokenizer: Tokenizer used for the dataloader and for scoring.
        example: Example to copy and modify.
        z_list: Sequence of candidate triples, or ``None`` to evaluate the
            unmodified ``example`` only (both score tensors are then empty).

    Returns:
        ``(p_lambda, p_z, right_num)`` where the first two are 1-D tensors
        on ``trainer.device`` with one entry per candidate and ``right_num``
        comes from the last forward call (``None`` if the dataloader yielded
        no batch). ``(None, None, None)`` if the forward pass raised or its
        output size did not match the number of examples.
    """
    device = trainer.device
    sep_word = '[SEP]'
    # ``None`` means "no candidates": the forward pass still runs on the
    # original example, but there is nothing to score afterwards.
    candidates = [] if z_list is None else z_list

    if z_list is None:
        examples = [example]
    else:
        examples = []
        for z in z_list:
            example_temp = copy.deepcopy(example)
            example_temp.causal_text['triples'] = [z]
            # Keep the flattened text form in sync with the new triple list.
            triples = f' {sep_word} '.join(
                ' '.join(trip) for trip in example_temp.causal_text['triples'])
            example_temp.causal_text['triples_temp'] = triples
            example_temp.update_causal_prompt(args.model_type)
            examples.append(example_temp)

    # One batch large enough to hold every variant.
    total_batch_size = len(examples)
    dataloader = make_dataloader(
            args.experiment, examples, tokenizer, total_batch_size=total_batch_size,
            drop_last=False, max_seq_length=args.max_seq_length, shuffle=False, vary_segment_id=args.vary_segment_id, config=args, dev=True)

    right_num = None
    try:
        with torch.no_grad():
            for batch in dataloader:
                _, right_num, _, output, _ = trainer._forward(
                    batch, None, mode='dev', dataset_name=args.data_version, return_all=True)
                if output.scores[0].shape[0] != len(examples):
                    # Same outcome the previous assert-inside-try produced.
                    return None, None, None
                print(f'***********right_num: {right_num}')
    except Exception as exc:
        # Best-effort: a failed forward pass is reported to the caller as a
        # triple of Nones. KeyboardInterrupt/SystemExit are not intercepted.
        print(f'***********forward failed: {exc!r}')
        return None, None, None

    needs_leading_space = 'gpt2' in args.model_type or "llama3" in args.model_type

    # Calculate p_lambda: probability of each candidate's object text.
    p_lambda = []
    for z, candidate_example in zip(candidates, examples):
        prefix = candidate_example.causal_prompt
        object_text = z[2]
        prompt = f"{prefix} {object_text}"
        if needs_leading_space:
            object_text = " " + object_text
        o_tok = tokenizer.encode(object_text, add_special_tokens=False)
        p_lambda.append(
            _mean_tail_token_prob(trainer.model.model, tokenizer, device, prompt, o_tok))
    p_lambda = torch.tensor(p_lambda).to(device)

    # Calculate p_z: probability of each candidate's relation text.
    p_z = []
    prefix = f"To answer the question: {example.texts[0]['question_text']} I need information {example.texts[0]['question_concept']}"
    for z in candidates:
        relation = z[1]
        prompt = f"{prefix} {relation}"
        if needs_leading_space and not relation.startswith(' '):
            relation = " " + relation
        r_tok = tokenizer.encode(relation, add_special_tokens=False)
        p_z.append(
            _mean_tail_token_prob(trainer.model.model, tokenizer, device, prompt, r_tok))
    p_z = torch.tensor(p_z).to(device)

    return p_lambda, p_z, right_num