from utils.trainer import Trainer
from utils.helper import Save_Handle, AverageMeter
import os
import sys
import time
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import logging
import numpy as np
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

#from models.aspd_spatial_uq1 import New_bay_Net
#from models.SI_INR import New_bay_Net
from models.SI_INR_new3 import New_bay_Net
#from models.SI_INR_PSGC import New_bay_Net

#from datasets.crowd_unic_car import Crowd
from datasets.crowd import Crowd
#from losses.bay_loss_new import Bay_Loss
from losses.bay_loss import Bay_Loss
from losses.post_prob_duo import Post_Prob

#import wandb
import random

from torch.optim import lr_scheduler

import cv2
from matplotlib import pyplot as plt

import torchvision.transforms.functional as F


from scipy.stats import norm



def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` hook that currently does nothing.

    Per-worker reseeding of ``numpy`` / ``random`` from
    ``torch.initial_seed()`` used to live here and is disabled, so
    ``worker_id`` is ignored and the function simply returns ``None``.
    """
    del worker_id  # unused while per-worker reseeding is disabled
    return None
    
    
def get_parameters_number(net):
    """Count the scalar parameters of a module.

    Args:
        net: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        ``(total, trainable)``: number of elements over all parameters, and
        over only those whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable


def train_collate(batch):
    """Collate training samples whose number of keypoints varies per image.

    Each sample is expected to provide (at least) seven fields, in order:
    image, keypoints, target, st_size, grid_c, gridnum_sam_c, gd_count.
    Fixed-size fields are batched; variable-size ones are passed through.

    Returns:
        images:        ``torch.stack`` of the images along a new dim 0.
        points:        tuple of per-sample keypoint tensors (lengths differ).
        targets:       tuple of per-sample targets.
        st_sizes:      ``torch.FloatTensor`` of the per-sample ``st_size``
                       values (shortest image side, per the original note).
        grid_c:        ``torch.stack`` of the per-sample ``grid_c`` tensors.
        gridnum_sam_c: tuple of per-sample values.
        gd_count:      tuple of per-sample ground-truth counts.
    """
    # columns[k] gathers field k of every sample in the batch.
    columns = list(zip(*batch))

    images = torch.stack(columns[0], 0)
    points = columns[1]
    targets = columns[2]
    st_sizes = torch.FloatTensor(columns[3])
    grid_c = torch.stack(columns[4], 0)
    return images, points, targets, st_sizes, grid_c, columns[5], columns[6]


class RegTrainer(Trainer):
    """Trainer for crowd counting with ``New_bay_Net`` and a Bayesian loss.

    ``self.args`` and ``self.save_dir`` are not assigned in this class;
    presumably the ``Trainer`` base class provides them (TODO confirm).
    """

    def setup(self):
        """Initialise the datasets, model, loss and optimizers.

        Builds the train/val ``Crowd`` datasets and their loaders, the model,
        two Adam optimizers (``optimizer2`` for ``model.modelA``,
        ``optimizer1`` for every other parameter), a StepLR scheduler on
        ``optimizer1``, optionally restores state from ``args.resume``, and
        creates the posterior-probability module and ``Bay_Loss`` criterion.

        Raises:
            RuntimeError: if CUDA is unavailable or the number of visible GPUs
                is not exactly one (only the single-GPU path is supported).
        """
        using_method = 'bay_vae'
        args = self.args

        logging.info('using seed {}'.format(args.seed))
        logging.info('using method {}'.format(using_method))

        # Explicit raises (not assert) so the checks survive ``python -O``.
        if not torch.cuda.is_available():
            raise RuntimeError("gpu is not available")
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        if self.device_count != 1:
            # For code conciseness only the single-GPU version is released.
            raise RuntimeError(
                'expected exactly 1 visible gpu, found {}'.format(self.device_count))
        logging.info('using {} gpus'.format(self.device_count))

        self.downsample_ratio = args.downsample_ratio

        # The 'val' split is built from the test_data images.
        image_dirs = {
            'train': os.path.join(args.data_dir, 'train_data/images'),
            'val': os.path.join(args.data_dir, 'test_data/images'),
        }
        self.datasets = {x: Crowd(image_dirs[x],
                                  args.crop_size,
                                  args.downsample_ratio,
                                  args.is_gray, x) for x in ['train', 'val']}

        # Training uses train_collate (keypoint counts vary per image);
        # validation feeds one image at a time through the default collate.
        # Reproducible worker seeding (seed_worker / a seeded torch.Generator)
        # is currently not wired in.
        self.dataloaders = {x: DataLoader(self.datasets[x],
                                          collate_fn=(train_collate if x == 'train' else default_collate),
                                          batch_size=(args.batch_size if x == 'train' else 1),
                                          shuffle=(x == 'train'),
                                          num_workers=args.num_workers * self.device_count,
                                          pin_memory=(x == 'train'))
                            for x in ['train', 'val']}

        self.model = New_bay_Net(self.downsample_ratio, args.crop_size)
        self.model.to(self.device)

        self.use_sr = args.use_sr

        # Print (total, trainable) parameter counts for the model and its parts.
        for module in (self.model, self.model.modelA,
                       self.model.Encoder2z, self.model.cc_decoder):
            total_num, trainable_num = get_parameters_number(module)
            print(total_num, trainable_num)

        # Split parameters between two optimizers: modelA vs. everything else.
        # A set gives O(1) membership instead of scanning a list per parameter.
        modela_ids = {id(p) for p in self.model.modelA.parameters()}
        other_params = [p for p in self.model.parameters() if id(p) not in modela_ids]
        self.optimizer1 = optim.Adam(other_params, lr=args.lr, weight_decay=args.weight_decay)
        self.optimizer2 = optim.Adam(self.model.modelA.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)
        # gamma=1: stepping this scheduler never changes optimizer1's LR.
        self.scheduler = lr_scheduler.StepLR(self.optimizer1, step_size=3000, gamma=1)

        self.start_epoch = 0
        if args.resume:
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                # Full checkpoint as written by train_eopch.
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer1.load_state_dict(checkpoint['optimizer_state_dict1'])
                self.optimizer2.load_state_dict(checkpoint['optimizer_state_dict2'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                # Bare model state dict.
                self.model.load_state_dict(torch.load(args.resume, self.device))
            elif suf == 'pt':
                # Whole pickled models (as saved by val_epoch) are not restored here.
                logging.warning('resume file {} (.pt) is ignored'.format(args.resume))
            else:
                logging.warning('unrecognised resume file suffix: {}'.format(args.resume))
        self.epoch = self.start_epoch

        self.post_prob = Post_Prob(args.sigma,
                                   args.crop_size,
                                   args.downsample_ratio,
                                   args.background_ratio,
                                   args.use_background,
                                   self.device)
        self.criterion = Bay_Loss(args.use_background, self.device)
        self.save_list = Save_Handle(max_num=args.max_model_num)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):
        """Train from ``self.start_epoch`` up to ``args.max_epoch - 1``.

        Validation runs on epochs divisible by ``args.val_epoch`` once
        ``epoch >= args.val_start``.
        """
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch):
            logging.info('-' * 5 + 'Epoch {}/{}'.format(epoch, args.max_epoch - 1) + '-' * 5)
            self.epoch = epoch
            self.train_eopch(self.epoch)
            self.scheduler.step()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()

    def train_eopch(self, epoch):
        """Train for one epoch, log metrics and save ``{epoch}_ckpt.tar``.

        Args:
            epoch: current epoch index, forwarded to the model's forward call.
        """
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()

        self.model.train()

        for inputs, points, targets, st_sizes, grid_c, gridnum_sam_c, gd_count in self.dataloaders['train']:
            # Unused random rescale factor in [0.3, 1.1): the resize it fed is
            # disabled, but the draw is kept so the global `random` stream (and
            # therefore previously seeded runs) stays unchanged.
            _scale = 1 + (random.random() - 0.5) * 0.8 - 0.3

            inputs = inputs.to(self.device)
            points = [p.to(self.device) for p in points]
            targets = [t.to(self.device) for t in targets]
            st_sizes = st_sizes.to(self.device)
            grid_c = grid_c.to(self.device)
            gridnum_sam_c = [tt.to(self.device) for tt in gridnum_sam_c]

            with torch.set_grad_enabled(True):
                outputs, _, _ = self.model(inputs, grid_c, 'train', epoch)

                # Bayesian loss on the posterior probabilities of the points.
                prob_list = self.post_prob(points, st_sizes, gridnum_sam_c)
                loss1 = self.criterion(prob_list, targets, outputs)
                loss_KL = self.model.kl_div

                # Counts are taken from outputs scaled down by 10 (as in val_epoch).
                outputs = outputs / 10

                # True: update only modelA (optimizer2) on loss1 + 0.1 * KL.
                train_uq = False
                if train_uq:
                    print('use uq')
                    loss = 1.0 * loss1 + 0.1 * loss_KL
                    self.optimizer1.zero_grad()
                    self.optimizer2.zero_grad()
                    loss.backward()
                    self.optimizer2.step()
                else:
                    loss = 1.0 * loss1
                    self.optimizer1.zero_grad()
                    self.optimizer2.zero_grad()
                    loss.backward()
                    self.optimizer1.step()
                    self.optimizer2.step()

            # Running metrics on per-image predicted vs. ground-truth counts.
            N = inputs.size(0)
            pre_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
            res = pre_count - gd_count
            epoch_loss.update(loss.item(), N)
            epoch_mse.update(np.mean(res * res), N)
            epoch_mae.update(np.mean(np.abs(res)), N)

        logging.info('Epoch {} Train, Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(self.epoch, epoch_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                             time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
        torch.save({
            'epoch': self.epoch,
            'optimizer_state_dict1': self.optimizer1.state_dict(),
            'optimizer_state_dict2': self.optimizer2.state_dict(),
            'model_state_dict': model_state_dic
        }, save_path)
        self.save_list.append(save_path)  # control the number of saved models

    def val_epoch(self):
        """Evaluate on the validation loader and save best/recent model snapshots.

        Each validation image is resized with ``F.resize(inputs,
        round(512 * scale))`` for every factor in ``scale_ls``. The count error
        ``gt - sum(outputs / 10)`` is recorded for each. MAE and RMSE (logged as
        "MSE") are computed per scale, then averaged over scales.

        Saves ``best_model.pt`` when ``2 * mse + mae`` improves. Saves
        ``recent_model.pt`` / ``fine_model.pt`` when MAE is within 10 / 5 of the
        best. On epochs divisible by 10 it also plots up to four predicted maps.
        """
        epoch_start = time.time()
        self.model.eval()

        scale_ls = [1.0]  # multi-scale option: [1.0, 0.4, 0.6, 0.8, 1.2]

        epoch_res = []  # per image: count error at each scale
        c_results = []  # per image: predicted map (numpy) at each scale

        for inputs, count, _name, cor_C in self.dataloaders['val']:
            # Images have different sizes, so validation is one image at a time.
            assert inputs.size(0) == 1, 'the batch size should equal to 1 in validation mode'
            cor_C = cor_C.to(self.device)

            c_sub_result = []
            epoch_sub_res = []
            for scale in scale_ls:
                # Resize the CPU copy, then move the resized tensor to the GPU.
                inputs_ = F.resize(inputs, round(512 * scale)).to(self.device)
                with torch.no_grad():
                    outputs, _, _ = self.model(inputs_, cor_C, 'test', self.epoch)
                    outputs = outputs / 10
                    res = count[0].item() - torch.sum(outputs).item()
                    epoch_sub_res.append(res)
                    c_sub_result.append(outputs.data.cpu().numpy())

            c_results.append(c_sub_result)
            epoch_res.append(epoch_sub_res)

        epoch_res = np.array(epoch_res)
        # One RMSE / MAE value per scale, then the mean over scales.
        mses = np.sqrt(np.mean(np.square(epoch_res), axis=0, keepdims=True))
        maes = np.mean(np.abs(epoch_res), axis=0, keepdims=True)
        print('MAE', maes, np.mean(maes))
        print('MSE', mses, np.mean(mses))
        mae = np.mean(maes)
        mse = np.mean(mses)

        logging.info('Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                     .format(self.epoch, mse, mae, time.time() - epoch_start))

        # Model selection weights RMSE twice as heavily as MAE.
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            logging.info("save best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
                                                                                 self.best_mae,
                                                                                 self.epoch))
            torch.save(self.model, os.path.join(self.save_dir, 'best_model.pt'))
        logging.info("best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
                                                                        self.best_mae,
                                                                        self.epoch))
        print(self.best_mae, self.best_mse)

        # Rolling snapshots of models whose MAE is close to the best one.
        if abs(mae - self.best_mae) < 10:
            torch.save(self.model, os.path.join(self.save_dir, 'recent_model.pt'))
        if abs(mae - self.best_mae) < 5:
            torch.save(self.model, os.path.join(self.save_dir, 'fine_model.pt'))

        # One row per image (at most 4, fewer if the val set is smaller),
        # one column per scale; each map is resized to 64x64 for display.
        n_rows = min(4, len(c_results))
        if self.epoch % 10 == 0 and n_rows > 0:
            fig = plt.figure()
            n_cols = len(scale_ls)
            for i in range(n_rows):
                for j in range(n_cols):
                    # c_results[i][j] is presumably (1, 1, H, W); plot the single map.
                    c_example = cv2.resize(c_results[i][j][0][0], (64, 64),
                                           interpolation=cv2.INTER_CUBIC)
                    fig.add_subplot(n_rows, n_cols, i * n_cols + 1 + j)
                    plt.imshow(c_example)
            plt.savefig(os.path.join(self.save_dir, str(self.epoch) + '_test_fig.png'))
            plt.close()
            
            
            

        
        
        
        



