import argparse

import torch
from torch import nn
from data import data_helper
from optimizer.optimizer_helper import get_optim_and_scheduler
from utils.Logger import Logger
from utils.utils import set_requires_grad
import numpy as np
from models.resnet import resnet18, resnet50, weights_init
from models.algorithms import Classifier, SupspaceIndicator, Sigmoid_Classifier
import random
import os
import torch.nn.functional as F
from torch.distributions import Bernoulli, RelaxedBernoulli

class GradientReversalFunction(torch.autograd.Function):
	"""Identity on the forward pass, sign-flipped gradient on the backward pass.

	Call it as ``GradientReversalFunction.apply(x)``.
	"""

	@staticmethod
	def forward(ctx, x):
		# Hand back a copy so the output is a distinct tensor from the input.
		return x.clone()

	@staticmethod
	def backward(ctx, grads):
		# forward() has exactly one tensor input, so exactly one gradient is returned.
		return grads.neg()

class Trainer:
	"""Train a backbone + classifier together with a feature-gating "sup-space indicator".

	Each training iteration (see ``_do_epoch``) has two phases:
	  1. ``args.ber_sample`` updates of the backbone and classifier, with the
	     features multiplied by a Bernoulli sample of the indicator's mask;
	  2. one update of the indicator, with the (detached) features multiplied
	     by the indicator's soft mask.

	Evaluation (``do_test``) multiplies the features by the indicator mask,
	thresholded at each value of ``args.tau`` (or left soft for negative values).
	"""

	def __init__(self, args, device):
		"""Build models, optimizers/schedulers and data loaders.

		Args:
			args: parsed argument namespace. Mutated: ``args.trainer`` is set to
				``'ICLR_<network>_<image_size>'``.
			device: torch device every module is moved to.
		"""
		self.args = args
		self.args.trainer = 'ICLR_{}_{}'.format(self.args.network, self.args.image_size)
		print(self.args)
		self.device = device

		# Any network name other than 'ResNet18' falls back to ResNet-50.
		if self.args.network == 'ResNet18':
			model = resnet18(pretrained=True, classes=self.args.n_classes)
		else:
			model = resnet50(pretrained=True, classes=self.args.n_classes)
		self.model = model.to(device)

		if self.args.sigmoid_head:
			classifier = Sigmoid_Classifier(self.args.feature_dim, self.args.n_classes) #13703
		else:
			classifier = Classifier(self.args.feature_dim, self.args.n_classes)
		self.classifier = classifier.to(device)

		supspace_indicator = SupspaceIndicator(self.args.feature_dim)
		self.supspace_indicator = supspace_indicator.to(self.device)
		self.supspace_indicator.apply(weights_init)

		# Backbone and classifier share one optimizer; the indicator has its own
		# optimizer with a separate learning rate (args.mask_learning_rate).
		self.optimizer, self.scheduler = get_optim_and_scheduler(
			list(model.parameters()) + list(classifier.parameters()),
			self.args.epochs, self.args.learning_rate)
		self.optimizer_attetion, self.scheduler_supspace_indicator = get_optim_and_scheduler(
			list(supspace_indicator.parameters()), self.args.epochs, self.args.mask_learning_rate)

		self.source_loader, self.val_loader = data_helper.get_train_dataloader(self.args, patches=model.is_patch_based())
		self.target_loader = data_helper.get_val_dataloader(self.args, patches=model.is_patch_based())
		self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
		self.len_dataloader = len(self.source_loader)
		print("Dataset size: train %d, val %d, test %d" % (
			len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))

		# Position of the target domain inside the source list, if it is also a source.
		if self.args.target in self.args.source:
			self.target_id = self.args.source.index(self.args.target)
			print("Target in source: %d" % self.target_id)
			print(self.args.source)
		else:
			self.target_id = None

		# Evaluation history keyed by phase; do_training only fills "test".
		self.results = {"val": [], "test": []}
		self.utilized_features = {"val": [], "test": []}
		self.best_soft = 0.0
		if self.args.freeze_bn:
			# NOTE(review): _do_epoch calls self.model.train() every epoch; if
			# freeze_bn() works by putting BN layers in eval mode, that call
			# undoes it - confirm against models.resnet.
			self.model.freeze_bn() # sig 13714 13718 sof 13713 13719 /1.0:  sig 13724  sof 13723

	def _do_epoch(self, epoch):
		"""Run one training epoch over ``self.source_loader``.

		Every batch is concatenated with a copy flipped along dim 3, then:
		  * ``args.ber_sample`` steps update the backbone and classifier, with the
		    features gated by a sampled binary mask and passed through dropout
		    (``args.drop_rate``);
		  * one step updates the sup-space indicator through its soft mask.

		Args:
			epoch: index of the current epoch (only printed).
		"""
		print(epoch)
		if self.args.sigmoid_head:
			criterion = nn.BCELoss()
		else:
			criterion = nn.CrossEntropyLoss()
		# Built once per epoch instead of once per step. A newly constructed
		# module is in training mode, so dropout is active exactly as before.
		dropout = nn.Dropout(self.args.drop_rate)

		self.model.train()
		self.classifier.train()
		self.supspace_indicator.train()

		for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
			self.logger.current_iter += 1
			data = data.to(self.device)
			jig_l = jig_l.to(self.device)
			class_l = class_l.to(self.device)
			d_idx = d_idx.to(self.device)

			# Append copies flipped along dim 3 (the width axis, assuming NCHW
			# images - TODO confirm) and duplicate every label tensor to match.
			data_flip = torch.flip(data, (3,)).detach().clone()
			data = torch.cat((data, data_flip))
			class_l = torch.cat((class_l, class_l))
			jig_l = torch.cat((jig_l, jig_l))
			d_idx = torch.cat((d_idx, d_idx))

			# Phase 1: update backbone + classifier under sampled binary masks.
			for _ in range(self.args.ber_sample):
				self.optimizer.zero_grad()
				self.optimizer_attetion.zero_grad()
				features, att_z = self.model(data)

				mask, _ = self.supspace_indicator(att_z)
				# Bernoulli.sample() is not differentiable, so this phase sends no
				# gradient into the indicator.
				bernoulli_mask = Bernoulli(mask).sample()

				class_logit = self.classifier(dropout(bernoulli_mask * features))
				if self.args.sigmoid_head:
					class_loss = criterion(class_logit, jig_l)
				else:
					class_loss = criterion(class_logit, class_l)

				_, cls_pred = class_logit.max(dim=1)
				self.logger.log(it, len(self.source_loader),
								{"H_m": class_loss.item()},
								{"H_m": torch.sum(cls_pred == class_l.data).item(), }, data.shape[0])
				self.logger.log_debug({"used_features": torch.round(mask).sum(-1).mean()}, False)

				class_loss.backward()
				self.optimizer.step()

			# Phase 2: update the indicator through its soft mask.
			self.optimizer_attetion.zero_grad()
			self.optimizer.zero_grad()

			features, att_z = self.model(data)
			mask, _ = self.supspace_indicator(att_z)

			# Features are detached so this loss does not reach the backbone via the
			# feature path. Any gradient that still lands on backbone/classifier
			# parameters is never applied: only optimizer_attetion steps here, and
			# self.optimizer is zeroed before its next step.
			class_logit = self.classifier(mask * features.detach())
			if self.args.sigmoid_head:
				class_loss = criterion(class_logit, jig_l)
			else:
				class_loss = criterion(class_logit, class_l)

			class_loss.backward()
			self.optimizer_attetion.step()

			# NOTE(review): passes the loss tensor (still attached to its graph), not
			# .item(); confirm Logger.log_loss converts it rather than storing it.
			self.logger.log_loss({"sup-space indicator": class_loss})

	def do_test(self, loader, value):
		"""Evaluate on ``loader`` with features gated by the indicator mask.

		Args:
			loader: iterable of ``((data, _, class_l), _)`` batches.
			value: mask threshold. If ``value >= 0`` the mask is binarised as
				``mask >= value - 1e-3``; otherwise the soft mask is used as is.

		Returns:
			``(class_correct, used_features)``: the number of correct predictions
			(a tensor) and, averaged over batches, the per-sample count of gate
			entries that round to 1.

		Raises:
			ZeroDivisionError: if ``loader`` yields no batches.
		"""
		class_correct = 0
		used_features = 0
		count = 0
		for (data, _, class_l), _ in loader:
			count += 1
			data, class_l = data.to(self.device), class_l.to(self.device)

			features, att_z = self.model(data)
			mask, _ = self.supspace_indicator(att_z)
			if value >= 0.0:
				# The 1e-3 slack keeps mask entries lying just below the threshold.
				gate = (mask >= value - 1e-3).float()
			else:
				gate = mask
			class_logit = self.classifier(gate * features)
			used_features += torch.round(gate).sum(-1).mean()
			_, cls_pred = class_logit.max(dim=1)

			class_correct += torch.sum(cls_pred == class_l.data)

		return class_correct, used_features / count

	def do_training(self):
		"""Train for ``args.epochs`` epochs, evaluating every ``args.evaluation_iter``.

		At each evaluation, the "test" loader is scored once per threshold in
		``args.tau``; the "val" loader is not evaluated.

		Returns:
			``(self.results, self.utilized_features)``: dicts keyed by phase; the
			"test" lists get one entry per (evaluation, tau value).
		"""
		self.logger = Logger(self.args, update_frequency=30)

		for self.current_epoch in range(self.args.epochs):
			# Log the learning rates actually in use for this epoch.
			self.logger.new_epoch([group["lr"] for group in self.optimizer.param_groups])
			self._do_epoch(self.current_epoch)

			# Since PyTorch 1.1 schedulers must step *after* the optimizer; stepping
			# at the start of the epoch skipped the first value of the schedule.
			self.scheduler.step()
			self.scheduler_supspace_indicator.step()

			if (self.current_epoch + 1) % self.args.evaluation_iter == 0:
				self.model.eval()
				self.classifier.eval()
				self.supspace_indicator.eval()
				phase = 'test'
				loader = self.test_loaders[phase]
				total = len(loader.dataset)
				with torch.no_grad():
					for value in self.args.tau:
						class_correct, used_features = self.do_test(loader, value)

						class_acc = float(class_correct) / total
						self.logger.log_test(phase, {"accuracies_{}_{}".format(self.args.drop_rate, value): class_acc})
						self.logger.log_test(phase, {"used_features_{}_{}".format(self.args.drop_rate, value): used_features})
						self.results[phase].append(class_acc)
						self.utilized_features[phase].append(float(used_features))

		# Uncomment to save checkpoints:
		# torch.save(self.model.state_dict(), 'iclr_checkpoints/{}_{}_model.pth.tar'.format(self.args.network, self.args.target))
		# torch.save(self.classifier.state_dict(), 'iclr_checkpoints/{}_{}_classifier.pth.tar'.format(self.args.network, self.args.target))
		# torch.save(self.supspace_indicator.state_dict(), 'iclr_checkpoints/{}_{}_supspace_indicator.pth.tar'.format(self.args.network, self.args.target))

		return self.results, self.utilized_features
