"""
For each of the variables W1, X, and Y, trains a classifier to estimate:
- digit
- color
- thickness
"""

import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import argparse
from tqdm.auto import tqdm
import os
from torch.utils.data import DataLoader
from cfg.dataloader_pickle import PickleDataset


# =================================================
# =              Color Identifier                 =
# =================================================


def get_color(x):
    """Return the color code of every image in a batch.

    A channel counts as "on" when the L2 norm of its 1024 pixels exceeds 1.0;
    the on/off pattern of the (R, G, B) channels selects the color code.

    Args:
        x: tensor of [B, 3, 32, 32] images, either in [0, 1] or in [-1, 1].
           Any negative value is taken to mean the [-1, 1] convention, and
           the batch is rescaled to [0, 1] before thresholding.

    Returns:
        CPU LongTensor of length B with codes 0=Red, 1=Green, 2=Blue,
        3=Yellow, 4=Magenta, 5=Cyan, 6=White, 7=Black.
    """
    # min() is undefined on an empty tensor, so answer empty batches directly.
    if x.numel() == 0:
        return torch.empty(0, dtype=torch.long)

    if x.min() < 0.0:
        # clamp() is out-of-place: its result must be kept, otherwise values
        # outside [-1, 1] stay outside [0, 1] and inflate the channel norms.
        x = (x * 0.5 + 0.5).clamp(0, 1)

    # reshape (rather than view) also accepts non-contiguous inputs.
    channel_on = x.reshape(-1, 3, 1024).norm(dim=2) > 1.0

    COLOR_CODES = {(True, False, False): 0,   # Red
                   (False, True, False): 1,   # Green
                   (False, False, True): 2,   # Blue
                   (True, True, False): 3,    # Yellow
                   (True, False, True): 4,    # Magenta
                   (False, True, True): 5,    # Cyan
                   (True, True, True): 6,     # White
                   (False, False, False): 7}  # Black

    codes = [COLOR_CODES[tuple(flags)] for flags in channel_on.tolist()]
    return torch.tensor(codes, dtype=torch.long)




# =======================================================
# =           Basic MNIST training code stuff           =
# =======================================================

class ConvNet(nn.Module):
    """Two-stage convolutional classifier for 3-channel 32x32 images.

    Each stage is conv(5x5, same padding) -> batch norm -> ReLU -> 2x2 max
    pool, followed by one linear layer producing `num_classes` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.layer1 = self._conv_stage(3, 16)
        self.layer2 = self._conv_stage(16, 32)
        # Two 2x2 poolings reduce a 32x32 input to 8x8 maps with 32 channels.
        self.fc = nn.Linear(8 * 8 * 32, num_classes)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        # One conv/batch-norm/ReLU/pool stage that halves the spatial size.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return the class logits for a batch of images `x`."""
        features = self.layer2(self.layer1(x))
        flattened = features.reshape(features.size(0), -1)
        return self.fc(flattened)

