import argparse
import torch
from torch import nn
from data import data_helper
from optimizer.optimizer_helper import get_optim_and_scheduler
from utils.Logger import Logger
import numpy as np
from models.caffenet import caffenet
from models.resnet import weights_init
from models.algorithms import Classifier, SupspaceIndicator, Sigmoid_Classifier
import random
import os
import torch.nn.functional as F
from torch.distributions import Bernoulli, RelaxedBernoulli


class Trainer:
	"""Train a caffenet backbone with a learned, stochastic sub-space feature mask.

	Each training iteration runs two phases on a batch augmented with
	horizontally flipped copies:

	1. Backbone + classifier update: the sub-space indicator's output is used
	   as Bernoulli probabilities, a hard mask is sampled and applied to the
	   backbone features, and only the backbone/classifier optimizer steps.
	2. Indicator update: the soft indicator output is applied to freshly
	   computed features and only the indicator's optimizer steps.

	After every epoch the networks are evaluated on the ``"test"`` loader for
	each threshold in ``args.tau``.
	"""

	def __init__(self, args, device):
		"""Build networks, optimizers, schedulers and data loaders.

		Args:
			args: parsed argument namespace. Fields read here include
				``partition``, ``image_size``, ``sigmoid_head``, ``feature_dim``,
				``n_classes``, ``epochs``, ``learning_rate``,
				``mask_learning_rate``, ``source`` and ``target``.
				``args.trainer`` is overwritten with a run tag.
			device: torch device every module is moved to.
		"""
		self.args = args
		self.args.trainer = 'ICLR_Alexnet_{}_{}'.format(self.args.partition, self.args.image_size)
		print(self.args)
		self.device = device

		model = caffenet(self.args.partition)
		self.model = model.to(device)
		if self.args.sigmoid_head:
			classifier = Sigmoid_Classifier(self.args.feature_dim, self.args.n_classes)
		else:
			classifier = Classifier(self.args.feature_dim, self.args.n_classes)
		self.classifier = classifier.to(device)

		# Number of feature channels covered by one mask entry; the mask has
		# one entry per partition and is broadcast to feature_dim channels.
		self.supspace_indicator_dim = int(self.args.feature_dim / self.args.partition)
		supspace_indicator = SupspaceIndicator(self.args.partition, k=self.supspace_indicator_dim)
		self.supspace_indicator = supspace_indicator.to(self.device)
		self.supspace_indicator.apply(weights_init)

		# Two independent optimizers: one for backbone + classifier, one for
		# the sub-space indicator (with its own learning rate).
		self.optimizer, self.scheduler = get_optim_and_scheduler(
			list(model.parameters()) + list(classifier.parameters()),
			self.args.epochs, self.args.learning_rate)
		self.optimizer_attetion, self.scheduler_supspace_indicator = get_optim_and_scheduler(
			list(supspace_indicator.parameters()),
			self.args.epochs, self.args.mask_learning_rate)

		################################
		self.source_loader, self.val_loader = data_helper.get_train_dataloader(self.args, patches=model.is_patch_based())
		self.target_loader = data_helper.get_val_dataloader(self.args, patches=model.is_patch_based())
		self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
		self.len_dataloader = len(self.source_loader)
		print("Dataset size: train %d, val %d, test %d" % (
			len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))

		################################
		if self.args.target in self.args.source:
			self.target_id = self.args.source.index(self.args.target)
			print("Target in source: %d" % self.target_id)
			print(self.args.source)
		else:
			self.target_id = None

		# Metric histories; one entry is appended per evaluated threshold per epoch.
		self.results = {"val": [], "test": []}
		self.utilized_features = {"val": [], "test": []}

	def _expand_mask(self, mask, batch_size):
		"""Broadcast a per-partition mask to full feature width.

		Each entry of a row is repeated ``supspace_indicator_dim`` times, giving
		a ``(batch_size, feature_dim)`` tensor that multiplies the features
		element-wise. Assumes ``mask`` is ``(batch_size, partition)`` — the
		``repeat(1, 1, k)`` below requires at most 2 dims before unsqueeze.
		"""
		return mask.unsqueeze(-1).repeat(1, 1, self.supspace_indicator_dim).reshape(
			batch_size, self.args.feature_dim)

	def _do_epoch(self, epoch=None):
		"""Run one training epoch over ``source_loader``, then evaluate.

		Args:
			epoch: current epoch index; not used, kept for interface compatibility.
		"""
		criterion = nn.BCELoss() if self.args.sigmoid_head else nn.CrossEntropyLoss()
		self.model.train()
		self.classifier.train()
		self.supspace_indicator.train()

		for it, ((data, jig_l, class_l), _) in enumerate(self.source_loader):
			self.logger.current_iter += 1
			data, jig_l, class_l = data.to(self.device), jig_l.to(self.device), class_l.to(self.device)

			# Augment with flipped copies along dim 3 (presumably image width
			# for NCHW batches — TODO confirm); labels are duplicated to match.
			data = torch.cat((data, torch.flip(data, (3,))))
			class_l = torch.cat((class_l, class_l))
			jig_l = torch.cat((jig_l, jig_l))
			# The sigmoid head is trained against jig_l, the softmax head against class_l.
			target = jig_l if self.args.sigmoid_head else class_l
			batch_size = data.shape[0]

			################################
			# Phase 1: update backbone + classifier through a sampled hard mask.
			self.optimizer.zero_grad()
			self.optimizer_attetion.zero_grad()
			features, att_z = self.model(data)
			mask, _ = self.supspace_indicator(att_z)
			# Sampling is not differentiable, so the indicator gets no gradient here.
			bernoulli_mask = self._expand_mask(Bernoulli(mask).sample(), batch_size)
			class_logit = self.classifier(
				F.dropout(bernoulli_mask * features, p=self.args.drop_rate, training=True))
			class_loss = criterion(class_logit, target)

			_, cls_pred = class_logit.max(dim=1)
			self.logger.log(it, len(self.source_loader),
							{"H_m": class_loss.item()},
							{"H_m": torch.sum(cls_pred == class_l.data).item(), }, batch_size)

			class_loss.backward()
			self.optimizer.step()

			################################
			# Phase 2: update only the indicator using the soft mask. The
			# backbone is not stepped here, so its forward pass needs no
			# autograd graph (equivalent to the previous att_z.detach(), but
			# without a wasted backward pass through the backbone).
			self.optimizer_attetion.zero_grad()
			self.optimizer.zero_grad()
			with torch.no_grad():
				features, att_z = self.model(data)
			mask, _ = self.supspace_indicator(att_z)
			soft_mask = self._expand_mask(mask, batch_size)
			total_loss = criterion(self.classifier(soft_mask * features), target)
			total_loss.backward()
			self.optimizer_attetion.step()

			################################
			# Detached so the logger does not keep the autograd graph alive.
			self.logger.log_loss({"sup-space indicator": total_loss.detach()})

		self._evaluate()

	def _evaluate(self):
		"""Evaluate on the ``"test"`` loader for every threshold in ``args.tau``.

		Appends the accuracy to ``self.results["test"]`` and the average number
		of active feature channels to ``self.utilized_features["test"]``. The
		``"val"`` loader is currently skipped.
		"""
		self.model.eval()
		self.classifier.eval()
		self.supspace_indicator.eval()
		with torch.no_grad():
			for value in self.args.tau:
				for phase, loader in self.test_loaders.items():
					if phase != 'test':
						continue
					total = len(loader.dataset)
					class_correct, used_features = self.do_test(loader, value)
					# Guard against an empty evaluation set.
					class_acc = float(class_correct) / total if total else 0.0
					self.logger.log_test(phase, {"accuracies_{}".format(value): class_acc})
					self.logger.log_test(phase, {"used_features_{}".format(value): used_features})
					self.results[phase].append(class_acc)
					self.utilized_features[phase].append(float(used_features))

	def do_test(self, loader, value):
		"""Count correct predictions using a thresholded sub-space mask.

		Args:
			loader: iterable yielding ``((data, aux_label, class_label), _)``.
			value: mask threshold. If ``value >= 0`` the mask is binarised at
				``value - 1e-3``; otherwise the soft mask is used as-is.

		Returns:
			``(class_correct, used_features)``: the number of correct
			predictions, and the average over batches of the per-sample count
			of active (rounded) feature channels. Both are 0 for an empty loader.
		"""
		class_correct = 0
		used_features = 0
		n_batches = 0
		for (data, _, class_l), _ in loader:
			n_batches += 1
			data, class_l = data.to(self.device), class_l.to(self.device)

			features, att_z = self.model(data)
			mask, _ = self.supspace_indicator(att_z)
			hard = (mask >= value - 1e-3).float() if value >= 0.0 else mask
			hard = self._expand_mask(hard, data.shape[0])

			class_logit = self.classifier(hard * features)
			used_features += torch.round(hard).sum(-1).mean()
			_, cls_pred = class_logit.max(dim=1)
			class_correct += torch.sum(cls_pred == class_l.data)

		# max(..., 1) avoids ZeroDivisionError when the loader is empty.
		return class_correct, used_features / max(n_batches, 1)

	def _current_lr(self):
		"""Return the backbone scheduler's current learning rate(s) for logging.

		``get_last_lr()`` (PyTorch >= 1.4) reports the rate actually in use;
		``get_lr()`` is kept only as a fallback for older releases, since
		calling it outside ``step()`` is deprecated there.
		"""
		get_last_lr = getattr(self.scheduler, "get_last_lr", None)
		if get_last_lr is not None:
			return get_last_lr()
		return self.scheduler.get_lr()

	def do_training(self):
		"""Train for ``args.epochs`` epochs and return the collected metrics.

		Returns:
			``(results, utilized_features)``: dicts keyed by ``"val"`` and
			``"test"``; only ``"test"`` receives entries (one per threshold in
			``args.tau`` per epoch).
		"""
		self.logger = Logger(self.args, update_frequency=30)

		for self.current_epoch in range(self.args.epochs):
			# NOTE(review): both schedulers step at the start of the epoch,
			# before any optimizer.step(). That matches the pre-1.1 PyTorch
			# convention; on newer PyTorch it skips the first LR value and
			# warns. Left unchanged to preserve the training schedule — confirm
			# the targeted PyTorch version before reordering.
			self.scheduler.step()
			self.scheduler_supspace_indicator.step()

			self.logger.new_epoch(self._current_lr())
			self._do_epoch(self.current_epoch)

		return self.results, self.utilized_features
