import torch
import numpy as np
from PIL import Image
import os
from efficientnet_pytorch import EfficientNet
#from kornia.losses import focal
import torch.nn as nn

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from scipy.special import lambertw



def MSE(logits, label, epoch, features, thrs_epochs, Alpha_Emb):
    """Half squared error against one-hot targets, averaged over the batch.

    Args:
        logits: 2-D tensor indexed as [example, class].
        label: class index of each example, used to build the one-hot target.
        epoch, features, thrs_epochs, Alpha_Emb: unused here; accepted so
            every loss in this file shares the signature that ``train`` and
            ``test`` call with.

    Returns:
        Scalar tensor: 0.5 * mean over examples of the per-example sum of
        squared differences between ``logits`` and the one-hot target.
    """
    one_hot = torch.zeros_like(logits)
    example_idx = torch.arange(one_hot.size(0)).long()
    # Put a 1 in the ground-truth column of each row.
    one_hot[example_idx, label] = 1

    squared_error = (logits - one_hot) ** 2
    per_example = squared_error.sum(dim=1)
    return 0.5 * per_example.mean()

def CE(logits, label, epoch, features, thrs_epochs, Alpha_Emb):
    """Mean cross-entropy between ``logits`` and integer class ``label``.

    Args:
        logits: unnormalised class scores, one row per example.
        label: ground-truth class index of each example.
        epoch, features, thrs_epochs, Alpha_Emb: unused here; accepted so
            every loss in this file shares the same call signature.

    Returns:
        Scalar tensor with the batch-averaged cross-entropy.
    """
    # Functional form of torch.nn.CrossEntropyLoss() with default settings.
    return torch.nn.functional.cross_entropy(logits, label)

def MSE_CE(logits, label, epoch, features, threshold_on_epoch, Alpha_Emb):
    """Blend MSE and cross-entropy with an epoch-dependent sigmoid weight.

    The weight ``eta = sigmoid(epoch - threshold_on_epoch)`` moves from 0
    towards 1 as ``epoch`` passes ``threshold_on_epoch``, so the loss shifts
    from MSE (early epochs) to cross-entropy (late epochs).

    Args:
        logits: class scores, one row per example.
        label: ground-truth class index of each example.
        epoch: current epoch number.
        features: forwarded unchanged to ``CE`` and ``MSE``.
        threshold_on_epoch: epoch at which both losses are weighted equally.
        Alpha_Emb: forwarded unchanged to ``CE`` and ``MSE``.

    Returns:
        Scalar tensor ``eta * CE + (1 - eta) * MSE``.
    """
    # Build eta on the same device as the logits rather than forcing CUDA,
    # so the loss also works for CPU tensors. float64 is kept from the
    # original implementation.
    eta = torch.sigmoid(
        torch.tensor(epoch - threshold_on_epoch, dtype=torch.float64, device=logits.device)
    )

    # Thresholds that used to be selected here by commented-out code
    # (kept for reference; the caller now passes threshold_on_epoch):
    #   Tanh: cifar10=7,  fmnist=0, mnist=2
    #   ReLU: cifar10=19, fmnist=8, mnist=18

    CE_loss = CE(logits, label, epoch, features, threshold_on_epoch, Alpha_Emb)
    MSE_loss = MSE(logits, label, epoch, features, threshold_on_epoch, Alpha_Emb)
    return eta * CE_loss + (1 - eta) * MSE_loss

def MSE_vec(logits, label):
    """Per-example half squared error against one-hot targets.

    Args:
        logits: 2-D tensor indexed as [example, class].
        label: class index of each example.

    Returns:
        1-D tensor with one value per example:
        0.5 * sum over classes of (logit - one_hot) ** 2. No batch reduction.
    """
    one_hot = torch.zeros_like(logits)
    example_idx = torch.arange(one_hot.size(0)).long()
    # Put a 1 in the ground-truth column of each row.
    one_hot[example_idx, label] = 1

    squared_error = (logits - one_hot) ** 2
    return 0.5 * squared_error.sum(dim=1)

