import os
import logging
from random import shuffle
from turtle import forward
import torch
import numpy as np
import pandas as pd
import cv2
from PIL import Image
from torchvision import models, transforms
from matplotlib import pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from skimage.color import rgb2gray
from sklearn import preprocessing
from pdb import set_trace as bp

# Seed NumPy and PyTorch at import time so repeated runs are reproducible.

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Module-level scikit-learn MinMaxScaler instance shared by this module.
min_max_scaler = preprocessing.MinMaxScaler()

# The student network uses downscaled images for a higher receptive field.
def get_student_transform():
    """Build the preprocessing pipeline applied to student-network inputs.

    Returns:
        A ``transforms.Compose`` that resizes to 128x128 and then
        centre-crops to 128 (a no-op crop after the resize).
    """
    steps = [
        transforms.Resize((128, 128)),
        transforms.CenterCrop(128),
    ]
    return transforms.Compose(steps)

# Incorporating color diversity is important to overcome color sensitivity.
# RGB -> HSV conversion
def rgb2hsv(input, epsilon=1e-10):
    """Convert a batch of RGB images to peak-normalised HSV.

    The hue is computed from whichever channel is the per-pixel minimum
    (red-min -> 120..240, green-min -> 240..360, blue-min -> 0..120 degrees).
    Each of H, S and V is then divided by its own maximum over the whole
    batch, so the returned channels are relative values, not degrees.

    Args:
        input: tensor whose dim 1 holds the three R, G, B channels
            (e.g. ``(N, 3, H, W)``).
        epsilon: small constant added to denominators to avoid division
            by zero for grey / black pixels.

    Returns:
        Tensor of the same shape as ``input`` with H, S, V stacked on dim 1.
        An empty input is returned un-normalised (nothing to scale).
    """
    assert(input.shape[1] == 3)

    r, g, b = input[:, 0], input[:, 1], input[:, 2]
    max_rgb, _ = input.max(1)
    min_rgb, argmin_rgb = input.min(1)

    max_min = max_rgb - min_rgb + epsilon

    h1 = 60.0 * (g - r) / max_min + 60.0
    h2 = 60.0 * (b - g) / max_min + 180.0
    h3 = 60.0 * (r - b) / max_min + 300.0

    # argmin index 0 (R is min) selects h2, 1 (G) selects h3, 2 (B) selects h1.
    h = torch.stack((h2, h3, h1), dim=0).gather(dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0)
    s = max_min / (max_rgb + epsilon)
    v = max_rgb

    def _scale_by_peak(t):
        peak = t.max()
        # A zero peak (e.g. V of an all-black batch) would give 0/0 = NaN;
        # divide by 1 instead so the channel stays at zero. Non-zero peaks
        # are used unchanged.
        return t / torch.where(peak == 0, torch.ones_like(peak), peak)

    # Tensor.max() raises on empty tensors, so skip scaling when there is
    # no data at all (empty batch or zero-sized spatial dims).
    if h.numel() > 0:
        h, s, v = _scale_by_peak(h), _scale_by_peak(s), _scale_by_peak(v)

    return torch.stack((h, s, v), dim=1)

#HSV2RGB
def hsv2rgb(input):
    """Convert a batch of HSV images to RGB.

    Args:
        input: tensor whose dim 1 holds H, S, V (e.g. ``(N, 3, H, W)``).
            Hue is interpreted in degrees and wrapped into [0, 360);
            S and V are used as given.

    Returns:
        Tensor of the same shape as ``input`` with R, G, B stacked on dim 1.
    """
    assert(input.shape[1] == 3)

    h, s, v = input[:, 0], input[:, 1], input[:, 2]
    # Hue expressed in 60-degree sectors, nominally in [0, 6).
    h_ = (h - torch.floor(h / 360) * 360) / 60
    c = s * v
    x = c * (1 - torch.abs(torch.fmod(h_, 2) - 1))

    zero = torch.zeros_like(c)
    # One (R, G, B) candidate per hue sector, stacked on a new leading dim.
    y = torch.stack((
        torch.stack((c, x, zero), dim=1),
        torch.stack((x, c, zero), dim=1),
        torch.stack((zero, c, x), dim=1),
        torch.stack((zero, x, c), dim=1),
        torch.stack((x, zero, c), dim=1),
        torch.stack((c, zero, x), dim=1),
    ), dim=0)

    # Floating-point rounding can push the wrapped hue to exactly 6.0 (e.g. a
    # tiny negative hue); sector 6 is the same colour as sector 0, so wrap it
    # rather than let gather fail with an out-of-range index.
    sector = torch.floor(h_).to(torch.long) % 6
    index = torch.repeat_interleave(sector.unsqueeze(1), 3, dim=1).unsqueeze(0)
    # Add the (v - c) offset per sample: it needs an explicit channel axis,
    # otherwise broadcasting aligns the batch dim with the channel dim.
    rgb = y.gather(dim=0, index=index).squeeze(0) + (v - c).unsqueeze(1)
    return rgb

