
from torch.utils.data import Dataset,DataLoader
import numpy as np
import torch
from sklearn import manifold
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
from copy import deepcopy

import torchvision.datasets as datasets
from torchvision import models

class PatientECGDataset(Dataset):
    """Map-style dataset of ECG signals, labels and per-sample metadata.

    ``X`` and ``Y`` are converted to float tensors with ``torch.Tensor``. A 2-D
    ``X`` gets a singleton dimension inserted at axis 1 (presumably a channel
    axis for 1-D conv models -- confirm against the model). ``metadata`` is
    indexed in parallel with ``X``/``Y``; ``get_victim_sample`` compares its
    entries against a patient identifier.
    """

    def __init__(self, X, Y, metadata) -> None:
        super().__init__()
        self.X = torch.Tensor(X)
        if self.X.ndim == 2:
            self.X = self.X.unsqueeze(1)
        self.Y = torch.Tensor(Y)
        self.metadata = metadata

    def __len__(self):
        # Length comes from metadata; X, Y and metadata are assumed aligned.
        return len(self.metadata)

    def __getitem__(self, index):
        return (self.X[index], self.Y[index], self.metadata[index])

    def _label_mask(self, label):
        """Return a flat boolean numpy mask of the samples whose label equals ``label``."""
        # reshape(-1) instead of squeeze(): squeeze would collapse a
        # single-sample dataset into a 0-d array.
        return (self.Y == label).detach().cpu().numpy().reshape(-1)

    def get_sample_indexes(self, size, selected_label=None):
        """Draw up to ``size`` distinct sample indexes at random.

        Args:
            size: number of indexes requested.
            selected_label: if given, only samples whose label equals this
                value are eligible. When fewer than ``size`` samples match,
                all matching indexes are returned (in random order).

        Returns:
            numpy array of indexes, sampled without replacement.
        """
        if selected_label is None:
            return np.random.choice(len(self), size=size, replace=False)

        viable_indexes = np.flatnonzero(self._label_mask(selected_label))
        # Clamp instead of failing: sampling without replacement cannot draw
        # more items than exist.
        size = min(size, len(viable_indexes))
        return np.random.choice(viable_indexes, size=size, replace=False)

    def get_victim_sample(self, victim_label=None, victim_patient=None):
        """Return one random sample index matching the given filters.

        Args:
            victim_label: if not ``None``, require ``Y == victim_label``
                (label ``0`` is a valid filter).
            victim_patient: if not ``None``, require the sample's metadata
                entry to equal this value.

        Returns:
            A Python ``int`` index.

        Raises:
            ValueError: if no sample satisfies the filters.
        """
        mask = np.ones(len(self), dtype=bool)
        if victim_label is not None:
            mask &= self._label_mask(victim_label)
        if victim_patient is not None:
            mask &= np.array(
                [entry == victim_patient for entry in self.metadata], dtype=bool
            ).reshape(-1)

        # flatnonzero always yields a 1-D array, so a single candidate is not
        # mistaken for an integer upper bound by np.random.choice.
        candidates = np.flatnonzero(mask)
        if candidates.size == 0:
            raise ValueError(
                f"No sample matches victim_label={victim_label!r}, "
                f"victim_patient={victim_patient!r}"
            )
        return np.random.choice(candidates, size=1).item()

    def create_poisoned_copy(self, poisoned_vectors, poisoned_indexes):
        """Return a new dataset where ``X[poisoned_indexes[i]]`` is replaced by ``poisoned_vectors[i]``.

        X, Y and metadata are deep-copied first, so this dataset is not modified.
        """
        X_copy = deepcopy(self.X)
        Y_copy = deepcopy(self.Y)
        metadata_copy = deepcopy(self.metadata)
        for vector, target_index in zip(poisoned_vectors, poisoned_indexes):
            X_copy[target_index] = vector

        return PatientECGDataset(X_copy, Y_copy, metadata_copy)

def get_single_grad_vector(model):
    """Flatten and concatenate the gradients of all of ``model``'s parameters.

    Parameters are visited in ``model.parameters()`` order (the same order
    as ``named_parameters()``). A parameter whose ``.grad`` is ``None`` (e.g.
    it did not take part in the backward pass, or its grad was freed)
    contributes zeros of its own size. Each parameter therefore always
    occupies the same slice of the returned vector.

    Returns:
        1-D tensor of length ``sum(p.numel() for p in model.parameters())``.
    """
    flat_grads = [
        # reshape (not view) also handles non-contiguous gradient tensors.
        (param.grad if param.grad is not None else torch.zeros_like(param)).reshape(-1)
        for param in model.parameters()
    ]
    return torch.cat(flat_grads, dim=0)

def set_grad_to_sign(model):
    """Replace every existing parameter gradient of ``model`` with its sign, in place.

    Parameters whose ``.grad`` is ``None`` are skipped instead of raising
    ``AttributeError``.
    """
    for param in model.parameters():
        if param.grad is not None:
            param.grad.sign_()

def free_grad_memory(model):
    """Release all gradient tensors of ``model`` by setting each parameter's ``.grad`` to ``None``."""
    for param in model.parameters():
        param.grad = None



