from .comlib import *
from . import compute_grad, convert_vec_tuple, model_setup,criterion
from .. import util,save_log,dataset

class Train():
    """Mini-batch training loop that periodically evaluates and logs to CSV.

    Every logged row has the columns ``epoch``, ``round`` (global batch index),
    ``weight_norm``, ``grad_norm`` followed by the columns contributed by each
    evaluator in ``evalObjList`` (via their ``get_header()``).
    """

    def __init__(self,evalObjList,save_folder,
          batch_size,max_epoch,
          save_round_inteval=None,
          save_epoch_inteval=None,):
        """
        Args:
            evalObjList: evaluators exposing ``get_header()`` and ``get(modelSetup)``.
            save_folder: directory in which the ``SaveTrain`` CSV is created.
            batch_size: training mini-batch size.
            max_epoch: number of passes over the training set.
            save_round_inteval: log when the global round is a multiple of this
                (None disables round-based logging).
            save_epoch_inteval: log at the first batch of every epoch that is a
                multiple of this (None disables epoch-based logging).
        """
        self.evalObjList = evalObjList
        # The full header is known up front, so the CSV is set up with it immediately.
        self.saveObj = SaveTrain(save_folder, self.get_header(evalObjList), initialize=True)
        self.batch_size = batch_size
        self.max_epoch = max_epoch
        self.save_round_inteval = save_round_inteval
        self.save_epoch_inteval = save_epoch_inteval

    @staticmethod
    def get_header(evalObjList):
        """Return the CSV column names: fixed columns, then each evaluator's columns."""
        header = ["epoch", "round", "weight_norm", "grad_norm"]
        for evalObj in evalObjList:
            header = header + evalObj.get_header()
        return header

    def train(self,train_dataset,modelSetup:model_setup.ModelSetup,
          optimizer,criterion=None):
        """Run ``self.max_epoch`` epochs of training, logging before each step.

        Args:
            train_dataset: dataset yielding ``(data, target)`` pairs.
            modelSetup: provides ``calcuJacobian(inputs, criterion, outputs)``.
            optimizer: provides ``stepByGradTuple(grad_tuple)``.
            criterion: loss used for the gradient; a new
                ``nn.CrossEntropyLoss()`` is created when None.
        """
        if criterion is None:
            # Build per call instead of sharing one default-argument instance.
            criterion = nn.CrossEntropyLoss()

        # shuffle=False keeps the batch order identical in every epoch.
        dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=False)
        # Integer number of batches (== ceil(len(dataset)/batch_size) since
        # drop_last is False), so logged round indices are ints, not floats.
        batch_per_epoch = len(dataloader)
        grad_tuple = None

        for epoch in range(self.max_epoch):
            for batch_idx, (data, target) in enumerate(dataloader):
                # Logs the state before this step; grad_tuple is the previous step's gradient.
                self.eval_and_save(epoch, batch_per_epoch, batch_idx, modelSetup, grad_tuple)
                grad_tuple = modelSetup.calcuJacobian((data,), criterion, (target,))
                optimizer.stepByGradTuple(grad_tuple)

        # Final evaluation after the last step. Uses self.max_epoch rather than the
        # loop variable so that max_epoch == 0 does not raise NameError.
        self.eval_and_save(self.max_epoch, batch_per_epoch, 0, modelSetup, grad_tuple)

    def eval_and_save_(self,round,epoch,modelSetup,grad_tuple=None):
        """Evaluate the model now and append one row to the CSV (no interval check)."""
        row = self.get_norms(modelSetup, grad_tuple)
        for evalObj in self.evalObjList:
            row.update(evalObj.get(modelSetup))
        self.saveObj.save(epoch=epoch, round=round, **row)

    def eval_and_save(self,epoch,batch_per_epoch,batch_idx,modelSetup,grad_tuple=None):
        """Log a row if the current position matches either save interval.

        The global round is ``batch_idx + epoch * batch_per_epoch``. A row is
        written when that round is a multiple of ``save_round_inteval``, or when
        ``batch_idx == 0`` and ``epoch`` is a multiple of ``save_epoch_inteval``.
        With both intervals None, nothing is written.
        """
        global_round = batch_idx + epoch * batch_per_epoch
        due_by_round = (self.save_round_inteval is not None
                        and global_round % self.save_round_inteval == 0)
        due_by_epoch = (self.save_epoch_inteval is not None
                        and batch_idx == 0
                        and epoch % self.save_epoch_inteval == 0)
        if due_by_round or due_by_epoch:
            self.eval_and_save_(global_round, epoch, modelSetup, grad_tuple)

    @staticmethod
    def get_weight_norm(modelSetup):
        """Return the model weight norm as a Python scalar."""
        # NOTE(review): method name spelled 'gettWeightNorm' as on ModelSetup — confirm.
        return modelSetup.gettWeightNorm().item()

    @staticmethod
    def get_grad_norm(grad_tuple=None):
        """Return the norm of ``grad_tuple`` as a Python scalar, or None if absent."""
        if grad_tuple is None:
            # No gradient exists yet before the first optimizer step.
            return None
        return model_setup.ModelPara(grad_tuple).norm().item()

    @staticmethod
    def get_norms(modelSetup,grad_tuple):
        """Return ``{"weight_norm": ..., "grad_norm": ...}`` for the current state."""
        return {
            "weight_norm": Train.get_weight_norm(modelSetup),
            "grad_norm": Train.get_grad_norm(grad_tuple),
        }