#RGB2BGR
def rgb2bgr(input):
    """Swap the first and third channels of a 4-D batch in place.

    Args:
        input: 4-D array/tensor with channels on dim 1.

    Returns:
        The same object that was passed in, now with channels 0 and 2
        exchanged (the caller's data is modified).
    """
    rank = len(input.shape)
    assert rank == 4
    # Advanced indexing on the right-hand side makes a copy, so the
    # assignment below does not read from memory it has already overwritten.
    reversed_channels = input[:, [2, 1, 0], :, :]
    input[:, [0, 1, 2], :, :] = reversed_channels
    return input

#DAME mask learning network
class MaskLearningNetwork(nn.Module):
    """Learn a three-channel, block-wise saliency mask for an image batch.

    Three independent two-layer 5x5 conv branches each produce one
    non-negative mask channel. The concatenated mask is max-pooled over 8x8
    blocks, upsampled back by a factor of 8 (nearest neighbour) and scaled
    per sample so that its largest value is 1.
    """

    def __init__(self):
        super(MaskLearningNetwork, self).__init__()
        self.conv1 = self._conv_layer_set(3, 32)
        self.conv2 = self._conv_layer_set(32, 1)
        self.conv3 = self._conv_layer_set(3, 32)
        self.conv4 = self._conv_layer_set(32, 1)
        self.conv5 = self._conv_layer_set(3, 32)
        self.conv6 = self._conv_layer_set(32, 1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(8, stride=8)
        self.upsample = nn.Upsample(scale_factor=8, mode='nearest')

    def _conv_layer_set(self, in_c, out_c):
        """Return a 5x5 'same'-padded convolution followed by LeakyReLU."""
        n_kernel = (5, 5)
        n_pad = int(n_kernel[0] / 2)
        conv_layer = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=n_kernel, padding=n_pad),
            nn.LeakyReLU(),
        )
        return conv_layer

    def forward(self, ip_images):
        """Compute the mask for ``ip_images``.

        Args:
            ip_images: batch with 3 channels on dim 1. Height and width
                should be multiples of 8, otherwise the pool/upsample round
                trip returns a smaller spatial size than the input.

        Returns:
            Tensor with 3 mask channels per sample and values in [0, 1].
        """
        out_mask1 = self.relu(self.conv2(self.conv1(ip_images)))
        out_mask2 = self.relu(self.conv4(self.conv3(ip_images)))
        out_mask3 = self.relu(self.conv6(self.conv5(ip_images)))
        out_mask = torch.cat((out_mask1, out_mask2, out_mask3), dim=1)

        out_mask = self.maxpool(out_mask)
        out_mask = self.upsample(out_mask)

        normalised = []
        for sample_mask in out_mask:
            peak = torch.max(sample_mask)
            # A sample whose mask is entirely zero would otherwise produce
            # 0/0 = NaN and poison the loss; divide by 1 so it stays zero.
            safe_peak = torch.where(peak > 0, peak, torch.ones_like(peak))
            normalised.append(self.relu(sample_mask / safe_peak))
        return torch.stack(normalised)

