from .comlib import *
from . import compute_grad, convert_vec_tuple, model_setup,criterion
from .. import util,save_log,dataset

class Train():
    """Mini-batch training loop that periodically logs statistics to a CSV file.

    On a "save" round, one CSV row is written. A round is selected by
    ``save_round_inteval`` and/or ``save_epoch_inteval``. The row holds:
    epoch, global round index, weight norm, gradient norm, and every metric
    returned by the evaluation objects in ``evalObjs``. Model statistics are
    taken before that round's optimizer step.
    """

    def __init__(self,evalObjs,save_folder,
          batch_size,max_epoch,save_name="train",
          save_round_inteval=None,
          save_epoch_inteval=None,):
        """
        Args:
            evalObjs: dict of evaluation objects; every value must provide
                ``get_header()`` and ``get(modelSetup)`` (e.g. EvalWithDataset).
            save_folder: folder in which the CSV log ``<save_name>.csv`` lives.
            batch_size: mini-batch size used by the training DataLoader.
            max_epoch: number of epochs to run.
            save_name: base name of the CSV log file.
            save_round_inteval: if not None, log every round whose global
                index is a multiple of this value.
            save_epoch_inteval: if not None, log the first round of every
                epoch whose index is a multiple of this value.
        """
        self.evalObjs=evalObjs
        self.saveObj=SaveTrain(save_folder,self.get_header(evalObjs),initialize=True,save_name=save_name)
        self.batch_size=batch_size
        self.max_epoch=max_epoch
        self.save_round_inteval=save_round_inteval
        self.save_epoch_inteval=save_epoch_inteval

    @staticmethod
    def get_header(evalObjs):
        """Return the CSV column names.

        The fixed columns come first, then the columns of each evaluation
        object in dict order.
        """
        l=["epoch","round","weight_norm","grad_norm"]
        for key in evalObjs:
            l=l+evalObjs[key].get_header()
        return l

    def train(self,train_dataset,modelSetup:model_setup.ModelSetup,
          optimizer,criterion=None):
        """Run ``max_epoch`` epochs over ``train_dataset`` (no shuffling).

        Args:
            train_dataset: dataset yielding ``(data, target)`` batches via DataLoader.
            modelSetup: model wrapper; ``calcuJacobian`` supplies the gradient tuple.
            optimizer: object exposing ``stepByGradTuple(grad_tuple)``.
            criterion: loss function; defaults to a new ``nn.CrossEntropyLoss()``.
        """
        if criterion is None:
            # Build the default per call instead of once at import time.
            criterion=nn.CrossEntropyLoss()
        dataloader=DataLoader(train_dataset,batch_size=self.batch_size,shuffle=False)
        # Integer batch count per epoch (ceil(len(dataset)/batch_size) here),
        # so the logged round index is an int rather than a numpy float.
        batch_per_epoch=len(dataloader)

        for epoch in range(self.max_epoch):
            for batch_idx, (data, target) in enumerate(dataloader):
                flag=self.get_save_flag(epoch,batch_per_epoch,batch_idx)
                # Statistics are recorded before this round's update is applied.
                self.save_model_stat(modelSetup,flag)
                grad_tuple=modelSetup.calcuJacobian((data,),criterion,(target,))
                self.save_grad_norm(grad_tuple,flag)
                optimizer.stepByGradTuple(grad_tuple)
                # Writes nothing unless data was added to the row this round.
                self.saveObj.save_cur_row()

    def save_grad_norm(self,grad_tuple,flag):
        """If ``flag`` is set, add the gradient norm to the pending row."""
        if flag:
            self.saveObj.add_data_to_row({"grad_norm":self.get_grad_norm(grad_tuple)})

    def save_model_stat(self,modelSetup,flag):
        """If ``flag`` is set, add the weight norm and all evaluation metrics to the row."""
        if not flag:
            return
        self.saveObj.add_data_to_row({"weight_norm":self.get_weight_norm(modelSetup)})
        # Use every evaluation object so the saved columns match get_header().
        for evalObj in self.evalObjs.values():
            self.saveObj.add_data_to_row(evalObj.get(modelSetup))

    def get_save_flag(self,epoch,batch_per_epoch,batch_idx):
        """Decide whether this round is logged.

        If it is, epoch and round are also recorded in the pending row.

        Returns:
            bool: True if statistics should be saved for this round.
        """
        round_idx=batch_idx+epoch*batch_per_epoch
        by_round=(self.save_round_inteval is not None
                  and round_idx%self.save_round_inteval==0)
        by_epoch=(self.save_epoch_inteval is not None
                  and batch_idx==0
                  and epoch%self.save_epoch_inteval==0)
        flag=by_round or by_epoch
        if flag:
            self.saveObj.add_data_to_row({"epoch":epoch,"round":round_idx})
        return flag


    @staticmethod
    def get_weight_norm(modelSetup):
        """Return the model's weight norm as a Python number."""
        # NOTE(review): `gettWeightNorm` looks like a typo but must match the
        # ModelSetup API — confirm before renaming.
        weight_norm=modelSetup.gettWeightNorm().item()
        return weight_norm

    @staticmethod
    def get_grad_norm(grad_tuple):
        """Return the norm of the gradient tuple as a Python number."""
        grad_norm=model_setup.ModelPara(grad_tuple).norm().item()
        return grad_norm
    
