""" Super simple neural net training code for a CLASSIFIER
    P(N | C, X)


    General details:
    Architecture:
    - Create embedding for C (simple nn.Embedding, since C is binary)
    - Create embedding for X (a resnet18 whose final FC layer is replaced to output the embedding)
    - Then concatenate and run through 4 layers of FC network

    Training:
    - Train with standard binary cross entropy 

"""

import torch
import torch.nn as nn 
from torchvision.models import resnet18
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms

import argparse
from tqdm.auto import tqdm 
import os
from cfg.dataloader_pickle import PickleDataset, load_data, transback
from cxray.cxray_dataset import CXRayDataset, DiffusionScale, transback
from cxray import cifar_resnets as cr
from cfg.Scheduler import GradualWarmupScheduler
from sklearn.metrics import roc_auc_score
# =========================================================
# =                   Model Definitions                   =
# =========================================================


class Classifier(nn.Module):
    """Binary classifier producing P(N | C, X).

    The image ``x`` is embedded by a ResNet-18 whose first convolution is
    replaced to take single-channel input and whose final fully-connected
    layer is replaced to emit ``hdim`` features.  The binary condition ``c``
    is embedded by a two-entry ``nn.Embedding``.  The two embeddings are
    concatenated and fed to a four-layer fully-connected network that ends in
    a sigmoid, so the output is a probability rather than a logit.

    Args:
        hdim: Size of each of the two embeddings.  The fully-connected
            "joiner" network operates on ``2 * hdim`` features.
    """

    def __init__(self, hdim=64):
        super().__init__()

        # Randomly initialised backbone.  Called without the `pretrained`
        # keyword: it is deprecated in recent torchvision releases, and the
        # default in every release is "no pretrained weights".
        self.resnet = resnet18()
        # Single input channel instead of the stock three-channel stem.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
        self.resnet.fc = nn.Linear(512, hdim)

        # Alternative backbone that was previously kept here as dead code:
        #   self.resnet = cr.resnet20()
        #   self.resnet.linear = nn.Linear(64, hdim)

        # Two rows: one per value of the binary condition C.
        self.embedding = nn.Embedding(2, hdim)

        self.joiner = nn.Sequential(
            nn.Linear(hdim * 2, hdim * 2),
            nn.ReLU(),
            nn.Linear(hdim * 2, hdim * 2),
            nn.ReLU(),
            nn.Linear(hdim * 2, hdim * 2),
            nn.ReLU(),
            nn.Linear(hdim * 2, 1),
            nn.Sigmoid()
            )

    def forward(self, x, c):
        """Return the predicted probability for each sample.

        Args:
            x: Batch of single-channel images, laid out as expected by
                ``nn.Conv2d`` (batch, 1, H, W).
            c: Integer tensor of condition indices in {0, 1}.
                Assumed to be 1-D with one entry per sample so that its
                embedding is (batch, hdim) -- TODO confirm against the dataset.

        Returns:
            Tensor whose last dimension is 1, with values in [0, 1].
        """
        x_emb = self.resnet(x)
        #x_emb = torch.zeros_like(x_emb)
        c_emb = self.embedding(c)
        #c_emb = torch.zeros_like(c_emb)
        joined = torch.cat([x_emb, c_emb], dim=1)
        return self.joiner(joined)


# ========================================================
# =                Training Block                        =
# ========================================================