#DAME student regression network
class StudentClassifier(nn.Module):
    """Small CNN that maps a 3-channel image to one sigmoid score.

    Two conv + LeakyReLU + 2x2 max-pool stages feed a 8192 -> 128 -> 1
    fully connected head. The 8192 input width of ``fc1`` corresponds to
    8 channels of 32x32 features, i.e. 128x128 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv_layer_set(3, 32)
        self.conv2 = self._conv_layer_set(32, 8)
        self.fc1 = nn.Linear(8192, 128)
        self.fc2 = nn.Linear(128, 1)
        self.relu = nn.LeakyReLU()
        self.batch = nn.BatchNorm1d(128)
        self.drop = nn.Dropout(p=0.1)
        self.sigmoid = nn.Sigmoid()

    def _conv_layer_set(self, in_c, out_c):
        """Return a 3x3 'same'-padded conv, LeakyReLU and 2x2 max-pool."""
        kernel_hw = (3, 3)
        half_kernel = int(kernel_hw[0] / 2)
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=kernel_hw, padding=half_kernel),
            nn.LeakyReLU(),
            nn.MaxPool2d((2, 2)),
        )

    def forward(self, ip_images):
        """Return a ``(batch, 1)`` tensor of scores in (0, 1)."""
        features = self.conv2(self.conv1(ip_images))
        # Flatten everything except the batch dimension for the dense head.
        flat = features.view(features.size(0), -1)
        hidden = self.batch(self.relu(self.fc1(flat)))
        logits = self.fc2(self.drop(hidden))
        return self.sigmoid(logits)

#DAME student model
class StudentModel(nn.Module):
    """Mask network plus student classifier, trained with a combined loss.

    The input is multiplied by the learned mask before classification. The
    loss combines a sample-weighted MSE, an L1 sparsity term on the mask and
    a KL term between the batch-normalised predictions and targets.
    """

    def __init__(self):
        super(StudentModel, self).__init__()
        self.loss = nn.MSELoss(reduction='none')
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.mask_compute = MaskLearningNetwork()
        self.student_classifier = StudentClassifier()
        # Binning attributes; no method of this class reads them.
        self.bins = 10
        self.min = 0
        self.max = 1
        self.sigma = 0.05
        self.delta = float(self.max - self.min) / float(self.bins)
        # Plain tensor (not a registered buffer), so it does not follow
        # ``.to(device)`` calls on the module.
        self.centers = float(self.min) + self.delta * (torch.arange(self.bins).float() + 0.5)

    @staticmethod
    def _min_max(arr):
        """Rescale a NumPy array to [0, 1]; a constant array maps to zeros."""
        low = arr.min()
        span = arr.max() - low
        if span == 0:
            # Avoid 0/0 = NaN for a flat mask or image.
            return np.zeros_like(arr)
        return (arr - low) / span

    def _save_mask_visualizations(self, ip_images, out_mask, epoch):
        """Write mask, masked-image and heat-map PNGs for the first sample.

        Files go to ``results/masks/`` and are suffixed with ``epoch``.

        Returns:
            The min-max normalised 2-D mask as a NumPy array.
        """
        os.makedirs('results/masks', exist_ok=True)

        max_pool = nn.MaxPool2d(8, stride=8)
        up_sample = nn.Upsample(scale_factor=8, mode='bilinear')

        # Average the three mask channels of sample 0, then smooth the
        # 8x8 block structure with a pool + bilinear upsample round trip.
        mean_mask = (out_mask[0][0] + out_mask[0][1] + out_mask[0][2]) / 3
        mask_tosave = up_sample(max_pool(mean_mask.unsqueeze(dim=0).unsqueeze(dim=0)))
        mask_tosave = mask_tosave.squeeze(dim=0).squeeze(dim=0).detach().cpu().numpy()

        plt.imsave('results/masks/mask'+str(epoch)+'.png', mask_tosave, cmap='gray', vmax=1, vmin=0)

        first_image = ip_images[0].permute(1, 2, 0).detach().cpu().numpy()
        ip_image_for_masking = self._min_max(first_image)
        mask_for_masking = self._min_max(mask_tosave)

        # Grey out every pixel whose mask value is below mean + 1.2 * std.
        threshold = mask_for_masking.mean() + 1.2 * mask_for_masking.std()
        ip_image_for_masking[mask_for_masking < threshold] = 0.6
        plt.imsave('results/masks/masked_img'+str(epoch)+'.png', ip_image_for_masking)

        # Scale to 0..255: scaling by 256 makes the hottest pixel overflow
        # the uint8 cast and wrap around to 0 (rendered as the coldest).
        heat_levels = (mask_for_masking * 255).astype(np.uint8)
        heatmap_img = cv2.applyColorMap(heat_levels, cv2.COLORMAP_JET)
        # The image weight is 0.0, so the saved heat map is the colour map only.
        fin = cv2.addWeighted(heatmap_img, 1.0, first_image.astype(np.uint8), 0.0, 0)
        RGBimage = cv2.cvtColor(fin, cv2.COLOR_BGR2RGB)
        PILimage = Image.fromarray(RGBimage)

        PILimage.save('results/masks/hmap'+str(epoch)+'.png', dpi=(172, 172))
        return mask_for_masking

    def predict_proba(self, ip_images, ind, epoch, train):
        """Run mask + classifier and optionally dump visualisations.

        Args:
            ip_images: image batch with 3 channels on dim 1.
            ind: batch index within the epoch; visualisations are written
                only for ``ind == 0``.
            epoch: epoch number, used in the output file names.
            train: visualisations are written only when this is truthy.

        Returns:
            ``(out_prob, out_mask, mask_for_masking)`` where the last item
            is the normalised NumPy mask of sample 0, or ``None`` when no
            visualisation was produced.
        """
        out_mask = self.mask_compute(ip_images)
        masked_img = ip_images*out_mask
        out_prob = self.student_classifier(masked_img)
        mask_for_masking = None
        if ind == 0 and train:
            mask_for_masking = self._save_mask_visualizations(ip_images, out_mask, epoch)
        return out_prob, out_mask, mask_for_masking

    def forward(self, ip_images, targets, label_idx, weights, ind, epoch, train):
        """Compute the training loss for one batch.

        Args:
            ip_images: image batch passed to ``predict_proba``.
            targets: 2-D tensor; column ``label_idx`` is the regression target.
            label_idx: column of ``targets`` to fit.
            weights: per-sample weights (one per batch element).
            ind, epoch, train: forwarded to ``predict_proba``.

        Returns:
            ``(loss, loss1, loss2, loss3, mask_for_masking)`` with
            ``loss = loss1 + 0.001 * loss2 + 0.02 * loss3``.
        """
        prob, out_mask, mask_for_masking = self.predict_proba(ip_images, ind, epoch, train)

        batch_weights = weights.view(weights.shape[0], -1)
        target_col = targets[:, label_idx].view(targets.shape[0], -1)

        # Sample-weighted squared error.
        loss1 = torch.mean(batch_weights*self.loss(prob, target_col))
        # L1 sparsity penalty on the mask.
        loss2 = torch.mean(torch.abs(out_mask))

        # Normalise predictions and targets to sum to 1 over the batch so
        # they can be compared as distributions.
        prob_norm = prob/prob.sum()
        target_norm = target_col/target_col.sum()

        # KLDivLoss with 'batchmean' returns a scalar, so this is that scalar
        # times the mean weight.
        loss3 = torch.mean(batch_weights*self.kl_loss(prob_norm.log(), target_norm))

        loss = loss1+0.001*loss2+0.02*loss3

        return loss, loss1, loss2, loss3, mask_for_masking

class StudentDataset(Dataset):
    """In-memory dataset of (input, label, weight) triples.

    All three tensors are moved to ``device`` once at construction time;
    indexing then returns slices of those tensors.
    """

    def __init__(self, inputs, labels, weights, device):
        """Store the tensors on ``device``."""
        self.labels, self.inputs, self.weights = (
            tensor.to(device) for tensor in (labels, inputs, weights)
        )

    def __len__(self):
        """Return the number of samples (length of the label tensor)."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the ``(input, label, weight)`` triple at ``index``."""
        return self.inputs[index], self.labels[index], self.weights[index]

