from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import os
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import TrainerCallback, TrainingArguments
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer,set_seed,DataCollatorWithPadding
from transformers.trainer import Trainer
import evaluate
from copy import deepcopy
from peft import PeftModel, PeftConfig
import json

@dataclass
class CoHTrainArgs(TrainingArguments):
    """Hugging Face ``TrainingArguments`` with project defaults plus CoH options.

    The fields under "training args" override Hugging Face defaults; the fields
    under "COH ARGS" are project-specific.

    NOTE(review): ``CoHTrainer.compute_loss`` branches on ``train_method`` from
    these arguments, while the callbacks and ``MyCollator`` in this file branch
    on ``method`` from a separate ``args`` object -- confirm callers always keep
    the two in sync.
    """
    # training args
    learning_rate: float = 5e-4
    warmup_steps: int = 10000
    weight_decay: float = field(default=0, metadata={"help": "typically, set this to 0.01"})
    per_device_train_batch_size: int = 1
    # Also read by the evaluation callbacks in this file (via coh_train_args)
    # as the batch size of their test-set DataLoaders.
    per_device_eval_batch_size: int = 1
    dataloader_drop_last: bool = False
    report_to: str = 'wandb'
    # The callbacks write generations, metrics and checkpoints under
    # output_dir/<wandb_project_name>/<wandb_run_name>/.
    output_dir: str = 'outputs'
    logging_steps: int = 100
    dataloader_num_workers: int = 0  # TODO
    gradient_accumulation_steps: int = 1
    log_on_each_node: bool = False
    deepspeed: str =""
    gradient_checkpointing: bool = False
    # Not read anywhere in this module's visible code.
    pretrain_task: bool = False
    # Token budget for evaluation-time generate(): the callbacks pass
    # max_length = prompt_length + generate_max_length.
    generate_max_length: int =512
    ############### COH ARGS ##################
    # Not read anywhere in this module's visible code.
    pt_loss_weight: float = field(default=1.0, metadata={"help": "Pretrain Data loss weight."})
    # NOTE(review): presumably the number of prompt-tuning virtual tokens and a
    # learning-rate multiplier for them; neither is read in this module -- confirm.
    num_virtual_tokens: int = 0
    tokens_lr_times: int = 1
    # "coh": CoHTrainer computes a masked cross-entropy on a plain causal LM.
    # Any other value (including None) calls the model with good/bad masks.
    train_method: str = None
    ####################################################################
    ############################ DO NOT CHANGE!! #######################
    ####################################################################
    # The collator and callbacks read columns such as "prefix", "target" and
    # "attention_masks" that the model's forward() does not accept, so they must
    # not be stripped from the dataset.
    remove_unused_columns: bool = False  # since CoHDataset uses non-standard columns

# NOTE(review): this enables autograd anomaly detection globally, as a side
# effect of importing this module. PyTorch documents it as a debugging aid that
# slows down every backward pass; consider disabling it (or putting it behind a
# flag) for real training runs.
torch.autograd.set_detect_anomaly(True)
class CoHTrainer(Trainer):
    """Trainer whose loss is next-token prediction on ``inputs['input_ids']``.

    The mode is selected by ``self.args.train_method``:

    * ``"coh"``: ``model`` is a causal LM; the loss is the per-token cross
      entropy multiplied by ``inputs['masks']`` and averaged.
    * any other value: ``model`` is called with ``good_masks``, ``bad_masks``
      and ``labels`` and must return a mapping containing ``"loss"``.

    TODO: implement forgetful causal masking (fcm)
    """

    def compute_loss(self, model, inputs, flag=False, return_outputs=False, **kwargs):
        """Return the training loss for one collated batch.

        Args:
            model: the model being trained (supplied by ``Trainer``).
            inputs: collated batch with ``input_ids`` plus either ``masks``
                (coh) or ``good_loss_masks`` / ``bad_loss_masks`` (other modes).
            flag: unused; kept so existing positional callers keep working.
            return_outputs: when True, return ``(loss, [None, None])`` so that
                ``Trainer.prediction_step`` can unpack a pair.
            **kwargs: extra keywords newer ``Trainer`` versions pass (e.g.
                ``num_items_in_batch``); accepted and ignored.

        Returns:
            The scalar loss, or ``(loss, [None, None])`` if ``return_outputs``.
        """
        # Shift by one so position t is trained to predict token t + 1.
        hf_input_ids = inputs['input_ids'][:, :-1]
        targets = inputs['input_ids'][:, 1:]

        if self.args.train_method != "coh":
            good_masks = inputs['good_loss_masks'][:, 1:]
            bad_masks = inputs['bad_loss_masks'][:, 1:]
            loss = model(input_ids=hf_input_ids, good_masks=good_masks, bad_masks=bad_masks, labels=targets)["loss"]
        else:
            masks = inputs['masks'][:, 1:]
            logits = model(input_ids=hf_input_ids).logits
            # F.cross_entropy wants the class dimension second: (batch, vocab, seq).
            token_loss = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction='none')
            # NOTE(review): .mean() divides by *all* positions, masked ones
            # included, so the loss scale depends on how much of each sequence is
            # masked/padded. A true masked mean would be
            # (token_loss * masks).sum() / masks.sum(); kept as-is to preserve
            # the current training dynamics.
            loss = (token_loss * masks).mean()

        if return_outputs:
            return loss, [None, None]  # fake outputs
        return loss