def MSE_Focal(logits, label, epoch, features, threshold_on_epoch, Alpha_Emb, gamma=5):
    """Blend per-example MSE and focal loss with an epoch-dependent weight.

    The weight ``eta = sigmoid(epoch - threshold_on_epoch)`` shifts the loss
    from MSE (early epochs) to focal loss (late epochs).

    Args:
        logits: class scores, one row per example.
        label: ground-truth class index of each example.
        epoch: current epoch number.
        features: unused; accepted for a uniform loss signature.
        threshold_on_epoch: epoch at which both losses are weighted equally.
        Alpha_Emb: unused; accepted for a uniform loss signature.
        gamma: focal-loss focusing exponent (default 5, the value that was
            previously hard-coded).

    Returns:
        Scalar tensor: batch mean of ``eta * focal + (1 - eta) * mse``.
    """
    # Per-example cross-entropy; exp(-CE) is the probability assigned to the
    # ground-truth label, as needed by the focal term.
    ce_loss = torch.nn.functional.cross_entropy(logits, label, reduction='none')
    pt = torch.exp(-ce_loss)

    # Build eta on the same device as the logits rather than forcing CUDA,
    # so the loss also works for CPU tensors.
    eta = torch.sigmoid(
        torch.tensor(epoch - threshold_on_epoch, dtype=torch.float64, device=logits.device)
    )

    Focal_loss = ((1 - pt) ** gamma) * ce_loss
    MSE_loss = MSE_vec(logits, label)

    return (eta * Focal_loss + (1 - eta) * MSE_loss).mean()


def L2(features):
    """Per-example mean of half the squared feature values.

    Args:
        features: tensor whose first dimension indexes examples; every other
            dimension is flattened with ``view``.

    Returns:
        1-D tensor with one value per example: mean over the flattened
        entries of ``0.5 * x ** 2``.
    """
    batch = features.size(0)
    flat = features.view(batch, -1)
    half_squares = 0.5 * (flat ** 2)
    return half_squares.mean(dim=1)


def MSE_Focal_L2(logits, label, epoch, features, threshold_on_epoch, Alpha_Emb, gamma=5):
    """Blend MSE and focal loss and add an L2 penalty on the features.

    With ``eta = sigmoid(epoch - threshold_on_epoch)`` the per-example loss is
    ``eta * focal + (1 - eta) * mse + ((1 - eta) / Alpha_Emb) * l2``, so the
    feature penalty fades out together with the MSE term.

    Args:
        logits: class scores, one row per example.
        label: ground-truth class index of each example.
        epoch: current epoch number.
        features: iterable of feature tensors; ``L2`` is applied to each and
            the per-example results are summed.
        threshold_on_epoch: epoch at which focal and MSE are weighted equally.
        Alpha_Emb: divisor applied to the L2 term (larger means a weaker
            penalty).
        gamma: focal-loss focusing exponent (default 5, the value that was
            previously hard-coded).

    Returns:
        Scalar tensor: batch mean of the combined per-example loss.
    """
    # Per-example cross-entropy; exp(-CE) is the probability assigned to the
    # ground-truth label, as needed by the focal term.
    ce_loss = torch.nn.functional.cross_entropy(logits, label, reduction='none')
    pt = torch.exp(-ce_loss)

    # Build eta on the same device as the logits rather than forcing CUDA,
    # so the loss also works for CPU tensors.
    eta = torch.sigmoid(
        torch.tensor(epoch - threshold_on_epoch, dtype=torch.float64, device=logits.device)
    )

    Focal_loss = ((1 - pt) ** gamma) * ce_loss
    MSE_loss = MSE_vec(logits, label)

    # Sum the per-example penalties of every feature tensor (0 if there are none).
    L2_loss = 0
    for feature in features:
        L2_loss = L2_loss + L2(feature)

    Total_loss = (eta * Focal_loss + (1 - eta) * MSE_loss + ((1 - eta) / Alpha_Emb) * L2_loss).mean()
    return Total_loss

  

