import torch
import numpy as np

def find(s, ch):
    """Return the positions in ``s`` whose element compares equal to ``ch``.

    Args:
        s: Any iterable (typically a string).
        ch: The value to look for.

    Returns:
        A list of zero-based indices, in ascending order; empty when there
        is no match.
    """
    positions = []
    for position, element in enumerate(s):
        if element == ch:
            positions.append(position)
    return positions

 

def graph_evaluation(args, model, dataloader, tokenizer, TO_TOKEN):
    """Score greedy (argmax) predictions with exact-match and per-token accuracy.

    Two scoring modes are selected by ``args.model``:

    * ``"gmlp"``: ``model(x)`` is used directly as the logits and every full
      prediction row is compared with the matching row of
      ``batch['label_ids']``.
    * anything else: the logits are read from ``model(x)['logits']``. The
      target is the span of ``input_ids`` strictly between the first ``"|"``
      token and the first ``"."`` token that follows it. It is compared with
      the predictions shifted one position to the left, i.e. the prediction
      at position ``t`` is matched against the input token at ``t + 1``.

    Args:
        args: Object exposing a ``model`` attribute (a string).
        model: Callable taking the ``input_ids`` tensor.
        dataloader: Iterable of dict batches holding an ``'input_ids'`` tensor
            (assumed ``(batch, seq)`` -- TODO confirm) and, for ``"gmlp"``,
            ``'label_ids'``.
        tokenizer: Not used by the scoring; kept so the signature stays
            compatible with existing callers.
        TO_TOKEN: Mapping that provides the token ids for ``"|"`` and ``"."``.

    Returns:
        Tuple ``(string_accuracy, char_accuracy)``: the fraction of sequences
        predicted exactly, and the mean per-sequence fraction of matching
        tokens. ``(0.0, 0.0)`` is returned when the dataloader yields nothing.

    Raises:
        ValueError: In the non-gmlp mode, when a sequence has no ``"|"`` token
            or no ``"."`` token after it.
    """
    sep_value = TO_TOKEN["|"]
    eos_value = TO_TOKEN["."]
    is_gmlp = args.model == "gmlp"

    # Follow the model's own device instead of assuming CUDA, so evaluation
    # also works on CPU-only hosts and on non-default GPUs.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    global_str_accuracy = 0
    global_char_accuracy = 0.0
    max_test_steps = 50
    num_sequences = 0

    for step, batch in enumerate(dataloader):
        x = batch['input_ids'].to(device)
        num_sequences += x.shape[0]

        with torch.no_grad():
            output = model(x)
        logits = output if is_gmlp else output['logits']

        # Convert once per batch rather than once per sequence.
        pred_rows = torch.argmax(logits, dim=-1).tolist()
        if is_gmlp:
            target_rows = [row.tolist() for row in batch['label_ids']]
        else:
            target_rows = x.tolist()

        for x_out, pred_out in zip(target_rows, pred_rows):
            if is_gmlp:
                gt = x_out
                pred_model = pred_out
            else:
                idx_gt_sep = x_out.index(sep_value)
                # Look for the terminator only after the separator: a "."
                # inside the prompt must not truncate the target to nothing.
                idx_gt_eos = x_out.index(eos_value, idx_gt_sep + 1)
                gt = x_out[idx_gt_sep + 1:idx_gt_eos]
                # The prediction at position t targets the token at t + 1,
                # hence the slice starts on the separator itself.
                pred_model = pred_out[idx_gt_sep:idx_gt_sep + len(gt)]

            global_str_accuracy += int(gt == pred_model)
            if gt:
                # zip() tolerates a prediction shorter than the target;
                # the missing positions simply count as misses.
                matches = sum(g == p for g, p in zip(gt, pred_model))
                global_char_accuracy += matches / len(gt)
            else:
                # Empty target: mirror the exact-match metric instead of
                # dividing by zero.
                global_char_accuracy += float(not pred_model)

        # NOTE(review): the cap is checked after the batch has been scored, so
        # batches 0..max_test_steps (51 in total) are evaluated. Left as-is so
        # metrics stay comparable with earlier runs.
        if step >= max_test_steps:
            break

    if num_sequences == 0:
        return 0.0, 0.0
    return global_str_accuracy / num_sequences, global_char_accuracy / num_sequences


        
               

