import os
import numpy as np
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import sys
sys.setrecursionlimit(15000)

from utils.softdtw_cuda import SoftDTW
from model.deep_FeatureExtract import FCN, ResNet
from utils.feature_visualize import *

from captum.attr import Occlusion
from matplotlib.collections import LineCollection
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.figure import figaspect
import matplotlib.cm as cm

from src.lrp_for_model import construct_lrp

# Global seeding so that runs are reproducible across torch, numpy and
# Python's `random` module.
random_seed=10
torch.manual_seed(random_seed) # torch RNG (e.g. weight init, DataLoader shuffling)
torch.backends.cudnn.deterministic = True # ask cuDNN to use deterministic kernels
torch.backends.cudnn.benchmark = False # disable cuDNN autotuning, which may pick different kernels between runs
np.random.seed(random_seed) # numpy global RNG (also what scikit-learn falls back to without a random_state)
random.seed(random_seed) # Python's built-in `random` module
torch.cuda.manual_seed(random_seed) # current CUDA device
torch.cuda.manual_seed_all(random_seed) # every CUDA device (multi-GPU)

# DataLoader settings used by the train/visualize/LRP helpers below.
num_workers = 4
pin_memory = True # NOTE(review): not passed to any DataLoader in the code visible here
device = 'cuda' # NOTE(review): the visible code hard-codes .cuda() / "cuda" instead of reading this
   
    
class ConvSwitch(nn.Module):
    """Classifier that switches between three temporal pooling strategies.

    A feature extractor (``conv1``) maps the input to a feature map ``h``
    with 256 channels (required by the sizes of ``protos`` and the decoders).
    ``h`` is pooled three ways -- global (``gtpool``), static segments
    (``stpool``) and soft-DTW-aligned prototypes (``dtpool``) -- each
    producing ``protos_num`` pooled columns.  A learned attention over the
    concatenation (``switch`` + ``encoding``) selects one of the three
    results, which ``decoder`` classifies.  The attention-weighted ensemble
    of all three is classified by ``ensem_decoder`` for the auxiliary
    "perspective" loss.

    Args:
        input_size: number of input channels given to the feature extractor.
        time_length: stored on the instance; not used inside this class.
        classes: number of output classes.  Also sets ``protos_num``
            (``classes`` clamped to the range [4, 10]).
        data_type: stored on the instance; not used inside this class.
        args: namespace providing ``pool``, ``pool_op``, ``switch_op``,
            ``cost_type`` and ``deep_extract``.
    """

    def __init__(self, input_size, time_length, classes, data_type, args):
        super(ConvSwitch, self).__init__()
        self.input_size = input_size
        self.time_length = time_length
        self.classes = classes
        self.data_type = data_type

        self.pool = args.pool
        self.pool_op = args.pool_op  # selection rule used by _select_pool ('MAX' or 'AVG')
        self.switch_op = args.switch_op

        # Number of prototypes: the class count clamped to [4, 10].
        self.protos_num = min(max(classes, 4), 10)
        self.dtp_distance = args.cost_type
        self.deep_extract = args.deep_extract

        # Prototype sequence (256 channels x protos_num steps) aligned by
        # soft-DTW; zero here and filled by init_protos().
        self.protos = nn.Parameter(torch.zeros(256, self.protos_num), requires_grad=True)
        self.softdtw = SoftDTW(use_cuda=True, gamma=1.0, cost_type=args.cost_type, normalize=False)

        # One learned switch weight per pooled column of the three branches.
        self.switch = nn.Parameter(torch.ones(1, self.protos_num * 3), requires_grad=True)

        # Feature extractor. Every option must emit 256 channels to match
        # protos / decoders. Module creation order is kept unchanged because
        # it determines how the seeded RNG is consumed by weight init.
        if args.deep_extract == 'ResNet':
            self.conv1 = ResNet(classes, 0, input_size)
        elif args.deep_extract == 'shallow':
            self.conv1 = nn.Conv1d(in_channels=input_size,
                                   out_channels=256,
                                   kernel_size=1)
        else:
            self.conv1 = FCN(classes, 0, input_size)

        # Classifier for the single selected pooling result.
        self.decoder = nn.Sequential(
            nn.Linear(256 * self.protos_num, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, classes),
        )

        # 1x1 conv collapsing the 256 channels of the switch outer product.
        self.encoding = nn.Conv2d(256, 1, 1)
        # Classifier for the attention-weighted ensemble of all three poolings.
        self.ensem_decoder = nn.Sequential(
            nn.Linear(256 * (self.protos_num * 3), 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, classes),
        )

        self.relu = nn.ReLU()
        self.kl = nn.KLDivLoss(reduction="batchmean")

    def gtpool(self, h, op):
        """Global temporal pooling: reduce ``h`` over dim 2 with ``op``.

        Args:
            h: feature map pooled along dim 2 (the time axis).
            op: one of 'AVG', 'SUM', 'MAX'.

        Raises:
            ValueError: if ``op`` is not a supported reduction.
        """
        if op == 'AVG':
            return torch.mean(h, dim=2)
        if op == 'SUM':
            return torch.sum(h, dim=2)
        if op == 'MAX':
            return torch.max(h, dim=2)[0]
        raise ValueError(f"Unknown pooling op: {op!r}")

    def stpool(self, h, n, op):
        """Static temporal pooling: split dim 2 into ``n`` segments and reduce each.

        Segments have length ``h.shape[2] // n``; the last one also absorbs
        the remainder.  Assumes ``h.shape[2] >= n`` (otherwise the leading
        segments are empty).  The result has size ``n`` along dim 2.

        Raises:
            ValueError: if ``op`` is not one of 'AVG', 'SUM', 'MAX'.
        """
        length = h.shape[2]
        base = length // n
        segment_sizes = [base] * n
        segment_sizes[-1] += length - base * n

        segments = torch.split(h, segment_sizes, dim=2)
        if op == 'AVG':
            pooled = [seg.mean(dim=2, keepdim=True) for seg in segments]
        elif op == 'SUM':
            pooled = [seg.sum(dim=2, keepdim=True) for seg in segments]
        elif op == 'MAX':
            pooled = [seg.max(dim=2, keepdim=True)[0] for seg in segments]
        else:
            raise ValueError(f"Unknown pooling op: {op!r}")
        return torch.cat(pooled, dim=2)

    def dtpool(self, h, op, num=0, class_num=0, visualize=False, result_folder=None):
        """Dynamic temporal pooling: pool ``h`` along its soft-DTW alignment to ``protos``.

        ``A`` is the alignment returned by ``SoftDTW.align`` between the
        prototypes and ``h`` -- presumably (batch, protos_num, time); TODO
        confirm against utils.softdtw_cuda.  Each prototype step pools the
        time steps weighted by its alignment row.

        Raises:
            ValueError: if ``op`` is not one of 'AVG', 'SUM', 'MAX'.
        """
        A = self.softdtw.align(self.protos.repeat(h.shape[0], 1, 1), h)
        if visualize:
            visualize_alignmatrix(A, num, class_num, result_folder)

        if op == 'AVG':
            # Row-normalise the alignment (out of place so A is left intact).
            weights = A / A.sum(dim=2, keepdim=True)
            return torch.bmm(h, weights.transpose(1, 2))
        if op == 'SUM':
            return (h.unsqueeze(dim=2) * A.unsqueeze(dim=1)).sum(dim=3)
        if op == 'MAX':
            return (h.unsqueeze(dim=2) * A.unsqueeze(dim=1)).max(dim=3)[0]
        raise ValueError(f"Unknown pooling op: {op!r}")

    def get_htensor(self, x):
        """Return the ReLU-activated output of the feature extractor."""
        h = F.relu(self.conv1(x))
        return h

    def init_protos(self, data_loader):
        """Initialise ``protos`` as the mean static-AVG-pooled feature map.

        Every batch of ``data_loader`` (dicts with a ``'data'`` tensor) is
        pooled into ``protos_num`` segments and averaged over the batch.
        The per-batch results are then averaged and written into ``protos``,
        replacing any previous value.  Batch-norm layers (if any) still see
        the data in whatever train/eval mode the model is in.

        Raises:
            ValueError: if ``data_loader`` yields no batches.
        """
        device = self.protos.device
        total = torch.zeros_like(self.protos.data)
        n_batches = 0
        # No gradients are needed: the result is written straight into .data.
        with torch.no_grad():
            for batch in data_loader:
                data = batch['data'].to(device)
                h = self.get_htensor(data).squeeze(2)
                total += self.stpool(h, self.protos_num, 'AVG').mean(dim=0)
                n_batches += 1
        if n_batches == 0:
            raise ValueError("init_protos needs a data_loader with at least one batch")
        self.protos.data.copy_(total / n_batches)

    def _select_pool(self, attn):
        """Return which pooling result to use: 0 = gtpool, 1 = stpool, 2 = dtpool.

        The last axis of ``attn`` indexes the concatenated pooled columns:
        [0, P) come from gtpool, [P, 2P) from stpool and [2P, 3P) from
        dtpool, with P = ``self.protos_num``.  The thresholds below are the
        original (deliberately offset) decision boundaries.

        Raises:
            ValueError: if ``self.pool_op`` is not 'MAX' or 'AVG'.
        """
        P = self.protos_num
        if self.pool_op == 'MAX':
            # Mean position of the most-attended column, across rows and batch.
            ind = torch.mean(torch.max(attn, dim=-1)[1].float()).item()
            if ind < P + 1:
                return 0
            if ind <= P * 2 + 1:
                return 1
            return 2
        if self.pool_op == 'AVG':
            # Mean attention each row gives to each branch -> (..., 3); the
            # argmax over that last axis is a branch index in {0, 1, 2},
            # which is what the 0.6 / 1.6 thresholds expect.
            group_scores = torch.stack([attn[..., :P].mean(dim=-1),
                                        attn[..., P:P * 2].mean(dim=-1),
                                        attn[..., P * 2:].mean(dim=-1)], dim=-1)
            ind = torch.mean(torch.max(group_scores, dim=-1)[1].float()).item()
            if ind < 0.6:
                return 0
            if ind < 1.6:
                return 1
            return 2
        raise ValueError(f"Unknown pool_op: {self.pool_op!r}")

    def switch_pool(self, h, op, num=0, class_num=0, visualize=False, result_folder=None, name=None, count=None):
        """Pool ``h`` three ways and let the switch attention choose one.

        Args:
            h: feature map from ``get_htensor`` (pooled along dim 2).
            op: reduction ('AVG', 'SUM' or 'MAX') used by all three poolings.
            num, class_num, result_folder, name, count: forwarded to
                ``visualize_attn`` when ``visualize`` is true.

        Returns:
            (ensemble, out, op, concat_out, attn) where ``concat_out`` is the
            three pooled results concatenated on the last axis (gtpool |
            stpool | dtpool, ``protos_num`` columns each), ``attn`` the
            softmaxed switch attention, ``ensemble`` ``concat_out`` reweighted
            by ``attn``, ``out`` the single pooled result picked by
            ``_select_pool`` and ``op`` its index (0/1/2).
        """
        out1 = self.gtpool(h, op).unsqueeze(2).repeat(1, 1, self.protos_num)
        out2 = self.stpool(h, self.protos_num, op)
        out3 = self.dtpool(h, op)

        concat_out = torch.cat([out1, out2, out3], dim=-1)
        raw_attn = self.switch.repeat(h.shape[0], 1, 1)

        # Per-channel outer product of the pooled columns with the switch
        # weights, collapsed over channels by the 1x1 conv, then softmaxed.
        encode_attn = torch.matmul(concat_out.unsqueeze(3), raw_attn.unsqueeze(1))
        attn = F.softmax(self.encoding(encode_attn), dim=-1).squeeze(1)

        if visualize:
            visualize_attn(attn, self.protos_num, num, class_num, result_folder, name, count)

        ensemble = torch.matmul(concat_out.unsqueeze(2), attn.unsqueeze(1))

        choice = self._select_pool(attn)
        return ensemble, (out1, out2, out3)[choice], choice, concat_out, attn

    def compute_aligncost(self, h):
        """Soft-DTW cost between ``protos`` and detached ``h``, averaged and scaled by 1/h.shape[2]."""
        cost = self.softdtw(self.protos.repeat(h.shape[0], 1, 1), h.detach())
        return cost.mean() / h.shape[2]

    def compute_perspectivecost(self, ensem, one):
        """Classify the ensemble and the single pooled result; KL between them.

        Returns:
            (diverse, cost): ``diverse`` are the ``ensem_decoder`` logits and
            ``cost`` is ``KL(softmax(diverse) || softmax(decoder(one)))`` with
            batch-mean reduction.
        """
        diverse = self.ensem_decoder(ensem.reshape(ensem.shape[0], -1))
        one_logits = self.decoder(one.reshape(one.shape[0], -1))
        # Explicit class axis; the implicit-dim form is deprecated and
        # resolves to dim=1 for these 2-D logits anyway.
        cost = self.kl(F.log_softmax(one_logits, dim=1), F.softmax(diverse, dim=1))
        return diverse, cost

    def compute_attentioncost(self, concat, h):
        """Mean of ``h^T @ concat.squeeze(2)``, scaled by 1/h.shape[2].

        The training loop passes the ``ensemble`` from ``switch_pool`` as
        ``concat`` and the selected pooled result as ``h``.
        """
        cost = torch.bmm(h.transpose(1, 2), concat.squeeze(2))
        return cost.mean() / h.shape[2]

    def forward(self, x, y=None, visualize=False, num = 0, class_num=0, result_folder=None, name=None, count=None):
        """Extract features, switch-pool them with 'MAX' and classify.

        Args:
            x: model input for ``conv1``.
            y: only checked for ``None``; selects the return form.
            visualize: when true, attention plots are written through
                ``visualize_attn`` using ``num``, ``class_num``,
                ``result_folder``, ``name`` and ``count``.

        Returns:
            ``logits`` when ``y`` is None.  Otherwise
            ``(features, logits, ensem, one, op)`` when visualizing, or
            ``(features, logits, ensem, one, op, attn)`` when not.
        """
        x = self.get_htensor(x).squeeze(2)
        if visualize:
            ensem, one, op, raw, attn = self.switch_pool(
                x, 'MAX', num=num, class_num=class_num, visualize=True,
                result_folder=result_folder, name=name, count=count)
        else:
            ensem, one, op, raw, attn = self.switch_pool(x, 'MAX')

        out = self.decoder(one.reshape(one.shape[0], -1))

        if y is None:
            return out
        if visualize:
            return x, out, ensem, one, op
        return x, out, ensem, one, op, attn
    
    
    