class EvalWithDataset():
    """Evaluate a dict of named criterions for a model on a single dataset.

    Each result is keyed ``"<criterion name>_<dataset_name>"``.
    """
    def __init__(self,dataset,dataset_name,criterions:dict,
          batch_size):
        self.dataset, self.dataset_name = dataset, dataset_name
        self.criterions, self.batch_size = criterions, batch_size

    @staticmethod
    def getModelDatasetCriterion(modelSetup:model_setup.ModelSetup,criterion,dataset,
         batch_size):
        """Return ``criterion`` evaluated by ``modelSetup`` over ``dataset``.

        The evaluation uses ``reduction="mean"``.
        """
        return modelSetup.getDatasetCriterion(
            model_setup.Argset(dataset), criterion, batch_size, reduction="mean")

    @staticmethod
    def getComposedCriterion(criterions):
        """Fuse several criterions into a single vector-valued criterion.

        Returns:
            ``(composed, names)``: calling ``composed(out, target)`` returns a
            float ``np.ndarray``. Its i-th entry is
            ``criterions[names[i]](out, target).item()``.
        """
        names=list(criterions.keys())

        def composed(out,target):
            return np.array(
                [criterions[name](out,target).item() for name in names],
                dtype=float)

        return composed,names


    def get(self,modelSetup:model_setup.ModelSetup):
        """Evaluate every criterion on the dataset.

        Returns ``{"<cri>_<dataset>": value}``.
        """
        composed,names=self.getComposedCriterion(self.criterions)
        values=self.getModelDatasetCriterion(
            modelSetup, composed, self.dataset, self.batch_size)
        return {f"{name}_{self.dataset_name}": values[idx]
                for idx, name in enumerate(names)}

    def get_header(self):
        """Return the result column names, in criterion order."""
        suffix=self.dataset_name
        return [f"{name}_{suffix}" for name in self.criterions]

    @staticmethod
    def get_line(df,x_name,cri_name,dataset_name):
        """Build a plot line of column ``<cri_name>_<dataset_name>`` against ``x_name``."""
        xs=df.loc[:,x_name]
        ys=df.loc[:,f"{cri_name}_{dataset_name}"]
        return save_log.Line(xs,ys)
    
    

