import torch
import torch.nn.functional as F
import copy
from specific.tensor import make_dataloader

def _kgedit_label_recovered(args, trainer, tokenizer, example_temp, label, sep_word):
    """Re-score ``example_temp`` and report whether the prediction equals ``label``.

    The ``'triples_temp'`` string of ``example_temp.texts[label]`` is rebuilt
    from its current ``'triples'`` list (tokens joined by spaces, triples
    joined by ``sep_word``), a one-example dataloader is built with
    ``make_dataloader`` and pushed through ``trainer._forward`` in ``'dev'``
    mode under ``torch.no_grad()``.

    Returns:
        ``True`` when the argmax over dim 1 of the returned logits equals
        ``label``, ``False`` otherwise.

    Raises:
        RuntimeError: if the dataloader yields no batch, so there are no
            logits to inspect.
    """
    text = example_temp.texts[label]
    # update triple_temp
    text['triples_temp'] = f' {sep_word} '.join(' '.join(trip) for trip in text['triples'])
    examples = [example_temp]

    dataloader = make_dataloader(
        args.experiment, examples, tokenizer,
        total_batch_size=len(examples) * len(examples[0].texts),
        drop_last=False, max_seq_length=args.max_seq_length, shuffle=False,
        vary_segment_id=args.vary_segment_id, config=args, dev=True)

    logits = None
    with torch.no_grad():
        for batch in dataloader:
            _, _, _, logits, _ = trainer._forward(
                batch, None, mode='dev', dataset_name=args.data_version, return_all=True)
            assert logits.shape[0] == len(examples)

    if logits is None:
        raise RuntimeError('make_dataloader produced no batch for the edited example')

    predicts = torch.argmax(logits, dim=1)
    return bool(predicts.squeeze() == label)


def KGEdit(args, trainer, tokenizer, example, z_list, p_lambda):
    """Greedily add/remove triples until the model predicts ``example.label``.

    Candidates from ``z_list`` are ordered by their paired ``p_lambda`` score.
    Each round first tries to add the highest-scored remaining candidate (if
    it is not yet among the triples of ``example.texts[example.label]``) and
    then tries to remove the lowest-scored remaining candidate (if it is
    currently among them). After every single edit the example is re-scored;
    the search stops as soon as the prediction equals the label or the
    candidates are exhausted. All edits are applied to a deep copy.

    Args:
        args: Configuration object; ``experiment``, ``max_seq_length``,
            ``vary_segment_id`` and ``data_version`` are read from it.
        trainer: Object exposing ``_forward(batch, None, mode=..., ...)``.
        tokenizer: Passed through to ``make_dataloader``.
        example: Object with ``label`` and ``texts``; ``texts[label]`` must be
            a dict with a ``'triples'`` list.
        z_list: Candidate triples (each a sequence of strings).
        p_lambda: Scores aligned with ``z_list``.

    Returns:
        ``(example_out, succeed, added_z, removed_z)``. ``example_out`` is the
        edited copy on success and the untouched input on failure;
        ``added_z`` / ``removed_z`` list the edits that were attempted.
    """
    sep_word = '[SEP]'

    # sort from smallest to the largest; only candidates scoring at least as
    # high as the first entry of p_lambda are kept.
    # NOTE(review): presumably the first entry acts as a baseline - confirm.
    z_list_sorted = [z for score, z in sorted(zip(p_lambda, z_list)) if score >= p_lambda[0]]

    example_temp = copy.deepcopy(example)
    labels = example.label
    triples = example_temp.texts[labels]['triples']
    succeed = False

    removed_z = []
    added_z = []

    while z_list_sorted:
        progressed = False

        # plus z with high conditional probability
        z = z_list_sorted[-1]
        if z not in triples:
            z_list_sorted.pop()
            added_z.append(z)
            triples.append(z)
            progressed = True
            if _kgedit_label_recovered(args, trainer, tokenizer, example_temp, labels, sep_word):
                succeed = True
                break

        if not z_list_sorted:
            break

        # minus z with low conditional probability
        z = z_list_sorted[0]
        if z in triples:
            z_list_sorted.pop(0)
            removed_z.append(z)
            triples.remove(z)
            progressed = True
            if _kgedit_label_recovered(args, trainer, tokenizer, example_temp, labels, sep_word):
                succeed = True
                break

        if not progressed:
            # The best candidate is already present and the worst one is
            # already absent, so neither yields an edit. Drop both; otherwise
            # the loop would spin forever on an unchanged state.
            z_list_sorted.pop()
            if z_list_sorted:
                z_list_sorted.pop(0)

    if succeed:
        return example_temp, succeed, added_z, removed_z
    return example, succeed, added_z, removed_z
    


