import torch
import torch.nn.functional as F
from torch import nn
import torchvision
from torchvision import datasets, models, transforms
from torchsummary import summary

import gc
import re
import time
import math
import random
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import numpy as np 
import os

import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('__file__'))))

import setting
from utils import *
from custom_dataset import cub
from explainer import vgg_lrps
from explainer.lrp_utils import *


class CLA_Layer(nn.Module):
    """Learnable attribute prototypes matched to feature patches by squared L2 distance.

    The layer holds ``nb_attrs`` prototype tensors of shape
    ``(C + 1, sz_attr, sz_attr)`` where ``C = layer_shape[1]``. The extra
    channel matches the ICP map that is concatenated to every patch.

    ``forward`` does three things:
    - normalises its input over the channel axis;
    - passes it through a small 1x1-conv ``add_on_layers`` stack;
    - compares the result with every prototype via convolution, giving one
      distance map per attribute.

    The layer owns its own Adam optimiser (``self.opt``) and a ``LambdaLR``
    scheduler (``self.scheduler``) that multiplies the learning rate by
    ``0.95 ** epoch``.
    """

    def __init__(self,
                 device,
                 idx,
                 layer_shape,
                 nb_classes,
                 nb_attrs,
                 sz_patch,
                 sz_attr,
                 lr,
                 max_epoches,
                 alpha,
                 beta):
        """
        Args:
            device: Device for the prototypes and the add-on layers.
            idx: Identifier of this layer (stored; not read within this class).
            layer_shape: Shape of the source feature map. Only ``layer_shape[1]``
                (the channel count) is used here.
            nb_classes: Number of classes (stored; not read within this class).
            nb_attrs: Number of attribute prototypes.
            sz_patch: Spatial size of the feature patches compared in ``predict``.
            sz_attr: Spatial size of each prototype.
            lr: Learning rate of the internal Adam optimiser.
            max_epoches: Stored; not read within this class.
            alpha: Scale applied inside the softplus of the loss.
            beta: Weight of the positive term; the negative term gets ``1 - beta``.
        """
        super().__init__()
        self.device = device
        self.idx = idx
        self.feature_shape = layer_shape
        self.nb_classes = nb_classes
        self.nb_attrs = nb_attrs
        self.sz_patch = sz_patch
        self.sz_attr = sz_attr
        self.max_epoches = max_epoches
        self.alpha = alpha
        self.beta = beta

        self.attr_shape, self.attr_vec, self.ones = self._init_attr_vector()
        self.add_on_layers = self._init_add_on_layers()

        self._initialize_weights()

        self.opt = torch.optim.Adam(self.parameters(), lr=lr)
        # `verbose` is deprecated in recent PyTorch releases. Its default was
        # already False, so omitting it keeps the behaviour unchanged.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=self.opt,
                                                           lr_lambda=lambda epoch: 0.95 ** epoch,
                                                           last_epoch=-1)

    def _init_attr_vector(self):
        """Create the prototypes and the all-ones kernel used for window sums.

        Returns:
            attr_shape: ``(nb_attrs, C + 1, sz_attr, sz_attr)``.
            attr_vec: Trainable prototypes initialised from U[0, 1).
            ones: Frozen all-ones tensor of the same shape. Convolving ``x ** 2``
                with it yields the squared norm of every sz_attr x sz_attr window.
        """
        f_dim = self.feature_shape[1] + 1  # +1: extra channel for the ICP input

        attr_shape = (self.nb_attrs, f_dim, self.sz_attr, self.sz_attr)
        # Sample on the CPU generator and then move to self.device. Seeded runs
        # therefore get the same initial prototypes as the previous `.cuda()`
        # code, while respecting the configured device.
        attr_vec = nn.Parameter(torch.rand(attr_shape).to(self.device), requires_grad=True)
        # Kept as a frozen Parameter (not a buffer) so the existing module and
        # optimiser state_dict layouts stay unchanged.
        ones = nn.Parameter(torch.ones(attr_shape, device=self.device), requires_grad=False)

        return attr_shape, attr_vec, ones

    def _init_add_on_layers(self):
        """Build two 1x1-conv + SiLU stages that preserve channel count and spatial size."""
        f_channel = self.attr_shape[1]

        add_on_layers = nn.Sequential(
            nn.Conv2d(in_channels=f_channel, out_channels=f_channel, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=f_channel, out_channels=f_channel, kernel_size=1),
            nn.SiLU()
        )
        add_on_layers.to(self.device)

        return add_on_layers

    def _initialize_weights(self):
        """Kaiming-normal init (fan_out) for the add-on convs, with zero biases."""
        for m in self.add_on_layers.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _get_attribute_distance(self, x):
        """Squared L2 distance between every prototype and every window of ``x``.

        Uses ||x - a||^2 = sum(x^2) - 2 <x, a> + sum(a^2), where each term is
        computed with a convolution (no padding, stride 1).

        Args:
            x: Output of ``add_on_layers``; (B, C + 1, w, h).

        Returns:
            Distances of shape (B, nb_attrs, w - sz_attr + 1, h - sz_attr + 1).
        """
        # Sum of x^2 over each window: (B, nb_attrs, w-sz_attr+1, h-sz_attr+1).
        x2_patch_sum = F.conv2d(input=x ** 2, weight=self.ones)

        # Squared norm of each prototype, shaped (nb_attrs, 1, 1) for broadcasting.
        a2_reshape = torch.sum(self.attr_vec ** 2, dim=(1, 2, 3)).view(-1, 1, 1)

        # Inner products <x_window, a>: (B, nb_attrs, w-sz_attr+1, h-sz_attr+1).
        xa = F.conv2d(input=x, weight=self.attr_vec)

        # relu clamps small negative values caused by floating-point cancellation.
        return F.relu(x2_patch_sum - 2 * xa + a2_reshape)

    def _get_all_distances_for_subfeatures(self, a, p, y, icp):
        """Compute distance maps for every sz_patch x sz_patch window of ``a``.

        Args:
            a: Activation map, indexed as (B, C, W, H).
            p: Predicted classes; used to select ICP rows when ``y`` is None.
            y: Optional labels; used instead of ``p`` when given.
            icp: Tensor indexed by class on its first axis. Each selected entry
                must add exactly one channel of size sz_patch x sz_patch to a
                patch. Presumably (nb_classes, 1, sz_patch, sz_patch); TODO
                confirm with the caller.

        Returns:
            Tensor (B, nb_attrs, nb_sub_features, k, k), where:
            - nb_sub_features = (W - sz_patch + 1) * (H - sz_patch + 1);
            - k = sz_patch - sz_attr + 1;
            - sub-features are ordered w-major, i.e.
              flat index = w * (H - sz_patch + 1) + h.
        """
        icp_sel = icp[p if y is None else y, :]
        n_w = a.shape[2] - self.sz_patch + 1
        n_h = a.shape[3] - self.sz_patch + 1

        # Each window is detached so no gradient flows back into `a`. The
        # selected ICP map is appended as the last channel.
        distances = torch.stack(
            [self.forward(torch.cat((a[:, :, w:w + self.sz_patch, h:h + self.sz_patch].clone().detach(),
                                     icp_sel), dim=1))
             for w in range(n_w)
             for h in range(n_h)],
            dim=2)

        return distances

    def _get_loss(self, y, p, d_pos_list, d_neg_list):
        """Contrastive attribute loss.

        For attribute ``i`` only channel ``i`` of the ``i``-th distance tensor is
        used. These maps are averaged spatially and divided by ``sz_attr``.

        - The positive term takes the max over attributes.
        - The negative term takes the mean over attributes.
        - The loss is mean(softplus(alpha * (beta * pos - (1 - beta) * neg))).

        The loss is computed only over samples returned by ``get_valid_labels``.
        That helper is defined elsewhere; presumably it keeps correctly
        predicted samples (confirm in utils).

        Args:
            y: Labels.
            p: Predictions.
            d_pos_list: Sequence of at least ``nb_attrs`` tensors,
                each (B, nb_attrs, k, k).
            d_neg_list: Same layout for the negative features.

        Returns:
            (loss, mean positive term, mean negative term).
        """
        valid_idx = get_valid_labels(y=y, p=p)

        # Diagonal selection: (B, nb_attrs, k, k).
        d_pos_diversity = torch.stack([d_pos_list[i][:, i] for i in range(self.nb_attrs)], dim=1)
        d_neg_diversity = torch.stack([d_neg_list[i][:, i] for i in range(self.nb_attrs)], dim=1)

        d_pos_diversity = d_pos_diversity.mean(dim=[2, 3]).div(self.sz_attr).amax(dim=1)[valid_idx]
        d_neg_diversity = d_neg_diversity.mean(dim=[2, 3]).div(self.sz_attr).mean(dim=1)[valid_idx]

        delta = (self.beta * d_pos_diversity) - ((1 - self.beta) * d_neg_diversity)
        loss = F.softplus(self.alpha * delta).mean()

        return loss, d_pos_diversity.mean(), d_neg_diversity.mean()

    def forward(self, x):
        """Compute the distance map for each attribute on ``x``.

        Args:
            x: Feature patch with the ICP channel appended;
               (B, C + 1, sz_patch, sz_patch).

        Returns:
            Distances (B, nb_attrs, sz_patch - sz_attr + 1, sz_patch - sz_attr + 1).
        """
        # Normalise over the channel axis; the epsilon guards against division by zero.
        x = x / (x.norm(2, 1, keepdim=True) + 1e-4)
        x = self.add_on_layers(x)

        return self._get_attribute_distance(x)

    def train(self, y=True, p=None, f_pos_list=None, f_neg_list=None):
        """Run one optimisation step, or set the training mode like ``nn.Module.train``.

        This method overrides ``nn.Module.train(mode)``. PyTorch itself calls
        that method with a single bool: ``eval()`` calls ``self.train(False)``,
        and a parent's ``train(mode)`` calls ``child.train(mode)``. That call
        shape is forwarded to ``nn.Module.train``, so mode switching keeps
        working.

        Args:
            y: Labels for the batch, or the ``mode`` bool in the nn.Module form.
            p: Predictions for the batch.
            f_pos_list: Positive inputs; element ``i`` feeds attribute ``i``
                in the loss.
            f_neg_list: Negative inputs, same layout.

        Returns:
            ``(loss, pos_mean, neg_mean)`` for an optimisation step; otherwise
            ``self``, as returned by ``nn.Module.train``.
        """
        if p is None and f_pos_list is None and f_neg_list is None:
            return super().train(y)

        dist_pos_list = [self.forward(f_pos) for f_pos in f_pos_list]
        dist_neg_list = [self.forward(f_neg) for f_neg in f_neg_list]

        loss, pos_mean, neg_mean = self._get_loss(y, p, dist_pos_list, dist_neg_list)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        return loss, pos_mean, neg_mean

    @torch.no_grad()
    def predict(self, a, p, y, icp):
        """Locate, for each attribute, the sub-feature with the lowest mean distance.

        Args:
            a: Target activation; (B, C, W, H).
            p: Predictions.
            y: Labels, or None. If not None, y is used instead of p to select
               the ICP.
            icp: ICP tensor for the current patch size (see
                 ``_get_all_distances_for_subfeatures``).

        Returns:
            attr_dist_xy: int numpy array (B, nb_attrs, 2) holding the (w, h)
                offset of the best sz_patch window for each attribute.
            attr_dist_map: float64 CPU tensor (B, nb_attrs, k, k), the distance
                map at that window, with k = sz_patch - sz_attr + 1.
        """
        n_h = a.shape[3] - self.sz_patch + 1  # window positions along the last axis

        # (B, nb_attrs, nb_sub_features, k, k)
        distances = self._get_all_distances_for_subfeatures(a=a, p=p, y=y, icp=icp)

        # Best window per (sample, attribute): lowest spatially averaged distance.
        best_sf = distances.mean(dim=(3, 4)).argmin(dim=2)  # (B, nb_attrs)

        # Windows are enumerated w-major with n_h positions per row, so
        # flat index = w * n_h + h.
        rows, cols = np.divmod(best_sf.cpu().numpy(), n_h)
        attr_dist_xy = np.stack([rows, cols], axis=-1).astype(int)

        gather_idx = best_sf[:, :, None, None, None].expand(-1, -1, 1, *distances.shape[3:])
        attr_dist_map = distances.gather(2, gather_idx).squeeze(2).detach().cpu().to(torch.float64)

        return attr_dist_xy, attr_dist_map