def train(params:argparse.Namespace):
    """Train a ``Classifier`` with binary cross entropy and save checkpoints.

    After every epoch the learning-rate scheduler is stepped and the model is
    evaluated on the validation split; the resulting metrics are shown in the
    progress bar of the following epoch.  Every ``params.interval`` epochs a
    checkpoint holding the model, optimizer and scheduler state is written to
    ``params.moddir`` as ``ckpt_<epoch>_checkpoint.pt``.

    Args:
        params: Parsed command-line options.  Fields read here: ``root_dir``,
            ``batch_size``, ``num_workers``, ``device`` (CUDA device index),
            ``hdim``, ``lr``, ``epoch``, ``multiplier``, ``interval`` and
            ``moddir``.
    """
    xform = transforms.Compose([transforms.ToTensor(), DiffusionScale()])
    train_data = CXRayDataset(params.root_dir, train=True, transform=xform)
    dataloader = DataLoader(train_data, shuffle=True, batch_size=params.batch_size,
                            num_workers=params.num_workers, drop_last=True)

    val_data = CXRayDataset(params.root_dir, train=False, transform=xform)
    # No shuffling for validation: the metric is order-independent, and a
    # shuffled loader would draw from the global RNG once per epoch.
    val_loader = DataLoader(val_data, shuffle=False, batch_size=params.batch_size,
                            num_workers=params.num_workers, drop_last=False)

    device = torch.device('cuda', params.device)

    # Create the checkpoint directory up front so that an unusable path fails
    # before any training time is spent.
    os.makedirs(params.moddir, exist_ok=True)

    model = Classifier(params.hdim).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=params.lr, weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=params.epoch,
        eta_min=0,
        last_epoch=-1
        )

    # Presumably warms the LR up over the first tenth of the epochs before
    # handing over to the cosine schedule -- confirm against cfg.Scheduler.
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer,
        multiplier=params.multiplier,
        warm_epoch=params.epoch // 10,
        after_scheduler=cosineScheduler,
        last_epoch=0)

    last_eval = {}
    for epc in range(params.epoch):
        model.train()  # evaluate_model() leaves the model in eval mode
        with tqdm(dataloader, dynamic_ncols=True) as tqdm_dataloader:
            for batch in tqdm_dataloader:
                optimizer.zero_grad()
                x = batch['X'].to(device)
                c = batch['C'].to(device)
                n = batch['N'].to(device)

                # The model ends in a sigmoid, so plain BCE (not the
                # with-logits variant) is the matching loss.
                loss = F.binary_cross_entropy(model(x,c), n.view(-1, 1).float())
                loss.backward()
                optimizer.step()
                postfix = {'epoch': epc+1,
                           'loss': loss.item(),
                           # Read the live param group instead of rebuilding
                           # the whole optimizer state dict every step.
                           'LR': optimizer.param_groups[0]['lr']}
                postfix.update(last_eval)
                tqdm_dataloader.set_postfix(
                    ordered_dict=postfix)
            warmUpScheduler.step()
            last_eval = evaluate_model(model, val_loader, device)


        if (epc+1) % params.interval == 0:
            checkpoint = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': warmUpScheduler.state_dict()
            }
            torch.save(checkpoint, os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))


@torch.no_grad()
def evaluate_model(model, val_loader, device):
    """Compute the ROC AUC of ``model`` over ``val_loader``.

    The model is switched to eval mode and is left in that mode on return;
    callers that continue training must call ``model.train()`` themselves.

    Args:
        model: Callable as ``model(X, C)`` returning one score per sample.
        val_loader: Iterable of batches, each a mapping with keys ``'X'``,
            ``'C'`` and ``'N'`` holding tensors.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``{'AUC': <roc_auc_score of the targets N against the predictions>}``.

    Note:
        An empty ``val_loader`` makes ``torch.cat`` fail, and
        ``roc_auc_score`` may reject degenerate targets (e.g. a single class)
        -- see the scikit-learn documentation.
    """
    model.eval()
    targets = []
    scores = []
    for batch in val_loader:
        X = batch['X'].to(device)
        C = batch['C'].to(device)

        # Flatten to one value per sample and move off the device right away,
        # so only the current batch's outputs occupy device memory.
        scores.append(model(X, C).reshape(-1).cpu())
        targets.append(batch['N'].reshape(-1).cpu())

    roc_auc = roc_auc_score(torch.cat(targets).numpy(), torch.cat(scores).numpy())
    return {'AUC': roc_auc}






# ====================================================
# =           Main Block                             =
# ====================================================
def main():
    """Parse the command-line options and start training."""
    cli = argparse.ArgumentParser(description='Training for P(N|C,X)')

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ('--root_dir', {'type': str, 'required': True}),
        ('--batch_size', {'type': int, 'default': 256}),
        ('--num_workers', {'type': int, 'default': 8}),
        ('--epoch', {'type': int, 'default': 101}),
        ('--hdim', {'type': int, 'default': 64}),
        ('--lr', {'type': float, 'default': 2e-4}),
        ('--multiplier', {'type': float, 'default': 2.5,
                          'help': 'multiplier for warmup'}),
        ('--interval', {'type': int, 'default': 25}),
        ('--moddir', {'type': str, 'required': True}),
        ('--device', {'type': int, 'required': True}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    train(cli.parse_args())


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()





