from random import seed, shuffle
import warnings
from sklearn import metrics
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn
import torch
import cv2
import pandas as pd
from sklearn.model_selection import KFold
import numpy as np
import wandb
import pathlib
import os
import time
from torchvision.models import resnet50
import matplotlib.pyplot as plt
from datetime import datetime
import transformers
warnings.filterwarnings('ignore')


def seed_everything(seed):
    """Seed every random number generator this script can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then puts cuDNN into deterministic mode
    with auto-tuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the ``seed`` parameter shadows the ``seed`` function this
    # file imports from ``random`` at module level, so go through the module.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not only the current one; it is a no-op when
    # CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_transforms():
    """Build the preprocessing pipeline shared by the train and val datasets.

    The steps, in order: convert the input to a PIL image, resize it to
    ``IMAGE_SIZE x IMAGE_SIZE``, convert it to a tensor and normalize it with
    the module-level ``MEAN`` / ``STD``.

    Returns:
        A ``transforms.Compose`` wrapping the four steps.
    """
    target_size = (IMAGE_SIZE, IMAGE_SIZE)
    pipeline = [
        transforms.ToPILImage(),
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
    return transforms.Compose(pipeline)


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` whose items
            are indexable by ``'lr'``.

    Returns:
        The ``'lr'`` entry of the first group, or ``None`` when there are no
        parameter groups.
    """
    no_group = object()
    first_group = next(iter(optimizer.param_groups), no_group)
    if first_group is no_group:
        return None
    return first_group['lr']

class UltraMNISTDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs described by a DataFrame.

    Each row of ``df`` supplies an ``image_id`` (the file stem of a ``.jpeg``
    under ``root_dir``) and a ``digit_sum``; the label is
    ``digit_sum - START_SUM``.

    Args:
        df: DataFrame with ``image_id`` and ``digit_sum`` columns.
        root_dir: Directory containing ``<image_id>.jpeg`` files.
        transforms: Optional callable applied to the loaded image.
    """

    def __init__(self, df, root_dir, transforms=None):
        self.df = df
        self.root_dir = root_dir
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Load the image and label for the row at position ``index``.

        Returns:
            Tuple of the (optionally transformed) image and the label as a
            ``torch.Tensor`` scalar.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be
                decoded.
        """
        # Fetch the row once instead of indexing the frame per column.
        row = self.df.iloc[index]
        image_id = row.image_id
        digit_sum = row.digit_sum - START_SUM

        image_path = f"{self.root_dir}/{image_id}.jpeg"
        # cv2.imread does not raise on failure; it returns None. Fail here
        # with the offending path rather than deep inside the transforms or
        # the DataLoader collate step.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(
                f"Could not read image '{image_path}' (missing or unreadable file)"
            )

        if self.transforms is not None:
            image = self.transforms(image)
        return image, torch.tensor(digit_sum)

def get_train_val_dataset(print_lengths=True):
    """Create the training and validation datasets from their CSV files.

    Both CSVs are read, shuffled row-wise and re-indexed. When
    ``SANITY_CHECK`` is set, each frame is truncated to its first
    ``SANITY_DATA_LEN`` rows. The unique ``digit_sum`` values of the training
    frame are always printed.

    Args:
        print_lengths: When true, print the sizes of both frames.

    Returns:
        Tuple ``(train_dataset, validation_dataset)`` of
        ``UltraMNISTDataset`` objects sharing one transform pipeline.
    """
    shared_transforms = get_transforms()

    train_frame = pd.read_csv(TRAIN_CSV_PATH).sample(frac=1).reset_index(drop=True)
    print(train_frame.digit_sum.unique())
    val_frame = pd.read_csv(VAL_CSV_PATH).sample(frac=1).reset_index(drop=True)

    if SANITY_CHECK:
        # Keep only a small prefix of each split for a quick smoke run.
        train_frame, val_frame = train_frame[:SANITY_DATA_LEN], val_frame[:SANITY_DATA_LEN]

    if print_lengths:
        print(f"Train set length: {len(train_frame)}, validation set length: {len(val_frame)}")

    return (
        UltraMNISTDataset(train_frame, TRAIN_ROOT_DIR, shared_transforms),
        UltraMNISTDataset(val_frame, VAL_ROOT_DIR, shared_transforms),
    )

