""" Super simple neural net training code for a CLASSIFIER
    P(N | C, X)


    General details:
    Architecture:
    - Create embedding for C (simple nn.Embedding, since C is binary)
    - Create embedding for X (a resnet18 whose final fc layer outputs the embedding)
    - Then concatenate and run through 4 layers of FC network

    Training:
    - Train with standard cross entropy over the two output logits (F.cross_entropy)

"""

import torch
import torch.nn as nn 
from torchvision.models import resnet18
from torch.utils.data import DataLoader, random_split
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms

import argparse
from tqdm.auto import tqdm 
import os
from cxray.cxray_dataset import CXRayDataset, DiffusionScale, transback
from cfg.Scheduler import GradualWarmupScheduler
from sklearn.metrics import roc_auc_score

import cxray.calibration_tools as ct
# =========================================================
# =                   Model Definitions                   =
# =========================================================


class Classifier(nn.Module):
    """Two-class classifier over an image ``x`` and a binary code ``c``.

    The image is embedded by a ResNet-18 whose first convolution takes a
    single input channel and whose final fully connected layer outputs
    ``hdim`` features. The code is embedded by a 2-entry ``nn.Embedding`` of
    the same width. Both embeddings are concatenated and passed through four
    linear layers (ReLU between them) ending in 2 output logits.

    Args:
        hdim: Width of each of the two embeddings; the joiner MLP works at
            ``2 * hdim``.
    """

    def __init__(self, hdim=64):
        super().__init__()

        # Randomly initialised backbone. Called without the `pretrained=`
        # keyword, which newer torchvision releases deprecate; the no-argument
        # call yields the same untrained network on old and new versions.
        self.resnet = resnet18()
        # Replace the stem so the network accepts 1-channel images.
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
        # Replace the classification layer with a projection to the embedding
        # width, reading the input width from the layer being replaced.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, hdim)

        # One learned vector per value of the binary code.
        self.embedding = nn.Embedding(2, hdim)

        # No final activation: the module returns raw logits, which is what
        # F.cross_entropy (used by the training loop in this file) expects.
        self.joiner = nn.Sequential(
            nn.Linear(hdim * 2, hdim * 2),
            nn.ReLU(),
            nn.Linear(hdim * 2, hdim * 2),
            nn.ReLU(),
            nn.Linear(hdim * 2, hdim * 2),
            nn.ReLU(),
            nn.Linear(hdim * 2, 2),
            )

    def forward(self, x, c):
        """Return the 2-way logits for each (image, code) pair.

        Args:
            x: Batch of single-channel images (the stem is ``Conv2d(1, ...)``).
            c: Integer tensor of embedding indices, one per image; valid
                values are 0 and 1.

        Returns:
            Tensor of raw logits with 2 entries per sample.
        """
        x_emb = self.resnet(x)
        c_emb = self.embedding(c)
        # Concatenate along the feature axis: hdim image + hdim code features.
        joined = torch.cat([x_emb, c_emb], dim=1)
        return self.joiner(joined)


# ========================================================
# =                Training Block                        =
# ========================================================



def train(params:argparse.Namespace):
    """Train a ``Classifier`` on the CXRay training split and checkpoint it periodically.

    After every epoch the learning-rate schedule is stepped and the model is
    evaluated on the validation split with ``evaluate_model_cal``; those
    metrics are shown in the progress bar during the following epoch.

    Args:
        params: Parsed command-line namespace. Fields read here: ``root_dir``,
            ``batch_size``, ``num_workers``, ``device`` (CUDA device index),
            ``hdim``, ``lr``, ``epoch``, ``multiplier``, ``interval`` and
            ``moddir``.

    Side effects:
        Every ``params.interval`` epochs a file ``ckpt_<epoch>_checkpoint.pt``
        is written to ``params.moddir`` (created if missing). It holds the
        model, optimizer and scheduler state dicts and, when available, the
        temperature ``t`` from the evaluation run on those same weights.
    """
    xform = transforms.Compose([transforms.ToTensor()])

    train_data = CXRayDataset(params.root_dir, split='train', transform=xform)
    dataloader = DataLoader(train_data, shuffle=True, batch_size=params.batch_size,
                            num_workers=params.num_workers, drop_last=True)

    # NOTE(review): drop_last=True means a trailing partial validation batch is
    # never evaluated. Left unchanged so reported metrics stay comparable.
    val_data = CXRayDataset(params.root_dir, split='val', transform=xform)
    val_loader = DataLoader(val_data, shuffle=False, batch_size=params.batch_size,
                            num_workers=params.num_workers, drop_last=True)

    # Held-out split. The loader is built from the test data (it previously
    # wrapped the validation data by mistake) and is not shuffled, matching the
    # validation loader. It is not consumed by the loop below.
    test_data = CXRayDataset(params.root_dir, split='test', transform=xform)
    test_loader = DataLoader(test_data, shuffle=False, batch_size=params.batch_size,
                             num_workers=params.num_workers, drop_last=False)

    device = torch.device('cuda', params.device)

    model = Classifier(params.hdim).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=params.lr, weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=params.epoch,
        eta_min=0,
        last_epoch=-1
        )

    # Warm-up wrapper configured with a tenth of the epochs as `warm_epoch`
    # and the cosine schedule as its `after_scheduler`.
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer,
        multiplier=params.multiplier,
        warm_epoch=params.epoch // 10,
        after_scheduler=cosineScheduler,
        last_epoch=0)

    # Metrics from the most recent validation pass; empty until the first
    # epoch has finished.
    last_eval = {}
    for epc in range(params.epoch):
        # Re-enter training mode every epoch; evaluation may have changed it.
        model.train()
        with tqdm(dataloader, dynamic_ncols=True) as tqdm_dataloader:
            for batch in tqdm_dataloader:
                optimizer.zero_grad()
                x = batch['X'].to(device)
                c = batch['C'].to(device)
                n = batch['N'].to(device)

                loss = F.cross_entropy(model(x,c), n)
                loss.backward()
                optimizer.step()

                # Read the LR straight from the live param group instead of
                # rebuilding the optimizer's whole state dict on every batch.
                postfix = {'epoch': epc+1,
                           'loss': loss.item(),
                           'LR': optimizer.param_groups[0]['lr']}
                postfix.update(last_eval)
                tqdm_dataloader.set_postfix(
                    ordered_dict=postfix)
            warmUpScheduler.step()
            last_eval = evaluate_model_cal(model, val_loader, device)


        if (epc+1) % params.interval == 0:
            os.makedirs(params.moddir, exist_ok=True)
            checkpoint = {
                'net': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': warmUpScheduler.state_dict()
            }
            # Take the temperature from the evaluation that just ran on these
            # weights. The progress-bar dict still carries the previous
            # epoch's value and is undefined if no batch was processed.
            if 't' in last_eval:
                checkpoint['t'] = last_eval['t']
            torch.save(checkpoint, os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))


