import numpy as np
import torch
import torch.nn as nn

import datasets as ds
import models
from torch.optim.lr_scheduler import CosineAnnealingLR

import argparse
import os
import torch.optim as optim
import random

from engine import *
from augmentations import *

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import wandb

from utils import seed_worker


def seeding(seed):
    """Seed every random number generator used by training and pin cuDNN to deterministic mode.

    Seeds Python's ``random`` module, NumPy's global RNG, the PyTorch CPU RNG
    and the CUDA RNGs (current device and all devices). ``PYTHONHASHSEED`` is
    exported to the environment; it only influences interpreters started
    afterwards, since the running interpreter's hash seed is fixed at startup.

    Args:
        seed: integer seed applied to every generator.

    Returns:
        The same ``seed`` value, unchanged.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU first, then CUDA).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    return seed

def parse_option(argv=None):
    """Build the command-line parser and parse the training / evaluation options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with one attribute per option (dashes become
        underscores, e.g. ``--fovea-size`` -> ``args.fovea_size``).
    """
    parser = argparse.ArgumentParser('arguments for training')
    parser.add_argument('--dataset', type=str, default='stl10')
    parser.add_argument('--img-size', type=int, default=96)
    parser.add_argument('--data-root', type=str, default='DEFAULT')
    parser.add_argument('--optimizer', type=str, default='AdamW')
    parser.add_argument('--lr', type=float, default=0.0004)
    parser.add_argument('--scheduler', action='store_true', default=False)
    parser.add_argument('--warmup', type=int, default=20)
    parser.add_argument('--weight-decay', type=float, default=0.001)
    parser.add_argument('--mlp-proj', action='store_true', default=False)
    parser.add_argument('--epochs', type=int, default=2000)
    parser.add_argument('--batch-size', type=int, default=512)
    parser.add_argument('--num-saccades', type=int, default=5)
    parser.add_argument('--fovea-size', type=int, default=32)
    parser.add_argument('--two-emb', action='store_true', default=False)
    parser.add_argument('--conv-ijepa', action='store_true', default=False)
    parser.add_argument('--plus-projector', action='store_true', default=False)
    parser.add_argument('--ema', action='store_true', default=False)
    parser.add_argument('--ema-decay', type=float, default=0.996)
    parser.add_argument('--pred-hidden', type=int, default=1024)  # hidden dimension of the predictor
    parser.add_argument('--num-heads', type=int, default=4)  # number of transformer attention heads
    parser.add_argument('--num-enc-layers', type=int, default=3)  # number of transformer encoder layers
    # The following integer options are on/off switches: 0 disables, 1 enables.
    parser.add_argument('--shuffle-saccades', type=int, default=0)  # shuffle the patch sequence
    parser.add_argument('--ior', type=int, default=1)  # inhibition of return
    parser.add_argument('--use-sal', type=int, default=1)  # saliency-guided saccades
    parser.add_argument('--aug-patches', type=int, default=0)  # patch augmentation
    parser.add_argument("--cifar-resnet", action='store_true', default=False)
    parser.add_argument('--act-cond', type=int, default=1)  # action conditioning
    parser.add_argument('--learn-act-emb', type=int, default=1)  # learnable action embeddings (needs --act-cond)
    parser.add_argument('--act-latentdim', type=int, default=2)  # main() currently forces this to 2
    parser.add_argument('--act-projdim', type=int, default=128)

    # Miscellaneous
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--offline-wandb', action='store_true', default=False)
    parser.add_argument('--run-id', type=str, default='use_default')
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--gpu-id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--output-folder', type=str, default='DEFAULT')
    parser.add_argument('--is-eval', action='store_true', default=False)
    parser.add_argument('--eval-type', type=str, default='pos')
    parser.add_argument('--load-path', type=str, default='DEFAULT')
    parser.add_argument('--backbone', type=str, default="resnet18")
    parser.add_argument('--save-freq', type=int, default=10)
    parser.add_argument('--data-path-img', type=str, default='DEFAULT')
    parser.add_argument('--data-path-sal', type=str, default='DEFAULT')

    return parser.parse_args(argv)

    
def _set_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def _build_datasets(args):
    """Construct the ``(unlabeled, train, test)`` datasets for ``args.dataset``.

    ``unlabeled`` is only built for STL-10; it is ``None`` for ImageNet.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'stl10'`` nor ``'imagenet'``.
    """
    if args.dataset == 'stl10':
        data_path, sal_path = args.data_path_img, args.data_path_sal
        saccade_opts = (args.num_saccades, args.shuffle_saccades, args.use_sal, args.ior)
        # The last positional argument is the aug-patches switch; only the SSL
        # (unlabeled) split honours --aug-patches.
        unlabeled = ds.STL10_SalMap(data_path, sal_path, 'unlabeled', *saccade_opts, args.aug_patches)
        train = ds.STL10_SalMap(data_path, sal_path, 'train', *saccade_opts, False)
        test = ds.STL10_SalMap(data_path, sal_path, 'test', *saccade_opts, False)
        return unlabeled, train, test

    if args.dataset == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),  # Random crop with scale / aspect-ratio jitter.
            transforms.ToTensor(),
            normalize,
        ])
        transform_val = transforms.Compose([
            transforms.Resize(256),      # Resize the shorter side to 256 pixels,
            transforms.CenterCrop(224),  # then center-crop to 224x224.
            transforms.ToTensor(),
            normalize,
        ])
        sal_kwargs = dict(full_img_size=args.img_size, patch_size=args.fovea_size,
                          num_patches=args.num_saccades)
        train = ds.Imagenet1k_Sal(args.data_path_img, args.data_path_sal, "train",
                                  transform=transform_train, **sal_kwargs)
        test = ds.Imagenet1k_Sal(args.data_path_img, args.data_path_sal, "val",
                                 transform=transform_val, **sal_kwargs)
        return None, train, test

    raise ValueError("Dataset not supported yet!")


def main(args):
    """Train seq-JEPA (PLS) with online probes, or train evaluation heads when ``--is-eval``.

    Single-GPU only. In training mode:
    - the SSL objective is optimised, with optional linear warmup and cosine annealing;
    - online linear and equivariance probes are trained alongside;
    - metrics are optionally logged to wandb;
    - a checkpoint is written every ``args.save_freq`` epochs.

    In evaluation mode, ``model.add_probes`` attaches probe heads, which are
    trained by ``val_all_one_epoch_pls``. The model is first restored from
    ``args.load_path`` if one is given.

    Note: ``args`` is mutated (``pos_dim``, ``act_cond``, ``learn_act_emb``,
    ``act_latentdim`` and, for Conv-I-JEPA, ``resout_eval``) before being
    logged as the run configuration.

    Raises:
        ValueError: unsupported dataset or optimizer, or learnable action
            embeddings requested without action conditioning.
    """
    # Single GPU setup
    global_rank = 0
    gpu = args.gpu_id if args.gpu_id >= 0 else 0
    torch.cuda.set_device(gpu)
    # Place everything on the selected GPU so the model agrees with
    # set_device() and with the checkpoint's map_location below.
    device = torch.device("cuda", gpu)
    torch.cuda.empty_cache()
    print("Single GPU training On GPU:", torch.cuda.current_device())

    seeding(args.seed)

    if args.run_id == 'use_default':
        run_id = f"seq-jepa-pls_{args.dataset}_{args.backbone}_pls"
    else:
        run_id = args.run_id

    load_path = args.load_path if args.load_path != 'DEFAULT' else None
    load_dict = torch.load(load_path, map_location=f"cuda:{gpu}") if load_path is not None else None

    output_folder = os.path.join(args.output_folder, run_id)
    if global_rank == 0:
        os.makedirs(output_folder, exist_ok=True)

    unlabeled_dataset, train_dataset, test_dataset = _build_datasets(args)

    # One seeded generator shared by all loaders for reproducible shuffling.
    g = torch.Generator()
    g.manual_seed(args.seed)
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers,
                         worker_init_fn=seed_worker, generator=g)
    unlabeled_loader = None
    if unlabeled_dataset is not None:
        unlabeled_loader = DataLoader(unlabeled_dataset, shuffle=True, **loader_kwargs)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # args.dataset was validated by _build_datasets, so it is stl10 or imagenet here.
    num_classes = 10 if args.dataset == 'stl10' else 1000
    n_channels = 3
    args.pos_dim = 128

    args.act_cond = args.act_cond > 0
    args.learn_act_emb = args.learn_act_emb > 0
    if args.learn_act_emb and not args.act_cond:
        raise ValueError("Learnable action embedding cannot be used without action conditioning!")

    # Fixed at 2, overriding --act-latentdim; it is also the output size of
    # the online equivariance probe below.
    latent_dim = 2
    args.act_latentdim = latent_dim

    kwargs = {"num_heads": args.num_heads, "n_channels": n_channels, "act_cond": args.act_cond, "pred_hidden": args.pred_hidden,
              "num_enc_layers": args.num_enc_layers, "num_classes": num_classes, "pos_dim": args.pos_dim, "backbone": args.backbone, "ema_decay": args.ema_decay,
              "act_projdim": args.act_projdim, "act_latentdim": args.act_latentdim, "learn_act_emb": args.learn_act_emb, "cifar_resnet": args.cifar_resnet}

    if args.conv_ijepa:
        model = models.Conv_IJEPA(args.fovea_size, args.img_size, args.ema, args.plus_projector, **kwargs)
        args.resout_eval = True
        print("Conv-I-JEPA model!")
        if args.plus_projector:
            print("Plus projector!")
    else:
        model = models.SeqJEPA_PLS(args.fovea_size, args.img_size, args.ema, **kwargs)
    model = model.to(device)

    # Probes are created before the checkpoint is loaded to keep the RNG
    # consumption order (and hence reproducibility) unchanged.
    if not args.is_eval:
        online_linprobe = nn.Sequential(nn.Linear(model.emb_dim, num_classes))
        online_equiprobe = nn.Sequential(nn.Linear(model.res_out_dim * 2, latent_dim))

    if load_path is not None:
        model.load_state_dict(load_dict['model'])
        print("Model loaded!")
    if args.is_eval:
        if not args.conv_ijepa:
            model.add_probes(max_val_len=args.num_saccades - 1, non_relative_trans=False, num_obs=args.num_saccades - 1)
        else:
            model.add_probes(non_relative_trans=False)
        model = model.to(device)
        # Define optimizers for PLS probe heads
        optimizer_pos_regressor = torch.optim.Adam(model.pos_regressor.parameters(), lr=args.lr, weight_decay=0)
        optimizer_res_class = torch.optim.Adam(model.res_classifier.parameters(), lr=args.lr, weight_decay=0)
        if not args.conv_ijepa:
            optimizer_agg_class = [torch.optim.Adam(classifier.parameters(), lr=args.lr, weight_decay=0)
                                   for classifier in model.agg_classifier]
        else:
            optimizer_agg_class = None

    learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of learnable parameters: {learnable_params}")

    params_to_optimize = [param for param in model.parameters() if param.requires_grad]

    # Linear probing uses no weight decay; SSL training uses --weight-decay.
    optimizer_classes = {'Adam': optim.Adam, 'AdamW': optim.AdamW}
    if args.optimizer not in optimizer_classes:
        raise ValueError("Optimizer not supported!")
    weight_decay = 0.0 if args.is_eval else args.weight_decay
    optimizer = optimizer_classes[args.optimizer](params_to_optimize, lr=args.lr, weight_decay=weight_decay)
    purpose = "linear probing" if args.is_eval else "SSL training"
    print(f"{args.optimizer} Optimizer for {purpose}...")

    if not args.is_eval:
        print("Online linear probes!")
        online_linprobe = online_linprobe.to(device)
        online_equiprobe = online_equiprobe.to(device)
        optimizer_linprobe = optim.Adam(online_linprobe.parameters())
        optimizer_equiprobe = optim.Adam(online_equiprobe.parameters())

    lr_scheduler = None
    if args.scheduler and not args.is_eval:
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)
        print("Scheduler added!")
    if load_path is not None and not args.is_eval:
        optimizer.load_state_dict(load_dict['optimizer'])
        print("Optimizer loaded!")
        # Restore the schedule as well; otherwise a resumed run would restart
        # the cosine curve from its peak.
        if lr_scheduler is not None and 'lr_scheduler' in load_dict:
            lr_scheduler.load_state_dict(load_dict['lr_scheduler'])
            print("Scheduler loaded!")
    config_dict = vars(args)
    print(config_dict)
    id_ = "no-id"

    if global_rank == 0 and args.wandb:
        os.environ["WANDB__SERVICE_WAIT"] = "300"
        if args.offline_wandb:
            os.environ['WANDB_MODE'] = 'offline'
        else:
            wandb.login()
        id_ = wandb.util.generate_id()
        run_id = f"wandbid-{id_}_" + run_id
        wandb_logger = wandb.init(name=run_id, id=run_id, config=config_dict)
        print("Wandb initialized!")

    if load_path is not None and not args.is_eval:
        ep_tr = int(load_dict['epoch'])
        min_loss = load_dict['min_loss']
        print("Resuming training from epoch:", ep_tr)
    else:
        ep_tr = 0
        min_loss = 1e9
    epochs = args.epochs

    if not args.is_eval:
        print("Training...")
        # STL-10 optimises the SSL objective on its unlabeled split. ImageNet
        # has no unlabeled split here, so its train split is used instead.
        ssl_loader = unlabeled_loader if unlabeled_loader is not None else train_loader
        for epoch in range(ep_tr, epochs):
            if args.warmup > 0:
                if epoch < args.warmup:
                    # Linear warmup from 1e-5 towards args.lr.
                    initial_lr = 1e-5
                    _set_lr(optimizer, initial_lr + (args.lr - initial_lr) * (epoch / args.warmup))
                elif epoch == args.warmup or lr_scheduler is None:
                    # Restore the base lr when warmup ends. With a scheduler this
                    # must happen only once; re-pinning it every epoch would
                    # overwrite each cosine step before it is ever used.
                    _set_lr(optimizer, args.lr)
            model, optimizer, result = train_one_epoch_stl10_pls(model, ssl_loader, optimizer, args.num_saccades, args.img_size,
                                                        args.fovea_size, device, online_linprobe, optimizer_linprobe,
                                                        online_equiprobe, optimizer_equiprobe, args.ema, args.ema_decay, epoch, epochs, conv_jepa=args.conv_ijepa,
                                                        train_loader=train_loader, test_loader=test_loader, dataset=args.dataset)
            # The cosine schedule only starts advancing after warmup.
            if lr_scheduler is not None and not (args.warmup and epoch < args.warmup):
                lr_scheduler.step()

            epoch_loss = result["ep_loss"]
            ep_time = result["ep_time"]
            min_loss = min(min_loss, epoch_loss)

            if global_rank == 0:
                online_linacc = result["online_linacc_test"]
                online_r2 = result["online_r2_test"]

                print("Epoch {}/{}, Loss: {:.6f}, min_loss: {:.6f}, ep_time:{:.2f}, online_linacc: {:.4f}, online_r2: {:.4f}".format(epoch+1, args.epochs, epoch_loss, min_loss, ep_time, online_linacc, online_r2))
                if args.wandb:
                    log_data = {"ep_loss": epoch_loss, "ep_time": ep_time}
                    log_data["online_linacc_test"] = online_linacc
                    log_data["online_r2_test"] = online_r2
                    log_data["online_r2_train"] = result["online_r2_train"]
                    log_data["online_linacc_train"] = result["online_linacc_train"]
                    log_data["online_r2_loss_test"] = result["online_r2_loss_test"]
                    log_data["online_r2_loss_train"] = result["online_r2_loss_train"]
                    log_data["online_linloss_train"] = result["online_linloss_train"]
                    log_data["online_linloss_test"] = result["online_linloss_test"]

                    log_data["online_linacc"] = online_linacc
                    log_data["online_r2"] = online_r2
                    wandb_logger.log(log_data, step=epoch)
                    print("Wandb logged!")
                if (epoch + 1) % args.save_freq == 0:
                    save_state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'min_loss': min_loss, 'epoch': epoch+1, 'run_id': run_id}
                    if lr_scheduler is not None:
                        save_state['lr_scheduler'] = lr_scheduler.state_dict()
                    save_path = os.path.join(output_folder, f'ckpt_wandb-{id_}_epoch_{epoch+1}.pth')
                    torch.save(save_state, save_path)
    else:
        for epoch in range(ep_tr, epochs):
            print("Evaluating... training eval head...")
            # Aggregate-token probes only exist for the sequence model.
            agg_eval = None if args.conv_ijepa else True
            model, results = val_all_one_epoch_pls(
                model, device, train_loader, test_loader, args.img_size, args.fovea_size,
                optimizer_pos_regressor, optimizer_res_class, agg_eval, optimizer_agg_class
            )

            if global_rank == 0:
                formatted_results = ", ".join(
                    f"{key}: {value:.6f}" if isinstance(value, (int, float)) else f"{key}: {value}"
                    for key, value in results.items()
                )
                print(f"Epoch {epoch+1}/{args.epochs}, {formatted_results}")
                if args.wandb:
                    log_data = {}
                    for key, value in results.items():
                        if isinstance(value, list):
                            # Log each element separately in Wandb
                            for j, v in enumerate(value):
                                log_data[f"{key}_{j}"] = v
                        else:
                            log_data[key] = value

                    wandb_logger.log(log_data, step=epoch)
    print("Done!")



if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(parse_option())