from __future__ import absolute_import, division, print_function
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
import os
import matplotlib.pyplot as plt
import math
from tqdm import tqdm, trange
import time
import os

import torch

from comp_Optimizer import comp_Optimizer
import networks
import datasets
import samplers
import actions
import VisualizationUtils

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

# Global cuDNN configuration: keep cuDNN enabled and let it benchmark
# convolution algorithms, caching the fastest one per input shape (most useful
# when input shapes stay constant between iterations).
# Autograd anomaly detection (torch.autograd.set_detect_anomaly) is left off;
# turn it on temporarily when hunting NaN/inf gradients.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
def learning_rate(init, epoch, milestones=(60, 120, 160), gamma=0.2):
    """Return a step-decayed learning rate for ``epoch``.

    The rate starts at ``init`` and is multiplied by ``gamma`` once for every
    milestone that ``epoch`` has strictly passed (``epoch > milestone``).
    With the defaults this gives ``init`` up to epoch 60, ``0.2 * init`` up to
    120, ``0.04 * init`` up to 160 and ``0.008 * init`` afterwards.

    Args:
        init: base learning rate.
        epoch: current epoch index.
        milestones: epochs after which the rate is decayed once more.
        gamma: multiplicative decay applied per passed milestone.

    Returns:
        The decayed learning rate as a float.
    """
    passed = sum(1 for milestone in milestones if epoch > milestone)
    return init * math.pow(gamma, passed)