def kernel(d, kernel_width=0.25):
    """Turn distances into similarity weights.

    Computes ``sqrt(exp(-d**2 / kernel_width**2))``, which is 1 at
    ``d == 0`` and decays towards 0 as ``d`` grows.

    Args:
        d: distance value(s); scalar or NumPy array.
        kernel_width: width of the kernel.

    Returns:
        Weight(s) with the same shape as ``d``.
    """
    exponent = -(d ** 2) / kernel_width ** 2
    squared_weight = np.exp(exponent)
    return np.sqrt(squared_weight)

def train(ip_images, targets, label_idx, distances, device_id, random_seed):
    """Train a StudentModel on one neighbourhood and return its mean mask.

    The samples are weighted by ``kernel(distances)``, split 90/10 (in
    order, no shuffling) into train and validation sets and trained for a
    fixed number of epochs with Nesterov SGD. Loss curves are written to
    ``results/plots/`` and mask images to ``results/masks/``.

    Args:
        ip_images: sequence of image tensors; they are stacked and permuted
            with ``(0, 3, 1, 2)``, so each is presumably channels-last
            ``(H, W, 3)`` -- confirm against the caller.
        targets: per-sample target rows; column ``label_idx`` is fitted.
        label_idx: index of the target column to explain.
        distances: per-sample distances passed to ``kernel``.
        device_id: device to use when CUDA is available; otherwise the
            CPU is used.
        random_seed: seed applied to torch and NumPy.

    Returns:
        NumPy array: the mask of the first training sample averaged over
        epochs 1..EPOCHS-1 (epoch 0 is skipped) and min-max normalised.
    """
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    print(f'Generating explanation for index: {label_idx}')
    weights = kernel(distances)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger = logging.getLogger(__name__)

    device = torch.device(device_id if torch.cuda.is_available() else "cpu")

    # Every artefact below is written under these directories; create them
    # up front so a fresh checkout does not fail on the first save.
    os.makedirs('results/masks', exist_ok=True)
    os.makedirs('results/plots', exist_ok=True)

    img_orig = ip_images[0].detach().cpu().numpy()
    img_orig = (img_orig - img_orig.min())/(img_orig.max() - img_orig.min())
    plt.imsave('results/masks/img.png', img_orig)

    ip_images = torch.stack(ip_images).permute(0,3,1,2)

    targets = torch.tensor(targets)
    weights = torch.tensor(weights)
    # Split the train/val data (ordered, not shuffled) and set hyperparameters.
    X_train, X_val, y_train, y_val, w_train, w_val = train_test_split(
        ip_images, targets, weights, test_size=0.1, shuffle=False)
    BATCH_SIZE = 16
    EPOCHS = 10
    base_lr = 1e-3

    # Use the resolved device, not the raw device_id: when CUDA is missing
    # the model lives on the CPU and the data must follow it.
    train_dataset = StudentDataset(X_train, y_train, w_train, device)
    val_dataset = StudentDataset(X_val, y_val, w_val, device)
    train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
    val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    student_model = StudentModel()
    student_model.to(device)

    optimizer = SGD([{'params':student_model.parameters()}], lr=base_lr, momentum=0.9, nesterov=True)

    student_transf = get_student_transform()
    loss_epochs = []
    loss_epochs1, loss_epochs2, loss_epochs3 = [], [], []
    val_loss_epochs = []
    val_loss_epochs1, val_loss_epochs2, val_loss_epochs3 = [], [], []
    mask_expls = []
    for epoch in range(EPOCHS):
        tot_loss, val_loss = 0.0, 0.0
        tot_loss1, tot_loss2, tot_loss3 = 0.0, 0.0, 0.0
        val_loss1, val_loss2, val_loss3 = 0.0, 0.0, 0.0
        student_model.train()
        for ind, (local_x, local_y, local_w) in enumerate(train_data_loader):
            local_x = student_transf(local_x)
            local_x = local_x/local_x.max()

            ridx = torch.randperm(local_x.shape[0])

            if ind==0:
                # Sample 0 of the first batch is the one being visualised:
                # always convert it, and keep it out of the random pick below
                # so it is not converted twice.
                ridx = ridx[ridx!=0]
                local_x[0] = rgb2hsv(local_x[0].unsqueeze(0)).squeeze(0)

            # Colour augmentation: convert 3/5 of the batch to HSV on even
            # batches and 2/5 on odd ones.
            if ind%2==0:
                n_hsv = int(local_x.shape[0]*3/5)
            else:
                n_hsv = int(local_x.shape[0]*2/5)
            hsv_rows = ridx[:n_hsv]
            local_x[hsv_rows, :] = rgb2hsv(local_x[hsv_rows, :])

            local_x = local_x*255.0

            student_model.zero_grad()
            optimizer.zero_grad()

            loss, loss1, loss2, loss3, mask_for_masking = student_model(local_x, local_y, label_idx, local_w, ind, epoch, train=True)
            if ind==0 and epoch>=1:
                # The model returns a mask for the first training batch only.
                assert mask_for_masking is not None

                mask_expls.append(mask_for_masking)

            tot_loss += loss.item()
            tot_loss1 += loss1.item()
            tot_loss2 += loss2.item()
            tot_loss3 += loss3.item()

            loss.backward()
            optimizer.step()
        logger.info("Training done...Validation starting")

        student_model.eval()

        with torch.no_grad():
            for ind, (local_x, local_y, local_w) in enumerate(val_data_loader):
                local_x = student_transf(local_x)
                local_x = local_x/local_x.max()
                # Validation uses the same kind of augmentation on half the batch.
                ridx = torch.randperm(local_x.shape[0])[:int(local_x.shape[0]/2)]
                local_x[ridx, :] = rgb2hsv(local_x[ridx, :])
                local_x = local_x*255.0
                loss, loss1, loss2, loss3, _ = student_model(local_x, local_y, label_idx, local_w,ind,epoch, train=False)

                val_loss += loss.item()
                val_loss1 += loss1.item()
                val_loss2 += loss2.item()
                val_loss3 += loss3.item()

        e_log = epoch + 1
        n_train_batches = len(train_data_loader)
        n_val_batches = len(val_data_loader)
        train_loss = tot_loss/n_train_batches
        val_loss_log = val_loss/n_val_batches
        loss_epochs.append(train_loss)
        loss_epochs1.append(tot_loss1/n_train_batches)
        loss_epochs2.append(tot_loss2/n_train_batches)
        loss_epochs3.append(tot_loss3/n_train_batches)
        val_loss_epochs.append(val_loss_log)
        val_loss_epochs1.append(val_loss1/n_val_batches)
        val_loss_epochs2.append(val_loss2/n_val_batches)
        val_loss_epochs3.append(val_loss3/n_val_batches)
        logger.info(f"Epoch {e_log}, \
                    Training Loss {train_loss}")
        logger.info(f"Epoch {e_log}, \
                    Validation Loss {val_loss_log}")

    # Persist the per-epoch loss curves.
    curves = (
        ('results/plots/train_loss.npy', loss_epochs),
        ('results/plots/train_loss1.npy', loss_epochs1),
        ('results/plots/train_loss2.npy', loss_epochs2),
        ('results/plots/train_loss3.npy', loss_epochs3),
        ('results/plots/val_loss.npy', val_loss_epochs),
        ('results/plots/val_loss1.npy', val_loss_epochs1),
        ('results/plots/val_loss2.npy', val_loss_epochs2),
        ('results/plots/val_loss3.npy', val_loss_epochs3),
    )
    for curve_path, curve in curves:
        np.save(curve_path, curve)

    plt.figure()
    plt.plot(loss_epochs)
    plt.plot(val_loss_epochs)
    plt.legend(['train loss', 'val loss'])
    plt.savefig('results/plots/lossplot.png')
    plt.close()

    # Average the collected masks over epochs and rescale to [0, 1].
    mask_expls = np.array(mask_expls).mean(axis=0)
    mask_expls = (mask_expls - mask_expls.min())/(mask_expls.max() - mask_expls.min())
    return mask_expls