import torch
import numpy as np
import copy
import os
import torch.nn.functional as F
from operator import itemgetter
import datetime
from tqdm import tqdm
import metrics
from utils import *
from model.util import *

import copy

class Trainer(object):
    """Train, validate, test and checkpoint a forecasting model for one stage.

    The trainer owns the optimizer and LR scheduler, runs the epoch loop with
    early stopping on validation MSE, optionally mixes in samples replayed
    from a memory buffer, and writes/reads the best checkpoint under
    ``model_check_point/<model_name>/<selection_method>/<data_name>/``.
    """

    # Model-name prefixes that select the forecaster + graph-generator path.
    # NOTE(review): the original code tested 'ski-cl' in some methods and
    # 'pki-cl' in others, so no single model name could get through both
    # training and validation. Both spellings are accepted here so that every
    # method agrees; confirm which one is canonical and drop the other.
    _GRAPH_MODEL_PREFIXES = ('ski-cl', 'pki-cl')

    # Number of consecutive non-improving epochs tolerated before stopping.
    _EARLY_STOP_PATIENCE = 5

    # Training progress is printed every this many epochs.
    _LOG_EVERY = 3

    def __init__(self,
                 args,
                 model,
                 stage,
                 adj_set,
                 optimizer_lr
                 ):
        """Set up the model, optimizer and scheduler for one training stage.

        Args:
            args: Configuration object. This class reads ``model_name``,
                ``training_epoch``, ``selection_method``, ``prior_form``,
                ``alpha_ts_stage``, ``alpha_ts_mem``, ``graph_coef``,
                ``batch_size`` and ``data_name_stage`` from it.
            model: Module to train; it is moved to the GPU here.
            stage: Index of the current stage (stored, not otherwise used
                inside this class).
            adj_set: Indexable collection of adjacency matrices, looked up by
                the sample indices yielded by the data loaders.
            optimizer_lr: Learning rate for the Adam optimizer.
        """
        self.args = args
        self.model_name = args.model_name
        self.model = model.cuda()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=optimizer_lr)
        # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
        # releases -- confirm against the installed version before upgrading.
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.5, verbose=True)
        # self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,factor = 0.3,mode='min', verbose=True)

        self.epochs = args.training_epoch
        self.method = args.selection_method
        self.current_stage_index = stage
        self.adj_set = adj_set
        self.model_previous_stage = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values`` as a float, or NaN when it is empty.

        ``np.mean([])`` also yields NaN but emits a RuntimeWarning each time;
        several of the per-epoch record lists are legitimately empty.
        """
        return float(np.mean(values)) if len(values) > 0 else float('nan')

    def _uses_graph_generator(self):
        """Return True when the model name selects the graph-learning path."""
        return self.model_name.startswith(self._GRAPH_MODEL_PREFIXES)

    def _checkpoint_dir(self, data_name):
        """Return the checkpoint directory for ``data_name``."""
        return "model_check_point/{}/{}/{}".format(self.model_name, self.method, data_name)

    def _forecast(self, x_norm):
        """Run the model on ``x_norm`` and return the squeezed prediction.

        Raises:
            NotImplementedError: if the model name is not one of the
                supported prefixes (previously this surfaced later as a
                ``NameError`` on an undefined ``y_hat``).
        """
        if not self._uses_graph_generator():
            raise NotImplementedError(
                "Unsupported model_name {!r}; expected a name starting with one of {}".format(
                    self.model_name, self._GRAPH_MODEL_PREFIXES))
        y_hat, _ = self.model(x_norm, prior_form=self.args.prior_form)
        return torch.squeeze(y_hat)

    def _stack_adjacency(self, indices):
        """Look up ``adj_set`` for every index and stack along a new batch axis.

        A plain list lookup is used instead of ``itemgetter(*indices)``:
        with a single index ``itemgetter`` returns the bare item rather than
        a 1-tuple, which made ``np.stack`` drop the batch dimension.
        """
        matrices = [self.adj_set[idx] for idx in indices.tolist()]
        return torch.from_numpy(np.stack(matrices, axis=0)).float().cuda()

    def _graph_loss(self, structure_pred, structure_true):
        """Return the structure loss between predicted and reference graphs.

        Both tensors are flattened per sample; binary priors use binary
        cross-entropy, every other prior form uses ``metrics.masked_mse``.
        """
        batch_size, _, _ = structure_pred.shape
        flat_pred = structure_pred.view(batch_size, -1)
        flat_true = structure_true.view(batch_size, -1)
        if self.args.prior_form == 'binary':
            return F.binary_cross_entropy(flat_pred, flat_true, reduction='mean')
        return metrics.masked_mse(flat_pred, flat_true)

    def _forecast_and_graph_losses(self, indices, x_norm, y_norm):
        """Return ``(ts_loss, graph_loss)`` for one batch.

        The call order (forward pass, graph sampling, adjacency lookup) is the
        same as in the original inline code, so any random draws inside the
        model happen in the same sequence.
        """
        y_hat = self._forecast(x_norm)
        structure_pred = self.model.graph_generator.sample(x_norm, prior_form=self.args.prior_form)
        structure_true = self._stack_adjacency(indices)
        ts_loss = metrics.masked_mse(y_hat, y_norm)
        graph_loss = self._graph_loss(structure_pred, structure_true)
        return ts_loss, graph_loss

    def _replay_enabled(self, buffer):
        """Return True when memory replay applies to this run.

        Replay is skipped for selection methods starting with 'joint' or
        'none'. ``buffer`` is only touched after those checks pass.
        """
        method = self.args.selection_method
        if method[:5] == 'joint' or method[:4] == 'none':
            return False
        return buffer.memory_size > 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def train_epoch(self, train_dataloader, buffer, raw_data):
        """Run one optimisation pass over ``train_dataloader``.

        Each batch contributes a forecasting loss and a graph-structure loss;
        when replay is enabled, a batch sampled from ``buffer`` contributes
        the same two losses with its own forecasting weight.

        Args:
            train_dataloader: Iterable yielding
                ``(ind, x, y, x_norm, y_norm)`` tuples.
            buffer: Replay memory exposing ``memory_size`` and
                ``sample(batch_size)``.
            raw_data: Unused; kept for interface compatibility.

        Returns:
            A 7-tuple of lists, in this order: stage forecasting losses,
            replay forecasting losses, three lists that this method never
            fills (``mi_histolic_stage``, ``mi_histolic_mem``,
            ``mi_distill``), stage graph losses, replay graph losses.
        """
        ts_stage_records = []
        graph_stage_records = []
        ts_mem_records = []
        graph_mem_records = []

        # Returned for interface compatibility; nothing here populates them.
        mi_histolic_stage = []
        mi_histolic_mem = []
        mi_distill = []

        for ind_stage, _x_stage, _y_stage, x_norm_stage, y_norm_stage in train_dataloader:
            # x.shape = batch_size, num_nodes, sequence_length
            x_norm_stage = x_norm_stage.float().cuda()
            y_norm_stage = y_norm_stage.float().cuda()

            ts_loss_stage, graph_loss_stage = self._forecast_and_graph_losses(
                ind_stage, x_norm_stage, y_norm_stage)
            ts_stage_records.append(ts_loss_stage.item())
            graph_stage_records.append(graph_loss_stage.item())

            loss = self.args.alpha_ts_stage * ts_loss_stage
            loss = loss + self.args.graph_coef * graph_loss_stage

            if self._replay_enabled(buffer):
                # The replayed tensors are used exactly as the buffer returns
                # them; no dtype/device conversion is applied here.
                ind_mem, _x_mem, _y_mem, x_norm_mem, y_norm_mem = buffer.sample(self.args.batch_size)

                ts_loss_mem, graph_loss_mem = self._forecast_and_graph_losses(
                    ind_mem, x_norm_mem, y_norm_mem)
                ts_mem_records.append(ts_loss_mem.item())
                graph_mem_records.append(graph_loss_mem.item())

                loss = loss + self.args.alpha_ts_mem * ts_loss_mem
                loss = loss + self.args.graph_coef * graph_loss_mem

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=1, norm_type=2)
            self.optimizer.step()

        return (ts_stage_records, ts_mem_records, mi_histolic_stage, mi_histolic_mem,
                mi_distill, graph_stage_records, graph_mem_records)

    def val_epoch(self, val_dataloader, raw_data):
        """Return the per-batch forecasting MSE over ``val_dataloader``.

        Gradient tracking is not disabled here; ``train`` wraps the call in
        ``torch.no_grad()``. ``raw_data`` is unused and kept for interface
        compatibility.
        """
        val_records = []

        for _ind_stage, _x_stage, _y_stage, x_norm_stage, y_norm_stage in val_dataloader:
            x_norm_stage = x_norm_stage.float().cuda()
            y_norm_stage = y_norm_stage.float().cuda()

            y_hat = self._forecast(x_norm_stage)
            val_records.append(metrics.masked_mse(y_hat, y_norm_stage).item())

        return val_records

    def train(self, train_dataloader, val_dataloader, buffer, raw_data):
        """Train for up to ``self.epochs`` epochs with early stopping.

        After every epoch the scheduler is stepped; the model is checkpointed
        whenever the mean validation MSE improves, and training stops after
        ``_EARLY_STOP_PATIENCE`` consecutive epochs without improvement.
        """
        best_val_loss = float('inf')
        # Must exist before the first comparison: a NaN/inf validation loss
        # in epoch 0 takes the "not improved" branch straight away.
        not_improved_count = 0

        for epoch in range(self.epochs):
            self.model.train()
            (ts_train_records, _ts_mem_records, _mi_stage, _mi_mem,
             mi_distill, graph_stage_records, graph_mem_records) = self.train_epoch(
                train_dataloader, buffer, raw_data)
            ts_train_loss = self._mean_or_nan(ts_train_records)

            self.model.eval()
            with torch.no_grad():
                ts_val_records = self.val_epoch(val_dataloader, raw_data)
            ts_val_loss = self._mean_or_nan(ts_val_records)

            self.lr_scheduler.step()

            if ts_val_loss < best_val_loss:
                best_val_loss = ts_val_loss
                not_improved_count = 0
                self.save_model(epoch)
            else:
                not_improved_count += 1

            if epoch % self._LOG_EVERY == 0:
                print("Epoch :{} train MSE: {:.4f}  validation MSE: {:.4f}  Graph current: {:.4f}  Graph mem: {:.4f} MI distill: {:.4f} ".format(
                    epoch, ts_train_loss, ts_val_loss,
                    self._mean_or_nan(graph_stage_records),
                    self._mean_or_nan(graph_mem_records),
                    self._mean_or_nan(mi_distill)))

            if not_improved_count >= self._EARLY_STOP_PATIENCE:
                print('early stopping')
                break

    def test(self, results, test_loader_stage, scaler_stage, data_name_stage, train_stage_index, test_stage_index, raw_data):
        """Evaluate the best checkpoint for ``data_name_stage`` on a test set.

        Loads ``best_model.pth`` for ``data_name_stage``, computes forecasting
        metrics on de-normalised predictions and graph metrics on the sampled
        structure, then writes the averages into
        ``results[<metric>][train_stage_index, test_stage_index]`` and prints
        them.

        Args:
            results: Mapping of metric name to a 2-D indexable store. Keys
                written: 'rmse', 'mse', 'mae', 'mape', plus 'recall' and
                'precision' (binary prior) or 'graph_mae' and 'graph_rmse'
                (continuous prior).
            test_loader_stage: Iterable yielding
                ``(ind, x, y, x_norm, y_norm)`` tuples.
            scaler_stage: Object whose ``inverse_transform`` maps normalised
                predictions back to the scale of ``y``.
            data_name_stage: Dataset name used to locate the checkpoint.
            train_stage_index: Row index into each ``results`` entry.
            test_stage_index: Column index into each ``results`` entry.
            raw_data: Unused; kept for interface compatibility.
        """
        # Forecasting metrics, one entry per batch.
        test_rmse_loss = []
        test_mse_loss = []
        test_mae_loss = []
        test_mape_loss = []

        # Graph metrics: per sample for binary priors, per batch otherwise.
        test_recall_list = []
        test_precision_list = []
        test_graph_mae = []
        test_graph_rmse = []

        path = "{}/best_model.pth".format(self._checkpoint_dir(data_name_stage))
        check_point = torch.load(path)
        self.model.load_state_dict(check_point['state_dict'])
        self.model = self.model.cuda()

        prior_form = self.args.prior_form

        self.model.eval()
        with torch.no_grad():
            for ind_stage, _x_stage, y_stage, x_norm_stage, _y_norm_stage in test_loader_stage:
                x_norm_stage = x_norm_stage.float().cuda()
                y_stage = y_stage.float().cuda()

                y_hat = self._forecast(x_norm_stage)

                pred = self.model.graph_generator.sample(x_norm_stage, prior_form)
                if prior_form == 'binary':
                    # NOTE(review): every sample is scored against the single
                    # matrix adj_set[test_stage_index + 1], whereas the
                    # continuous branch looks matrices up by sample index --
                    # confirm this difference is intended.
                    reference = self.adj_set[test_stage_index + 1]
                    for b in range(pred.shape[0]):
                        recall, precision, _err = calc_matrics(matrix=reference, matrix_pred=pred[b, :, :])
                        test_recall_list.append(recall)
                        test_precision_list.append(precision)
                elif prior_form == 'continuous':
                    true_label = self._stack_adjacency(ind_stage)
                    test_graph_mae.append(metrics.masked_mae(pred, true_label).item())
                    test_graph_rmse.append(np.sqrt(metrics.masked_mse(pred, true_label).item()))

                # De-normalise once and reuse for all three metrics.
                y_pred = scaler_stage.inverse_transform(y_hat)
                mse_loss = metrics.masked_mse(y_pred, y_stage)
                mae_loss = metrics.masked_mae(y_pred, y_stage)
                mape_loss = metrics.masked_mape_np(y_pred.cpu().numpy(), y_stage.cpu().numpy())

                test_rmse_loss.append(torch.sqrt(mse_loss).item())
                test_mae_loss.append(mae_loss.item())
                test_mape_loss.append(mape_loss.item())
                test_mse_loss.append(mse_loss.item())

        rmse = self._mean_or_nan(test_rmse_loss)
        mse = self._mean_or_nan(test_mse_loss)
        mae = self._mean_or_nan(test_mae_loss)
        mape = self._mean_or_nan(test_mape_loss)

        graph_model = self._uses_graph_generator()

        if graph_model:
            if prior_form == 'binary':
                results['recall'][train_stage_index, test_stage_index] = self._mean_or_nan(test_recall_list)
                results['precision'][train_stage_index, test_stage_index] = self._mean_or_nan(test_precision_list)
            elif prior_form == 'continuous':
                results['graph_mae'][train_stage_index, test_stage_index] = self._mean_or_nan(test_graph_mae)
                results['graph_rmse'][train_stage_index, test_stage_index] = self._mean_or_nan(test_graph_rmse)

        results['rmse'][train_stage_index, test_stage_index] = rmse
        results['mse'][train_stage_index, test_stage_index] = mse
        results['mae'][train_stage_index, test_stage_index] = mae
        results['mape'][train_stage_index, test_stage_index] = mape

        print("Testing RMSE: {:.4f} Testing MSE: {:.4f} MAE: {:.4f} MAPE: {:.4f} ".format(rmse, mse, mae, mape))

        if graph_model:
            if prior_form == 'binary':
                print("Testing Graph Precision: {:.4f}  Recall: {:.4f} ".format(
                    self._mean_or_nan(test_precision_list), self._mean_or_nan(test_recall_list)))
            elif prior_form == 'continuous':
                print("Testing Graph MAE: {:.4f}  RMSE: {:.4f} ".format(
                    self._mean_or_nan(test_graph_mae), self._mean_or_nan(test_graph_rmse)))

    def freeze_model(self, models):
        """Disable gradient tracking for every parameter of ``models``."""
        for param in models.parameters():
            param.requires_grad = False
        return

    def save_model(self, epoch_num):
        """Write model and optimizer state to this stage's ``best_model.pth``.

        The target directory is derived from ``self.args.data_name_stage`` and
        created if missing.
        """
        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        dirs = self._checkpoint_dir(self.args.data_name_stage)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(dirs, exist_ok=True)

        best_path = "{}/best_model.pth".format(dirs)
        torch.save(state, best_path)
        print('Best_model_saved at epoch {}'.format(str(epoch_num)))


                

            
            
            

            







      

    
     
    
