import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm as tqdm

# class to represent dataset
class SCADataset():
    """Map-style dataset pairing each trace with its label.

    ``data`` is read positionally:
      * ``data[0]`` -- traces; a float32 copy is stored as ``self.x``
      * ``data[2]`` -- labels, stored as ``self.y``
    ``data[1]`` is not used by this class. The dataset length is the size of
    the first axis of ``data[0]``.
    """

    def __init__(self, data):
        traces = data[0]
        self.x = traces.astype(np.float32)
        self.y = data[2]
        self.n_samples = traces.shape[0]

    def __len__(self):
        """Number of samples (first-axis length of the trace array)."""
        return self.n_samples

    def __getitem__(self, index):
        """Return the ``(trace, label)`` pair at ``index``."""
        trace, target = self.x[index], self.y[index]
        return trace, target

class SCABPDataset():
    """Map-style dataset yielding a trace, four bp values and a label.

    ``data`` is read positionally:
      * ``data[0]`` -- traces; a float32 copy is stored as ``self.x``
      * ``data[1]`` -- bp source; ``self.bp0``..``self.bp3`` are
        ``data[1][0]``..``data[1][3]``
      * ``data[2]`` -- labels, stored as ``self.y``
    The dataset length is the size of the first axis of ``data[0]``.

    NOTE(review): ``data[1][k]`` takes the k-th *row* of ``data[1]``, so each
    bp sequence is expected along the second axis (``data[1]`` laid out as
    (4, n_samples)). The ``__main__`` comment describes bp as "100000 x 4";
    if that is the real layout, ``bpK[index]`` fails for any index >= 4 and
    these should be ``data[1][:, k]`` -- confirm against the data files.
    """

    def __init__(self, data):
        self.x = data[0].astype(np.float32)
        bp_source = data[1]
        self.bp0, self.bp1, self.bp2, self.bp3 = (bp_source[k] for k in range(4))
        self.y = data[2]
        self.n_samples = data[0].shape[0]

    def __len__(self):
        """Number of samples (first-axis length of the trace array)."""
        return self.n_samples

    def __getitem__(self, index):
        """Return ``(trace, bp0, bp1, bp2, bp3, label)`` for sample ``index``."""
        bp_values = tuple(seq[index] for seq in (self.bp0, self.bp1, self.bp2, self.bp3))
        return (self.x[index],) + bp_values + (self.y[index],)

