from __future__ import absolute_import, division, print_function
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
import  matplotlib.lines as  mlines
import os
import matplotlib.pyplot as plt
import math
from tqdm import tqdm, trange
import time
import os
import pandas
import torch

from comp_Optimizer import comp_Optimizer
import networks
import datasets
import samplers
import actions
import VisualizationUtils

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

import argparse
import os

# Directory that contains this script.
# NOTE(review): not referenced anywhere in the code visible here.
file_dir = os.path.split(__file__)[0]

class TestTester:
    """Evaluate a trained classifier's accuracy and calibration (ECE) under
    test-time augmentation.

    ``options`` is the namespace produced by the argument parser in this
    file's ``__main__`` block.  The attributes read are ``no_cuda``,
    ``vis_dir``, ``sampler``, ``action``, ``classifier``, ``dataset``,
    ``set_const_group``, ``test_size``, ``log_dir`` and ``result``.
    """

    # Augmented copies per test image swept by test(); 0 means a plain,
    # un-augmented forward pass.  NOTE: ``options.number_samples`` is not
    # consulted -- this sweep has always overridden it.
    SAMPLE_COUNTS = tuple(range(0, 10, 2))
    # Independent evaluation passes per sample count; metrics are averaged.
    CALIBRATION_REPEATS = 5
    # Number of equal-width confidence bins on [0, 1] used for the ECE.
    ECE_BINS = 15
    # Prefix of the TensorBoard scalar tags.
    METRIC_PREFIX = "TAE_Soft_"

    def __init__(self, options):
        self.opt = options
        self.device = 'cpu' if self.opt.no_cuda else 'cuda'
        print("Device: ", self.device)

        # Registries mapping command-line names to project classes.
        self.sampler_dict = {'AffineSampler': samplers.AffineSampler,
                             'ParamSamplerCol': samplers.ParamSamplerCol}
        self.action_dict = {'AffineActions': actions.AffineActions,
                            'ParamActionsCol': actions.ParamActionsCol}
        self.classifier_dict = {"ResNet": networks.ResNet18,
                                "ResNet34": networks.ResNet34,
                                "smallnet": networks.smallnet,
                                "wideResNet": networks.WideResNet}

        self.writer = SummaryWriter(self.opt.vis_dir)
        self.dataset_dict = {"MNIST": torchvision.datasets.MNIST,
                             "CIFAR10": torchvision.datasets.CIFAR10,
                             "CIFAR100": torchvision.datasets.CIFAR100,
                             "rotMNIST": datasets.rotMNIST,
                             "SVHN": torchvision.datasets.SVHN}

    @staticmethod
    def _const_group_bits(group, num_bits, device='cpu'):
        """Return the ``num_bits`` lowest binary digits of ``group``, LSB first.

        Args:
            group: non-negative integer to decompose.
            num_bits: number of digits to return.
            device: device of the returned integer tensor.
        """
        exponent = torch.arange(num_bits, device=device)
        return torch.fmod(torch.div(group, 2 ** exponent, rounding_mode="trunc"), 2)

    @staticmethod
    def _calibration_stats(confidences, correct, bins_num=15, min_v=0.0, max_v=1.0):
        """Compute the Expected Calibration Error with equal-width bins.

        Args:
            confidences: tensor of predicted-class probabilities (flattened).
            correct: tensor (bool or 0/1) of the same length marking correct
                predictions.
            bins_num: number of equal-width bins spanning ``[min_v, max_v]``.
            min_v, max_v: range covered by the bins.

        Returns:
            ``(ece, tot_acc, bin_centers, y_axis, bin_width)``: ``ece`` and
            ``tot_acc`` are floats (NaN for empty input), ``bin_centers`` is a
            1-D tensor, ``y_axis`` a numpy array of per-bin accuracies (0 for
            empty bins) and ``bin_width`` a 0-d tensor.
        """
        confidences = confidences.detach().reshape(-1).float().cpu()
        correct = correct.detach().reshape(-1).float().cpu()

        bins_bounds = torch.linspace(min_v, max_v, bins_num + 1)
        bin_width = bins_bounds[1] - bins_bounds[0]
        bin_centers = bins_bounds[:-1] + bin_width / 2
        y_axis = np.zeros(bins_num)

        ece = 0.0
        num_samples = 0
        for bin_id in range(bins_num):
            lower, upper = bins_bounds[bin_id], bins_bounds[bin_id + 1]
            # The last bin is closed on the right: a confidence of exactly
            # max_v (e.g. 1.0) must be counted, not silently dropped.
            if bin_id == bins_num - 1:
                below_upper = confidences <= upper
            else:
                below_upper = confidences < upper
            in_bin = (confidences >= lower) & below_upper
            count = int(in_bin.sum().item())
            if count > 0:
                accur = correct[in_bin].sum().item() / count
                y_axis[bin_id] = accur
                mean_conf = confidences[in_bin].mean().item()
                ece += abs(mean_conf - accur) * count
            num_samples += count

        ece = ece / num_samples if num_samples else float('nan')
        tot_acc = correct.sum().item() / correct.numel() if correct.numel() else float('nan')
        return ece, tot_acc, bin_centers, y_axis, bin_width

    def _build_test_dataset_and_net(self):
        """Create the test split and a freshly initialised classifier.

        Returns:
            ``(test_dataset, net, image_size)`` where ``net`` is already on
            ``self.device`` and ``image_size`` is 28 for the MNIST variants
            (forwarded to the action) or ``None`` when the action is built
            with its own defaults.

        NOTE(review): the normalisation constants presumably must match the
        ones used at training time, so they are kept exactly as they were.
        """
        dataset_name = self.opt.dataset
        make_net = self.classifier_dict[self.opt.classifier]

        if dataset_name == "Mnist":
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            test_dataset = self.dataset_dict["MNIST"](root='./datasets/', download=True,
                                                      train=False, transform=transform)
            return test_dataset, make_net().to(self.device), 28

        if dataset_name == 'rotMNIST':
            print("Testing on rotMnist")
            test_dataset = self.dataset_dict["rotMNIST"](0.1307, 0.3081,
                                                         train=False).create_dataset()
            return test_dataset, make_net().to(self.device), 28

        if dataset_name == "SVHN":
            # NOTE(review): SVHN reuses the CIFAR-10 channel statistics.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
            test_dataset = self.dataset_dict["SVHN"](root='./datasets/', download=True,
                                                     split='test', transform=transform)
            return test_dataset, make_net(10).to(self.device), None

        if dataset_name == "Cifar10":
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])
            test_dataset = self.dataset_dict["CIFAR10"](root='./datasets/', download=True,
                                                        train=False, transform=transform)
            return test_dataset, make_net().to(self.device), None

        # Any other name ("Cifar100" from the CLI) is treated as CIFAR-100.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5070, 0.4865, 0.4409), (0.2673, 0.2564, 0.2761)),
        ])
        test_dataset = self.dataset_dict["CIFAR100"](root='./datasets/', download=True,
                                                     train=False, transform=transform)
        return test_dataset, make_net(num_classes=100).to(self.device), None

    def _load_checkpoint(self, net, sampler):
        """Load network and sampler weights from ``<log_dir>/saved_model``.

        If the file is missing a message is printed and the freshly
        initialised weights are left in place.
        """
        path = os.path.join(self.opt.log_dir, 'saved_model')
        if not os.path.isfile(path):
            print("Model doesn't exist")
            return
        # map_location lets a checkpoint saved on GPU be evaluated with --no_cuda.
        checkpoint = torch.load(path, map_location=self.device)
        net.load_state_dict(checkpoint['model_params'])
        sampler.load_state_dict(checkpoint['sampler'])

    def _predict_test_set(self, net, sampler, action, test_loader, number_samples):
        """Run one pass over ``test_loader``.

        Returns:
            ``(confidences, correct)``: 1-D tensors holding, per test image,
            the (augmentation-averaged) probability of the predicted class
            and whether that prediction equals the label.
        """
        conf_chunks = []
        corr_chunks = []
        with torch.no_grad():
            for data, target in tqdm(test_loader):
                data, target = data.to(self.device), target.to(self.device)
                b_size = data.shape[0]
                if number_samples == 0:
                    # Assumes the network outputs log-probabilities (LogSoftmax).
                    mean_augm_prob = torch.exp(net(data))
                else:
                    # NOTE(review): output is discarded; kept from the original
                    # code in case the call primes state in the network.
                    net(data[0].unsqueeze(0))
                    data_sample, _, _ = sampler(number_samples * b_size)
                    data_augm = action(data_sample.detach(), data)
                    data_augm = data_augm.flatten(start_dim=0, end_dim=1)
                    out_augm_prob = torch.exp(net(data_augm))
                    # Assumes the augmented batch is sample-major, i.e. it
                    # reshapes to (number_samples, b_size, classes) -- confirm.
                    # No squeeze here so that b_size == 1 keeps a batch axis.
                    mean_augm_prob = out_augm_prob.reshape(number_samples, b_size, -1).mean(dim=0)

                pred = mean_augm_prob.argmax(dim=1)
                corr_chunks.append(pred.eq(target.view_as(pred)))
                conf_chunks.append(mean_augm_prob.gather(1, pred.unsqueeze(1)).squeeze(1))

        if not conf_chunks:
            return (torch.empty(0, device=self.device),
                    torch.empty(0, dtype=torch.bool, device=self.device))
        # One concatenation at the end instead of one per batch.
        return torch.cat(conf_chunks), torch.cat(corr_chunks)

    def test(self):
        """Sweep over augmentation sample counts and report accuracy and ECE.

        For each count in ``SAMPLE_COUNTS`` the test set is evaluated
        ``CALIBRATION_REPEATS`` times.  With 0 samples the plain network
        output is used.  Otherwise ``number_samples`` augmentations per image
        are drawn from the sampler and applied by the action, and the class
        probabilities are averaged over them.  Per-repeat ECE, accuracy and
        reliability-diagram data are collected into a dict that is written to
        ``opt.result`` after each sample count.  The mean ECE and accuracy
        over repeats are logged to TensorBoard.

        Raises:
            RuntimeError: if the test set contains no full batch of
                ``opt.test_size`` images.
        """
        device = self.device
        test_size = self.opt.test_size

        # tau=0.3 is hard-coded for the sampler here.
        sampler = self.sampler_dict[self.opt.sampler](tau=0.3, device=device).to(device)
        if self.opt.set_const_group > -1:
            # Overwrite sampler.pi so that row 0 holds the binary digits of
            # the requested group index and every other entry is zero.
            with torch.no_grad():
                new_theta = torch.zeros_like(sampler.pi.detach())
                new_theta[0] = self._const_group_bits(self.opt.set_const_group,
                                                      new_theta.shape[1], device)
                sampler.pi.copy_(new_theta)

        test_dataset, net, image_size = self._build_test_dataset_and_net()
        make_action = self.action_dict[self.opt.action]
        if image_size is None:
            action = make_action(device=device)
        else:
            action = make_action(device=device, image_size=image_size)

        # drop_last reproduces the old "skip any batch smaller than
        # test_size" check: with a fixed batch size only the final batch can
        # be short.
        test_loader = DataLoader(test_dataset, batch_size=test_size, shuffle=True,
                                 num_workers=4, drop_last=True)

        self._load_checkpoint(net, sampler)
        net.eval()

        name = self.METRIC_PREFIX
        overall_dict = {}
        for number_samples in self.SAMPLE_COUNTS:
            print("####################################")
            print("Number samples= " + str(number_samples))
            ece_list = []
            accuracy_list = []
            bin_centers_list = []
            y_axis_list = []
            bin_width_list = []
            for _ in range(self.CALIBRATION_REPEATS):
                exp_soft, corr_list_augm = self._predict_test_set(
                    net, sampler, action, test_loader, number_samples)
                cnt_samples = corr_list_augm.numel()
                if cnt_samples == 0:
                    raise RuntimeError(
                        "No full batch of size {} in the test set; "
                        "use a smaller --test_size".format(test_size))
                print("Augmented Accuracy:" + str(corr_list_augm.sum().item() / cnt_samples))

                ece, tot_acc, bin_centers, y_axis, bin_width = self._calibration_stats(
                    exp_soft, corr_list_augm, bins_num=self.ECE_BINS)
                print("ECE=" + str(ece))
                print("ACC=" + str(tot_acc))
                ece_list.append(ece)
                accuracy_list.append(tot_acc)
                bin_centers_list.append(bin_centers)
                y_axis_list.append(y_axis)
                bin_width_list.append(bin_width)

            overall_dict[number_samples] = {'ece': ece_list,
                                            'tot_acc': accuracy_list,
                                            'bin_centers': bin_centers_list,
                                            'y_axis': y_axis_list,
                                            'bin_width': bin_width_list}
            # Rewritten after every sample count, so results gathered so far
            # are on disk even if a later count fails.
            torch.save(overall_dict, self.opt.result)
            # Log the mean over repeats (matching the printed summary) rather
            # than whichever repeat happened to run last.
            self.writer.add_scalar(name + "ECE", float(np.mean(ece_list)), number_samples)
            print("ECE=" + str(np.mean(np.array(ece_list))) + " STD=" + str(np.std(np.array(ece_list))))
            self.writer.add_scalar(name + "ACC", float(np.mean(accuracy_list)), number_samples)
            print("ACC=" + str(np.mean(np.array(accuracy_list))) + " STD=" + str(np.std(np.array(accuracy_list))))
        self.writer.flush()