class MyCollator(object):
    """Turn a list of per-sample dicts of tensors into one batch dict.

    Every value is flattened to 1-D. ``input_ids`` and the loss-mask columns
    are right-padded into 2-D tensors; every other column stays a Python list
    of 1-D tensors. The padded mask columns depend on ``args.method``:

    * ``"coh"``: ``masks``
    * anything else: ``good_loss_masks`` and ``bad_loss_masks``

    ``input_ids`` is padded with ``tokenizer.pad_token_id``, masks with 0.
    """

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer

    def __call__(self, batch):
        columns = {name: [] for name in batch[0]}
        for sample in batch:
            for name, tensor in sample.items():
                columns[name].append(tensor.view(-1))

        if self.args.method == "coh":
            mask_columns = ("masks",)
        else:
            mask_columns = ("good_loss_masks", "bad_loss_masks")

        pad = torch.nn.utils.rnn.pad_sequence
        columns["input_ids"] = pad(columns["input_ids"], batch_first=True, padding_value=self.tokenizer.pad_token_id)
        for name in mask_columns:
            columns[name] = pad(columns[name], batch_first=True, padding_value=0)
        return dict(columns)
class AccCallback(TrainerCallback):
    """At the end of every epoch, report masked next-token accuracy on a test set.

    Accuracy is computed per sample (correct predictions / counted positions,
    where only positions with a non-zero mask are counted) and then averaged
    over samples. Samples whose mask selects no position are skipped. The mask
    depends on ``args.method``:

    * ``"coh"``: ``masks``; the model is called as a plain causal LM.
    * anything else: ``good_loss_masks + bad_loss_masks``; the model is called
      with both masks and ``labels`` and must return a mapping with ``"logits"``.

    The result is printed on every process. On the main process it is also
    appended to ``output_dir/wandb_project_name/wandb_run_name/metrics_result.jsonl``
    and passed to ``logger.log``.
    """

    def __init__(self, test_dataset, logger, coh_train_args, model, tokenizer,args):
        self.model = model
        self.logger = logger
        self.coh_train_args = coh_train_args
        self.args = args
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate = MyCollator(self.args, self.tokenizer)
        self.dataloader = DataLoader(
            test_dataset,
            batch_size=coh_train_args.per_device_eval_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,  # TODO
            collate_fn=self.collate,
        )

    def _is_main_process(self):
        """Whether this process writes files and logs.

        ``local_rank`` is -1 when training is not distributed; that single
        process is also the main one, so it must not be excluded.
        """
        return self.coh_train_args.local_rank in (-1, 0)

    @staticmethod
    def _per_sample_accuracy(pred_ids, targets, masks):
        """Return the masked token accuracy of each row as a list of floats.

        Only positions where ``masks != 0`` are counted. Rows with no counted
        position are skipped rather than dividing by zero. Like ``zip``, only
        the overlapping length of the three tensors is compared.
        """
        device = pred_ids.device
        length = min(pred_ids.size(-1), targets.size(-1), masks.size(-1))
        preds = pred_ids[..., :length]
        gold = targets[..., :length].to(device)
        counted = masks[..., :length].to(device) != 0
        n_correct = ((preds == gold) & counted).sum(dim=-1).tolist()
        n_counted = counted.sum(dim=-1).tolist()
        return [c / n for c, n in zip(n_correct, n_counted) if n > 0]

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate accuracy on the whole test set and report it."""
        # Disable dropout for both modes (previously only the coh branch did).
        self.model.eval()
        is_coh = self.args.method == "coh"
        correct_rate = []
        for inputs in tqdm(self.dataloader, desc='Evaluating on Test Set'):
            with torch.no_grad():
                if is_coh:
                    device = self.model.device
                    hf_input_ids = inputs['input_ids'][:, :-1].to(device)
                    targets = inputs['input_ids'][:, 1:].to(device)
                    masks = inputs['masks'][:, 1:].to(device)
                    logits = self.model(input_ids=hf_input_ids).logits
                else:
                    hf_input_ids = inputs['input_ids'][:, :-1]
                    targets = inputs['input_ids'][:, 1:]
                    good_masks = inputs['good_loss_masks'][:, 1:]
                    bad_masks = inputs['bad_loss_masks'][:, 1:]
                    masks = good_masks + bad_masks
                    logits = self.model(input_ids=hf_input_ids, good_masks=good_masks, bad_masks=bad_masks, labels=targets)["logits"]
                pred_ids = torch.argmax(logits, dim=-1)
                correct_rate.extend(self._per_sample_accuracy(pred_ids, targets, masks))

        if not correct_rate:
            print("AccCallback: no test sample has a non-empty loss mask; accuracy not computed")
            return

        metrics = {'accuracy': sum(correct_rate) / len(correct_rate)}
        print(metrics)
        if self._is_main_process():
            run_dir = os.path.join(self.coh_train_args.output_dir, self.args.wandb_project_name, self.args.wandb_run_name)
            os.makedirs(run_dir, exist_ok=True)
            with open(os.path.join(run_dir, "metrics_result.jsonl"), "a+", encoding="utf-8") as f1:
                f1.write(json.dumps(metrics) + "\n")
            self.logger.log(metrics)




class RougeCallback(TrainerCallback):
    """At every epoch end: generate on a test set, score with ROUGE, log and checkpoint.

    Each epoch the callback performs these steps:
      1. Calls ``generate()`` for every test prefix.
      2. Keeps a subset of the samples (see ``_select_rows``).
      3. On the main process, appends the kept generations plus an epoch
         separator line to ``generate_result.jsonl``.
      4. Computes ROUGE of the kept generations against their decoded targets.
      5. On the main process, saves ``model.state_dict()`` as
         ``epoch_{epoch_flag}_model`` and advances ``args.epoch_flag`` by one.
      6. Prints the metrics. On the main process, it also appends them to
         ``metrics_result.jsonl`` and passes them to ``logger.log``.

    All files live in ``output_dir/wandb_project_name/wandb_run_name``.
    """

    def __init__(self, test_dataset, logger, coh_train_args, model, tokenizer,args):
        self.best_rouge = 0  # not updated anywhere in this class
        self.model = model
        self.logger = logger
        self.coh_train_args = coh_train_args
        self.args = args
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate = MyCollator(self.args, self.tokenizer)
        self.dataloader = DataLoader(
            test_dataset,
            batch_size=coh_train_args.per_device_eval_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,  # TODO
            collate_fn=self.collate,
        )

    def _run_dir(self):
        """Directory holding this run's generations, metrics and checkpoints."""
        return os.path.join(self.coh_train_args.output_dir, self.args.wandb_project_name, self.args.wandb_run_name)

    def _is_main_process(self):
        """Whether this process writes files, saves checkpoints and logs.

        ``local_rank`` is -1 when training is not distributed; that single
        process is also the main one, so it must not be excluded.
        """
        return self.coh_train_args.local_rank in (-1, 0)

    def _generate_batch(self, gen_model, inputs):
        """Run ``gen_model.generate`` on one collated batch.

        Expects ``self.tokenizer.padding_side == 'left'``: the padded prompts
        then all end in the last column, so the continuation starts at
        ``input_length``.

        Returns:
            ``(prefix_texts, generated_texts, target_texts)``, one entry per sample.
        """
        generate_dic = {"input_ids": inputs["prefix"], "attention_mask": inputs["attention_masks"]}
        generate_dic = self.tokenizer.pad(generate_dic, padding=True)
        generate_dic["input_ids"] = generate_dic["input_ids"].to(gen_model.device)
        generate_dic["attention_mask"] = generate_dic["attention_mask"].to(gen_model.device)
        input_length = len(generate_dic['input_ids'][0])

        # NOTE(review): no do_sample/num_beams is passed, so temperature and
        # early_stopping presumably have no effect unless the model's own
        # generation config enables sampling/beam search -- kept unchanged.
        generated_ids = gen_model.generate(
            **generate_dic,
            max_length=input_length + self.coh_train_args.generate_max_length,
            pad_token_id=self.tokenizer.pad_token_id,
            temperature=0,
            early_stopping=True,
        )
        # Strip the prompt columns; keep only the newly generated tokens.
        generated_ids = [g[input_length:] for g in generated_ids]
        prefix_text = self.tokenizer.batch_decode(generate_dic["input_ids"], skip_special_tokens=True)
        generate_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        target_text = self.tokenizer.batch_decode(inputs["target"], skip_special_tokens=True)
        return prefix_text, generate_text, target_text

    @staticmethod
    def _select_rows(is_coh, inputs, prefixes, generations, targets):
        """Return the ``(prefix, generation, target)`` triples that get scored.

        The selection depends on the mode:

        * coh: every other sample of the batch (indices 0, 2, 4, ...). The three
          lists are zipped first and sliced once, so they stay aligned.
        * otherwise: samples whose ``good_loss_masks`` row has a non-zero entry.
        """
        rows = list(zip(prefixes, generations, targets))
        if is_coh:
            return rows[::2]
        return [row for row, good in zip(rows, inputs['good_loss_masks']) if len(torch.nonzero(good)) != 0]

    @staticmethod
    def _summarize(scores, predictions):
        """Build the logged metric dict from raw ROUGE scores.

        ROUGE scores are scaled to percentages and ``avg_rouge`` is their mean.
        ``gen_len`` is the mean whitespace-split length of the predictions. It
        is a word count, so it is not multiplied by 100.
        """
        metrics = {k: round(float(v) * 100, 4) for k, v in scores.items()}
        metrics["avg_rouge"] = round(sum(float(v) for v in scores.values()) / len(scores) * 100, 4)
        metrics["gen_len"] = round(float(np.mean([len(gen.split(' ')) for gen in predictions])), 4)
        return metrics

    def on_epoch_end(self, args, state, control, **kwargs):
        """Generate, score, checkpoint and log for the epoch that just finished."""
        is_coh = self.args.method == "coh"
        # In non-coh mode the model exposes a ``good_model`` that does the generating.
        gen_model = self.model if is_coh else self.model.good_model
        self.model.eval()
        rouge = evaluate.load("rouge")

        rows = []
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = 'left'
        try:
            for inputs in tqdm(self.dataloader, desc='Evaluating on Test Set'):
                with torch.no_grad():
                    prefixes, generations, targets = self._generate_batch(gen_model, inputs)
                rows.extend(self._select_rows(is_coh, inputs, prefixes, generations, targets))
        finally:
            # Do not leak the padding change into other users of the tokenizer.
            self.tokenizer.padding_side = previous_padding_side

        # (sic) attribute name kept for any existing readers.
        self.genertae_result = [list(row) for row in rows]

        run_dir = self._run_dir()
        is_main = self._is_main_process()
        if is_main:
            os.makedirs(run_dir, exist_ok=True)
            with open(os.path.join(run_dir, "generate_result.jsonl"), "a+", encoding="utf-8") as f1:
                for pre, gen, tar in rows:
                    f1.write(json.dumps({"prefix": pre, "generate_text": gen, "target": tar}) + "\n")
                f1.write(f"#################################New epoch {self.args.epoch_flag} result#################################" + "\n")

        metrics = None
        if rows:
            predictions = [gen for _, gen, _ in rows]
            references = [tar for _, _, tar in rows]
            scores = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
            metrics = self._summarize(scores, predictions)

        if is_main:
            torch.save(self.model.state_dict(), os.path.join(run_dir, f"epoch_{self.args.epoch_flag}_model"))
            # Advance exactly once per epoch.
            self.args.epoch_flag = self.args.epoch_flag + 1

        if metrics is None:
            print("RougeCallback: no test samples were selected; ROUGE not computed")
            return
        print(metrics)
        if is_main:
            with open(os.path.join(run_dir, "metrics_result.jsonl"), "a+", encoding="utf-8") as f1:
                f1.write(json.dumps(metrics) + "\n")
            self.logger.log(metrics)



