import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from Dataloader_funcs.utils import RASampler

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
from transformers import get_cosine_schedule_with_warmup
from torchvision.transforms import v2
from imagenet_dataset import ILSVRC2012, ImageNet
import timm
import tqdm
from custom_vit import CustomPE_ViT
import Positional_Embeddings.AxialRoPE
from model_registry import MODEL_REGISTRY
import wandb
import torch_optimizer as optimizers
import random

# Launch on a single node with 4 processes:
#   torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py

# --- Experiment selection ---------------------------------------------------
PE = 'Default'        # key looked up in MODEL_REGISTRY by build_model()
dset = 'ImageNet v3'  # used in the wandb project name and the checkpoint path
version = 'imagenet'  # forwarded to the ImageNet dataset; alternative: 'imagenet235px'

# --- Optimisation -----------------------------------------------------------
num_epochs = 200
batch_size = 256
lr = 4e-3
w_d = .05        # weight decay
smoothing = 0.1  # label smoothing for the training cross-entropy

# --- Architecture -----------------------------------------------------------
n_heads = 6
dropout = 0.
layer_scale = 1.
n = 3
# Largest multiple of n that does not exceed 64 (63 for n = 3); this is the
# per-head width, multiplied by n_heads below to get the model width.
embed_dim = (64 // n) * n

ViT_Params = dict(
    patch_size=16,
    img_size=224,
    in_chans=3,
    num_classes=1000,
    embed_dim=embed_dim * n_heads,
    depth=12,
    num_heads=n_heads,
    global_pool='',
    class_token=False,
    drop_rate=dropout,
    drop_path_rate=dropout,
    init_values=layer_scale,
)
if PE != 'Default':
    # Any non-default positional embedding turns off the ViT's own one.
    ViT_Params.update(pos_embed='none')

attn_args = dict(attn_drop=dropout, proj_drop=dropout)

save_path = f'./trained_models/{dset}_{PE}'

def build_model():
    """Build the ``CustomPE_ViT`` for the positional embedding named by ``PE``.

    The registry entry ``MODEL_REGISTRY[PE]`` supplies the base keyword
    arguments. On top of those, ``flash_att`` is set to ``True`` and the
    module-level ``ViT_Params`` / ``attn_args`` dicts are passed as
    ``ViT_kwargs`` / ``attn_kwargs``.

    Returns:
        The constructed ``CustomPE_ViT`` instance.
    """
    # Work on a shallow copy: assigning into MODEL_REGISTRY[PE] directly would
    # leave these per-run settings behind in the shared registry entry.
    kwargs = dict(MODEL_REGISTRY[PE])
    kwargs['flash_att'] = True
    kwargs['ViT_kwargs'] = ViT_Params
    kwargs['attn_kwargs'] = attn_args
    return CustomPE_ViT(**kwargs)

def ddp_setup():
    """Bind this process to its GPU and join the default NCCL process group.

    No rendezvous arguments are passed to ``init_process_group``, so the
    launcher is expected to provide them through the environment (torchrun
    does this).
    """
    # Select the GPU before creating the NCCL group so that this rank's
    # default CUDA device is its own GPU instead of device 0. The guard keeps
    # the function usable when LOCAL_RANK is not exported.
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        torch.cuda.set_device(int(local_rank))
    init_process_group(backend="nccl")

def cleanup():
    """Tear down the default distributed process group."""
    destroy_process_group()

def split_list(lst, k):
    """Cut ``lst`` into ``k`` contiguous slices of near-equal length.

    Slice ``i`` covers ``lst[(i * len(lst)) // k : ((i + 1) * len(lst)) // k]``,
    so slice lengths differ by at most one and concatenating the slices gives
    back ``lst``. Some slices are empty when ``k`` exceeds ``len(lst)``.

    Args:
        lst: Sequence to split; only ``len`` and slicing are used.
        k: Number of slices. A value of zero or less yields an empty list.

    Returns:
        A list of ``k`` slices of ``lst``, in order.
    """
    size = len(lst)
    pieces = []
    start = 0
    for upper in range(1, k + 1):
        # The end of one slice is the start of the next.
        stop = (upper * size) // k
        pieces.append(lst[start:stop])
        start = stop
    return pieces

if __name__ == '__main__':
    # Per-process training entry point. WORLD_SIZE and LOCAL_RANK must be
    # present in the environment (torchrun exports both).
    print('setting up')

    ddp_setup()
    print('setup complete')
    world_size = int(os.environ["WORLD_SIZE"])
    print('world', world_size)

    local_rank = int(os.environ["LOCAL_RANK"])
    model = build_model().to(local_rank)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # Data is partitioned by giving each rank its own slice of the ids 0..25
    # instead of using a DistributedSampler, so each loader below shuffles on
    # its own.
    # NOTE(review): the slice is selected with LOCAL_RANK, which presumes a
    # single-node run — confirm before launching on several nodes.
    # NOTE(review): split_list may hand ranks different numbers of ids (e.g.
    # 6/7/6/7 for 4 ranks). If that gives ranks different batch counts, DDP's
    # gradient synchronisation can stall at the end of an epoch — confirm the
    # shard sizes or wrap the training loop in `model.join()`.
    chunks = split_list(list(range(26)), world_size)

    dataset = ImageNet('train+', version=version, image_size=224, chunks=chunks[local_rank])
    val_dataset = ImageNet('val+', version=version, image_size=224, chunks=chunks[local_rank])

    if local_rank == 0:
        wandb.init(
            project=f"QuatRo-{dset}",
            name=f"{PE}",
            group=f"{PE}",
            config={
                'epochs': num_epochs,
                'batch_size': batch_size,
                'learning_rate': lr,
                'model': 'ViT-B',
                'dataset': dset,
            },
        )
        # torch.save does not create missing parent directories, so make sure
        # the checkpoint folder exists before the first save.
        ckpt_dir = os.path.dirname(save_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=12, pin_memory=True)

    scaler = torch.amp.GradScaler()
    best_acc = 0.0

    opt = optimizers.lamb.Lamb(model.parameters(), lr=lr, weight_decay=w_d)

    # The scheduler is stepped once per epoch, so warmup and total length are
    # both expressed in epochs.
    scheduler = get_cosine_schedule_with_warmup(
        opt,
        num_warmup_steps=5,
        num_training_steps=num_epochs,
    )

    # CutMix / MixUp are constructed but not currently applied in the
    # training loop.
    cm = v2.CutMix(num_classes=1000, alpha=1.)
    mu = v2.MixUp(num_classes=1000, alpha=.8)
    cmmu = [cm, mu]

    criterion = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)

    # One scheduler step before any training — presumably so the first epoch
    # does not run at the warmup's starting learning rate.
    scheduler.step()
    print('PE:', PE)

    for epoch in range(num_epochs):
        # Metrics are accumulated as device tensors so the loop does not have
        # to synchronise with the GPU on every batch.
        running_loss = torch.zeros((), device=local_rank)
        correct = torch.zeros((), dtype=torch.long, device=local_rank)
        total = 0

        model.train()
        for sample in tqdm.tqdm(loader):
            opt.zero_grad()
            inputs = sample[0].to(local_rank)
            labels = sample[1].to(local_rank)
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                pred = model(inputs)
                loss = criterion(pred, labels)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            batch_n = labels.size(0)
            # detach(): keep autograd history out of the running sum.
            # * batch_n: the criterion returns a per-batch mean, so weighting
            # by batch size makes the division by `total` below a true
            # per-sample average.
            running_loss += loss.detach().float() * batch_n
            correct += (pred.argmax(dim=1) == labels).sum()
            total += batch_n

        epoch_loss = running_loss.item() / total
        epoch_accuracy = 100.0 * correct.item() / total
        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {epoch_loss:.4f} | Accuracy: {epoch_accuracy:.2f}%")
        if local_rank == 0:
            wandb.log({
                'train_loss': epoch_loss,
                'train_accuracy': epoch_accuracy,
                'lr': scheduler.get_last_lr()[0]
            }, step=epoch+1)

        # Validation runs on this rank's own validation shard only; rank 0's
        # numbers are the ones logged and used for checkpoint selection.
        val_loss = torch.zeros((), device=local_rank)
        val_correct = torch.zeros((), dtype=torch.long, device=local_rank)
        val_total = 0

        model.eval()
        with torch.no_grad():
            for sample in tqdm.tqdm(val_loader):
                inputs = sample[0].to(local_rank)
                targets = sample[1].to(local_rank)
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    pred = model(inputs)
                    loss = torch.nn.functional.cross_entropy(pred, targets)
                batch_n = targets.size(0)
                # Same per-sample weighting as for the training loss.
                val_loss += loss.float() * batch_n
                val_correct += (pred.argmax(dim=1) == targets).sum()
                val_total += batch_n

        val_accuracy = 100.0 * val_correct.item() / val_total
        avg_val_loss = val_loss.item() / val_total

        print(f"Validation Accuracy: {val_accuracy:.2f}% | Val Loss: {avg_val_loss:.4f}")
        if local_rank == 0:
            wandb.log({
                'valid_loss': avg_val_loss,
                'val_accuracy': val_accuracy
            }, step=epoch+1)
            if val_accuracy > best_acc:
                best_acc = val_accuracy
                # The state dict comes from the DDP wrapper, so its keys carry
                # the 'module.' prefix.
                torch.save(model.state_dict(), save_path)
                wandb.save(save_path)
        scheduler.step()

    if local_rank == 0:
        wandb.finish()
    cleanup()