import numpy as np
import torchmetrics
from torchmetrics.classification import BinaryAccuracy
import numpy as np
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import T5Tokenizer,T5ForConditionalGeneration,RobertaForSequenceClassification,RobertaTokenizer,RobertaTokenizerFast, AutoTokenizer, GPT2Model,GPT2LMHeadModel, GPT2Tokenizer
from transformers import AutoConfig,RobertaConfig
# test123123
import pytorch_lightning as pl
from common import *
import os
import copy
import accelerate
from tokenizers import AddedToken
from compositional_t5 import Compositional_T5

class CustomTokenizer(T5Tokenizer):
    """T5 SentencePiece tokenizer extended with the Boolean-expression symbols.

    On top of the SentencePiece vocabulary in ``vocab_file`` it registers
    '1', '0', '~', '+', '*', '(' and ')' as added tokens.

    NOTE: ``T5Tokenizer.vocab_size`` is a read-only property that reports only the
    SentencePiece vocabulary. Use ``len(tokenizer)`` or ``full_vocab_size`` for the
    size including the added tokens.
    """

    def __init__(self, vocab_file, **kwargs):
        super().__init__(vocab_file=vocab_file, **kwargs)

        # Register the Boolean-expression symbols as additional tokens.
        self.add_tokens(['1', '0', '~', '+', '*', '(', ')'])

        # ``vocab_size`` has no setter on T5Tokenizer (assigning to it raises
        # AttributeError), so keep the full size in a separate attribute.
        self.full_vocab_size = len(self)
        print("vocab_size", self.full_vocab_size)