def get_metrics(predictions, actual, isTensor=False):
    """Compute quadratic-weighted Cohen's kappa and accuracy.

    Args:
        predictions: Predicted class labels.
        actual: Ground-truth class labels.
        isTensor: When true, both inputs are detached, moved to the CPU and
            converted to NumPy arrays before scoring.

    Returns:
        Dict with keys ``"kappa"`` and ``"accuracy"``.
    """
    if isTensor:
        predictions = predictions.detach().cpu().numpy()
        actual = actual.detach().cpu().numpy()

    return {
        "kappa": metrics.cohen_kappa_score(
            actual, predictions, labels=None, weights='quadratic', sample_weight=None
        ),
        "accuracy": metrics.accuracy_score(y_true=actual, y_pred=predictions),
    }

def get_output_shape(model, image_dim):
    """Return the shape of ``model``'s output for a random input.

    Args:
        model: Callable taking a tensor and returning a tensor.
        image_dim: Sequence of ints unpacked as the dimensions of the random
            input passed to ``torch.rand``.

    Returns:
        The ``shape`` of the output's ``.data``.
    """
    dummy_input = torch.rand(*image_dim)
    output = model(dummy_input)
    return output.data.shape

class Backbone(nn.Module):
    """ResNet-50 classifier whose final layer outputs ``NUM_CLASSES`` logits.

    The encoder is created with ``pretrained=True`` and its ``fc`` layer is
    replaced by a new ``nn.Linear(2048, NUM_CLASSES)``.
    """

    def __init__(self):
        super().__init__()
        encoder = resnet50(pretrained=True)
        # Swap the stock classification head for one sized to this task.
        encoder.fc = nn.Linear(2048, NUM_CLASSES)
        self.encoder = encoder

    def forward(self, x):
        """Run ``x`` through the encoder and return its output."""
        logits = self.encoder(x)
        return logits