def _evaluate_ConvSwitch(model, loader, criterion):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        (loss_sum, correct, seen, answers, predictions): ``loss_sum`` is the
        sample-weighted sum of ``criterion`` over all batches; ``answers`` and
        ``predictions`` are flat lists of true and predicted labels.
    """
    model.eval()
    loss_sum, correct, seen = 0, 0, 0
    answers, predictions = [], []
    with torch.no_grad():
        for batch in loader:
            data, labels = batch['data'].cuda(), batch['labels'].cuda()
            answers.extend(labels.detach().cpu().numpy())

            _, logits, _, _, _, _ = model(data, y=1)
            _, predicted = torch.max(logits, 1)
            predictions.extend(predicted.detach().cpu().numpy())

            seen += data.size(0)
            correct += (predicted == labels).sum().item()
            loss_sum += criterion(logits, labels).item() * data.size(0)
    return loss_sum, correct, seen, answers, predictions


def train_ConvSwitch(args, train_dataset, valid_dataset, test_dataset, num, data_type, model, result_folder):
    """Train ``model`` and evaluate the test set at each new best validation loss.

    Each training step makes three optimizer updates from one forward pass:
    the switch weights on the attention cost, all model parameters on
    cross-entropy plus a 1e-4-weighted ensemble/perspective term, and the
    prototypes on the soft-DTW alignment cost.

    Whenever the validation loss is strictly lower than every earlier
    epoch's (always at epoch 0), a checkpoint is written to
    ``{result_folder}/{args.model}-best-{data_type}-{num}.pt`` and the test
    set is evaluated.

    Returns:
        (performance, train_loss_list, valid_loss_list, valid_acc_list).
        ``performance`` is ``[test_loss, test_acc, f1_macro, f1_micro,
        f1_weighted, mean per-class f1, None]`` from the last checkpointed
        epoch, or ``[]`` if no epoch ran.
    """
    # A tenth of the training set, capped by args.batch_size, and never below
    # 1 (DataLoader rejects batch_size=0 for datasets under 10 samples).
    batch_size = max(1, int(min(len(train_dataset) / 10, args.batch_size)))
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model.cuda()

    weight1 = torch.Tensor(train_dataset.weight).cuda()
    ce = torch.nn.CrossEntropyLoss(weight=weight1)
    optim_h = torch.optim.Adam(model.parameters(), lr=args.lr)
    optim_p1 = torch.optim.Adam([model.protos], lr=args.lr)
    optim_a1 = torch.optim.Adam([model.switch], lr=args.lr)

    model.init_protos(train_loader)

    train_loss_list = []
    valid_loss_list = []
    valid_acc_list = []
    performance = []
    best_test_acc = None

    for epoch in range(args.num_epoch):
        model.train()
        total, total_ce_loss = 0, 0

        for batch in train_loader:
            data, labels = batch['data'].cuda(), batch['labels'].cuda()
            x1, logits, ensem, one, _, _ = model(data, y=1)
            diverse, perspective_loss = model.compute_perspectivecost(ensem, one)

            # 1) Switch attention weights only.
            attn_loss = model.compute_attentioncost(ensem, one)
            optim_a1.zero_grad()
            attn_loss.backward(retain_graph=True)
            optim_a1.step()

            # 2) All parameters: classification loss plus a small auxiliary term.
            raw_loss = ce(logits, labels)
            ce_loss = raw_loss + (ce(diverse, labels) + perspective_loss) * 1e-4
            optim_h.zero_grad()
            ce_loss.backward(retain_graph=True)
            optim_h.step()

            # 3) Prototypes: soft-DTW alignment to the (detached) features.
            dtw_loss = model.compute_aligncost(x1)
            optim_p1.zero_grad()
            dtw_loss.backward(retain_graph=True)
            optim_p1.step()

            total_ce_loss += ce_loss.item() * data.size(0)
            total += data.size(0)

        train_loss_list.append(total_ce_loss / total)

        val_loss_sum, correct_val, val_total, _, _ = _evaluate_ConvSwitch(model, valid_loader, ce)
        valid_loss_list.append(val_loss_sum / val_total)
        valid_acc_list.append(correct_val / val_total)

        if epoch == 0 or min(valid_loss_list[:-1]) > valid_loss_list[-1]:
            torch.save({
                'epoch': epoch,
                'loss' : valid_loss_list[-1],
                'acc' : valid_acc_list[-1],
                'model_state_dict' : model.state_dict(),
                'optimizer_state_dict' : optim_h.state_dict(),
                'criterion' : ce
            }, os.path.join(result_folder, f'{args.model}-best-{data_type}-{num}.pt'))

            test_loss_sum, correct, test_total, answers, predictions = _evaluate_ConvSwitch(model, test_loader, ce)
            test_loss = test_loss_sum / test_total
            best_test_acc = correct / test_total

            # The second loss slot is a placeholder (always 0.0) for a test DTW loss.
            print('\tEpoch [{:3d}/{:3d}], Test Loss: {:.4f}, {:.4f}, Test Accuracy: {:.4f}'
            .format(epoch+1, args.num_epoch, test_loss, 0.0, best_test_acc))

            performance = [test_loss,
                           best_test_acc,
                           f1_score(answers, predictions, average='macro'),
                           f1_score(answers, predictions, average='micro'),
                           f1_score(answers, predictions, average='weighted'),
                           np.mean(f1_score(answers, predictions, average=None)),
                           None
                          ]

    if best_test_acc is not None:
        print('The Best Test Accuracy: {:.4f}'.format(best_test_acc))

    return performance, train_loss_list, valid_loss_list, valid_acc_list




def visualize_ConvSwitch(args, train_dataset, valid_dataset, test_dataset, num, data_type, model, model_folder, result_folder, name):
    """Load the best checkpoint and render visualizations for the first test batch.

    Runs ``model(..., visualize=True)`` once on the first batch of
    ``test_dataset`` with plots written under ``result_folder``.
    ``train_dataset`` only determines the batch size (same rule as training);
    ``valid_dataset`` is accepted for interface compatibility and unused.
    Does nothing beyond loading the model if the test set is empty.
    """
    # Same batch-size rule as training, never below 1.
    batch_size = max(1, int(min(len(train_dataset) / 10, args.batch_size)))
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # NOTE(review): the checkpoint also pickles the loss module ('criterion');
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which may refuse
    # it -- pass weights_only=False if loading fails on a newer PyTorch.
    checkpoint = torch.load(os.path.join(model_folder, f'{args.model}-best-{data_type}-{num}.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.cuda()
    model = model.eval()

    os.makedirs(result_folder, exist_ok=True)
    with torch.no_grad():
        batch = next(iter(test_loader), None)
        if batch is None:
            return
        x = batch['data'].cuda()
        model(x, visualize=True, num=0, class_num=0, result_folder=result_folder, name=name)
                    
def LRP_ConvSwitch(args, train_dataset, valid_dataset, test_dataset, num, data_type, model, model_folder, result_folder, name):
    """Plot LRP relevance over the first sample of each split's first batch.

    For the train, valid and test splits (in that order), this takes the
    first batch and computes class-specific LRP relevances via
    ``construct_lrp``.  It averages |relevance| over axis 1, drops 4 entries
    from each end, and draws the first sample's first input channel coloured
    by that relevance.  All splits write to the same
    ``{result_folder}/{name}_LRP.pdf``, so the file left on disk is the test
    split's plot.  The label and the relevance mean/std are printed per split.
    """
    # Same batch-size rule as training, never below 1.
    batch_size = max(1, int(min(len(train_dataset) / 10, args.batch_size)))
    loaders = [
        DataLoader(dataset=ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        for ds in (train_dataset, valid_dataset, test_dataset)
    ]

    # NOTE(review): the checkpoint also pickles the loss module ('criterion');
    # PyTorch >= 2.6 defaults torch.load to weights_only=True, which may refuse
    # it -- pass weights_only=False if loading fails on a newer PyTorch.
    checkpoint = torch.load(os.path.join(model_folder, f'{args.model}-best-{data_type}-{num}.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.cuda()
    model = model.eval()
    lrp_model = construct_lrp(args, model, "cuda")

    def compute_lrp(lrp_model, x, y, class_specific):
        """Return the list of layer relevances produced by ``lrp_model``."""
        output = lrp_model.forward(x, y=y, class_specific=class_specific)
        return output['all_relevnaces']

    os.makedirs(result_folder, exist_ok=True)
    sample = 0  # only the first sample of each batch is plotted

    for loader in loaders:
        batch = next(iter(loader), None)
        if batch is None:
            continue
        data, labels = batch['data'].cuda(), batch['labels'].cuda()

        # Plain forward pass kept from the original; its output is unused
        # here. NOTE(review): may be removable unless construct_lrp relies on it.
        model(data)

        # Reverse so index 0 is Relevance 0 (input-level relevance).
        lrps = compute_lrp(lrp_model, data.unsqueeze(2), labels, class_specific=True)[::-1]
        lrp = lrps[0].squeeze(2).detach().cpu().numpy()
        relevance = np.abs(lrp).mean(axis=1)
        attributions = relevance[:, 4:-4]
        occlusion = attributions[sample].reshape(-1)

        series = data.detach().cpu().numpy()[sample][0]
        xs = np.linspace(0, len(series), len(series))
        points = np.array([xs, series]).T.reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        fig, axs = plt.subplots(1, figsize=(6, 2), sharex=True, sharey=True)
        norm = plt.Normalize(occlusion.min(), occlusion.max())
        lc = LineCollection(segments, cmap='Reds', norm=norm)
        lc.set_array(occlusion)
        lc.set_linewidth(3)
        axs.add_collection(lc)

        axs.set_xlim(xs.min(), xs.max())
        axs.set_ylim(series.min() - 0.5, series.max() + 0.5)

        axs.set_title('SoM-TP', color='k', fontsize=25)
        axs.tick_params(axis='x', size=1, color='white')
        axs.tick_params(axis='y', size=1, color='white')
        axs.set_xticklabels([''])
        axs.set_yticklabels([''])
        fig.tight_layout()
        fig.savefig(os.path.join(result_folder, f'{name}_LRP.pdf'))
        # Release the figure; previously one figure leaked per split.
        plt.close(fig)

        print(labels[sample], np.mean(occlusion), np.sqrt(np.var(occlusion)))