class ParamTrainer:
    """Jointly train a classifier and a learned augmentation sampler.

    All configuration is read from ``options`` (stored as ``self.opt``), an
    attribute namespace (e.g. an ``argparse.Namespace``) that names the
    dataset, classifier, sampler, action and visualizer to use, and carries
    the optimisation hyper-parameters and the logging directory.
    """

    # Normalisation statistics used by the dataset transforms.
    _MNIST_MEAN = 0.1307
    _MNIST_STD = 0.3081
    _CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    _CIFAR10_STD = (0.2023, 0.1994, 0.2010)
    # The CIFAR-100 train transform used to say 0.4867 for the green channel
    # while the test transform said 0.4865; both now share this constant.
    _CIFAR100_MEAN = (0.5070, 0.4865, 0.4408)
    _CIFAR100_STD = (0.2675, 0.2565, 0.2761)
    # Floor of the classifier's cosine schedule during the first cycle.
    _COSINE_ETA_MIN = 0.0005

    def __init__(self, options):
        """Store ``options``, choose the device and build the name registries.

        Args:
            options: configuration namespace; this method reads ``no_cuda``,
                ``no_vis`` and ``log_dir``.
        """
        self.opt = options
        self.device = 'cpu' if self.opt.no_cuda else 'cuda'
        print("Device: ", self.device)

        # Registries mapping configuration names to constructors.
        self.sampler_dict = {'AffineSampler': samplers.AffineSampler,
                             'ParamSamplerCol': samplers.ParamSamplerCol}

        self.action_dict = {'AffineActions': actions.AffineActions,
                            'ParamActionsCol': actions.ParamActionsCol}

        self.classifier_dict = {"ResNet": networks.ResNet18,
                                "ResNet34": networks.ResNet34,
                                "smallnet": networks.smallnet,
                                "wideResNet": networks.WideResNet}

        # TensorBoard logging is disabled entirely when ``no_vis`` is set.
        self.writer = SummaryWriter(self.opt.log_dir) if not self.opt.no_vis else None
        if self.writer is not None:
            print("Log Directory: ", self.opt.log_dir)

        self.visualizer_dict = {"Toy2d": VisualizationUtils.visualize2D,
                                "Toy3d": VisualizationUtils.visualize3D,
                                "Mnist": VisualizationUtils.visualizeMnist}

        self.dataset_dict = {"MNIST": torchvision.datasets.MNIST,
                             "CIFAR10": torchvision.datasets.CIFAR10,
                             "CIFAR100": torchvision.datasets.CIFAR100,
                             "rotMNIST": datasets.rotMNIST,
                             "SVHN": torchvision.datasets.SVHN}

    @staticmethod
    def _per_class_prefix(targets, fraction, num_classes=10):
        """Return indices of the first ``fraction`` of samples of each class.

        For every label in ``range(num_classes)`` the first
        ``int(count * fraction)`` matching positions of ``targets`` are kept.
        The result is always a 1-D ``int64`` tensor, even when it selects a
        single sample.
        """
        targets = torch.as_tensor(targets)
        chunks = []
        for label in range(num_classes):
            members = torch.nonzero(targets == label, as_tuple=True)[0]
            chunks.append(members[:int(members.numel() * fraction)])
        return torch.cat(chunks)

    @staticmethod
    def _keep_two_classes(ds, digit1, digit2):
        """Restrict ``ds`` in place to classes ``digit1``/``digit2``, relabelled 0/1.

        ``ds`` must expose ``.data`` and ``.targets`` as torchvision datasets
        do: tensor targets and tensor data (MNIST), or list targets and NumPy
        data (CIFAR). The container types are preserved. Both class masks are
        computed before any relabelling, so e.g. ``digit2 == 0`` no longer
        collapses every kept sample into class 1.
        """
        targets = torch.as_tensor(ds.targets)
        is_first = targets == digit1
        keep = is_first | (targets == digit2)
        # digit1 -> 0, digit2 -> 1, keeping the original integer dtype.
        new_targets = (~is_first[keep]).to(targets.dtype)

        if isinstance(ds.data, np.ndarray):
            ds.data = ds.data[keep.numpy()]
        else:
            ds.data = ds.data[keep]

        if isinstance(ds.targets, torch.Tensor):
            ds.targets = new_targets
        else:
            ds.targets = new_targets.tolist()

    def _restrict_to_twodig(self, *two_class_datasets):
        """Apply the ``opt.twodig`` two-class restriction to each dataset in place."""
        digit1 = int(self.opt.twodig[0])  # becomes label 0
        digit2 = int(self.opt.twodig[1])  # becomes label 1
        print("New Dataset. Digits", digit1, "and", digit2)
        for ds in two_class_datasets:
            self._keep_two_classes(ds, digit1, digit2)

    def _build_data(self):
        """Create the train/test datasets and the classifier for ``opt.dataset``.

        Recognised names are ``"Mnist"``, ``"rotMNIST"``, ``"SVHN"`` and
        ``"Cifar10"``; any other value falls through to CIFAR-100.

        Returns:
            ``(dataset, test_dataset, net)`` with ``net`` moved to ``self.device``.
        """
        opt = self.opt

        if opt.dataset == "Mnist":
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((self._MNIST_MEAN,), (self._MNIST_STD,)),
            ])
            dataset = self.dataset_dict["MNIST"](root='./datasets/', download=True,
                                                 train=True, transform=transform)
            test_dataset = self.dataset_dict["MNIST"](root='./datasets/', download=True,
                                                      train=False, transform=transform)
            if opt.split_dataset > 0:
                keep = self._per_class_prefix(dataset.targets, opt.split_dataset)
                dataset.targets = dataset.targets[keep]
                dataset.data = dataset.data[keep]
            if opt.twodig is not None:
                self._restrict_to_twodig(dataset, test_dataset)
            net = self.classifier_dict[opt.classifier]()

        elif opt.dataset == 'rotMNIST':
            print("Training on rotMnist")
            rot_mnist = self.dataset_dict["rotMNIST"]
            dataset = rot_mnist(self._MNIST_MEAN, self._MNIST_STD, train=True).create_dataset()
            test_dataset = rot_mnist(self._MNIST_MEAN, self._MNIST_STD, train=False).create_dataset()
            net = self.classifier_dict[opt.classifier]()

        elif opt.dataset == "SVHN":
            # NOTE: SVHN reuses the CIFAR-10 normalisation statistics.
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self._CIFAR10_MEAN, self._CIFAR10_STD),
            ])
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(self._CIFAR10_MEAN, self._CIFAR10_STD),
            ])
            dataset = self.dataset_dict["SVHN"](root='./datasets/', download=True,
                                                split='train', transform=transform_train)
            test_dataset = self.dataset_dict["SVHN"](root='./datasets/', download=True,
                                                     split='test', transform=transform_test)
            net = self.classifier_dict[opt.classifier](10)

        elif opt.dataset == "Cifar10":
            # Train and test share the same non-augmenting transform.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(self._CIFAR10_MEAN, self._CIFAR10_STD),
            ])
            dataset = self.dataset_dict["CIFAR10"](root='./datasets/', download=True,
                                                   train=True, transform=transform)
            test_dataset = self.dataset_dict["CIFAR10"](root='./datasets/', download=True,
                                                        train=False, transform=transform)
            # Filter classes in place *before* the per-class split. The split
            # wraps the dataset in a Subset, which has no .data/.targets to
            # filter. Both operations are per class, so the selected samples
            # are the same as splitting first.
            if opt.twodig is not None:
                self._restrict_to_twodig(dataset, test_dataset)
            if opt.split_dataset > 0:
                keep = self._per_class_prefix(dataset.targets, opt.split_dataset)
                dataset = torch.utils.data.Subset(dataset, keep.tolist())
            net = self.classifier_dict[opt.classifier]()

        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(self._CIFAR100_MEAN, self._CIFAR100_STD),
            ])
            dataset = self.dataset_dict["CIFAR100"](root='./datasets/', download=True,
                                                    train=True, transform=transform)
            test_dataset = self.dataset_dict["CIFAR100"](root='./datasets/', download=True,
                                                         train=False, transform=transform)
            net = self.classifier_dict[opt.classifier](num_classes=100)

        return dataset, test_dataset, net.to(self.device)

    def _make_net_optimizer(self, net, lr, announce=False):
        """Build the classifier optimiser: Adam if ``opt.optimizer == 'Adam'``, else Nesterov SGD."""
        if self.opt.optimizer == 'Adam':
            net_optim = optim.Adam(net.parameters(), lr)
            label = "Adam"
        else:
            net_optim = optim.SGD(net.parameters(), lr, momentum=0.9,
                                  weight_decay=5e-4, nesterov=True)
            label = "SGD"
        if announce:
            print("Using " + label + " optimizer")
        return net_optim

    @staticmethod
    def _save_checkpoint(path, epoch, net, net_optim, sampler, sampler_optim):
        """Serialise the epoch index plus model, sampler and optimiser states to ``path``."""
        torch.save({'epoch': epoch,
                    'model_params': net.state_dict(),
                    'net_optim': net_optim.state_dict(),
                    'sampler': sampler.state_dict(),
                    'sampler_optim': sampler_optim.state_dict()}, path)

    def train(self):
        """Run (or resume) training for ``num_epochs + extra_epochs`` epochs.

        Epochs ``[0, num_epochs)`` form the first cycle. In it, the
        classifier, the sampler parameters and the categorical group
        probabilities ``sampler.pi`` are all updated. At
        ``epoch == num_epochs``, the ``comp_Optimizer`` learning rates are
        zeroed and the classifier optimiser and cosine schedule restart for
        ``extra_epochs`` more epochs.

        A checkpoint is written to ``<log_dir>/saved_model`` after every
        epoch, and to ``<log_dir>/firstCycleSaved_model`` after epoch
        ``num_epochs - 1``. If ``<log_dir>/saved_model`` already exists,
        training resumes from the epoch after the one it records.
        """
        opt = self.opt
        device = self.device
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR

        number_samples = opt.number_samples
        batch_size = opt.batch_size
        num_epochs = opt.num_epochs
        warmup = opt.warmup
        lr_class = opt.learning_rate_class
        lr_sampler = opt.learning_rate_sampler
        lr_reg = opt.lambda_reg * lr_sampler
        lr_param = opt.lr_param
        plot_every = opt.plot_every

        # NOTE(review): the figure is not used in this method. It is kept in
        # case the visualizer draws on pyplot's current figure; confirm
        # before removing.
        plt.subplots()

        sampler = self.sampler_dict[opt.sampler](tau=opt.tau, device=device).to(device)
        if opt.set_const_group > -1:
            # Pin row 0 of pi to the binary digits (LSB first) of set_const_group.
            with torch.no_grad():
                new_theta = torch.zeros_like(sampler.pi.detach())
                exponent = torch.arange(new_theta.shape[1], device=device)
                new_theta[0] = torch.fmod(torch.div(opt.set_const_group, 2 ** exponent,
                                                    rounding_mode="trunc"), 2)
                sampler.pi.copy_(new_theta)

        num_groups = sampler.dist_dim
        print(sampler.pi)

        soft_loss = nn.NLLLoss(reduction='none') if opt.theta_loss == 'NLLLoss' else None
        theta_stat = comp_Optimizer(num_groups, only_pos=True, lr=lr_sampler, lr_reg=lr_reg,
                                    soft_loss=soft_loss, device=device, transf_sampler=sampler)

        # MNIST-style datasets get image_size=28; others use the action's default.
        action_kwargs = {'image_size': 28} if opt.dataset in ("Mnist", "rotMNIST") else {}
        action = self.action_dict[opt.action](device=device, **action_kwargs)

        dataset, test_dataset, net = self._build_data()
        loss_cr = nn.NLLLoss()

        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)
        test_loader = DataLoader(test_dataset, batch_size=opt.test_size, shuffle=False, num_workers=4)
        visualizer = self.visualizer_dict[opt.visualizer]

        print(list(sampler.parameters()))
        checkpoint_path = os.path.join(opt.log_dir, 'saved_model')
        checkpoint = None
        start_epoch = 0
        if os.path.isfile(checkpoint_path):
            print("Continuing from checkpoint")
            # map_location lets a checkpoint written on GPU be resumed on CPU.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            start_epoch = checkpoint['epoch'] + 1
            print("Epoch " + str(start_epoch))
            net.load_state_dict(checkpoint['model_params'])

        sampler_optim = optim.SGD(sampler.parameters(), lr_param)
        net_optim = self._make_net_optimizer(net, lr_class, announce=True)

        if checkpoint is None:
            scheduler = cosine(net_optim, num_epochs, eta_min=self._COSINE_ETA_MIN)
        else:
            net_optim.load_state_dict(checkpoint['net_optim'])
            sampler.load_state_dict(checkpoint['sampler'])
            sampler_optim.load_state_dict(checkpoint['sampler_optim'])
            in_extra_cycle = start_epoch > num_epochs
            if in_extra_cycle:
                theta_stat.lr = 0
                theta_stat.lr_reg = 0
            # Replay the per-epoch comp_Optimizer schedule steps (capped at num_epochs).
            for _ in range(num_epochs if in_extra_cycle else start_epoch):
                theta_stat.schedule_step(opt.reg_sched, opt.exploration_rate)
            if in_extra_cycle:
                scheduler = cosine(net_optim, opt.extra_epochs,
                                   last_epoch=start_epoch - 1 - num_epochs)
            else:
                # Same eta_min as a fresh start, so resumed runs follow the same schedule.
                scheduler = cosine(net_optim, num_epochs, eta_min=self._COSINE_ETA_MIN,
                                   last_epoch=start_epoch - 1)
        global_step = start_epoch

        prob = [[] for _ in range(num_groups)]  # per-group history of pi (for the visualizer)
        mean_regularizer = []  # passed to the visualizer; not filled in this method
        overall_reg_group = torch.zeros(1, num_groups, device=device)
        slow_mix_coef = torch.tensor([1.0 if opt.slow_mix_rate < 1 else 0.0], device=device)

        if opt.no_augment:
            print("Training without augmentations")
        if opt.exp_loss:
            print("NOT IMPLEMENTED: Training using the loss of the expectation")

        # Number of (augmented) training rows per epoch; normalises the epoch loss.
        rows_per_epoch = len(dataset) * number_samples

        for epoch in range(start_epoch, num_epochs + opt.extra_epochs):
            if epoch == num_epochs:
                # Second cycle: freeze comp_Optimizer updates and restart the classifier optimiser.
                theta_stat.lr = 0
                theta_stat.lr_reg = 0
                net_optim = self._make_net_optimizer(net, lr_class)
                scheduler = cosine(net_optim, opt.extra_epochs)

            # The visualizer receives `net` and may leave it in eval mode.
            net.train()
            tr_loss_epoch = 0
            mistakes = torch.zeros(num_groups, device=device)
            num_choices = torch.zeros(num_groups, device=device)

            for rotated_data, y in tqdm(train_loader):
                rotated_data, y = rotated_data.to(device), y.to(device)
                current_batch = rotated_data.shape[0]

                data_sample, _, indexes = sampler(number_samples * current_batch)
                # Merge the first two dims of the transformed batch into one.
                # Labels are tiled number_samples times to line up; this
                # presumes the action output is (number_samples, batch, ...)
                # -- TODO confirm.
                inputs = action(data_sample, rotated_data).flatten(start_dim=0, end_dim=1)
                targets = y.repeat(number_samples)

                out = net(inputs)

                theta_stat.update_stat(indexes.detach(), out.detach(), targets.detach())
                mistakes += theta_stat.num_mistakes
                num_choices += theta_stat.num_choices

                # Classification loss + lambda_reg * sum(pi[0, :param_dim] * log(1 / theta)).
                param_reg = torch.sum(sampler.pi[0, 0:sampler.param_dim]
                                      * torch.log(1 / sampler.get_param_theta()))
                loss = loss_cr(out, targets) + opt.lambda_reg * param_reg

                tr_loss_epoch += loss.item() * out.shape[0] / rows_per_epoch

                net_optim.zero_grad()
                sampler_optim.zero_grad()
                loss.backward()
                net_optim.step()
                sampler_optim.step()

                # Update the categorical group probabilities outside autograd.
                grad_regularizer = sampler.get_param_theta().detach()
                with torch.no_grad():
                    theta_stat.lr_sampler = (theta_stat.lr / (1 - slow_mix_coef)
                                             if (1 - slow_mix_coef) > 0 else theta_stat.lr)
                    new_theta, _ = theta_stat.update_pi(sampler.pi.detach(), grad_regularizer)
                    sampler.pi.copy_(new_theta)

            scheduler.step()
            theta_stat.schedule_step(opt.reg_sched, opt.exploration_rate)
            self._save_checkpoint(checkpoint_path, epoch, net, net_optim, sampler, sampler_optim)
            if epoch == num_epochs - 1:
                self._save_checkpoint(os.path.join(opt.log_dir, 'firstCycleSaved_model'),
                                      epoch, net, net_optim, sampler, sampler_optim)

            print("Epoch " + str(epoch))
            print("Loss=" + str(tr_loss_epoch))
            print("Epoch mistakes= " + str(mistakes))
            print("Epoch num choice= " + str(num_choices))

            pi = sampler.pi.view(-1)
            print("Categorical Probabilities= " + str(pi))
            learned_theta = sampler.get_param_theta().detach().reshape(-1)
            prob_dict = {}
            theta_dict = {}
            for g, name in enumerate(action.group_names):
                group_prob = pi[g].item()
                prob_dict[name] = group_prob
                prob[g].append(group_prob)
                if g < learned_theta.shape[0]:
                    theta_dict[name] = learned_theta[g].item()

            if self.writer is not None:
                for name, value in prob_dict.items():
                    self.writer.add_scalar("Categorical_Probabilities/" + name, value,
                                           global_step=global_step)
                for name, value in theta_dict.items():
                    self.writer.add_scalar("Uniform_Bounds/" + name, value,
                                           global_step=global_step)
                self.writer.add_scalar("Losses/" + str(opt.dataset) + "_Train_Loss",
                                       tr_loss_epoch, global_step=global_step)

            if epoch > warmup and epoch % plot_every == 0:
                visualizer(self, global_step, net, dataset, test_loader, prob, action.group_names,
                           mean_regularizer, overall_reg_group, sampler=sampler, action=action)
            global_step += 1