class Eval():
    """Evaluate a model on one dataset under a set of named criteria.

    Each criterion yields one output column named
    ``"{criterion_name}_{dataset_name}"``.
    """
    def __init__(self,dataset,dataset_name,criterions:dict,
          batch_size):
        """
        Args:
            dataset: dataset to evaluate on.
            dataset_name: suffix used in the produced column names.
            criterions: ordered mapping ``name -> callable(out, target)``; each
                call's result must support ``.item()``.
            batch_size: batch size forwarded to the model's dataset evaluation.
        """
        self.dataset = dataset
        self.dataset_name = dataset_name
        self.criterions = criterions
        self.batch_size = batch_size

    @staticmethod
    def getModelDatasetCriterion(modelSetup:model_setup.ModelSetup,criterion,dataset,
         batch_size):
        """Delegate to ``modelSetup.getDatasetCriterion`` with ``reduction="mean"``."""
        # presumably averages the criterion over the dataset's batches — confirm in model_setup
        return modelSetup.getDatasetCriterion(
            model_setup.Argset(dataset), criterion, batch_size, reduction="mean")

    @staticmethod
    def getComposedCriterion(criterions):
        """Fuse several criteria into one callable returning a float numpy vector.

        Returns:
            ``(composed, names)`` where ``composed(out, target)[k]`` is the value
            of the criterion named ``names[k]``.
        """
        names = [key for key in criterions]

        def composedCriterion(out,target):
            values = np.zeros(len(criterions))
            for pos, key in enumerate(names):
                values[pos] = criterions[key](out, target).item()
            return values

        return composedCriterion, names

    def get(self,modelSetup:model_setup.ModelSetup):
        """Return ``{"{criterion}_{dataset_name}": value}`` for every criterion."""
        composed, names = self.getComposedCriterion(self.criterions)
        values = self.getModelDatasetCriterion(modelSetup, composed, self.dataset,
                                               self.batch_size)
        return {f"{key}_{self.dataset_name}": values[pos] for pos, key in enumerate(names)}

    def get_header(self):
        """Return the column names produced by ``get``, in criterion order."""
        suffix = self.dataset_name
        return [f"{key}_{suffix}" for key in self.criterions.keys()]
    
    

def get_evalObjList(train_dataset,test_dataset,batch_size):
    """Build ``[Eval(train), Eval(test)]`` measuring loss and accuracy.

    Both evaluators share one criteria dict: ``"loss"`` is
    ``nn.CrossEntropyLoss()`` and ``"acc"`` is
    ``criterion.accuracy(reduction="mean")``.
    """
    shared_criterions = {
        "loss": nn.CrossEntropyLoss(),
        "acc": criterion.accuracy(reduction="mean"),
    }
    named_datasets = (("train", train_dataset), ("test", test_dataset))
    return [Eval(dset, dname, shared_criterions, batch_size)
            for dname, dset in named_datasets]


class SaveTrain():
    """Wrapper around ``save_log.SaveCsvHeader`` for ``{save_folder}/train.csv``.

    Rows can be written directly with ``save(**columns)``, or accumulated with
    ``add_data_to_row`` and flushed with ``save_cur_row``.
    """
    def __init__(self,save_folder,header=None,initialize=True,x_name="round"):
        """
        Args:
            save_folder: directory holding ``train.csv``.
            header: column names passed to ``SaveCsvHeader``.
            initialize: if true, call ``initialize()`` on the CSV writer.
            x_name: column used as the x axis by ``get_line``.
        """
        self.save_folder = save_folder
        csv_path = f"{save_folder}/train.csv"
        self.save_file = csv_path
        self.saveCsvObj = save_log.SaveCsvHeader(csv_path, header)
        if initialize:
            self.saveCsvObj.initialize()

        self.x_name = x_name
        # Pending row built up by add_data_to_row until save_cur_row flushes it.
        self.cur_row = {}

    def add_data_to_row(self,data_dict):
        """Merge ``data_dict`` into the pending row (later values win)."""
        pending = self.cur_row
        pending.update(data_dict)

    def save_cur_row(self):
        """Append the pending row to the CSV, then start a fresh empty row."""
        rows = [self.cur_row]
        self.saveCsvObj.append_data(rows)
        self.cur_row = {}

    def save(self,**kwargs):
        """Append one row built from the keyword arguments."""
        rows = [kwargs]
        self.saveCsvObj.append_data(rows)

    def read_to_df(self):
        """Return the CSV contents as read by ``SaveCsvHeader.read_to_df``."""
        return self.saveCsvObj.read_to_df()

    def get_line(self,df:pd.DataFrame,col_name):
        """Return a ``save_log.Line`` of ``col_name`` against ``self.x_name``."""
        x_values = df.loc[:, self.x_name]
        y_values = df.loc[:, col_name]
        return save_log.Line(x_values, y_values)


def plot(save_folder):
    """Render the curves logged in ``{save_folder}/train.csv`` as PNG files.

    Writes ``loss.png`` and ``acc.png`` (train and test lines each), plus
    ``weight_norm.png`` and ``grad_norm.png`` (a single line keyed ``0``),
    all into ``save_folder``.
    """
    saveObj = SaveTrain(save_folder, initialize=False)
    df = saveObj.read_to_df()

    figs = {}
    for metric in ("loss", "acc"):
        figs[metric] = {split: saveObj.get_line(df, f"{metric}_{split}")
                        for split in ("train", "test")}
    for norm_name in ("weight_norm", "grad_norm"):
        figs[norm_name] = {0: saveObj.get_line(df, norm_name)}

    for fig_name, lines in figs.items():
        save_log.plot1(lines, f"{save_folder}/{fig_name}.png")