def phonebook_evaluation(args, model, dataloader, tokenizer, TO_TOKEN):
    """Score greedy (argmax) predictions with exact-match and per-token accuracy.

    Two scoring modes are selected by ``args.model``:

    * ``"gmlp"``: ``model(x)`` is used directly as the logits and every full
      prediction row is compared with the matching row of
      ``batch['label_ids']``.
    * anything else: the logits are read from ``model(x)['logits']``. The
      target is the span of ``input_ids`` strictly between the first ``"|"``
      token and the first ``"."`` token that follows it. It is compared with
      the predictions shifted one position to the left, i.e. the prediction
      at position ``t`` is matched against the input token at ``t + 1``.

    Args:
        args: Object exposing a ``model`` attribute (a string).
        model: Callable taking the ``input_ids`` tensor.
        dataloader: Iterable of dict batches holding an ``'input_ids'`` tensor
            (assumed ``(batch, seq)`` -- TODO confirm) and, for ``"gmlp"``,
            ``'label_ids'``.
        tokenizer: Not used by the scoring; kept so the signature stays
            compatible with existing callers.
        TO_TOKEN: Mapping that provides the token ids for ``"|"`` and ``"."``.

    Returns:
        Tuple ``(string_accuracy, char_accuracy)``: the fraction of sequences
        predicted exactly, and the mean per-sequence fraction of matching
        tokens. ``(0.0, 0.0)`` is returned when the dataloader yields nothing.

    Raises:
        ValueError: In the non-gmlp mode, when a sequence has no ``"|"`` token
            or no ``"."`` token after it.
    """
    sep_value = TO_TOKEN["|"]
    eos_value = TO_TOKEN["."]
    is_gmlp = args.model == "gmlp"

    # Follow the model's own device instead of assuming CUDA, so evaluation
    # also works on CPU-only hosts and on non-default GPUs.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    global_str_accuracy = 0
    global_char_accuracy = 0.0
    max_test_steps = 50
    num_sequences = 0

    for step, batch in enumerate(dataloader):
        x = batch['input_ids'].to(device)
        num_sequences += x.shape[0]

        with torch.no_grad():
            output = model(x)
        logits = output if is_gmlp else output['logits']

        # Convert once per batch rather than once per sequence.
        pred_rows = torch.argmax(logits, dim=-1).tolist()
        if is_gmlp:
            target_rows = [row.tolist() for row in batch['label_ids']]
        else:
            target_rows = x.tolist()

        for x_out, pred_out in zip(target_rows, pred_rows):
            if is_gmlp:
                gt = x_out
                pred_model = pred_out
            else:
                idx_gt_sep = x_out.index(sep_value)
                # Look for the terminator only after the separator: a "."
                # inside the prompt must not truncate the target to nothing.
                idx_gt_eos = x_out.index(eos_value, idx_gt_sep + 1)
                gt = x_out[idx_gt_sep + 1:idx_gt_eos]
                # The prediction at position t targets the token at t + 1,
                # hence the slice starts on the separator itself.
                pred_model = pred_out[idx_gt_sep:idx_gt_sep + len(gt)]

            global_str_accuracy += int(gt == pred_model)
            if gt:
                # zip() tolerates a prediction shorter than the target;
                # the missing positions simply count as misses.
                matches = sum(g == p for g, p in zip(gt, pred_model))
                global_char_accuracy += matches / len(gt)
            else:
                # Empty target: mirror the exact-match metric instead of
                # dividing by zero.
                global_char_accuracy += float(not pred_model)

        # NOTE(review): the cap is checked after the batch has been scored, so
        # batches 0..max_test_steps (51 in total) are evaluated. Left as-is so
        # metrics stay comparable with earlier runs.
        if step >= max_test_steps:
            break

    if num_sequences == 0:
        return 0.0, 0.0
    return global_str_accuracy / num_sequences, global_char_accuracy / num_sequences