class EvalCallback(TrainerCallback):
    """At every epoch end: dump test-set generations and checkpoint the model.

    Each epoch the callback performs these steps:
      1. Calls ``generate()`` for every test prefix.
      2. Keeps a subset of the samples (see ``_select_rows``).
      3. On the main process, appends the kept generations plus an epoch
         separator line to ``generate_result.jsonl``.
      4. On the main process, saves ``model.state_dict()`` as
         ``epoch_{epoch_flag}_model`` and advances ``args.epoch_flag`` by one.

    No metric is computed. All files live in
    ``output_dir/wandb_project_name/wandb_run_name``.
    """

    def __init__(self, test_dataset, logger, coh_train_args, model, tokenizer,args):
        self.model = model
        self.logger = logger
        self.coh_train_args = coh_train_args
        self.args = args
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate = MyCollator(self.args, self.tokenizer)
        self.dataloader = DataLoader(
            test_dataset,
            batch_size=coh_train_args.per_device_eval_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,  # TODO
            collate_fn=self.collate,
        )

    def _run_dir(self):
        """Directory holding this run's generations and checkpoints."""
        return os.path.join(self.coh_train_args.output_dir, self.args.wandb_project_name, self.args.wandb_run_name)

    def _is_main_process(self):
        """Whether this process writes files and saves checkpoints.

        ``local_rank`` is -1 when training is not distributed; that single
        process is also the main one, so it must not be excluded.
        """
        return self.coh_train_args.local_rank in (-1, 0)

    def _generate_batch(self, gen_model, inputs):
        """Run ``gen_model.generate`` on one collated batch.

        Expects ``self.tokenizer.padding_side == 'left'``: the padded prompts
        then all end in the last column, so the continuation starts at
        ``input_length``.

        Returns:
            ``(prefix_texts, generated_texts, target_texts)``, one entry per sample.
        """
        generate_dic = {"input_ids": inputs["prefix"], "attention_mask": inputs["attention_masks"]}
        generate_dic = self.tokenizer.pad(generate_dic, padding=True)
        generate_dic["input_ids"] = generate_dic["input_ids"].to(gen_model.device)
        generate_dic["attention_mask"] = generate_dic["attention_mask"].to(gen_model.device)
        input_length = len(generate_dic['input_ids'][0])

        # NOTE(review): no do_sample/num_beams is passed, so temperature and
        # early_stopping presumably have no effect unless the model's own
        # generation config enables sampling/beam search -- kept unchanged.
        generated_ids = gen_model.generate(
            **generate_dic,
            max_length=input_length + self.coh_train_args.generate_max_length,
            pad_token_id=self.tokenizer.pad_token_id,
            temperature=0,
            early_stopping=True,
        )
        # Strip the prompt columns; keep only the newly generated tokens.
        generated_ids = [g[input_length:] for g in generated_ids]
        prefix_text = self.tokenizer.batch_decode(generate_dic["input_ids"], skip_special_tokens=True)
        generate_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        target_text = self.tokenizer.batch_decode(inputs["target"], skip_special_tokens=True)
        return prefix_text, generate_text, target_text

    @staticmethod
    def _select_rows(is_coh, inputs, prefixes, generations, targets):
        """Return the ``(prefix, generation, target)`` triples that get dumped.

        The selection depends on the mode:

        * coh: every other sample of the batch (indices 0, 2, 4, ...). The three
          lists are zipped first and sliced once, so they stay aligned.
        * otherwise: samples whose ``good_loss_masks`` row has a non-zero entry.
        """
        rows = list(zip(prefixes, generations, targets))
        if is_coh:
            return rows[::2]
        return [row for row, good in zip(rows, inputs['good_loss_masks']) if len(torch.nonzero(good)) != 0]

    def on_epoch_end(self, args, state, control, **kwargs):
        """Generate, dump and checkpoint for the epoch that just finished."""
        is_coh = self.args.method == "coh"
        # In non-coh mode the model exposes a ``good_model`` that does the generating.
        gen_model = self.model if is_coh else self.model.good_model
        self.model.eval()

        rows = []
        previous_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = 'left'
        try:
            for inputs in tqdm(self.dataloader, desc='Evaluating on Test Set'):
                with torch.no_grad():
                    prefixes, generations, targets = self._generate_batch(gen_model, inputs)
                rows.extend(self._select_rows(is_coh, inputs, prefixes, generations, targets))
        finally:
            # Do not leak the padding change into other users of the tokenizer.
            self.tokenizer.padding_side = previous_padding_side

        if not self._is_main_process():
            return

        # The two modes have always written slightly different separators.
        if is_coh:
            separator = f"#################################New epoch {self.args.epoch_flag} result#################################"
        else:
            separator = f"################################# New epoch {self.args.epoch_flag} result#################################"

        run_dir = self._run_dir()
        os.makedirs(run_dir, exist_ok=True)
        with open(os.path.join(run_dir, "generate_result.jsonl"), "a+", encoding="utf-8") as f1:
            for pre, gen, tar in rows:
                f1.write(json.dumps({"prefix": pre, "generate_text": gen, "target": tar}) + "\n")
            f1.write(separator + "\n")

        torch.save(self.model.state_dict(), os.path.join(run_dir, f"epoch_{self.args.epoch_flag}_model"))
        self.args.epoch_flag = self.args.epoch_flag + 1
        print("Eval Finish")

