import os
import numpy as np
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import sys
sys.setrecursionlimit(15000)

from utils.softdtw_cuda import SoftDTW
from model.deep_FeatureExtract import FCN, ResNet
from utils.feature_visualize import *

from captum.attr import Occlusion
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.figure import figaspect
import matplotlib.cm as cm

from src.lrp_for_model import construct_lrp

# Reproducibility: every random number generator used by the pipeline is
# seeded with the same value.
random_seed = 10
random.seed(random_seed)                  # Python's built-in `random` module
np.random.seed(random_seed)               # NumPy global RNG (also used by scikit-learn)
torch.manual_seed(random_seed)            # PyTorch default generator
torch.cuda.manual_seed(random_seed)       # current CUDA device
torch.cuda.manual_seed_all(random_seed)   # every CUDA device (multi-GPU runs)

# cuDNN: use deterministic kernels and disable the benchmark autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared data-loading / device settings read by the functions below.
num_workers, pin_memory, device = 4, True, 'cuda'
   
    
class ConvPool(nn.Module):
    """Classifier made of a feature extractor, a temporal pooling stage and an MLP decoder.

    The feature map produced by ``conv1`` is pooled along dim 2 with one of
    three strategies selected by ``args.pool``:

    * ``'GTP'`` - global temporal pooling over the whole of dim 2,
    * ``'STP'`` - static pooling over ``protos_num`` equally sized segments,
    * anything else - dynamic pooling driven by the soft-DTW alignment between
      the learnable prototypes ``self.protos`` and the feature map.

    ``args.pool_op`` selects the reduction (``'AVG'``, ``'SUM'`` or ``'MAX'``).
    """

    def __init__(self, input_size, time_length, classes, data_type, args):
        """Build the extractor, prototypes and decoder.

        Args:
            input_size: number of input channels handed to the feature extractor.
            time_length: stored on the instance; not used by this class itself.
            classes: number of output classes (size of the final linear layer).
            data_type: stored on the instance; not used by this class itself.
            args: namespace providing ``pool``, ``pool_op``, ``cost_type`` and
                ``deep_extract``.
        """
        super(ConvPool, self).__init__()
        self.input_size = input_size
        self.time_length = time_length
        self.classes = classes
        self.data_type = data_type

        self.pool = args.pool
        self.pool_op = args.pool_op

        # Number of prototypes / segments: the class count clamped to [4, 10].
        self.protos_num = min(max(classes, 4), 10)
        self.dtp_distance = args.cost_type
        self.deep_extract = args.deep_extract

        # One 256-dimensional prototype per segment, initialised to zero
        # (filled in later by `init_protos`).
        self.protos = nn.Parameter(torch.zeros(256, self.protos_num), requires_grad=True)
        self.softdtw = SoftDTW(use_cuda=True, gamma=1.0, cost_type=args.cost_type, normalize=False)

        # Feature extractor: ResNet, a single kernel-size-1 Conv1d ('shallow'),
        # or FCN by default.
        if args.deep_extract == 'ResNet':
            self.conv1 = ResNet(classes, 0, input_size)
        elif args.deep_extract == 'shallow':
            self.conv1 = nn.Conv1d(in_channels=input_size,
                                   out_channels=256,
                                   kernel_size=(1))
        else:
            self.conv1 = FCN(classes, 0, input_size)

        # Decoder network. Global pooling yields one 256-vector per sample,
        # the segment-wise poolings yield one 256-vector per prototype.
        decoder_in = 256 if self.pool == 'GTP' else 256 * self.protos_num
        self.decoder = nn.Sequential(
            nn.Linear(decoder_in, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, classes),
        )

        self.relu = nn.ReLU()

    # global temporal pooling
    def gtpool(self, h, op):
        """Reduce ``h`` over dim 2 with ``op`` ('AVG', 'SUM' or 'MAX').

        Raises:
            ValueError: if ``op`` is not one of the supported reductions
                (previously ``None`` was returned silently).
        """
        if op == 'AVG':
            return torch.mean(h, dim=2)
        if op == 'SUM':
            return torch.sum(h, dim=2)
        if op == 'MAX':
            return torch.max(h, dim=2)[0]
        raise ValueError(f"Unsupported pooling op {op!r}; expected 'AVG', 'SUM' or 'MAX'")

    # static temporal pooling
    def stpool(self, h, n, op):
        """Split dim 2 of ``h`` into ``n`` contiguous segments and reduce each one.

        Every segment has ``h.shape[2] // n`` steps; the last segment also
        absorbs the remainder. The result has size ``n`` along dim 2. For an
        ``op`` other than 'AVG', 'SUM' or 'MAX' the segments are concatenated
        back unchanged.
        """
        length = h.shape[2]
        segment_sizes = [length // n] * n
        segment_sizes[-1] += length - sum(segment_sizes)

        hs = torch.split(h, segment_sizes, dim=2)
        if op == 'AVG':
            hs = [h_.mean(dim=2, keepdim=True) for h_ in hs]
        elif op == 'SUM':
            hs = [h_.sum(dim=2, keepdim=True) for h_ in hs]
        elif op == 'MAX':
            hs = [h_.max(dim=2)[0].unsqueeze(dim=2) for h_ in hs]
        return torch.cat(hs, dim=2)

    # dynamic temporal pooling
    def dtpool(self, h, op, num=0, class_num=0, visualize=False, result_folder=None, name=None):
        """Pool ``h`` along dim 2 according to its soft-DTW alignment with the prototypes.

        The alignment ``A`` is computed between the prototypes (repeated once
        per sample) and ``h``. With ``visualize=True`` the alignment is also
        passed to ``visualize_alignmatrix``. For an ``op`` other than 'AVG',
        'SUM' or 'MAX', ``h`` is returned unchanged.
        """
        A = self.softdtw.align(self.protos.repeat(h.shape[0], 1, 1), h)
        if visualize == True:
            visualize_alignmatrix(A, A, num, class_num, result_folder, name)

        if op == 'AVG':
            # Normalise each prototype's alignment weights to sum to one
            # (out of place, so the alignment itself is not modified).
            A = A / A.sum(dim=2, keepdim=True)
            h = torch.bmm(h, A.transpose(1, 2))
        elif op == 'SUM':
            h = (h.unsqueeze(dim=2) * A.unsqueeze(dim=1)).sum(dim=3)
        elif op == 'MAX':
            h = (h.unsqueeze(dim=2) * A.unsqueeze(dim=1)).max(dim=3)[0]

        return h

    def get_htensor(self, x):
        """Return the ReLU-activated output of the feature extractor for ``x``."""
        return F.relu(self.conv1(x))

    def init_protos(self, data_loader):
        """Add the mean static-pooled ('AVG') feature of ``data_loader`` to the prototypes.

        Each batch contributes its batch-mean of ``stpool(h, protos_num, 'AVG')``;
        the accumulated sum is then divided by the number of batches. The
        values are added onto the current prototype data, which ``__init__``
        sets to zero. Batches are dicts with a ``'data'`` entry that is moved
        to CUDA.

        Raises:
            ValueError: if ``data_loader`` has no batches (the division would
                otherwise turn the prototypes into NaN).
        """
        num_batches = len(data_loader)
        if num_batches == 0:
            raise ValueError("init_protos requires a data_loader with at least one batch")

        # Only `.data` is written, so no autograd graph is needed for the
        # forward passes.
        with torch.no_grad():
            for batch in data_loader:
                data = batch['data'].cuda()
                h = self.get_htensor(data).squeeze(2)
                self.protos.data += self.stpool(h, self.protos_num, 'AVG').mean(dim=0)

            self.protos.data /= num_batches

    def compute_aligncost(self, h):
        """Return the batch-mean soft-DTW cost between the prototypes and ``h``.

        ``h`` is detached, so gradients of this cost only reach the
        prototypes. The cost is divided by ``h.shape[2]``.
        """
        cost = self.softdtw(self.protos.repeat(h.shape[0], 1, 1), h.detach())
        return cost.mean() / h.shape[2]

    def compute_gradcam(self, x, labels):
        """Compute a min-max normalised Grad-CAM map over dim 2 and the DTW alignment.

        The module's gradients are zeroed and a backward pass is run on the
        mean logit of the given ``labels``, so parameter gradients are
        overwritten as a side effect. The gradient of the feature map is kept
        on ``self.h_grad``.

        Returns:
            (gradcam, A): ``gradcam`` is ``(h * dscore/dh)`` summed over dim 1
            (kept as a singleton) and scaled to [0, 1] per sample; ``A`` is the
            soft-DTW alignment between the prototypes and the feature map.
        """

        def hook_func(grad):
            self.h_grad = grad

        h, logits = self.forward(x, y=1)
        h.register_hook(hook_func)

        self.zero_grad()
        scores = torch.gather(logits, 1, labels.unsqueeze(dim=1))
        scores.mean().backward()
        gradcam = (h * self.h_grad).sum(dim=1, keepdim=True)

        # min-max normalization
        gradcam_min = torch.min(gradcam, dim=2, keepdim=True)[0]
        gradcam_max = torch.max(gradcam, dim=2, keepdim=True)[0]
        span = gradcam_max - gradcam_min
        # A constant map has zero span; divide by one instead so the result is
        # all zeros rather than NaN. Non-zero spans are left untouched.
        span = torch.where(span > 0, span, torch.ones_like(span))
        gradcam = (gradcam - gradcam_min) / span

        A = self.softdtw.align(self.protos.repeat(h.shape[0], 1, 1), h)

        return gradcam, A

    def forward(self, x, y=None, visualize=False, num=0, class_num=0, result_folder=None, name=None):
        """Extract features, pool them and decode to class logits.

        The return value depends on ``visualize`` and ``y``:

        * ``visualize`` false: ``logits`` if ``y is None``, else ``(features, logits)``.
        * ``visualize`` true:  ``(features, logits)`` if ``y is None``, else ``features``.

        In visualize mode with dynamic pooling, the alignment matrix is also
        plotted through ``dtpool`` (``class_num`` is replaced by
        ``self.protos_num`` in that call, as before).
        """
        feats = self.get_htensor(x).squeeze(2)

        if self.pool == 'GTP':
            pooled = self.gtpool(feats, self.pool_op)
        elif self.pool == 'STP':
            pooled = self.stpool(feats, self.protos_num, self.pool_op)
        elif visualize:
            pooled = self.dtpool(feats, self.pool_op, visualize=True, num=num,
                                 class_num=self.protos_num, result_folder=result_folder, name=name)
        else:
            pooled = self.dtpool(feats, self.pool_op)

        out = self.decoder(pooled.reshape(pooled.shape[0], -1))

        if visualize:
            return feats if y is not None else (feats, out)
        return out if y is None else (feats, out)
    
    

def _evaluate_ConvPool(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean_loss, accuracy, answers, predictions) where the loss is the
        sample-weighted mean of ``criterion`` and ``answers`` / ``predictions``
        are lists of per-sample labels.
    """
    answers, predictions = [], []
    correct, total, total_loss = 0, 0, 0.0

    model.eval()
    with torch.no_grad():
        for batch in loader:
            data, labels = batch['data'].cuda(), batch['labels'].cuda()
            _, logits = model(data, y=1)
            _, predicted = torch.max(logits, 1)

            answers.extend(labels.detach().cpu().numpy())
            predictions.extend(predicted.detach().cpu().numpy())

            total += data.size(0)
            correct += (predicted == labels).sum().item()
            total_loss += criterion(logits, labels).item() * data.size(0)

    return total_loss / total, correct / total, answers, predictions


def train_ConvPool(args, train_dataset, valid_dataset, test_dataset, num, data_type, model, result_folder):
    """Train ``model`` and keep the checkpoint with the lowest validation loss.

    Each training step performs two updates: a cross-entropy step on all model
    parameters, then a soft-DTW alignment-cost step on the prototypes only.
    After every epoch the model is validated; on the first epoch and whenever
    the validation loss improves, a checkpoint is written to
    ``result_folder/{args.model}-best-{data_type}-{num}.pt`` and the test set
    is evaluated.

    Args:
        args: namespace providing ``batch_size``, ``lr``, ``num_epoch`` and ``model``.
        train_dataset: training dataset; must expose ``weight`` (class weights
            for the cross-entropy loss) and yield dicts with 'data' / 'labels'.
        valid_dataset, test_dataset: datasets yielding the same dict format.
        num: run index used in the checkpoint file name.
        data_type: dataset tag used in the checkpoint file name.
        model: the ``ConvPool`` model to train (moved to CUDA here).
        result_folder: directory the checkpoint is written to.

    Returns:
        (performance, train_loss_list, valid_loss_list, valid_acc_list) where
        ``performance`` is ``[test loss, test accuracy, macro F1, micro F1,
        weighted F1, mean per-class F1, None]`` for the best checkpoint, or an
        empty list if no epoch ran.
    """
    # At least one sample per batch: the len/10 rule alone gives 0 for
    # datasets with fewer than 10 samples, which DataLoader rejects.
    batch_size = max(1, int(min(len(train_dataset) / 10, args.batch_size)))
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # NOTE: shuffling the validation set does not affect its metrics; it is
    # kept so the random-number stream (and thus training) stays unchanged.
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model.cuda()

    class_weights = torch.Tensor(train_dataset.weight).cuda()
    ce = torch.nn.CrossEntropyLoss(weight=class_weights)
    optim_h = torch.optim.Adam(model.parameters(), lr=args.lr)
    optim_p1 = torch.optim.Adam([model.protos], lr=args.lr)

    model.init_protos(train_loader)

    train_loss_list = []
    valid_loss_list = []
    valid_acc_list = []
    performance = []

    min_loss = np.inf  # lowest validation loss seen in previous epochs
    for epoch in range(args.num_epoch):
        model.train()
        total, total_ce_loss = 0, 0

        for batch in train_loader:
            data, labels = batch['data'].cuda(), batch['labels'].cuda()
            x1, logits = model(data, y=1)

            # Step 1: classification loss on all parameters.
            ce_loss = ce(logits, labels)
            optim_h.zero_grad()
            ce_loss.backward(retain_graph=True)
            optim_h.step()

            # Step 2: alignment cost, which only reaches the prototypes
            # (compute_aligncost detaches the features).
            dtw_loss = model.compute_aligncost(x1)
            optim_p1.zero_grad()
            dtw_loss.backward(retain_graph=True)
            optim_p1.step()

            total_ce_loss += ce_loss.item() * data.size(0)
            total += data.size(0)

        train_loss_list.append(total_ce_loss / total)

        valid_loss, valid_acc, _, _ = _evaluate_ConvPool(model, valid_loader, ce)
        valid_loss_list.append(valid_loss)
        valid_acc_list.append(valid_acc)

        improved = (epoch == 0) or (valid_loss < min_loss)
        min_loss = min(min_loss, valid_loss)
        if not improved:
            continue

        torch.save({
            'epoch': epoch,
            'loss' : valid_loss,
            'acc' : valid_acc,
            'model_state_dict' : model.state_dict(),
            'optimizer_state_dict' : optim_h.state_dict(),
            'criterion' : ce
        }, os.path.join(result_folder, f'{args.model}-best-{data_type}-{num}.pt'))

        test_loss, test_acc, answers, predictions = _evaluate_ConvPool(model, test_loader, ce)

        # The second loss slot is a fixed 0 placeholder, as before.
        print('\tEpoch [{:3d}/{:3d}], Test Loss: {:.4f}, {:.4f}, Test Accuracy: {:.4f}'
              .format(epoch+1, args.num_epoch, test_loss, 0.0, test_acc))

        performance = [test_loss,
                       test_acc,
                       f1_score(answers, predictions, average='macro'),
                       f1_score(answers, predictions, average='micro'),
                       f1_score(answers, predictions, average='weighted'),
                       np.mean(f1_score(answers, predictions, average=None)),
                       None
                      ]

    # `performance` is only filled once a test evaluation has run (i.e. at
    # least one epoch); with zero epochs there is nothing to report.
    if performance:
        print('The Best Test Accuracy: {:.4f}'.format(performance[1]))

    return performance, train_loss_list, valid_loss_list, valid_acc_list




def visualize_ConvPool(args, train_dataset, valid_dataset, test_dataset, num, data_type, model, model_folder, result_folder, name):
    """Load the best checkpoint and run one visualising forward pass on the first test batch.

    The checkpoint ``model_folder/{args.model}-best-{data_type}-{num}.pt`` is
    loaded into ``model``, which is moved to CUDA and put in eval mode. The
    first batch of ``test_dataset`` is then passed through the model with
    ``visualize=True``; any plotting is done by the model itself using
    ``result_folder`` and ``name``.

    ``train_dataset`` is only used to derive the batch size and
    ``valid_dataset`` is not used; both are kept for signature compatibility.
    """
    # At least one sample per batch: the len/10 rule alone gives 0 for
    # datasets with fewer than 10 samples, which DataLoader rejects.
    batch_size = max(1, int(min(len(train_dataset) / 10, args.batch_size)))
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # NOTE(review): the checkpoint written by train_ConvPool also contains the
    # pickled criterion object, so it is not a plain tensor-only file -
    # confirm the installed torch version's `torch.load` defaults accept it.
    checkpoint = torch.load(os.path.join(model_folder, f'{args.model}-best-{data_type}-{num}.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.cuda()
    model.eval()

    os.makedirs(result_folder, exist_ok=True)
    with torch.no_grad():
        # Only the first test batch is visualised.
        for batch in test_loader:
            x = batch['data'].cuda()
            model(x, visualize=True, num=0, class_num=0, result_folder=result_folder, name=name)
            break
                    
                    
                    
def _dtp_segment_boundaries(align, n):
    """Return the time indices at which consecutive prototype rows of ``align`` end.

    ``align`` is indexed as ``align[prototype][time]``. A single time cursor
    moves left to right; for each of the ``n`` rows it advances until that row
    is zero and records the previous index. Scanning stops recording once the
    cursor reaches the last time step.
    """
    boundaries = []
    t = 0
    last_t = align.size(1) - 1
    for j in range(n):
        while t != last_t:
            if align[j][t].item() == 0:
                boundaries.append(t - 1)
                break
            t += 1
    return boundaries


def LRP_ConvPool(args, train_dataset, valid_dataset, test_dataset, num, data_type, model, model_folder, result_folder, name):
    """Plot the LRP relevance of the first sample of the first batch of each split.

    The best checkpoint is loaded into ``model`` and an LRP model is built
    from it with ``construct_lrp``. For the train, validation and test loaders
    in turn, the first batch is explained and the first channel of its first
    sample is drawn as a line coloured by relevance, with vertical lines
    marking the STP segments or DTP alignment boundaries. Every split writes
    to the same file ``result_folder/{name}_LRP_{args.pool}.pdf``, so the file
    left on disk is the one from the last loader that produced a batch.
    """
    # `pd` is not imported at the top of this file (it may only arrive through
    # a star import), so bring it into scope explicitly.
    import pandas as pd

    # At least one sample per batch: the len/10 rule alone gives 0 for
    # datasets with fewer than 10 samples, which DataLoader rejects.
    batch_size = max(1, int(min(len(train_dataset) / 10, args.batch_size)))
    loaders = [
        DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        for dataset in (train_dataset, valid_dataset, test_dataset)
    ]

    perf = pd.read_csv(os.path.join(model_folder, f'{args.model}_{args.deep_extract}_{args.pool}_{args.pool_op}_{args.switch_op}_uni_performance.csv'))
    # Accuracy of this run; currently only needed by the disabled suptitle.
    acc = perf.iloc[num, 2]

    checkpoint = torch.load(os.path.join(model_folder, f'{args.model}-best-{data_type}-{num}.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.cuda()
    model = model.eval()
    n = model.protos_num
    lrp_model = construct_lrp(args, model, "cuda")

    os.makedirs(result_folder, exist_ok=True)

    for loader in loaders:
        # Only the first batch of each split is explained.
        for batch in loader:
            data, labels = batch['data'].cuda(), batch['labels'].cuda()
            _, A = model.compute_gradcam(data, labels)

            # The relevance list is consumed in reverse order; take its first
            # entry after reversal, i.e. the last element.
            output = lrp_model.forward(data.unsqueeze(2), y=labels, class_specific=True)
            lrp = output['all_relevnaces'][-1].squeeze(2)

            # Mean absolute relevance over dim 1, with 4 steps trimmed from
            # each end of the last axis.
            # NOTE(review): the reason for the 4-step trim is not visible here
            # (presumably padding in the LRP output) - confirm.
            relevance = np.abs(lrp.detach().cpu().numpy()).mean(axis=1)
            attributions = relevance[:, 4:-4]

            # First channel of the first sample in the batch.
            series = data.detach().cpu().numpy()[0][0]
            x = np.linspace(0, len(series), len(series))
            y = series
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            occlusion = attributions[0].reshape(-1)

            fig, axs = plt.subplots(1, figsize=(6,2), sharex=True, sharey=True)
            norm = plt.Normalize(attributions[0, :].min(), attributions[0, :].max())

            lc = LineCollection(segments, cmap='Reds', norm=norm)
            lc.set_array(occlusion)
            lc.set_linewidth(3)
            axs.add_collection(lc)

            if args.pool == 'STP':
                segment = data.size(2)//n
                for j in range(1, n):
                    axs.axvline(x=segment*j, color='black')
            elif args.pool == 'DTP':
                for boundary in _dtp_segment_boundaries(A[0], n):
                    axs.axvline(x=boundary, color='black')

            axs.set_xlim(x.min(), x.max())
            axs.set_ylim(y.min()-0.5, y.max()+0.5)

            axs.set_title(f'{args.pool}', color='k', fontsize=30)
            axs.tick_params(axis='x', size=1, color='white')
            axs.tick_params(axis='y', size=1, color='white')
            axs.set_yticklabels([''])
            axs.set_xticklabels([''])
            fig.tight_layout()
            fig.savefig(os.path.join(result_folder, f'{name}_LRP_{args.pool}.pdf'))
            # Release the figure; otherwise one open figure per split piles up.
            plt.close(fig)

            print(labels[0], np.mean(occlusion), np.sqrt(np.var(occlusion)))
            break