if __name__ == "__main__":
    # Script entry point: set the run configuration, build the data loaders
    # and the ResNet-50 classifier, then train and validate for EPOCHS epochs
    # while (optionally) logging to Weights & Biases and writing checkpoints.
    #
    # NOTE: the names assigned below are module-level globals. The helpers and
    # dataset defined above read several of them (IMAGE_SIZE, MEAN, STD,
    # START_SUM, NUM_CLASSES, the CSV/root paths, SANITY_CHECK,
    # SANITY_DATA_LEN), so they must be assigned before those helpers run.

    # ---------------------------------------------------------------- config
    # Lines marked with a row of '#' are the knobs edited per run.
    DEVICE_ID = 3 ######################################
    # Restrict CUDA to one physical GPU; inside the process it is 'cuda:0'.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE_ID)
    now = datetime.now()
    date_time = now.strftime("%d_%m_%Y__%H_%M")
    MAIN_RUN = True ######################################
    CONINUE_FROM_LAST = False ######################################

    MONITOR_WANDB = True ######################################
    SAVE_MODELS = MONITOR_WANDB  # checkpoints are written only for tracked runs
    BASELINE = True
    SANITY_CHECK = False
    EPOCHS = 300
    LEARNING_RATE = 1e-3
    ACCELARATOR = f'cuda:0' if torch.cuda.is_available() else 'cpu'
    SCALE_FACTOR = 1 ######################################
    IMAGE_SIZE = int(SCALE_FACTOR * 512)
    BATCH_SIZE = 2 ######################################
    FEATURE = '' ######################################
    MEMORY = 16 ######################################
    MODEL_LOAD_DIR = '' ######################################
    RUN_NAME = f'' ######################################

    START_SUM = 0      # subtracted from digit_sum to form the class index
    NUM_CLASSES = 10
    SEED = 42
    LEARNING_RATE_BACKBONE = LEARNING_RATE
    LEARNING_RATE_HEAD = LEARNING_RATE
    WARMUP_EPOCHS = 2
    NUM_WORKERS = 4
    TRAIN_ROOT_DIR = f""
    VAL_ROOT_DIR = f""
    TRAIN_CSV_PATH = ''
    VAL_CSV_PATH = ''
    MEAN = [0.1307,0.1307,0.1307]
    STD = [0.3081,0.3081,0.3081]
    SANITY_DATA_LEN = None
    MODEL_SAVE_DIR = f""
    EXPERIMENT = ''
    # The LR schedule is laid out over DECAY_FACTOR * EPOCHS epochs.
    DECAY_FACTOR = 2

    if SAVE_MODELS:
        os.makedirs(MODEL_SAVE_DIR,exist_ok=True)

    if MONITOR_WANDB:
        run = wandb.init(project=EXPERIMENT, entity="", reinit=True)
        wandb.run.name = RUN_NAME
        wandb.run.save()

    seed_everything(SEED)

    # ------------------------------------------------------------------ data
    train_dataset, val_dataset = get_train_val_dataset()
    train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=NUM_WORKERS)
    validation_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=NUM_WORKERS)

    print(f"Length of train loader: {len(train_loader)},Validation loader: {(len(validation_loader))}")

    # ----------------------------------------------------------------- model
    model1 = Backbone()
    model1.to(ACCELARATOR)
    for param in model1.parameters():
        param.requires_grad = True

    print(ACCELARATOR)
    print(RUN_NAME)

    print(f"Baseline model:")

    criterion = nn.CrossEntropyLoss()
    lrs = {
        'head': LEARNING_RATE_HEAD,
        'backbone': LEARNING_RATE_BACKBONE
    }
    parameters = [{'params': model1.parameters(),
                    'lr': lrs['backbone']},
                    ]
    optimizer = optim.Adam(parameters)

    # One scheduler step is taken per batch. With drop_last left at its
    # default, len(train_loader) is ceil(len(train_dataset) / BATCH_SIZE).
    steps_per_epoch = len(train_loader)
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer,WARMUP_EPOCHS*steps_per_epoch,DECAY_FACTOR*EPOCHS*steps_per_epoch)

    if CONINUE_FROM_LAST:
        # map_location lets a checkpoint saved on a different device load here.
        checkpoint = torch.load(f"{MODEL_LOAD_DIR}/best_val_accuracy.pt", map_location=ACCELARATOR)
        start_epoch = checkpoint['epoch']
        print(f"Model already trained for {start_epoch} epochs on 512 size images.")
        print(model1.load_state_dict(checkpoint['model1_weights']))
        # NOTE(review): only the model weights are restored; the optimizer and
        # scheduler state and the epoch counter start fresh. Confirm this is
        # the intended behaviour (warm start, not an exact resume).

    def _save_checkpoint(file_name, epochs_done):
        """Write model, optimizer and scheduler state to MODEL_SAVE_DIR/file_name."""
        torch.save({
            'model1_weights': model1.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'epoch' : epochs_done,
        }, f"{MODEL_SAVE_DIR}/{file_name}")

    best_validation_loss = float('inf')
    best_validation_accuracy = 0
    best_validation_metric = -float('inf')

    # ------------------------------------------------------------ epoch loop
    for epoch in range(EPOCHS):
        print("="*31)
        print(f"{'-'*10} Epoch {epoch+1}/{EPOCHS} {'-'*10}")

        running_loss_train = 0.0
        running_loss_val = 0.0
        train_correct = 0
        val_correct  = 0
        num_train = 0
        num_val = 0

        # Per-batch arrays are collected in lists and joined once per epoch;
        # concatenating onto a growing array every step is quadratic.
        train_prediction_batches = []
        train_label_batches = []
        val_prediction_batches = []
        val_label_batches = []

        model1.train()
        print("Train Loop!")
        for images,labels in tqdm(train_loader):
            images = images.to(ACCELARATOR)
            labels = labels.to(ACCELARATOR)
            batch_size = labels.shape[0]
            num_train += batch_size
            optimizer.zero_grad()
            outputs = model1(images)
            _,preds = torch.max(outputs,1)
            train_correct += (preds == labels).sum().item()

            # Convert once and reuse for both the step metrics and the
            # epoch-level accumulation.
            preds_np = preds.detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy()
            train_metrics_step = get_metrics(preds_np,labels_np)
            train_prediction_batches.append(preds_np)
            train_label_batches.append(labels_np)

            loss = criterion(outputs,labels)
            step_loss = loss.item()
            # CrossEntropyLoss defaults to reduction='mean', so the value is
            # already a per-sample mean for this batch. Weight it by the batch
            # size so that running_loss / num_samples is a true per-sample
            # mean over the epoch.
            running_loss_train += step_loss * batch_size
            loss.backward()
            optimizer.step()
            scheduler.step()
            lr = get_lr(optimizer)
            if MONITOR_WANDB:
                wandb.log({'lr':lr,"train_loss_step":step_loss,'epoch':epoch,'train_accuracy_step_metric':train_metrics_step['accuracy'],'train_kappa_step_metric':train_metrics_step['kappa']})

        train_predictions = np.concatenate(train_prediction_batches)
        train_labels = np.concatenate(train_label_batches)
        train_metrics = get_metrics(train_predictions,train_labels)
        print(f"Train Loss: {running_loss_train/num_train} Train Accuracy: {train_correct/num_train}")
        print(f"Train Accuracy Metric: {train_metrics['accuracy']} Train Kappa Metric: {train_metrics['kappa']}")

        # Validation runs after every training epoch.
        model1.eval()
        with torch.no_grad():
            print("Validation Loop!")
            for images,labels in tqdm(validation_loader):
                images = images.to(ACCELARATOR)
                labels = labels.to(ACCELARATOR)
                batch_size = labels.shape[0]

                outputs = model1(images)
                num_val += batch_size
                _,preds = torch.max(outputs,1)
                val_correct += (preds == labels).sum().item()

                preds_np = preds.detach().cpu().numpy()
                labels_np = labels.detach().cpu().numpy()
                val_metrics_step = get_metrics(preds_np,labels_np)
                val_prediction_batches.append(preds_np)
                val_label_batches.append(labels_np)

                loss = criterion(outputs,labels)
                step_loss = loss.item()
                # Same batch-size weighting as in the training loop.
                running_loss_val += step_loss * batch_size
                if MONITOR_WANDB:
                    # 'lr' is the value left over from the last training step.
                    wandb.log({'lr':lr,"val_loss_step":step_loss,"epoch":epoch,'val_accuracy_step_metric':val_metrics_step['accuracy'],'val_kappa_step_metric':val_metrics_step['kappa']})

            val_predictions = np.concatenate(val_prediction_batches)
            val_labels = np.concatenate(val_label_batches)
            val_metrics = get_metrics(val_predictions,val_labels)
            print(f"Validation Loss: {running_loss_val/num_val} Validation Accuracy: {val_correct/num_val}")
            print(f"Val Accuracy Metric: {val_metrics['accuracy']} Val Kappa Metric: {val_metrics['kappa']}")

            # Keep a separate "best" checkpoint per criterion.
            epoch_val_loss = running_loss_val/num_val
            if epoch_val_loss < best_validation_loss:
                best_validation_loss = epoch_val_loss
                if SAVE_MODELS:
                    _save_checkpoint("best_val_loss.pt", epoch+1)

            if val_metrics['accuracy'] > best_validation_accuracy:
                best_validation_accuracy = val_metrics['accuracy']
                if SAVE_MODELS:
                    _save_checkpoint("best_val_accuracy.pt", epoch+1)

            if val_metrics['kappa'] > best_validation_metric:
                best_validation_metric = val_metrics['kappa']
                if SAVE_MODELS:
                    _save_checkpoint("best_val_metric.pt", epoch+1)

        if MONITOR_WANDB:
            wandb.log({
                "training_loss": running_loss_train/num_train,
                "validation_loss": running_loss_val/num_val,
                'training_accuracy_metric': train_metrics['accuracy'],
                'training_kappa_metric': train_metrics['kappa'],
                'validation_accuracy_metric': val_metrics['accuracy'],
                'validation_kappa_metrics': val_metrics['kappa'],
                'epoch':epoch,
                'best_loss':best_validation_loss,
                'best_accuracy':best_validation_accuracy,
                'best_metric': best_validation_metric})

        # Always keep the most recent state, whatever its validation score.
        if SAVE_MODELS:
            _save_checkpoint("last_epoch.pt", epoch+1)