class SaveCallback(TrainerCallback):
    """At the end of every epoch, save ``model.state_dict()`` on the main process.

    The checkpoint is written to
    ``output_dir/wandb_project_name/wandb_run_name/epoch_{epoch_flag}_model``,
    after which ``args.epoch_flag`` is incremented.
    """

    def __init__(self, test_dataset, logger, coh_train_args, model, tokenizer,args):
        self.model = model
        self.logger = logger
        self.coh_train_args = coh_train_args
        self.args = args
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        # The collator/DataLoader are not used by on_epoch_end; kept so this
        # callback exposes the same attributes as the other callbacks.
        self.collate = MyCollator(self.args, self.tokenizer)
        self.dataloader = DataLoader(
            test_dataset,
            batch_size=coh_train_args.per_device_eval_batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=0,  # TODO
            collate_fn=self.collate,
        )

    def on_epoch_end(self, args, state, control, **kwargs):
        """Save a checkpoint for the epoch that just finished."""
        self.model.eval()
        # local_rank is -1 when training is not distributed; that single
        # process is the main one too, so it must save as well (the old
        # `== 0` check never saved in single-process runs).
        if self.coh_train_args.local_rank not in (-1, 0):
            return
        run_dir = os.path.join(self.coh_train_args.output_dir, self.args.wandb_project_name, self.args.wandb_run_name)
        os.makedirs(run_dir, exist_ok=True)
        torch.save(self.model.state_dict(), os.path.join(run_dir, f"epoch_{self.args.epoch_flag}_model"))
        self.args.epoch_flag = self.args.epoch_flag + 1