def get_device():
    """Return the CUDA device, asserting that CUDA is available.

    Returns:
        ``torch.device("cuda")`` when CUDA is available.

    Raises:
        AssertionError: if ``torch.cuda.is_available()`` is false (when
            assertions are enabled).
    """
    cuda_available = torch.cuda.is_available()
    assert cuda_available
    # Only reachable without CUDA when assertions are disabled (python -O).
    if not cuda_available:
        return torch.device("cpu")
    return torch.device("cuda")


  


def train(model, train_loader, optimizer, epoch, max_grad_norm, log_angles,
          loss_function, act_type, loss_type, DATASET,
          grad_norm_per_example_all_itrs, weight_norm_per_epoch,
          grad_norm_per_epoch, cosine_per_itr, activation_norm, seed,
          thrs_epochs, Alpha_Emb, TRACK, n_acc_steps=1):
    """Run one training epoch, optionally recording gradient diagnostics.

    Batches are processed in groups of ``n_acc_steps``: the optimizer takes a
    real step on the last batch of each group and ``optimizer.virtual_step()``
    on the others. Trailing batches that do not fill a whole group are
    skipped.

    Args:
        model: module whose forward pass returns ``(output, features)``.
            With ``TRACK`` enabled, ``features`` must support item assignment
            and every parameter must expose ``grad_sample`` after backward.
        train_loader: iterable of ``(data, target)`` batches exposing
            ``batch_size`` (or ``batch_sampler.batch_size``) and ``len()``.
        optimizer: provides ``step()``, ``zero_grad()`` and, when
            ``n_acc_steps > 1``, ``virtual_step()``.
        epoch: current epoch; passed to the loss and used in saved filenames.
        max_grad_norm: clipping threshold used only for the cosine diagnostic
            computed here when ``TRACK`` is set.
        log_angles: path of the text file to which one cosine value per
            optimizer step is appended (``TRACK`` only).
        loss_function: callable invoked as
            ``loss_function(output, target, epoch, features, thrs_epochs, Alpha_Emb)``.
        act_type, loss_type, DATASET: used only to build saved filenames.
        grad_norm_per_example_all_itrs: list of per-example gradient norms;
            a new, extended list is returned (``TRACK`` only).
        weight_norm_per_epoch: dict mapping parameter name to a list of
            weight norms; appended to in place.
        grad_norm_per_epoch, cosine_per_itr: returned unchanged.
        activation_norm: dict with ``'features'`` and ``'output'`` lists;
            appended to in place (``TRACK`` only).
        seed: unused.
        thrs_epochs, Alpha_Emb: forwarded to ``loss_function``.
        TRACK: enable all diagnostics and the per-batch ``.npz`` dumps.
        n_acc_steps: number of batches per optimizer step.

    Returns:
        Tuple ``(train_loss, train_acc, grad_norm_per_example_all_itrs,
        weight_norm_per_epoch, grad_norm_per_epoch, cosine_per_itr,
        activation_norm)``, where ``train_loss`` is the example-weighted mean
        loss and ``train_acc`` is a percentage.

    Raises:
        ZeroDivisionError: if no batch was processed.
    """
    device = next(model.parameters()).device
    model.train()
    num_examples = 0
    correct = 0
    train_loss = 0

    # Only whole accumulation groups are processed.
    rem = len(train_loader) % n_acc_steps
    num_batches = len(train_loader) - rem

    cos = nn.CosineSimilarity(dim=0, eps=1e-6)

    bs = train_loader.batch_size if train_loader.batch_size is not None else train_loader.batch_sampler.batch_size
    print(f"training on {num_batches} batches of size {bs}")

    # Per-layer buffers of per-example gradients for the current group.
    grad_norm_for_angle = {name: [] for name, _ in model.named_parameters()}

    # Weight l2 norm before the training
    for name, param in model.named_parameters():
        weight_norm_per_epoch[name].append(param.data.norm(2).cpu().numpy().item())

    for batch_idx, (data, target) in enumerate(train_loader):

        if batch_idx > num_batches - 1:
            break

        data, target = data.to(device), target.to(device)

        output, features = model(data)

        if TRACK:
            # Save per-example activations and logits. The features are
            # flattened in place, so the loss below sees the flattened views.
            for i, feature in enumerate(features):
                flat_feature = feature.view(feature.size(0), -1)
                features[i] = flat_feature
                activation_norm['features'].append(flat_feature.norm(2, dim=1).detach().cpu().numpy())

            activation_norm['output'].append(output.norm(2, dim=1).detach().cpu().numpy())

        loss = loss_function(output, target, epoch, features, thrs_epochs, Alpha_Emb)

        loss.backward()

        if TRACK:
            # Per-example gradient norm over all layers, plus the flattened
            # per-example gradients of each layer for the angle computation.
            grad_norm_per_example = 0
            for name, param in model.named_parameters():
                flat_grad = param.grad_sample.view(param.grad_sample.size(0), -1)
                grad_norm_per_example += flat_grad.norm(2, dim=1) ** 2
                grad_norm_for_angle[name].append(flat_grad)
            grad_norm_per_example_all_itrs = grad_norm_per_example_all_itrs + list((grad_norm_per_example ** .5).cpu().numpy())

        if ((batch_idx + 1) % n_acc_steps == 0) or ((batch_idx + 1) == len(train_loader)):

            if TRACK:
                ##### Track angle
                # Concatenate the grads of all samples per layer, then all
                # layers, giving one row per example.
                All_layer_grads_per_example = torch.cat(
                    [torch.cat(grad_norm_for_angle[name]) for name, _ in model.named_parameters()],
                    dim=1,
                )

                # Compute grad norm per example
                All_layer_grads_norm_per_example = All_layer_grads_per_example.norm(dim=1)

                # Clipping factor per example: min(max_grad_norm / norm, 1)
                clip_coef = torch.min(
                    (All_layer_grads_norm_per_example * 0 + max_grad_norm) / (All_layer_grads_norm_per_example + 1e-6),
                    (All_layer_grads_norm_per_example * 0 + 1),
                )

                # Clip grad per example
                All_layer_grads_per_example_clipped = All_layer_grads_per_example * clip_coef[:, None]

                # Angle between the mean clipped grad and the mean
                # non-clipped grad
                cosine_similarity = cos(All_layer_grads_per_example.mean(dim=0), All_layer_grads_per_example_clipped.mean(dim=0))

                # Release the large buffers before touching the file system.
                del All_layer_grads_per_example
                del All_layer_grads_norm_per_example
                del All_layer_grads_per_example_clipped
                del clip_coef

                # The context manager closes the log even if the write fails.
                with open(log_angles, 'a+') as f1:
                    f1.write('{:.9f}\n'.format(cosine_similarity.cpu().numpy().item()))

                for name, _ in model.named_parameters():
                    grad_norm_for_angle[name] = []
                ##### End of track angle

            optimizer.step()

            if TRACK:
                # Save weight l2 norm after each iteration
                for name, param in model.named_parameters():
                    weight_norm_per_epoch[name].append(param.data.norm(2).cpu().numpy().item())

            optimizer.zero_grad()

        else:
            with torch.no_grad():
                # accumulate per-example gradients but don't take a step yet
                optimizer.virtual_step()

        pred = output.max(1, keepdim=True)[1]

        if TRACK:
            # Save labels and predictions per iteration
            saved_labels_path = 'LabelsResults'
            os.makedirs('{}/'.format(saved_labels_path), exist_ok=True)
            np.savez_compressed('{}/Labels{}_loss{}_Epoch{}_Itrs{}_act{}.npz'.format(saved_labels_path, DATASET, loss_type, epoch, batch_idx, act_type), Pred=pred.squeeze().data.cpu().numpy(), Label=target.cpu().numpy())

            # Save logits and labels per iteration (the file name keeps the
            # 'Labels' prefix; only the directory differs).
            saved_logits_path = 'LogitsResults'
            os.makedirs('{}/'.format(saved_logits_path), exist_ok=True)
            np.savez_compressed('{}/Labels{}_loss{}_Epoch{}_Itrs{}_act{}.npz'.format(saved_logits_path, DATASET, loss_type, epoch, batch_idx, act_type), Logit=output.data.cpu().numpy(), Label=target.cpu().numpy())

        # NOTE: disabled debugging aid that dumps misclassified images.
        # if epoch >28:
        #
        #   pred_save=pred.squeeze().cpu().numpy()
        #   true_save=target.cpu().numpy()
        #
        #   misclassified_idxs=np.where(pred_save!=true_save)
        #   true_save_mis=true_save[misclassified_idxs]
        #   pred_save_mis=pred_save[misclassified_idxs]
        #   misclassified_imgs=data[misclassified_idxs]
        #
        #   for mis_img_idx in range(len(misclassified_imgs)):
        #
        #     mis_img=misclassified_imgs[mis_img_idx].cpu().numpy().transpose(1, 2, 0)
        #
        #     mean = [0.485, 0.456, 0.406]
        #     std = [0.229, 0.224, 0.225]
        #
        #     for c in range(3):
        #         mis_img[:,:,c] *= std[c]
        #         mis_img[:,:,c] += mean[c]
        #     mis_img[mis_img > 1] = 1
        #     mis_img[mis_img < 0] = 0
        #
        #     mis_img=(mis_img*255).astype(np.uint8)
        #     mis_img_save = Image.fromarray(mis_img)
        #     saved_misclassified_imgs='imgs_{}_{}'.format(loss_type,DATASET)
        #     if not os.path.isdir('{}/'.format(saved_misclassified_imgs)):
        #         os.makedirs('{}/'.format(saved_misclassified_imgs))
        #     mis_img_save.save("{}/E-{}_Idx-{}_t-{}_p-{}.png".format(saved_misclassified_imgs,epoch,batch_idx,true_save_mis[mis_img_idx],pred_save_mis[mis_img_idx]))

        correct += pred.eq(target.view_as(pred)).sum().item()
        num_examples += len(data)
        # Reuse the loss already computed for this batch instead of calling
        # loss_function a second time on the same output.
        train_loss += loss.item() * len(data)

    train_loss /= num_examples
    train_acc = 100. * correct / num_examples

    print(f'Train set: Average loss: {train_loss:.4f}, '
          f'Accuracy: {correct}/{num_examples} ({train_acc:.2f}%)')

    return train_loss, train_acc, grad_norm_per_example_all_itrs, weight_norm_per_epoch, grad_norm_per_epoch, cosine_per_itr, activation_norm