class Flatten(torch.nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis."""

    def forward(self, x):
        # .view (not .reshape): requires a view-compatible layout, never copies.
        return x.view(x.shape[0], -1)
    
class ConvBlock(nn.Module):
    """Convolutional feature extractor mapping (batch, 1, L) traces to (batch, 128).

    Four BatchNorm1d -> Conv1d(kernel 3, 'same') -> MaxPool1d(2, stride=3)
    stages take the channels 1 -> 512 -> 256 -> 128 -> 64. A BatchNorm'd
    MLP 448 -> 1024 -> 512 -> 256 -> 128 follows.

    The first Linear hard-codes 448 = 64 channels x 7 positions. That is what
    a 600-sample trace leaves after the four pooling stages
    (600 -> 200 -> 67 -> 22 -> 7). Other trace lengths fail at that layer.

    Dropout uses ``nn.Dropout`` (element-wise). Every dropout layer here sees
    a 2-D (batch, features) tensor. ``nn.Dropout1d`` would treat such a
    tensor as unbatched (channels, length) and zero out entire samples.

    NOTE(review): there is no activation function in the stack. Apart from
    max-pooling, every layer is affine -- confirm this is intended.
    """

    def __init__(self):
        super(ConvBlock, self).__init__()

        # Module order fixes the Sequential indices and hence the state_dict keys.
        self.conv_layers = nn.Sequential(
            nn.BatchNorm1d(1),
            nn.Conv1d(1, 512, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(512),
            nn.Conv1d(512, 256, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 128, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 64, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(64),
            nn.Flatten(),
            nn.Linear(448, 1024),  # 64 channels x 7 positions for 600-sample input
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),  # element-wise; Dropout1d would drop whole samples
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
        )

    def forward(self, x):
        """Map a (batch, 1, L) trace tensor to (batch, 128) features."""
        return self.conv_layers(x)
class ConvBlockExtra(nn.Module):
    """Deeper variant of the conv feature extractor: (batch, 1, L) -> (batch, 128).

    It has the same four BatchNorm1d -> Conv1d('same') -> MaxPool1d(2, stride=3)
    stages as ``ConvBlock`` (channels 1 -> 512 -> 256 -> 128 -> 64). A longer
    BatchNorm'd MLP follows:
    448 -> 1024 -> 1024 -> 512 -> 256 -> 128 -> 1024 -> 512 -> 256 -> 128.

    The first Linear hard-codes 448 = 64 channels x 7 positions. That is what
    a 600-sample trace leaves after the four pooling stages
    (600 -> 200 -> 67 -> 22 -> 7). Other trace lengths fail at that layer.

    Dropout uses ``nn.Dropout`` (element-wise). Every dropout layer here sees
    a 2-D (batch, features) tensor. ``nn.Dropout1d`` would treat such a
    tensor as unbatched (channels, length) and zero out entire samples.

    NOTE(review): there is no activation function in the stack. Apart from
    max-pooling, every layer is affine -- confirm this is intended.
    """

    def __init__(self):
        super(ConvBlockExtra, self).__init__()

        # Module order fixes the Sequential indices and hence the state_dict keys.
        self.conv_layers = nn.Sequential(
            nn.BatchNorm1d(1),
            nn.Conv1d(1, 512, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(512),
            nn.Conv1d(512, 256, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 128, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 64, kernel_size=3, stride=1, padding='same', bias=False),
            torch.nn.MaxPool1d(kernel_size=2, stride=3),
            nn.BatchNorm1d(64),
            nn.Flatten(),
            nn.Linear(448, 1024),  # 64 channels x 7 positions for 600-sample input
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),  # element-wise; Dropout1d would drop whole samples
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.Linear(128, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
        )

    def forward(self, x):
        """Map a (batch, 1, L) trace tensor to (batch, 128) features."""
        return self.conv_layers(x)
class FeatureBlock(nn.Module):
    """BatchNorm'd MLP mapping (batch, 3458) inputs to (batch, 128).

    Layers: 3458 -> 1024 -> 1024 -> 512 -> 128, ending in dropout.

    NOTE(review): 3458 presumably equals 128 conv features plus a
    3330-wide one-hot bp vector. That matches the commented-out
    concatenation in ``CNNC.forward`` and the ``num_classes=3330`` one-hot
    encoding in the train/test loops -- confirm.

    Dropout uses ``nn.Dropout`` (element-wise). Every dropout layer here sees
    a 2-D (batch, features) tensor. ``nn.Dropout1d`` would treat such a
    tensor as unbatched (channels, length) and zero out entire samples.
    """

    def __init__(self):
        super(FeatureBlock, self).__init__()

        # Module order fixes the Sequential indices and hence the state_dict keys.
        self.layers = nn.Sequential(
            nn.Linear(3458, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),  # element-wise; Dropout1d would drop whole samples
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Map (batch, 3458) features to (batch, 128)."""
        return self.layers(x)
class ClassificationBlock(nn.Module):
    """MLP head mapping (batch, 128) features to (batch, 3329) class probabilities.

    Layers: 128 -> 1024 -> 1024 -> 512 -> 256 -> 128 -> 3329, then a softmax
    over the class axis. The output rows are probabilities (they sum to 1),
    not logits. ``nn.CrossEntropyLoss`` expects unnormalised logits, so a
    caller using it should feed it log-probabilities instead.

    Dropout uses ``nn.Dropout`` (element-wise). Every dropout layer here sees
    a 2-D (batch, features) tensor. ``nn.Dropout1d`` would treat such a
    tensor as unbatched (channels, length) and zero out entire samples.
    """

    def __init__(self):
        super(ClassificationBlock, self).__init__()

        # Module order fixes the Sequential indices and hence the state_dict keys.
        self.layers = nn.Sequential(
            nn.Linear(128, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),  # element-wise; Dropout1d would drop whole samples
            nn.Linear(1024, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.Linear(128, 3329),
            # Explicit class axis. For 2-D input this matches the deprecated
            # implicit choice, so outputs are unchanged.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map (batch, 128) features to (batch, 3329) class probabilities."""
        return self.layers(x)
#CNNC with 2C
class CNNC(nn.Module):
    """Trace classifier: ``ConvBlockExtra`` features fed to ``ClassificationBlock``.

    ``forward`` takes a tuple/sequence ``x`` but reads only ``x[0]`` (the
    trace tensor). The bp inputs at ``x[1]``/``x[2]`` are currently ignored,
    and so is ``self.FeatureBlock``. ``FeatureBlock`` is still constructed,
    so its parameters stay part of the model and its state_dict. The
    disabled "2C" path would concatenate the conv features with the
    one-hot bp vectors, pass each through ``FeatureBlock``, and concatenate
    the two results.
    """

    def __init__(self):
        super().__init__()

        # Registration order matches the attribute names used as state_dict prefixes.
        self.ConvBlock = ConvBlockExtra()
        self.FeatureBlock = FeatureBlock()
        self.ClBlock = ClassificationBlock()

    def forward(self, x):
        """Return class probabilities for the traces in ``x[0]``."""
        trace_batch = x[0]
        features = self.ConvBlock(trace_batch)
        return self.ClBlock(features)

def test_model(Testdataloader, model, realkey=1733, max_batches=4, num_traces=200):
    """Evaluate ``model`` on a few test batches and track the rank of ``realkey``.

    The model is run (without autograd) on the first ``max_batches`` batches
    of ``Testdataloader``; at least one batch is always evaluated.
    For each batch, the function records |realkey - mean(argmax prediction)|
    and prints the mean of these values over the evaluated batches.

    It then folds the first ``num_traces`` predicted distributions into
    running per-class sums of log-probabilities. After each trace it records
    how many classes score strictly higher than ``realkey``.

    Reads the module-level global ``device`` (bound in ``__main__``).
    ``model`` is expected to return per-class probabilities, one row per
    trace (e.g. ending in a softmax).

    Args:
        Testdataloader: iterable of ``(trace, bp0, bp1, bp2, bp3, label)``
            batches.
        model: called as ``model((trace.unsqueeze(1), bp0_onehot, bp1_onehot))``.
        realkey: index of the class whose rank is tracked.
        max_batches: number of batches to evaluate.
        num_traces: maximum number of predictions folded into the rank.

    Returns:
        1-D float64 array whose entry ``t`` is the rank of ``realkey`` after
        ``t + 1`` traces. Its length is the number of traces actually
        collected (at most ``num_traces``). The last entry is therefore always
        a real rank, never zero padding. It is empty if nothing was collected.
    """
    avg_label = []
    batch_probs = []
    with torch.no_grad():  # evaluation only: don't build or keep autograd graphs
        for i, (inp, bp0, bp1, _bp2, _bp3, _label) in enumerate(tqdm(Testdataloader)):
            inp = inp.unsqueeze(1).to(device)
            bp0 = F.one_hot(bp0.long(), num_classes=3330).to(device)
            bp1 = F.one_hot(bp1.long(), num_classes=3330).to(device)
            out = model((inp, bp0, bp1))
            out_label = torch.mean(torch.argmax(out, dim=1).float())
            avg_label.append(torch.abs(realkey - out_label).cpu().numpy())
            batch_probs.append(out.cpu().numpy())
            # Stop as soon as enough batches are in, without fetching another.
            if i + 1 >= max_batches:
                break
    print(np.mean(avg_label))

    if not batch_probs:
        return np.zeros(0)
    probs = np.concatenate(batch_probs)[:num_traces]
    if probs.shape[0] == 0:
        return np.zeros(0)

    # Floor at the smallest normal float32 so a probability that underflowed
    # to 0 cannot pin a class's running sum at -inf for all later traces.
    log_probs = np.log(np.maximum(probs, np.finfo(np.float32).tiny))
    print(log_probs[0])

    # Row t of lpsums = sum of log-probabilities over traces 0..t (float64 accumulation).
    lpsums = np.cumsum(log_probs, axis=0, dtype=np.float64)
    rankmat_byKey = np.count_nonzero(
        lpsums > lpsums[:, [realkey]], axis=1
    ).astype(np.float64)
    print(rankmat_byKey[-10:])
    print(np.argmax(lpsums[-1]))

    return rankmat_byKey


def train_model(model, SCABPdataloader, SCABPTestdataloader, num_epoch):
    """Train ``model``, evaluating the key rank and checkpointing after each epoch.

    Each epoch makes one pass over ``SCABPdataloader`` and prints the mean
    per-batch accuracy. It then calls ``test_model`` on
    ``SCABPTestdataloader``. If the final rank returned is below 20, the
    whole model is pickled with joblib to ``model<epoch>_<rank>.pth``.

    Reads module-level globals bound in ``__main__``:
      * ``device``
      * ``optimizer``
      * ``criterion``
      * ``acc`` (a torchmetrics MulticlassAccuracy)

    ``model`` is expected to output per-class probabilities (e.g. ending in
    a softmax), as ``CNNC`` does.

    Args:
        model: network called as ``model((trace, bp0_onehot, bp1_onehot))``.
        SCABPdataloader: training batches of
            ``(trace, bp0, bp1, bp2, bp3, label)``.
        SCABPTestdataloader: test batches in the same format, handed to
            ``test_model``.
        num_epoch: number of training epochs.

    Returns:
        The trained model (the same object that was passed in).
    """
    for ep in range(num_epoch):
        print("Epoch: " + str(ep))
        # Eval mode is only entered after the batch loop, so train mode only
        # needs to be set once per epoch.
        model.train()
        accuracy = 0
        n_batches = 0
        for inp, bp0, bp1, _bp2, _bp3, label in tqdm(SCABPdataloader):
            inp = inp.unsqueeze(1).to(device)
            bp0 = F.one_hot(bp0.long(), num_classes=3330).to(device)
            bp1 = F.one_hot(bp1.long(), num_classes=3330).to(device)
            label = label.long().to(device)
            optimizer.zero_grad()
            out = model((inp, bp0, bp1))
            pred = torch.argmax(out, dim=1)
            accuracy += acc(pred, label)
            # The model already outputs probabilities, but CrossEntropyLoss
            # expects logits and applies log_softmax itself (a double softmax).
            # Passing log-probabilities makes it the NLL of the predicted
            # distribution, since log_softmax(log p) ~= log p when p sums to 1.
            # The epsilon keeps log() finite if a probability underflows to 0.
            loss = criterion(torch.log(out + 1e-12), label)
            loss.backward()
            # Adjust learning weights
            optimizer.step()
            n_batches += 1
        # Mean over the number of batches processed (guarded for an empty loader).
        accuracy = accuracy / max(n_batches, 1)
        if torch.is_tensor(accuracy):
            accuracy = accuracy.detach().cpu().numpy()
        print(accuracy)
        model.eval()
        rankmat = test_model(SCABPTestdataloader, model)
        if len(rankmat) > 0 and rankmat[-1] < 20:
            # Imported here so the function also works when this module is
            # imported rather than run as a script.
            import joblib
            joblib.dump(model, 'model'+str(ep)+"_"+str(rankmat[-1])+'.pth')

    return model

def parse_arguments():
    """Build the command-line parser for the training script.

    Returns:
        ``argparse.ArgumentParser`` with the following options:
          * ``--data_path`` -- file path passed to ``np.load``
          * ``--test_path`` -- file path passed to ``np.load``
          * ``--batch_size``
          * ``--num_epoch``
          * ``--num_sample``
    """
    # Imported locally so this function works without a module-level import.
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--data_path', type=str, help='data path')
    # A file path handed to np.load, so it must be parsed as a string.
    parser.add_argument('--test_path', type=str, help='path for test set')
    parser.add_argument('--batch_size', type=int, help='batch size')
    parser.add_argument('--num_epoch', type=int, help='number of epoch')
    parser.add_argument('--num_sample', type=int, help='number of sample')
    return parser


if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the profiling and test
    # sets, build the CNNC model and train it.
    # The following are bound here as module globals because train_model /
    # test_model read them by name: device, optimizer, criterion, acc, joblib.
    parser = parse_arguments()
    args = parser.parse_args()
    data_path = args.data_path
    data = np.load(data_path)

    num_sample = args.num_sample  # None -> the [:num_sample] slices keep everything
    num_epoch = args.num_epoch
    BATCH_SIZE = args.batch_size
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Expected array layouts (per the original author's notes):
    #   trace_profiling ('data'):  100000 x 600
    #   bp_profiling ('bp'):       100000 x 4
    #   skpv_profiling ('label'):  100000, i.e. the target labels
    # NOTE(review): SCABPDataset takes data[1][k] (the k-th *row*) as bp k,
    # which only works if 'bp' is laid out 4 x N -- the shape print below
    # is the place to confirm which layout the files actually use.
    trace_input = data['data']
    bp_input = data['bp']
    label_set = data['label']
    print(bp_input.shape)

    input_data = [trace_input[:num_sample], bp_input[:num_sample], label_set[:num_sample]]

    # NOTE(review): SCAdataset / SCAdataloader are not used by the code shown
    # here. SCADataset also makes its own float32 copy of the traces (as does
    # SCABPDataset), so this holds an extra copy of the training set in memory.
    SCAdataset = SCADataset(input_data)
    SCAdataloader = DataLoader(dataset=SCAdataset, batch_size=BATCH_SIZE, shuffle=True)
    SCABPdataset = SCABPDataset(input_data)
    SCABPdataloader = DataLoader(dataset=SCABPdataset, batch_size=BATCH_SIZE, shuffle=True)

    # Test set: same keys as the profiling file; not truncated by num_sample.
    test_data = np.load(args.test_path)
    test_trace = test_data['data']
    test_bp = test_data['bp']
    test_label = test_data['label']
    test_data = (test_trace, test_bp, test_label)
    SCABPTestdataset = SCABPDataset(test_data)
    # shuffle=True: test_model only consumes the first few batches, so each
    # evaluation sees a different random subset of the test set.
    SCABPTestdataloader = DataLoader(dataset=SCABPTestdataset, batch_size=BATCH_SIZE, shuffle=True)
    from torch import tensor
    from torchmetrics.classification import MulticlassAccuracy
    #target = tensor([2, 1, 0, 0])
    #preds = tensor([2, 1, 0, 1])
    # Per-batch accuracy metric over the 3329 classes; read by train_model as `acc`.
    acc = MulticlassAccuracy(num_classes=3329).to(device)
    #metric(preds, target)

    # Per-class accuracy metric; not referenced by the code shown here.
    mca = MulticlassAccuracy(num_classes=3329, average=None).to(device)
    #mca(preds, target)
    # Module-level binding used by train_model to pickle checkpoints.
    import joblib
    #2C Model
    model = CNNC()
    model.to(device)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-6)
    criterion = torch.nn.CrossEntropyLoss()

    train_model(model, SCABPdataloader, SCABPTestdataloader, num_epoch)