# Sets of record identifiers (as strings) used for the test side of three
# different split schemes. The values look like MIT-BIH Arrhythmia record
# numbers -- presumably matched against per-sample metadata; confirm with the
# code that builds the splits.
TEST_ALL_DISEASE_SPLIT = {"102", "104", "105", "106", "213", "214", "223", "230"}
TEST_ONLY_BELOW_50_AGE_SPLIT = {"106", "111", "113", "115", "203", "208", "212", "230"}
TEST_WOMEN_OVER_75_AGE_SPLIT = {"102", "108", "121", "207", "220", "222", "228", "232"}


from mitbih import RELEVANT_CLASS_IDS_TO_NAMES,RELEVANT_CLASS_NAMES_TO_IDS
from sklearn.metrics import classification_report

def get_report(predictions, correct_labels):
    """Print an sklearn classification report and return it as a string.

    Class ids ``0 .. len(RELEVANT_CLASS_IDS_TO_NAMES) - 1`` are passed
    explicitly as ``labels``. A class that never appears in
    ``correct_labels`` or ``predictions`` is therefore still reported (with
    zero support) instead of making sklearn raise on a ``target_names``
    length mismatch. Assumes ``RELEVANT_CLASS_IDS_TO_NAMES`` is keyed by
    those contiguous integer ids -- confirm in ``mitbih``.
    """
    class_ids = list(range(len(RELEVANT_CLASS_IDS_TO_NAMES)))
    report = classification_report(
        correct_labels,
        predictions,
        labels=class_ids,
        target_names=[RELEVANT_CLASS_IDS_TO_NAMES[class_id] for class_id in class_ids],
    )
    print(report)
    return report


def _collect_tsne_features(model, dataset, batch_size, device):
    """Run ``model(x, tsne_out=True)`` over ``dataset``; return (features, labels, patients) numpy arrays."""
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.to(device)
    model.eval()
    features, labels, patients = [], [], []
    with torch.no_grad():
        for x, y, m in dataloader:
            x = x.to(device=device, dtype=torch.float)
            features.append(model.forward(x, tsne_out=True).cpu().numpy())
            patients.extend(m)
            labels.extend(y.numpy().tolist())
    return np.concatenate(features, axis=0), np.array(labels), np.array(patients)


def _per_class_subset(features, labels, patients, per_class):
    """Keep at most ``per_class`` rows per label; return aligned (features, labels, patients)."""
    # Build one index array so the three outputs stay row-aligned even when a
    # class has fewer than ``per_class`` samples.
    keep = np.concatenate(
        [np.flatnonzero(labels == label)[:per_class] for label in np.unique(labels)]
    )
    return features[keep], labels[keep], patients[keep]


def get_tsne_outputs(model, dataset, test_batch_size, device, coloring_scheme="label",
                     per_class=100, perplexities=range(20, 51), out_dir="out"):
    """Save t-SNE scatter plots of the model's feature embeddings, one per perplexity.

    Features come from ``model.forward(x, tsne_out=True)``. Up to
    ``per_class`` samples per label are embedded. Points are coloured by
    label and marked by patient metadata. Each plot is written to
    ``<out_dir>/out_<perplexity>.png``; ``out_dir`` is created if missing.

    Args:
        model: model whose ``forward`` accepts a ``tsne_out`` keyword.
        dataset: dataset yielding ``(x, y, metadata)`` triples.
        test_batch_size: DataLoader batch size.
        device: device to run the model on.
        coloring_scheme: currently unused; kept for API compatibility.
        per_class: maximum number of samples per label to embed.
        perplexities: iterable of t-SNE perplexity values to try.
        out_dir: directory for the PNG files.
    """
    import os

    features, labels, patients = _collect_tsne_features(model, dataset, test_batch_size, device)
    print(patients.shape, labels.shape, features.shape)
    feature_subset, label_subset, patient_subset = _per_class_subset(
        features, labels, patients, per_class
    )

    os.makedirs(out_dir, exist_ok=True)
    for perplexity in tqdm(perplexities):
        fig = plt.figure()
        tsne = manifold.TSNE(n_components=2, init="random", perplexity=perplexity)
        print('Fitting TSNE')
        tsne_results = tsne.fit_transform(feature_subset)

        swarm_plot = sns.scatterplot(x=tsne_results[:, 0], y=tsne_results[:, 1],
                                     hue=label_subset, style=patient_subset)
        swarm_plot.set(title=f"Perplexity {perplexity}")
        fig.savefig(os.path.join(out_dir, f"out_{perplexity}.png"))
        # Close each figure; otherwise every iteration leaks an open figure.
        plt.close(fig)


def get_device():
    """Pick ``cuda:0`` when CUDA is available, else the CPU; print and return the choice."""
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f'Using device : {device}')
    return device

class DatasetWrapper(Dataset):
    """Thin ``Dataset`` adapter over indexable ``X``/``Y`` (and optional ``metadata``).

    Items are ``(X[i], Y[i])``, or ``(X[i], Y[i], metadata[i])`` when metadata
    was supplied. A tensor index is converted with ``.tolist()`` before use.
    """

    def __init__(self, X, Y, metadata=None) -> None:
        super().__init__()
        self.X, self.Y, self.metadata = X, Y, metadata

    def __len__(self):
        # Size is defined by the label container.
        return len(self.Y)

    def __getitem__(self, index):
        key = index.tolist() if torch.is_tensor(index) else index
        item = (self.X[key], self.Y[key])
        if self.metadata is not None:
            item = item + (self.metadata[key],)
        return item