def train(model, train_loader, val_loader, datakey, labelkey,
          epochs, device):
    """Fit `model` with Adam and cross-entropy, validating after every epoch.

    Args:
        model: classifier to optimize; moved to `device`.
        train_loader: iterable of batches used for the gradient steps.
        val_loader: iterable of batches passed to `evaluate` each epoch.
        datakey: key under which each batch stores the model inputs.
        labelkey: key under which each batch stores the target labels.
        epochs: number of passes over `train_loader`.
        device: device the model and batches are moved to.

    Returns:
        The trained model (on `device`).
    """
    optimizer = optim.Adam(model.parameters())
    progress = tqdm(range(epochs))
    model = model.to(device)

    for _ in progress:
        # evaluate() leaves the model in eval mode, so switch back each epoch.
        model.train()
        for batch in train_loader:
            inputs = batch[datakey].to(device)
            targets = batch[labelkey].to(device)

            optimizer.zero_grad()
            batch_loss = F.cross_entropy(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

        val_acc, val_loss = evaluate(model, val_loader, datakey, labelkey,
                                     device)
        progress.set_postfix(ordered_dict={'acc': val_acc, 'loss': val_loss})

    return model

@torch.no_grad()
def evaluate(model, val_loader, datakey, labelkey, device):
    """Compute accuracy and mean cross-entropy of `model` over `val_loader`.

    The model is switched to eval mode and left that way. Each batch loss is
    weighted by its number of labels, so the result is a per-sample mean.

    Args:
        model: classifier to evaluate.
        val_loader: iterable of batches.
        datakey: key under which each batch stores the model inputs.
        labelkey: key under which each batch stores the target labels.
        device: device the batches are moved to.

    Returns:
        Tuple `(accuracy, mean_loss)` of Python floats.
    """
    model = model.eval()
    seen, loss_sum, hits = 0, 0, 0

    for batch in val_loader:
        inputs = batch[datakey].to(device)
        targets = batch[labelkey].to(device)
        logits = model(inputs)

        count = targets.numel()
        weighted_loss = F.cross_entropy(logits, targets) * count
        predicted = logits.max(dim=1)[1]

        seen += count
        loss_sum += weighted_loss.item()
        hits += (predicted == targets).sum().item()

    return hits / seen, loss_sum / seen


def save_models(W1_digit, W1_thickness, W1_color,
                X_digit, X_thickness, X_color,
                Y_digit, Y_thickness, Y_color,
                save_dir, save_name):
    """Write the state dicts of all nine classifiers to a single checkpoint.

    The checkpoint is a dict keyed by '<variable>_<attribute>' (for example
    'W1_digit') and is stored at `save_dir/save_name`; `save_dir` is created
    if it does not exist.
    """
    named_models = (
        ('W1_digit', W1_digit),
        ('W1_thickness', W1_thickness),
        ('W1_color', W1_color),
        ('X_digit', X_digit),
        ('X_thickness', X_thickness),
        ('X_color', X_color),
        ('Y_digit', Y_digit),
        ('Y_thickness', Y_thickness),
        ('Y_color', Y_color),
    )
    checkpoint = {key: net.state_dict() for key, net in named_models}

    os.makedirs(save_dir, exist_ok=True)
    destination = os.path.join(save_dir, save_name)
    torch.save(checkpoint, destination)
    

def load_models(save_dir, save_name):
    """Rebuild every classifier stored in the checkpoint `save_dir/save_name`.

    The output size of each network is chosen from the suffix of its
    checkpoint key: 'digit' -> 10 classes, 'thickness' -> 3, 'color' -> 8.
    Entries whose key matches none of these suffixes are skipped.

    Returns:
        Dict mapping checkpoint key to a ConvNet (on CPU) with its weights
        loaded.
    """
    checkpoint_path = os.path.join(save_dir, save_name)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Checked in this order; the first matching suffix wins.
    classes_by_suffix = (('digit', 10), ('thickness', 3), ('color', 8))

    output = {}
    for key, state in checkpoint.items():
        for suffix, num_classes in classes_by_suffix:
            if not key.endswith(suffix):
                continue
            net = ConvNet(num_classes)
            net.load_state_dict(state)
            output[key] = net
            break

    return output



# =======================================
# =           MAIN BLOCK                =
# =======================================



def main():
    """Train one classifier per (variable, attribute) pair and save them all.

    Parses the command line, builds the train/validation loaders from the
    given pickle files, trains a ConvNet for each of W1/X/Y crossed with
    digit/thickness/color, and writes all nine models to one checkpoint.
    """
    parser = argparse.ArgumentParser(
        description='Trains all models for classification of thickness/color')
    cli_options = (
        ('--train_pkl', {'type': str, 'required': True}),
        ('--val_pkl', {'type': str, 'required': True}),
        ('--epoch', {'type': int, 'default': 100}),
        ('--batch_size', {'type': int, 'default': 512}),
        ('--num_workers', {'type': int, 'default': 8}),
        ('--device', {'type': int, 'default': 0}),
        ('--save_dir', {'type': str, 'required': True}),
        ('--save_name', {'type': str, 'required': True}),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    device = 'cuda:{}'.format(args.device)

    loader_common = {'num_workers': args.num_workers,
                     'batch_size': args.batch_size}
    trainset = PickleDataset(args.train_pkl)
    train_loader = DataLoader(trainset, shuffle=True, drop_last=True,
                              **loader_common)
    valset = PickleDataset(args.val_pkl)
    val_loader = DataLoader(valset, shuffle=False, drop_last=False,
                            **loader_common)

    # Attribute name -> number of classes predicted for it.
    attributes = (('digit', 10), ('thickness', 3), ('color', 8))

    save_kwargs = {'save_dir': args.save_dir, 'save_name': args.save_name}
    for variable in ('W1', 'X', 'Y'):
        for attribute, num_classes in attributes:
            name = '%s_%s' % (variable, attribute)
            print("Training %s..." % (name))
            trained = train(ConvNet(num_classes).to(device),
                            train_loader, val_loader,
                            datakey=variable, labelkey=name,
                            epochs=args.epoch, device=device)
            # Park finished models on the CPU until the final save.
            save_kwargs[name] = trained.cpu()

    save_models(**save_kwargs)

# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()



    