class SaveTrain():
    """Collect one pending row of statistics and append rows to a CSV file.

    The file path is ``<save_folder>/<save_name>.csv``.
    """
    def __init__(self,save_folder,header=None,initialize=True,save_name="train"):
        csv_path=f"{save_folder}/{save_name}.csv"
        self.save_folder=save_folder
        self.save_file=csv_path
        self.saveCsvObj=save_log.SaveCsvHeader(csv_path,header)
        if initialize:
            self.saveCsvObj.initialize()

        self.cur_row={}

    def add_data_to_row(self,data_dict):
        """Merge ``data_dict`` into the pending row; later values overwrite earlier ones."""
        self.cur_row.update(data_dict)

    def save_cur_row(self):
        """Append the pending row to the CSV and reset it.

        Does nothing if the row is empty.
        """
        if not self.cur_row:
            return
        self.saveCsvObj.append_data([self.cur_row])
        self.cur_row={}

    def save(self,**kwargs):
        """Append one row made of the keyword arguments directly to the CSV."""
        self.saveCsvObj.append_data([kwargs])

    def read_to_df(self):
        """Load the CSV log via the underlying save_log object."""
        return self.saveCsvObj.read_to_df()

    @staticmethod
    def get_line(df:pd.DataFrame,col_name,x_name="round"):
        """Build a plot line of ``col_name`` against ``x_name``."""
        xs=df.loc[:,x_name]
        ys=df.loc[:,col_name]
        return save_log.Line(xs,ys)

class PlotEval():
    """Build and plot figures from a training CSV log written by Train/SaveTrain.

    Metrics are shown for the "train" and "test" datasets, plus the weight
    norm and the gradient norm.
    """
    def __init__(self):
        # Criterions evaluated per dataset; their keys also name the figures.
        self.criterions={
            "loss":nn.CrossEntropyLoss(),
            "acc":criterion.accuracy(reduction="mean"),
            "out_norm":criterion.out_norm
        }
        # Unreduced variants; not used by the methods of this class.
        self.criterions_unreduced={
            "loss":nn.CrossEntropyLoss(reduction="none"),
            "acc":criterion.accuracy(reduction="none"),
            # "out_norm":criterion.out_norm
        }
        self.dataset_name=["train","test"]

    def get_evalObjs(self,train_dataset,test_dataset,batch_size):
        """Return ``{"train": EvalWithDataset, "test": EvalWithDataset}``.

        Both use ``self.criterions``.
        """
        dataset_dict={"train":train_dataset,
                        "test":test_dataset,}

        evalObjs={}
        for dname, dset in dataset_dict.items():
            evalObjs[dname]=EvalWithDataset(dset,dname,self.criterions,batch_size)
        return evalObjs

    @staticmethod
    def get_df(save_folder,save_name="train"):
        """Read ``<save_folder>/<save_name>.csv`` into a DataFrame without resetting it."""
        saveObj=SaveTrain(save_folder,initialize=False,save_name=save_name)
        df=saveObj.read_to_df()
        return df

    def get_figs(self,save_folder,save_name="train"):
        """Return ``{figure name: {line name: Line}}`` with "round" on the x axis.

        There is one figure per criterion, with one line per dataset. There
        are also single-line "weight_norm" and "grad_norm" figures.
        """
        x_name="round"
        df=self.get_df(save_folder,save_name)
        figs={}
        # Derive criterion figures from self.criterions so they stay in sync.
        for cri_name in self.criterions:
            figs[cri_name]={
                dataset_name:EvalWithDataset.get_line(df,x_name,cri_name,dataset_name)
                for dataset_name in self.dataset_name
            }
        for fig_name in ["weight_norm","grad_norm"]:
            figs[fig_name]={0:SaveTrain.get_line(df,fig_name,x_name)}
        return figs

    def plot(self,save_folder,save_name="train"):
        """Plot every figure from the log to ``<save_folder>/<figure name>.png``.

        NOTE: the PNG names do not include ``save_name``, so plotting several
        logs from one folder overwrites the images.
        """
        figs=self.get_figs(save_folder,save_name)
        for name in figs:
            save_log.plot1(figs[name],f"{save_folder}/{name}.png")