@torch.no_grad()
def evaluate_model_cal(model, val_loader, device):
    """Measure accuracy and calibration on ``val_loader`` before and after temperature scaling.

    The model is switched to evaluation mode first (as ``evaluate_model``
    does) and is left in that mode; callers that continue training must call
    ``model.train()`` again.

    Args:
        model: Network taking the batch entries ``'X'`` and ``'C'``.
        val_loader: Iterable of batches holding ``'X'``, ``'C'`` and the
            label ``'N'``.
        device: Device forwarded to the calibration helpers.

    Returns:
        Dict with, in insertion order:
            ``acc``, ``ECE``, ``MCE``    -- metrics before calibration
            ``cacc``, ``cECE``, ``cMCE`` -- metrics after temperature scaling
            ``t``                        -- the learned temperature as a float
    """
    DATA_KEYS = ['X', 'C']
    LABEL_KEY = 'N'

    # Evaluate with inference behaviour (e.g. BatchNorm running statistics)
    # rather than whatever mode the training loop left the model in.
    model.eval()

    # Uncalibrated predictions and their calibration errors.
    preds, labels_oneh, total_acc = ct.get_preds(model, val_loader, DATA_KEYS, LABEL_KEY, device, disable_tqdm=True)
    ece, mce = ct.get_metrics(preds, labels_oneh)

    output = {'acc': total_acc,
              'ECE': ece,
              'MCE': mce}

    # Fit a temperature on the same loader, then recompute the metrics with
    # temperature scaling applied. Only the first element of the helper's
    # return value is used.
    temp = ct.learn_temperature(model, val_loader, DATA_KEYS, LABEL_KEY, device, disable_tqdm=True)[0]
    cal_preds, cal_labels, cal_acc = ct.get_preds(model, val_loader, DATA_KEYS, LABEL_KEY, device, disable_tqdm=True,
                                                  calibration_method=ct.T_scaling, cal_args=temp)
    cal_ece, cal_mce = ct.get_metrics(cal_preds, cal_labels)
    output['cacc'] = cal_acc
    output['cECE'] = cal_ece
    output['cMCE'] = cal_mce
    output['t'] = temp['temperature'].item()
    return output
                                                 



@torch.no_grad()
def evaluate_model(model, val_loader, device):
    """Compute mean cross-entropy loss and top-1 accuracy over ``val_loader``.

    The model is put into evaluation mode and left there.

    Args:
        model: Callable as ``model(X, C)`` returning per-class logits.
        val_loader: Iterable of batches holding ``'X'``, ``'C'`` and the
            class-index labels ``'N'``.
        device: Device every batch tensor is moved to.

    Returns:
        ``{'val_loss': <mean loss per label>, 'val_acc': <fraction correct>}``.

    Raises:
        RuntimeError: If ``val_loader`` yields no labels at all.
    """
    model.eval()
    total = 0
    total_loss = 0.0
    total_correct = 0
    for batch in val_loader:
        X = batch['X'].to(device)
        C = batch['C'].to(device)
        N = batch['N'].to(device)

        logits = model(X, C)
        total += N.numel()
        # Sum (not mean) per batch so the final average is weighted correctly
        # when batches differ in size.
        total_loss += F.cross_entropy(logits, N, reduction='sum').item()
        total_correct += (logits.max(dim=1)[1] == N).sum().item()
        # Only running scalars are kept; per-batch logits and labels are not
        # retained, so device memory does not grow with the dataset.

    if total == 0:
        # Same exception type the empty case raised before, but with a
        # message that names the actual problem.
        raise RuntimeError('evaluate_model: val_loader produced no samples')

    return {'val_loss': total_loss / total, 'val_acc': total_correct / total}






# ====================================================
# =           Main Block                             =
# ====================================================
def main():
    """Parse the command-line options and hand them to ``train``."""
    cli = argparse.ArgumentParser(description='Training for P(N|C,X)')

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ('--root_dir', {'type': str, 'required': True}),
        ('--batch_size', {'type': int, 'default': 256}),
        ('--num_workers', {'type': int, 'default': 8}),
        ('--epoch', {'type': int, 'default': 101}),
        ('--hdim', {'type': int, 'default': 64}),
        ('--lr', {'type': float, 'default': 2e-4}),
        ('--multiplier', {'type': float, 'default': 2.5, 'help': 'multiplier for warmup'}),
        ('--interval', {'type': int, 'default': 25}),
        ('--moddir', {'type': str, 'required': True}),
        ('--device', {'type': int, 'required': True}),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    train(cli.parse_args())

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()