def build_arg_parser():
    """Return the command-line parser for the test script.

    ``--result`` (where the metrics dict is saved) and ``--log_dir`` (which
    holds ``saved_model``) are required. Without them the run used to crash
    with a TypeError, for ``--result`` only after a full evaluation sweep.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--no_cuda", action="store_true")
    parser.add_argument("--result", type=str, required=True,
                        help="file the per-sample-count metrics dict is saved to")
    parser.add_argument("--vis_dir", type=str)
    parser.add_argument("--number_samples", type=int, default=4,
                        help="currently ignored: TestTester.test() sweeps its own sample counts")
    parser.add_argument("--naming", type=str)
    parser.add_argument("--log_dir", type=str, required=True,
                        help="directory containing the 'saved_model' checkpoint")
    parser.add_argument("--set_const_group", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="Cifar10",
                        choices=["Mnist", "Cifar10", "Cifar100", "rotMNIST", "SVHN"])
    parser.add_argument("--sampler", type=str, default="AffineSampler",
                        choices=['AffineSampler', 'ParamSamplerCol'])
    parser.add_argument("--action", type=str, default="AffineActions",
                        choices=["AffineActions", "ParamActionsCol"])
    # Only classifiers that TestTester.classifier_dict can build are accepted;
    # "realResNet34" was listed before but always failed with a KeyError.
    parser.add_argument("--classifier", type=str, default="ResNet",
                        choices=["ResNet", "ResNet34", "smallnet", "wideResNet"])
    parser.add_argument("--test_size", type=int, default=10)
    return parser


if __name__ == "__main__":
    options = build_arg_parser().parse_args()
    tester = TestTester(options)
    tester.test()