def KGEdit_causal(args, trainer, tokenizer, example, z_list, p_lambda, p_z, correct):
    """Edit ``example.causal_text["triples"]`` so its top-``p_z`` triple is correct.

    The function works on a deep copy of the current triples and repeats:

    1. Pick the triple in the working list with the highest ``p_z`` (the last
       one wins on ties). If it is flagged correct, commit the working list
       and succeed.
    2. Otherwise try to add the correct candidate from ``z_list`` whose
       ``p_z`` is at least the picked one's and whose ``p_lambda * p_z`` is
       largest. If such a candidate exists, commit and succeed.
    3. Otherwise remove the picked triple from the working list and retry.

    ``args``, ``trainer`` and ``tokenizer`` are accepted for signature parity
    but are not used here.

    Args:
        example: Object whose ``causal_text["triples"]`` list is edited; every
            triple in it must also appear in ``z_list``.
        z_list: Candidate triples (each a sequence of strings).
        p_lambda, p_z, correct: Per-candidate values aligned with ``z_list``;
            they are indexed element-wise with ``.item()`` and combined with
            element-wise arithmetic (assumed to be 1-D torch tensors - TODO
            confirm against the caller).

    Returns:
        ``(example, succeed, added_z, removed_z)``. On success ``example`` has
        been updated in place. On failure ``example`` is left untouched and
        ``removed_z`` lists the removals that were tried on the working copy.
    """
    if sum(correct) < 1:
        return example, False, [], []

    def _join(z):
        return " ".join(z)

    temp_z_list = copy.deepcopy(example.causal_text["triples"])

    # joined triple text -> [p_lambda, p_z, correct]
    dict_z_info = {
        _join(z): [p_lambda[i].item(), p_z[i].item(), correct[i].item()]
        for i, z in enumerate(z_list)
    }

    removed_z = []
    added_z = []

    while True:
        # find the z in temp_z_list with the highest p_z. It is also the z that
        # will be picked when only picking one z
        max_p_z = 0
        picked_z = None
        for z in temp_z_list:
            z_prob = dict_z_info[_join(z)][1]
            if z_prob >= max_p_z:
                max_p_z = z_prob
                picked_z = z
        # if the picked z is good, return
        if picked_z is not None and dict_z_info[_join(picked_z)][-1]:
            example.causal_text["triples"] = temp_z_list
            return example, True, added_z, removed_z

        # try to add a z with p_z >= max_p_z and correct and highest p_lambda * p_z
        filter_vector = (p_z >= max_p_z) * correct * p_lambda * p_z
        if torch.max(filter_vector) > 0:
            z_new = z_list[int(torch.argmax(filter_vector))]
            added_z.append(z_new)
            temp_z_list.append(z_new)
            example.causal_text["triples"] = temp_z_list
            return example, True, added_z, removed_z

        # fail to add a z, try to remove the picked z in current list
        if picked_z is None:
            # Nothing left to remove and nothing that can be added: report
            # failure to the caller instead of terminating the process.
            return example, False, added_z, removed_z
        temp_z_list.remove(picked_z)
        removed_z.append(picked_z)
    
        