class BooleanModel(pl.LightningModule):
    def __init__(
            self,
            task: str,
            composition: bool,
            model_type: str,
            model_name: str,
            warmup_steps: int,
            lr: float,
            max_input_len: int,
            num_beams: int,
            max_num_steps: int,
            method: str,
            num_layers: int,
            d_model: int,
            d_ff: int,
            num_heads: int,
            tokenizer_type: str,
            )  -> None:
        """Build the tokenizer and backbone model for a Boolean-expression task.

        Args:
            task: Evaluation mode used by ``val_test_step``: 'classification',
                'one-shot' or 'stepwise'.
            composition: If True, build a ``Compositional_T5``. Only
                method == 'un-pretrained' is handled in that branch.
            model_type: 'seq2seq' for generative training. With any other value,
                ``training_step`` trains on ``batch["label"]``.
            model_name: Hugging Face checkpoint name. The substrings 't5',
                'roberta' and 'gpt' (plus 'xl' for GPT) select the branch below.
            warmup_steps: Stored on ``self.warmup_steps``. Presumably read by the
                optimizer/scheduler setup, which is not visible here.
            lr: Learning rate, stored on ``self.lr``.
            max_input_len: Stored on ``self.max_input_len``.
            num_beams: Beam width used by ``inference``.
            max_num_steps: Step budget for ``generate_greedy_proofs``.
            method: 'un-pretrained' builds a randomly initialised model from the
                checkpoint's config with the size overrides below. Any other value
                loads the pretrained weights.
            num_layers, d_model, d_ff, num_heads: Architecture overrides written
                into the config. NOTE(review): in the pretrained branches they are
                written after the model is built, so they do not change the loaded
                architecture. The GPT branches do not use them at all.
            tokenizer_type: A substring selects the tokenizer: 'custom',
                'character', 'depth2', 'atomic' or 'pretrained'.
        """

        super().__init__()
        self.save_hyperparameters()  # Lightning: store all constructor arguments as hyperparameters
        self.lr=lr
        self.warmup_steps=warmup_steps
        self.max_input_len=max_input_len
        self.method=method
        self.num_layers=num_layers
        self.d_model=d_model
        self.d_ff=d_ff
        self.num_heads=num_heads
        self.model_name=model_name
        self.model_type=model_type
        self.tokenizer_type=tokenizer_type
        self.composition=composition
        # Disabled alternative, kept as an inert string literal: build T5 around
        # CustomTokenizer with a vocab sized to that tokenizer.
        """
        self.tokenizer = CustomTokenizer(vocab_file='./vocab.txt')
        config = AutoConfig.from_pretrained(model_name)
        config.vocab_size=len(self.tokenizer)
        self.model = T5ForConditionalGeneration(config=config)
        #self.model.resize_token_embeddings(len(self.tokenizer))
        """

        # Earlier experiments registered the Boolean symbols
        # ('0', '1', '~', '+', '*', '(', ')') on the tokenizer here via
        # add_special_tokens / add_tokens. CustomTokenizer (unused below) does the
        # same with add_tokens.

        print(self.composition)
        if self.composition:
            print("composition")
            # NOTE(review): only method == 'un-pretrained' is handled. For any other
            # method, self.tokenizer / self.model are never assigned and the
            # print(len(self.tokenizer)) at the end raises AttributeError.
            if method=='un-pretrained':
                # NOTE(review): a tokenizer_type other than 'custom'/'character' leaves
                # self.tokenizer unset, so resize_token_embeddings below fails.
                # Paths are hard-coded absolute locations.
                if 'custom' in tokenizer_type:
                    self.tokenizer = T5Tokenizer.from_pretrained("/userhomes/Boolean/src/vocab_100_sentencepiece.model")
                elif 'character' in tokenizer_type:
                    self.tokenizer = T5Tokenizer.from_pretrained("/app/input/dataset/boolean2/vocab_character_sentencepiece.model")
                # Randomly initialised Compositional_T5 built from the checkpoint's
                # config with the requested size overrides.
                config = AutoConfig.from_pretrained(self.model_name)
                config.num_layers = self.num_layers
                config.d_model = self.d_model
                config.d_ff = self.d_ff
                config.num_heads = self.num_heads
                self.model = Compositional_T5(config=config)
                # Match the embedding matrix to the custom tokenizer's vocabulary.
                self.model.resize_token_embeddings(len(self.tokenizer))



        else:
            # NOTE(review): a model_name containing none of 't5' / 'roberta' / 'gpt'
            # leaves self.tokenizer / self.model unset.
            if't5' in self.model_name:
                if method=='un-pretrained':
                    # NOTE(review): an unrecognised tokenizer_type leaves self.tokenizer
                    # unset and the resize below raises AttributeError.
                    if 'custom' in tokenizer_type:
                        self.tokenizer = T5Tokenizer.from_pretrained( "/userhomes/Boolean/src/vocab_100_sentencepiece.model")
                    elif 'character' in tokenizer_type:
                        self.tokenizer = T5Tokenizer.from_pretrained("/app/input/dataset/boolean2/vocab_character_sentencepiece.model")
                    elif 'depth2' in  tokenizer_type:
                        self.tokenizer = T5Tokenizer.from_pretrained("/userhomes/Boolean/src/vocab_depth2_sentencepiece.model")
                    elif 'atomic' in tokenizer_type:
                        self.tokenizer = T5Tokenizer.from_pretrained("/userhomes/Boolean/src/vocab_atomic_sentencepiece.model")
                    elif 'pretrained' in tokenizer_type:
                        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
                        # Add '~' (negation) as a standalone token; presumably the
                        # pretrained vocabulary has no piece for it.
                        self.tokenizer.add_tokens(AddedToken("~", single_word=True))

                    # Randomly initialised T5 built from the checkpoint's config with
                    # the requested size overrides.
                    config = AutoConfig.from_pretrained(model_name)
                    config.num_layers = self.num_layers
                    config.d_model=self.d_model
                    config.d_ff=self.d_ff
                    config.num_heads=self.num_heads
                    self.model = T5ForConditionalGeneration(config=config)
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    if 'character' in tokenizer_type:
                        # Special-token ids for the character-level SentencePiece model;
                        # presumably its eos/unk/pad piece ids. TODO confirm.
                        self.model.config.eos_token_id=2
                        self.model.config.unk_token_id = 0
                        self.model.config.pad_token_id = 125

                else:
                    # Pretrained T5 weights plus an added '~' token.
                    self.tokenizer = T5Tokenizer.from_pretrained(model_name,use_fast=False)
                    self.tokenizer.add_tokens(AddedToken("~", single_word=True))
                    self.model = T5ForConditionalGeneration.from_pretrained(model_name)
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    # NOTE(review): these overrides only mutate the already-built model's
                    # config object, so the loaded architecture keeps the checkpoint's
                    # sizes. Re-instantiating from the modified config was tried and
                    # is disabled.
                    config = self.model.config
                    config.num_layers = self.num_layers
                    config.d_model = self.d_model
                    config.d_ff = self.d_ff
                    config.num_heads = self.num_heads
                    if 'character' in tokenizer_type:
                        # NOTE(review): the tokenizer is replaced *after*
                        # resize_token_embeddings, so the embeddings are sized for the
                        # pretrained tokenizer (+ '~'), not this character vocabulary.
                        # This path also differs from the un-pretrained branch's
                        # character model path.
                        self.tokenizer = T5Tokenizer.from_pretrained("/userhomes/Boolean/src/vocab_character_sentencepiece.model")
                        self.model.config.eos_token_id = 2
                        self.model.config.unk_token_id = 0
                        self.model.config.pad_token_id = 125


            elif 'roberta' in self.model_name:
                if method=='un-pretrained':
                    # RobertaTokenizer built from custom vocab/merges files.
                    custom_vocab_file = "/app/input/dataset/boolean2/roberta_vocab2.json"
                    custom_merges_file = "/app/input/dataset/boolean2/roberta_merges2.txt"
                    self.tokenizer = RobertaTokenizer(vocab_file=custom_vocab_file,merges_file=custom_merges_file)
                    # Randomly initialised 2-class sequence classifier with the
                    # requested size overrides.
                    config = AutoConfig.from_pretrained(model_name)
                    config.num_hidden_layers = self.num_layers
                    config.hidden_size=self.d_model
                    config.intermediate_size=self.d_ff
                    config.num_attention_heads=self.num_heads
                    config.num_labels=2
                    self.model = RobertaForSequenceClassification(config=config)
                    self.model.resize_token_embeddings(len(self.tokenizer))

                else:
                    self.tokenizer = RobertaTokenizer.from_pretrained(model_name)

                    model = RobertaForSequenceClassification.from_pretrained(model_name)
                    # NOTE(review): as in the pretrained T5 branch, these edits happen
                    # after the model is built and do not change its architecture.
                    # d_model is not applied at all. Reloading with the modified config
                    # (optionally with ignore_mismatched_sizes) was tried and is disabled.
                    config = model.config
                    config.num_hidden_layers = self.num_layers
                    config.intermediate_size=self.d_ff
                    config.num_attention_heads=self.num_heads
                    config.num_labels=2

                    self.model=model




            elif 'gpt' in self.model_name:
                # NOTE(review): this branch matches any name containing 'gpt', while
                # generation/decoding elsewhere checks for 'gpt2'.
                if method == 'un-pretrained':
                    # GPT2Tokenizer built from custom vocab/merges files.
                    custom_vocab_file = "/app/input/dataset/boolean2/gpt2_vocab.json"
                    custom_merges_file = "/app/input/dataset/boolean2/gpt2_merges.txt"
                    self.tokenizer = GPT2Tokenizer(vocab_file=custom_vocab_file,merges_file=custom_merges_file)

                    # Only the config is taken from the checkpoint; the pretrained
                    # weights are loaded just to read it. The model is randomly
                    # initialised and the size overrides are not applied.
                    config = GPT2LMHeadModel.from_pretrained(model_name).config

                    self.model = GPT2LMHeadModel(config=config)
                    self.model.resize_token_embeddings(len(self.tokenizer))
                    # Reuse EOS as the padding token.
                    self.tokenizer.pad_token = self.tokenizer.eos_token

                else:

                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                    if 'xl' in self.model_name:
                        # Large checkpoints: low-memory loading with automatic device placement.
                        self.model = GPT2LMHeadModel.from_pretrained(model_name,low_cpu_mem_usage=True, trust_remote_code=True, device_map= "auto")
                    else:
                        self.model = GPT2LMHeadModel.from_pretrained(model_name)
                    # Reuse EOS as the padding token.
                    self.tokenizer.pad_token = self.tokenizer.eos_token



        # Raises AttributeError if no branch above assigned self.tokenizer.
        print(len(self.tokenizer))

        self.num_beams=num_beams
        self.task=task
        self.max_num_steps=max_num_steps
        # Per-batch (prediction, target) pairs appended by val_test_step and
        # consumed by on_validation_epoch_end / on_test_epoch_end.
        self.validation_step_outputs = []


    def forward(self, input_ids, attention_mask, labels, depth = None):
        """Return the loss of the wrapped model for one batch.

        Args:
            input_ids: Input token ids.
            attention_mask: Attention mask matching ``input_ids``.
            labels: Target ids (seq2seq / LM) or class labels (classifier).
            depth: Optional ids forwarded to the model as ``compositional_ids``.
                Only used when ``self.composition`` is True.

        Returns:
            The ``.loss`` attribute of the model output.
        """
        if not self.composition:
            return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

        try:
            loss = self.model(
                input_ids=input_ids,
                compositional_ids=depth,
                attention_mask=attention_mask,
                labels=labels,
            ).loss
        except Exception:
            # Fall back to a plain call when the depth-aware call fails, e.g. the
            # model does not accept ``compositional_ids``. Catch ``Exception``
            # rather than a bare ``except`` so KeyboardInterrupt / SystemExit
            # still propagate.
            loss = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
            print("without depth")
        else:
            print("depth")
        return loss


    def EM_score(self,predicted_proof_list, target_proof_list):
        """Exact-match rate: fraction of (prediction, target) pairs that are equal.

        Pairs are formed with ``zip``, so any surplus items in the longer list are
        ignored. Returns 0.0 for empty input instead of raising ZeroDivisionError.
        """
        matches = [int(a == b) for a, b in zip(predicted_proof_list, target_proof_list)]
        if not matches:
            return 0.0
        return sum(matches) / len(matches)

    def ACC(self,predicted_proof_list, target_proof_list):
        """Final-answer accuracy: fraction of pairs whose last elements agree.

        Compares ``pred[-1]`` with ``gold[-1]``, e.g. the last character of a proof
        string such as '~(1+1);~1;0'. An empty prediction or an empty target counts
        as wrong. Returns 0.0 for empty input instead of raising ZeroDivisionError.
        """
        hits = []
        for predicted, golden in zip(predicted_proof_list, target_proof_list):
            if not predicted or len(golden) == 0:
                hits.append(0)
            else:
                hits.append(int(predicted[-1] == golden[-1]))
        if not hits:
            return 0.0
        return sum(hits) / len(hits)

    def F1_score(self, predicted_proof_list, target_proof_list):
        """Mean per-example F1 between predicted and gold proof steps.

        A step is correct only if it equals the gold step at the same position.
        precision = correct / len(pred) and recall = correct / len(gold); each is 0
        when its denominator is 0 (the original raised ZeroDivisionError for an
        empty gold proof). The mean is taken over ``len(target_proof_list)``.
        Returns 0.0 for empty input.
        """
        if not target_proof_list:
            return 0.0
        total_f1 = 0.0
        for generated_proof, golden_proof in zip(predicted_proof_list, target_proof_list):
            overlap = min(len(golden_proof), len(generated_proof))
            correct_steps = sum(
                1 for i in range(overlap) if golden_proof[i] == generated_proof[i]
            )
            precision = correct_steps / len(generated_proof) if len(generated_proof) else 0
            recall = correct_steps / len(golden_proof) if len(golden_proof) else 0
            # Both terms are non-negative, so a positive sum means F1 is defined.
            if precision + recall > 0:
                total_f1 += 2 * precision * recall / (precision + recall)
        return total_f1 / len(target_proof_list)

    def First_wrong_step_index(self,predicted_proof_list, target_proof_list):
        """Average index of the first step where the prediction diverges from the target.

        For each pair, the score is the index of the first mismatching step in the
        common prefix. If the common prefix matches entirely, the score is its
        length, ``min(len(pred), len(gold))``. That equals ``len(gold)`` when the
        whole target was reproduced and ``len(pred)`` when the prediction stopped
        early. Previously a truncated but matching prediction was scored as fully
        correct (``len(gold)``).

        The sum is divided by ``len(target_proof_list)``. Returns 0.0 for empty input.
        """
        if not target_proof_list:
            return 0.0
        first_wrong_indices = []
        for predicted, golden in zip(predicted_proof_list, target_proof_list):
            common = min(len(predicted), len(golden))
            first_wrong = next(
                (i for i in range(common) if predicted[i] != golden[i]),
                common,
            )
            first_wrong_indices.append(first_wrong)
        return sum(first_wrong_indices) / len(target_proof_list)

    def Num_match(self,predicted_proof_list, target_proof_list):
        """Mean fraction of positionally matching steps over the common prefix.

        For each pair: (#positions i < min(len(pred), len(gold)) with
        pred[i] == gold[i]) / min(len(pred), len(gold)), or 0 when either proof is
        empty. The mean is taken over ``len(target_proof_list)``. Returns 0.0 for
        empty input instead of raising ZeroDivisionError.
        """
        if not target_proof_list:
            return 0.0
        ratios = []
        for generated_proof, golden_proof in zip(predicted_proof_list, target_proof_list):
            total_steps = min(len(golden_proof), len(generated_proof))
            correct_steps = sum(
                1 for i in range(total_steps) if golden_proof[i] == generated_proof[i]
            )
            ratios.append(correct_steps / total_steps if total_steps else 0)
        return sum(ratios) / len(target_proof_list)

    def Num_match_independent_order(self,predicted_proof_list, target_proof_list):
        """Order-independent overlap between predicted and gold proofs.

        For each pair: |set(pred) & set(gold)| / len(gold). Duplicates count once in
        the numerator but fully in the denominator. An empty gold proof scores 0
        instead of raising ZeroDivisionError. The mean is taken over
        ``len(target_proof_list)``. Returns 0.0 for empty input.
        """
        if not target_proof_list:
            return 0.0
        ratios = []
        for generated_proof, golden_proof in zip(predicted_proof_list, target_proof_list):
            shared = set(generated_proof) & set(golden_proof)
            ratios.append(len(shared) / len(golden_proof) if len(golden_proof) else 0.0)
        return sum(ratios) / len(target_proof_list)

    def generate_proof_step(
            self,
            input_text: List[str],
    ) -> Tuple[List[str], Any]:
        """Run one beam search and return the 5 candidate steps per example.

        Args:
            input_text: Batch of prompts.

        Returns:
            ``(texts, scores)``. ``texts[b]`` is the list of decoded candidates for
            example ``b``, and ``scores[b]`` is the matching slice of
            ``exp(sequences_scores)``, e.g. [[0.9, 0.6, ...], [0.84, 0.8, ...]].
        """
        encoded = self.tokenizer(
            input_text,
            padding="longest",
            max_length=self.trainer.datamodule.max_input_len,  # type: ignore
            truncation=True,
            return_tensors="pt",
        )

        # Decoder-only output includes the prompt (callers strip it), so GPT-2
        # gets twice the output budget.
        max_len = self.trainer.datamodule.max_output_len  # type: ignore
        if 'gpt2' in self.model_name:
            max_len = max_len * 2

        output = self.model.generate(
            input_ids=encoded.input_ids.to(self.device, non_blocking=True),
            attention_mask=encoded.attention_mask.to(self.device, non_blocking=True),
            max_length=max_len,
            num_beams=5,
            num_return_sequences=5,
            early_stopping=True,
            output_scores=True,
            return_dict_in_generate=True,
        )

        decode_kwargs = {"skip_special_tokens": True}
        if 't5' in self.model_name:
            decode_kwargs["spaces_between_special_tokens"] = False
        flat_texts = self.tokenizer.batch_decode(output.sequences, **decode_kwargs)

        n_examples = len(input_text)
        per_example = len(flat_texts) // n_examples  # candidates per example
        grouped_texts = [
            flat_texts[b * per_example:(b + 1) * per_example] for b in range(n_examples)
        ]

        flat_scores = output.sequences_scores.detach().exp().cpu().numpy()
        assert 0.0 <= flat_scores.min() <= flat_scores.max() <= 1.0
        grouped_scores = [
            flat_scores[b * per_example:(b + 1) * per_example] for b in range(n_examples)
        ]

        return grouped_texts, grouped_scores



    def generate_greedy_proofs(
            self, input_list: List[str]
    ) -> Tuple[List[List[str]], List[List[float]]]:
        """Greedy step-by-step proof generation.

        At every iteration, each example's prompt is its input followed by the
        steps generated so far: ``"<input> "`` on the first step, then
        ``"<input><step1>;<step2>;...;"``. The best beam from
        ``generate_proof_step`` becomes the next step. An example finishes once it
        produces '0', '1' or an empty step. The loop ends when every example has
        finished or after ``self.max_num_steps`` iterations.

        Args:
            input_list: Batch of input texts.

        Returns:
            ``(all_partial_proof, all_step_scores)``: per example, the generated
            steps and one score per step. Proofs that exhaust ``max_num_steps``
            are returned as-is and may not end in '0'/'1'.
        """
        all_partial_proof: List[List[str]] = [[] for _ in input_list]
        all_step_scores: List[List[float]] = [[] for _ in input_list]
        unfinished_indexes = list(range(len(input_list)))

        for _ in range(self.max_num_steps):
            if not unfinished_indexes:
                # Every example in the batch has finished.
                break

            # Prompts are rebuilt for *every* example, finished ones included, so
            # generation output indices line up with ``input_list``. Only
            # unfinished examples consume their result below.
            input_text = [
                f"{text}{' ' if not partial_proof else (';'.join(partial_proof) + ';')}"
                for text, partial_proof in zip(input_list, all_partial_proof)
            ]
            output_text, output_scores = self.generate_proof_step(input_text)

            # Greedy decoding: keep only the best beam for each example.
            proof_steps = [steps[0] if len(steps) > 0 else None for steps in output_text]
            if 'gpt2' in self.model_name:
                # Decoder-only output starts with the prompt; strip it.
                proof_steps = [
                    None if step is None else step[len(input_text[i]):]
                    for i, step in enumerate(proof_steps)
                ]
            scores = [s[0] if len(s) > 0 else 0.0 for s in output_scores]

            still_unfinished = []
            for j in unfinished_indexes:
                # A missing candidate is treated as an empty, terminating step.
                step = proof_steps[j] if proof_steps[j] is not None else ''
                all_partial_proof[j].append(step)
                all_step_scores[j].append(float(scores[j]))
                if step not in ('0', '1') and len(step) != 0:
                    still_unfinished.append(j)
            unfinished_indexes = still_unfinished

        print("input_list", input_list)
        print("all_partial_proof", all_partial_proof)
        return all_partial_proof, all_step_scores

    def encoder_predict(self,input_ids,attention_mask):
        """Return the classification logits for a batch (no labels, no loss).

        Arguments are passed by keyword so the call does not depend on the
        positional parameter order of the wrapped model's ``forward``. For example,
        GPT-2's second positional parameter is ``past_key_values``, not
        ``attention_mask``.
        """
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    def inference(
        self, input_text: List[str], compositional_ids=None,
    ) -> List[str]:
        """Single-shot proof generation with beam search.

        Args:
            input_text: Batch of prompts.
            compositional_ids: Optional ids passed to ``generate`` as
                ``compositional_ids`` for non-GPT-2 models. If that call fails,
                generation is retried without them.

        Returns:
            One decoded string per input (``num_return_sequences=1``). For GPT-2,
            the prompt prefix is stripped from each output.
        """
        assert self.trainer is not None
        encoded = self.tokenizer(
            input_text,
            padding="longest",
            max_length=self.trainer.datamodule.max_input_len,  # type: ignore
            truncation=True,
            return_tensors="pt",
        )
        is_gpt2 = 'gpt2' in self.model_name
        max_output_len = self.trainer.datamodule.max_output_len  # type: ignore
        generate_kwargs = dict(
            input_ids=encoded.input_ids.to(self.device, non_blocking=True),
            attention_mask=encoded.attention_mask.to(self.device, non_blocking=True),
            # Decoder-only output contains the prompt, so GPT-2 gets twice the budget.
            max_length=max_output_len * 2 if is_gpt2 else max_output_len,
            num_beams=self.num_beams,
            num_return_sequences=1,
            early_stopping=True,
            output_scores=True,
            return_dict_in_generate=True,
        )

        if is_gpt2:
            output = self.model.generate(**generate_kwargs)
        else:
            try:
                output = self.model.generate(compositional_ids=compositional_ids, **generate_kwargs)
            except Exception:
                # Retry without ``compositional_ids``, e.g. for models that do not
                # accept it. Catch ``Exception`` rather than a bare ``except`` so
                # KeyboardInterrupt / SystemExit are not swallowed.
                output = self.model.generate(**generate_kwargs)
                print("no_depth")
            else:
                print("depth")

        decode_kwargs = {"skip_special_tokens": True}
        if 't5' in self.model_name:
            decode_kwargs["spaces_between_special_tokens"] = False
        output_text = self.tokenizer.batch_decode(output.sequences, **decode_kwargs)
        print("output_text", output_text)
        if is_gpt2:
            output_text = [text[len(prompt):] for text, prompt in zip(output_text, input_text)]
            print("GPT_output_text", output_text)
        return output_text


    def training_step(self, batch:Batch, batch_idx: int) -> dict:
        """Compute and log the training loss for one batch.

        Returns:
            ``{"loss": loss}`` for Lightning's optimisation loop.

        Raises:
            ValueError: if the configured model has no loss defined for the active
                branch. Previously this surfaced as an UnboundLocalError on ``loss``.
        """
        if self.composition:
            if 't5' not in self.model_name:
                raise ValueError(f"composition=True requires a T5 model, got {self.model_name!r}")
            try:
                loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["output_seq_ids"], batch["depth"])
                print("composition with depth")
            except Exception:
                # Batches without "depth" (or a failing depth-aware call) fall back
                # to the plain loss.
                loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["output_seq_ids"])
                print("composition without depth")

        elif self.model_type == 'seq2seq':
            if 't5' in self.model_name:
                print("input_text", batch["input_text"])
                print("label=proof", batch["output_seq"])
                print("input_text_ids", self.tokenizer.batch_decode(batch["input_seq_ids"], skip_special_tokens=True,spaces_between_special_tokens = False))
                loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["output_seq_ids"])
            elif 'gpt2' in self.model_name:
                print("input_text", batch["input_text"])
                print("input_text_ids", self.tokenizer.batch_decode(batch["input_seq_ids"],skip_special_tokens=True,))
                print("proof", batch["output_seq"])
                # Causal LM: the input ids double as labels.
                loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["input_seq_ids"])
            else:
                raise ValueError(f"seq2seq training supports T5 and GPT-2 models, got {self.model_name!r}")

        else:
            # Classifier: train on the integer labels.
            loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["label"])

        self.log("loss_train", loss, on_epoch=True, sync_dist=True)
        return {"loss": loss}



    def validation_step(self,batch:Batch, batch_idx:int):
        """Lightning validation hook: delegate to ``val_test_step`` with split "val"."""
        split = "val"
        return self.val_test_step(split, batch, batch_idx)


    def test_step(self, batch, batch_idx):
        """Lightning test hook: delegate to ``val_test_step`` with split "test".

        The split selects the logged loss key, so test losses are logged as
        ``loss_test`` instead of colliding with the validation key ``loss_val``.
        """
        return self.val_test_step("test", batch, batch_idx)



    def val_test_step(self, split: str, batch: Batch, batch_idx: int) -> Tuple[Any]:
        """Shared validation/test logic: log the loss and collect predictions.

        Depending on ``self.task``:
          * 'classification': seq2seq models decode an answer string (e.g. '0' or
            '1') with ``inference``; other models take the ``argmax`` over the
            classifier logits.
          * 'one-shot': generate the whole proof in one pass.
          * 'stepwise': generate the proof step by step with
            ``generate_greedy_proofs``.

        The ``(prediction, target)`` pair is appended to
        ``self.validation_step_outputs`` and also returned. Returns None for an
        unrecognised task.

        Args:
            split: Prefix for the logged loss key, ``loss_{split}``.
            batch: Collated batch dictionary.
            batch_idx: Unused; required by Lightning.

        Raises:
            ValueError: for generative tasks with a model that is neither T5 nor GPT-2.
        """
        if self.task == 'classification':
            if self.model_type == 'seq2seq':
                loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["output_seq_ids"])
                target_answer = batch["output_seq"]
                predicted_answer = self.inference(batch["input_text"])
            else:
                loss = self.forward(batch["input_seq_ids"], batch["input_seq_mask"], batch["label"])
                target_answer = batch["label"]
                # One classifier pass is enough. argmax over the raw logits picks the
                # same class as argmax over their sigmoid (a monotonic function),
                # without float-saturation ties.
                logits = self.encoder_predict(batch["input_seq_ids"], batch["input_seq_mask"])
                predicted_answer = torch.argmax(logits, dim=-1)
            self.log(f"loss_{split}", loss, sync_dist=True)
            output = (predicted_answer, target_answer)
            self.validation_step_outputs.append(output)
            return output

        if self.task == 'one-shot':
            target_proof = batch["output_seq"]
            loss = self._generation_loss(batch, try_depth=True)
            self.log(f"loss_{split}", loss, sync_dist=True)
            try:
                generated_proof = self.inference(batch["input_text"], batch["depth"])
            except Exception:
                # No "depth" in the batch, or depth-aware generation failed.
                generated_proof = self.inference(batch["input_text"])
            print("input", batch["input_text"])
            print("generated_proof", generated_proof)
            # Both proofs look like '~(1+~0);~(1+1);~1;0'.
            output = (generated_proof, target_proof)
            self.validation_step_outputs.append(output)
            return output

        if self.task == 'stepwise':
            label = batch["entire_proof"]
            loss = self._generation_loss(batch)
            self.log(f"loss_{split}", loss, sync_dist=True)
            # generate_greedy_proofs returns (proofs, step_scores); keep the proofs.
            generated_proofs, _ = self.generate_greedy_proofs(batch["raw_input_text"])
            output = (generated_proofs, label)
            self.validation_step_outputs.append(output)
            return output

        return None

    def _generation_loss(self, batch: Batch, try_depth: bool = False):
        """Teacher-forced loss for generative (T5 / GPT-2) evaluation.

        For T5 with ``try_depth``, first passes ``batch["depth"]`` and falls back to
        a depth-free call if that fails. GPT-2 uses its input ids as labels.

        Raises:
            ValueError: for model names that are neither T5 nor GPT-2.
        """
        if 't5' in self.model_name:
            if not try_depth:
                return self(
                    batch["input_seq_ids"],
                    batch["input_seq_mask"],
                    batch["output_seq_ids"],
                )
            try:
                loss = self(
                    input_ids=batch["input_seq_ids"],
                    depth=batch["depth"],
                    attention_mask=batch["input_seq_mask"],
                    labels=batch["output_seq_ids"],
                )
            except Exception:
                loss = self(
                    input_ids=batch["input_seq_ids"],
                    attention_mask=batch["input_seq_mask"],
                    labels=batch["output_seq_ids"],
                )
                print("no_depth")
            else:
                print("depth")
            return loss
        if 'gpt2' in self.model_name:
            return self(
                batch["input_seq_ids"],
                batch["input_seq_mask"],
                batch["input_seq_ids"],
            )
        raise ValueError(f"generative evaluation supports T5 and GPT-2 models, got {self.model_name!r}")


    def on_validation_epoch_end(self):
        """Lightning hook: score all predictions collected during the validation epoch.

        Hands the accumulated ``(prediction, target)`` pairs from ``val_test_step``
        to ``val_test_epoch_end("val", ...)`` and starts a fresh buffer for the
        next epoch.
        """
        print("valid_step_outputs", self.validation_step_outputs)

        # Take ownership of the collected list and install a new empty buffer.
        # This isolates ``outputs`` from later appends without deep-copying every
        # prediction and tensor.
        outputs = self.validation_step_outputs
        self.validation_step_outputs = []
        return self.val_test_epoch_end("val", outputs)

    def on_test_epoch_end(self):
        """Score everything collected during the test epoch.

        Lightning calls this hook automatically once testing finishes. The
        outputs are read from ``self.validation_step_outputs``. This is the same
        buffer validation uses, presumably because test and validation share
        one step method (``val_test_step``); confirm against the step code.

        A deep copy of the buffer is handed to ``val_test_epoch_end`` and the
        buffer is then emptied to free memory. The ``"test"`` split is used so
        test metrics are logged as ``*_test`` rather than under the
        validation metric names.
        """
        print("test_step_outputs", self.validation_step_outputs)

        outputs = copy.deepcopy(self.validation_step_outputs)
        self.validation_step_outputs.clear()  # free memory
        return self.val_test_epoch_end("test", outputs)

    def val_test_epoch_end(self, split: str, outputs: Iterable[Any]) -> None:
        """Aggregate one epoch of step outputs and log the task's metrics.

        Args:
            split: Suffix appended to every logged metric name (e.g. ``"val"``).
            outputs: One ``(predictions, targets)`` pair per batch, as collected
                by the step method. The two elements are zipped together, so
                they are expected to be parallel per-example sequences.

        Per task:

        * ``classification``: logs the exact-match accuracy of predicted vs.
          target answers as ``Classification_ACC_{split}``. An empty epoch
          scores 0.0 instead of raising ZeroDivisionError.
        * ``one-shot``: predictions and targets are both ``';'``-separated
          proof strings and are split into steps.
        * ``stepwise``: only the targets are split on ``';'``. Predictions are
          used exactly as the step method produced them.

        For both proof tasks, every (prediction, target) pair is appended to
        ``./{model_name}{task}_generated_proof7.jsonl``. The metrics returned by
        the class's ``EM_score``, ``First_wrong_step_index``, ``F1_score``,
        ``Num_match``, ``Num_match_independent_order`` and ``ACC`` methods are
        then logged.
        """
        if self.task == 'classification':
            predicted_answer_list = []
            target_answer_list = []
            for out in outputs:
                for predicted_answer, target_answer in zip(*out):
                    predicted_answer_list.append(predicted_answer)
                    target_answer_list.append(target_answer)

            num_correct = sum(
                int(a == b) for a, b in zip(predicted_answer_list, target_answer_list)
            )
            # An empty epoch (e.g. no batches) would otherwise divide by zero.
            if target_answer_list:
                Classification_ACC = num_correct / len(target_answer_list)
            else:
                Classification_ACC = 0.0
            print("length", len(target_answer_list))
            print("sum", num_correct)

            print("Classification_ACC", Classification_ACC)
            self.log(f"Classification_ACC_{split}", Classification_ACC, sync_dist=True)

        elif self.task == 'one-shot':
            # Each batch entry looks like (generated_proofs, target_proofs) where a
            # proof string is e.g. '((1*~1)*0);((1*0)*0);(0*0);0'.
            predicted_proof_list = []
            target_proof_list = []
            for out in outputs:
                for predicted_proof, target_proof in zip(*out):
                    predicted_proof_list.append(predicted_proof.split(';'))
                    target_proof_list.append(target_proof.split(';'))

            One_shot_EM = self.EM_score(predicted_proof_list, target_proof_list)
            self._dump_generated_proofs(predicted_proof_list, target_proof_list)
            metrics = self._proof_metrics(predicted_proof_list, target_proof_list)

            print("One_shot_EM", One_shot_EM)
            print("AVG_First_wrong_step", metrics["AVG_First_wrong_step"])
            print("ACC", metrics["ACC"])
            self.log(f"One_shot_EM_{split}", One_shot_EM, sync_dist=True)
            self._log_proof_metrics(split, metrics)

        elif self.task == 'stepwise':
            predicted_proof_list = []
            target_proof_list = []
            for out in outputs:
                for predicted_proof, target_proof in zip(*out):
                    predicted_proof_list.append(predicted_proof)
                    target_proof_list.append(target_proof.split(';'))

            Step_EM = self.EM_score(predicted_proof_list, target_proof_list)
            self._dump_generated_proofs(predicted_proof_list, target_proof_list)
            metrics = self._proof_metrics(predicted_proof_list, target_proof_list)

            print("Step_EM ", Step_EM)
            print("AVG_First_wrong_step", metrics["AVG_First_wrong_step"])
            self.log(f"Step_EM_{split}", Step_EM, sync_dist=True)
            self._log_proof_metrics(split, metrics)
            self.log(f"Num_val_{split}", len(target_proof_list), sync_dist=True)

    def _proof_metrics(self, predicted_proof_list, target_proof_list) -> Dict[str, Any]:
        """Compute the step-level proof metrics shared by 'one-shot' and 'stepwise'."""
        return {
            "AVG_First_wrong_step": self.First_wrong_step_index(predicted_proof_list, target_proof_list),
            "F1_score": self.F1_score(predicted_proof_list, target_proof_list),
            "Num_match": self.Num_match(predicted_proof_list, target_proof_list),
            "Num_match_independent_order": self.Num_match_independent_order(predicted_proof_list, target_proof_list),
            "ACC": self.ACC(predicted_proof_list, target_proof_list),
        }

    def _log_proof_metrics(self, split: str, metrics: Dict[str, Any]) -> None:
        """Log the values returned by ``_proof_metrics`` under their existing names."""
        self.log(f"AVG_First_wrong_step_{split}", metrics["AVG_First_wrong_step"], sync_dist=True)
        self.log(f"F1_score_{split}", metrics["F1_score"], sync_dist=True)
        self.log(f"Num_match_{split}", metrics["Num_match"], sync_dist=True)
        self.log(f"Num_match_independent_order_{split}", metrics["Num_match_independent_order"], sync_dist=True)
        # No "_" before the split here: the name is kept unchanged so existing
        # dashboards/monitors that reference it keep working.
        self.log(f"ACC{split}", metrics["ACC"], sync_dist=True)

    def _dump_generated_proofs(self, predicted_proof_list, target_proof_list) -> None:
        """Append one JSON line per (prediction, target) pair for offline inspection."""
        path = f"./{self.model_name}{self.task}_generated_proof7.jsonl"
        with open(path, "a", encoding="utf-8") as f:
            for predicted_proof, target_proof in zip(predicted_proof_list, target_proof_list):
                # dict preserves insertion order, so the key order in the JSON is fixed.
                record = {
                    "First_wrong_index": self._first_wrong_index(predicted_proof, target_proof),
                    "generated_proof": predicted_proof,
                    "target": target_proof,
                }
                json.dump(record, f)
                f.write("\n")

    @staticmethod
    def _first_wrong_index(predicted_proof, target_proof):
        """Return the index of the first differing step, or ``"correct"``.

        Steps are compared position by position over the common prefix. If
        that prefix matches but the proofs have a different number of steps,
        the proof is not correct. In that case the first wrong index is the
        length of the shorter proof, i.e. the first missing or extra step.
        """
        for idx, (pred_step, gold_step) in enumerate(zip(predicted_proof, target_proof)):
            if pred_step != gold_step:
                return idx
        if len(predicted_proof) != len(target_proof):
            return min(len(predicted_proof), len(target_proof))
        return "correct"

    def configure_optimizers(self) -> Dict[str, Any]:
        """Return whatever ``get_optimizers`` builds for this module's parameters.

        ``get_optimizers`` receives the learning rate, the warmup steps and a
        total step count. The total is ``trainer.max_steps`` when it is set
        (anything other than ``-1``). Otherwise it is derived as
        ``max_epochs * len(train_dataloader) // accumulate_grad_batches``.
        """
        trainer = self.trainer
        assert trainer is not None
        if trainer.max_steps == -1:
            batches_per_epoch = len(trainer.datamodule.train_dataloader())  # type: ignore
            total_steps = trainer.max_epochs * batches_per_epoch // trainer.accumulate_grad_batches
        else:
            total_steps = trainer.max_steps
        return get_optimizers(
            self.parameters(),
            self.lr,
            self.warmup_steps,
            total_steps,
        )


