import pandas as pd 
import numpy as np 
import os
from random import choices
from collections import Counter, defaultdict
from mitbih import OUTDIR,RELEVANT_CLASS_IDS_TO_NAMES,RELEVANT_CLASS_NAMES_TO_IDS
import pickle as pkl
import torch
from torch.utils.data import TensorDataset,WeightedRandomSampler,DataLoader,Dataset
from neural_model import CNN1D_Classifier
import torch
from tqdm import tqdm 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils import get_report,PatientECGDataset
### Reproducibility code
torch.random.manual_seed(0)
np.random.seed(0)
import random
random.seed(0)


def preprocess_data(ddir):
    """Load ``mitbih.pkl`` from *ddir* and group its records by patient.

    The pickle is expected to hold three parallel sequences
    ``(x_dta, y_dta, u_dta)``. For every position, element 0 of the ``y``
    entry is used as the patient key and element 2 as the label.

    Args:
        ddir: Directory that contains the ``mitbih.pkl`` file.

    Returns:
        A ``defaultdict(list)`` mapping each patient key to a list of
        ``(x, label, u)`` tuples, in the order they appear in the file.
    """
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point this at a mitbih.pkl that was produced by a trusted pipeline.
    with open(os.path.join(ddir, 'mitbih.pkl'), 'rb') as fin:
        x_dta, y_dta, u_dta = pkl.load(fin)

    patient_dta = defaultdict(list)
    for x_, y_, u_ in zip(x_dta, y_dta, u_dta):
        patient_dta[y_[0]].append((x_, y_[2], u_))

    # One dict entry per distinct patient key, so its length is the patient count.
    print("Number of Patients", len(patient_dta))

    return patient_dta

def patient_level_data_to_train_test(patients_dta, patient_test_split,class_names_to_ids = None,
                                        train_batch_size=256,test_batch_size=512):
    """Split per-patient records into train/test sets and wrap them in loaders.

    Every patient whose key is contained in *patient_test_split* contributes
    all of its records to the test set; every other patient goes to the
    training set, so no patient appears on both sides.

    Args:
        patients_dta: Mapping of patient key -> iterable of records. Element 0
            of each record is used as the sample and element 1 as its label.
        patient_test_split: Container of patient keys assigned to the test set
            (only ``in`` membership is used).
        class_names_to_ids: Optional mapping from label to integer class id.
            When ``None`` the ids are assigned in order of first appearance in
            the training labels, followed by any labels that occur only in the
            test labels.
        train_batch_size: Batch size of the training ``DataLoader``.
        test_batch_size: Batch size of the test ``DataLoader``.

    Returns:
        Tuple ``(train_dataloader, test_dataloader, train_dataset,
        test_dataset, sample_importance_normalised)`` where the last element
        is a NumPy array with one sampling weight per training sample
        (inverse training-class frequency, normalised to sum to 1).

    Raises:
        KeyError: If an explicit *class_names_to_ids* lacks a label that is
            present in the data.
    """
    X_train, X_test = [], []
    Y_train_labels, Y_test_labels = [], []
    train_ptas, test_ptas = [], []

    for pat_idx, pat_dta in patients_dta.items():
        if pat_idx in patient_test_split:
            xs, ys, ptas = X_test, Y_test_labels, test_ptas
        else:
            xs, ys, ptas = X_train, Y_train_labels, train_ptas
        # Extend in place: rebuilding the lists with `a = a + b` for every
        # patient copies everything accumulated so far (quadratic overall).
        xs.extend(rec[0] for rec in pat_dta)
        ys.extend(rec[1] for rec in pat_dta)
        ptas.extend(pat_idx for _ in pat_dta)

    class_counts = Counter(Y_train_labels)
    print('Train Y',Counter(Y_train_labels))
    print('Test Y',Counter(Y_test_labels))

    if class_names_to_ids is None:
        class_names_to_ids = {k: idx for idx, k in enumerate(class_counts)}
        # A class seen only in test patients would otherwise raise KeyError
        # when encoding Y_test; give it the next free id instead. Ids of the
        # training classes are unaffected.
        for label in Y_test_labels:
            class_names_to_ids.setdefault(label, len(class_names_to_ids))

    print("Class Counts",class_counts)
    # Weight each class by the inverse of its training frequency so that the
    # sampler below draws the classes with roughly equal probability.
    class_sample_importance = {k: 1/v for k, v in class_counts.items()}
    print("Class Importance weights",class_sample_importance)

    sample_importance_unnormalised = [class_sample_importance[lbl] for lbl in Y_train_labels]
    sample_importance_normalised = np.array(sample_importance_unnormalised)/sum(sample_importance_unnormalised)

    Y_train = np.array([class_names_to_ids[lbl] for lbl in Y_train_labels])
    Y_test = np.array([class_names_to_ids[lbl] for lbl in Y_test_labels])

    X_train = np.array(X_train)
    X_test = np.array(X_test)
    print("Data shape",X_train.shape,Y_train.shape,X_test.shape,Y_test.shape)
    train_dataset = PatientECGDataset(X_train,Y_train,train_ptas)
    test_dataset = PatientECGDataset(X_test,Y_test,test_ptas)

    # Draw as many samples per epoch as there are training samples, with
    # replacement, according to the per-sample weights.
    sampler = WeightedRandomSampler(sample_importance_normalised,len(sample_importance_normalised),replacement = True)

    train_dataloader = DataLoader(train_dataset,batch_size=train_batch_size,sampler=sampler)
    test_dataloader = DataLoader(test_dataset,batch_size=test_batch_size)

    print(f'Train Dataset Len ={len(train_dataset)},Train Batches={len(train_dataloader)}, \nTest Dataset Len={len(test_dataset)},Test Batches={len(test_dataloader)}')
    return train_dataloader,test_dataloader,train_dataset,test_dataset,sample_importance_normalised


if __name__=="__main__":
    # Script entry point: build a patient-level train/test split, then train
    # and evaluate five freshly initialised models and report on the pooled
    # test predictions.
    #
    # NOTE(review): TEST_WOMEN_OVER_75_AGE_SPLIT, train_model and
    # evaluate_model are not defined or imported in the visible part of this
    # file; they must be brought into scope for this script to run.
    patients_dta = preprocess_data(ddir = OUTDIR)
    # patient_level_data_to_train_test returns five values; the per-sample
    # sampling weights are not needed here but must still be unpacked,
    # otherwise the assignment raises ValueError.
    train_dataloader,test_dataloader,train_dataset,test_dataset,_sample_weights = \
        patient_level_data_to_train_test(patients_dta=patients_dta,patient_test_split=TEST_WOMEN_OVER_75_AGE_SPLIT,
                                    train_batch_size=256,test_batch_size=512)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    total_predictions, total_correct_labels = [],[]
    for _run in range(5):
        # A new model per run, so each run starts from its own initialisation.
        # presumably the arguments are (input channels, number of classes) — TODO confirm
        model = CNN1D_Classifier(1,5)
        model = train_model(train_dataloader,model=model,n_epochs=100,lr=0.00001,device=device)
        predictions, correct_labels = evaluate_model(test_dataloader,model,device)
        total_predictions.extend(predictions)
        total_correct_labels.extend(correct_labels)

    get_report(total_predictions, total_correct_labels)