def test(model, test_loader, loss_function, epoch, DATASET,act_type,thrs_epochs, Alpha_Emb,):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: module whose forward pass returns ``(output, features)``.
        test_loader: iterable of ``(data, target)`` batches.
        loss_function: callable invoked as
            ``loss_function(output, target, epoch, features, thrs_epochs, Alpha_Emb)``.
        epoch: forwarded to ``loss_function``.
        DATASET, act_type: unused.
        thrs_epochs, Alpha_Emb: forwarded to ``loss_function``.

    Returns:
        ``(test_loss, test_acc)``: the example-weighted mean loss and the
        accuracy as a percentage.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no examples.
    """
    model.eval()
    device = next(model.parameters()).device

    seen = 0
    hits = 0
    loss_sum = 0

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output, features = model(data)

            batch_size = len(data)
            seen += batch_size

            # Weight the batch loss by its size so the final value is a
            # per-example average.
            batch_loss = loss_function(output, target, epoch, features, thrs_epochs, Alpha_Emb)
            loss_sum += batch_loss.item() * batch_size

            # Index of the largest logit is the predicted class.
            _, pred = output.max(1, keepdim=True)
            hits += (pred == target.view_as(pred)).sum().item()

    test_loss = loss_sum / seen
    test_acc = 100. * hits / seen

    print(f'Test set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {hits}/{seen} ({test_acc:.2f}%)')

    return test_loss, test_acc
