import os
import numpy as np
import torch
import wandb
import argparse
import torch.nn as nn
from absl import app
from tqdm import tqdm
from datetime import datetime
import math
from PIL import Image
import matplotlib.pyplot as plt
import cv2

from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, random_split
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from transformers import get_cosine_schedule_with_warmup, get_scheduler
from torch.optim.lr_scheduler import LinearLR, ChainedScheduler

from idm.dataset import DataSet
from idm.enhanced_regressor import EnhancedRegressor
from idm.idm import *
from idm.loss import AdaptiveLoss, NTKInspiredAdaptiveLoss, WeightedSmoothL1Loss, WeightedL2Loss
from idm.preprocessor import DinoPreprocessor, segment_robot_arms


def parse_args():
    parser = argparse.ArgumentParser(description="Train IDM")
    parser.add_argument("--load_from", type=str, default=None, help="Load from path")
    parser.add_argument("--wandb_mode", type=str, default="online", help="Wandb mode")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--use_transform", action="store_true", default=False, help="Use transform")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size per GPU")
    parser.add_argument("--eval_batch_size", type=int, default=128, help="Batch size per GPU")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of data loading workers")
    parser.add_argument("--prefetch_factor", type=int, default=4, help="Number of batches to prefetch")
    parser.add_argument("--dataset_path", type=str, default=None, help="Path of the dataset")
    parser.add_argument("--num_iterations", type=int, default=150000, help="Number of iterations")
    parser.add_argument("--eval_interval", type=int, default=2000, help="Intervals of evaluation. ")
    parser.add_argument("--run_name", type=str, default=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), help="Run name")
    parser.add_argument("--save_dir", type=str, default="output", help="Save dir")
    parser.add_argument("--ratio_eval", type=float, default=0.05, help="Ratio of data for validation, but eval_dataset_size is at most 10000")
    parser.add_argument("--model-name", type=str, default="direction_aware", help="Choose a suitable model.")
    parser.add_argument("--dinov2_name", type=str, default="facebook/dinov2-base", help="DINOv2 name.")
    parser.add_argument("--freeze_dinov2", action="store_true", default=False, help="Freeze DINOv2")
    parser.add_argument("--lr_scheduler", type=str, default="cosine", choices=["constant", "cosine"], help="Learning rate scheduler type")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay")
    parser.add_argument("--test_dataset_path", type=str, default=None, help="Path of the test dataset")
    parser.add_argument("--eval_only", action="store_true", default=True, help="Only run evaluation on val and test sets")
    parser.add_argument("--use_normalization", action="store_true", default=False, help="Use mean/std normalization")
    parser.add_argument("--load_mp4", action="store_true", default=True, help="load the data in mp4 format to save memory")
    parser.add_argument("--learning_dim", type=str, default=None, help="Comma-separated list of dimensions to learn (e.g. '0,1,2,3'). If None, learn all dimensions.")
    parser.add_argument("--model_name", type=str, default="direction_aware", help="Model name")
    parser.add_argument("--shuffle", type=bool, default=False, help="shuffle input")
    args = parser.parse_args()
    return args


def collate_fn(batch, preprocessor: DinoPreprocessor):
    """Turn a list of ``(pos, image)`` samples into a batched ``(images, pos)`` pair.

    Images are batched via ``preprocessor.process_batch`` and positions via
    ``torch.stack``; both are then passed through
    ``preprocessor.flip_images_pos_batch``, whose two results are returned.

    Per the original comments each image is a PIL image ``[H, W, C]`` and the
    processed batch is ``[B, 3, 518, 518]`` — TODO confirm against DinoPreprocessor.
    """
    raw_positions, raw_images = zip(*batch)
    pixel_batch = preprocessor.process_batch(raw_images)  # [B, 3, 518, 518]
    pos_batch = torch.stack(raw_positions)
    flipped_pixels, flipped_pos = preprocessor.flip_images_pos_batch(pixel_batch, pos_batch)
    return flipped_pixels, flipped_pos


def get_data_generator(dataloader):
    """Yield items from ``dataloader`` forever, restarting it each time it is exhausted.

    Args:
        dataloader: any re-iterable (e.g. a ``DataLoader``); a new pass is
            started with ``for`` whenever the previous one ends.

    Raises:
        ValueError: if a full pass yields nothing (an empty loader, or a one-shot
            iterator that was already consumed). Previously this spun forever.
    """
    while True:
        produced = False
        for data in dataloader:
            produced = True
            yield data
        if not produced:
            raise ValueError("dataloader produced no batches; cannot cycle over it")


def save_model(accelerator: "Accelerator", net: torch.nn.Module, optimizer: torch.optim.Optimizer, step, save_path):
    """Write a checkpoint with model weights, optimizer state and step to ``save_path``.

    Every process must call this: all ranks synchronise before and after the
    save and only the main process writes. Failures are printed rather than
    raised (best effort), so a failed save does not abort the run.

    Args:
        accelerator: provides the barriers, main-process check and model unwrapping.
        net: model (possibly wrapped); its unwrapped ``state_dict`` is saved.
        optimizer: its ``state_dict`` is saved alongside the model.
        step: stored under the ``"step"`` key.
        save_path: destination file; missing parent directories are created.
    """
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # A bare filename has an empty dirname; use the CWD instead of makedirs("").
        save_dir = os.path.dirname(save_path) or "."
        try:
            os.makedirs(save_dir, exist_ok=True)
            if not os.access(save_dir, os.W_OK):
                print(f"Warning: No write permission for directory {save_dir}")
            else:
                state_dict = {
                    "model_state_dict": accelerator.unwrap_model(net).state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "step": step,
                }
                torch.save(state_dict, save_path)
        except Exception as e:
            print(f"Error saving model: {str(e)}")
    # Reached on every path (no early return), so non-main ranks are never left
    # blocked at this barrier.
    accelerator.wait_for_everyone()


def is_close(pos, output, joint_tol=0.05, gripper_tol=0.5, gripper_dims=(6, 13)):
    """Check whether predicted positions are within per-dimension tolerance of targets.

    A sample is "close" when ``|pos - output| < tol`` holds in every dimension.
    ``tol`` is ``gripper_tol`` for the indices in ``gripper_dims`` and
    ``joint_tol`` for all others. The defaults reproduce the original hard-coded
    14-dim limits (0.05 everywhere, 0.5 at the gripper indices 6 and 13).

    Args:
        pos: target tensor, a single sample ``[D]`` or a batch ``[..., D]``.
        output: prediction tensor broadcastable against ``pos``.
        joint_tol: strict upper bound on the absolute error of non-gripper dims.
        gripper_tol: strict upper bound on the absolute error of gripper dims.
        gripper_dims: indices (into the last dimension) that use ``gripper_tol``.

    Returns:
        Bool tensor reduced over the last dimension: 0-dim for a single sample,
        one flag per row for a batch.
    """
    # Created directly on pos.device with the default float dtype, as before.
    limit = torch.full((pos.shape[-1],), joint_tol, device=pos.device)
    for dim in gripper_dims:
        limit[dim] = gripper_tol
    return torch.all(torch.abs(pos - output) < limit, dim=-1)

def _gradient_saliency(model, loss_fn, image, target):
    """Return ``|d loss_fn(model(image), target) / d image|`` summed over dim 1.

    This is a plain input-gradient saliency map; ``image`` itself is not
    modified (a detached copy is differentiated). For a ``[1, C, H, W]`` input
    the result is ``[1, 1, H, W]``. Parameter gradients are zeroed before backward.
    """
    image = image.detach().clone().requires_grad_(True)
    loss = loss_fn(model(image), target)
    model.zero_grad()
    loss.backward()
    return image.grad.detach().abs().sum(dim=1, keepdim=True)


def _save_heatmap(mask, path):
    """Min-max normalise a ``[1, 1, H, W]`` saliency tensor and save it as a JET-coloured PDF."""
    # Min-max alone is enough: the old z-score step beforehand is a positive
    # affine map and cancels out under min-max normalisation.
    lo, hi = mask.min(), mask.max()
    span = hi - lo
    if span > 0:
        mask = (mask - lo) / span
    else:
        # A constant map would otherwise produce 0/0 = NaN before the uint8 cast.
        mask = torch.zeros_like(mask)
    attention = (mask[0, 0].cpu().numpy() * 255).astype(np.uint8)
    heatmap_bgr = cv2.applyColorMap(attention, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; matplotlib displays arrays as RGB.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    plt.figure()
    plt.imshow(heatmap_rgb)
    plt.axis('off')
    plt.savefig(path, format='pdf', bbox_inches='tight')
    plt.close()


def eval(accelerator: "Accelerator", net: torch.nn.Module, dataloader: DataLoader, loss_fn, step, use_normalization, mode='val'):
    """Save input-gradient saliency heatmaps for the first batch of ``dataloader``.

    The first batch is gathered across processes. For each sample, the gradient
    of ``loss_fn(net(image), pos)`` with respect to the input image is reduced
    to a per-pixel magnitude, min-max normalised, colour-mapped (JET) and written
    to ``heatmap/heatmap_{i}.pdf`` by the main process. No metrics are computed.

    Note: the name shadows the builtin ``eval``; it is kept for existing callers.

    Args:
        accelerator: used for barriers, ``gather`` and main-process detection.
        net: model (possibly wrapped by ``accelerator.prepare``); put in eval
            mode for this pass and back in train mode afterwards.
        dataloader: yields ``(images, pos)`` batches.
        loss_fn: ``loss_fn(outputs, targets)``; must return a scalar for ``backward()``.
        step, use_normalization, mode: accepted for interface compatibility; unused.
    """
    accelerator.wait_for_everyone()
    net.eval()
    os.makedirs("heatmap", exist_ok=True)
    with torch.set_grad_enabled(True):
        for images, pos in tqdm(dataloader, disable=not accelerator.is_main_process):
            pos = accelerator.gather(pos)
            images = accelerator.gather(images)
            for i in range(images.shape[0]):
                # Every rank runs forward/backward on the same gathered data so any
                # gradient synchronisation in a wrapped model stays in lock-step;
                # only the main rank writes, avoiding concurrent writes to one file.
                mask = _gradient_saliency(net, loss_fn, images[i].unsqueeze(0), pos[i].unsqueeze(0))
                if accelerator.is_main_process:
                    _save_heatmap(mask, f'heatmap/heatmap_{i}.pdf')
            break  # only the first batch is visualised

    net.train()
    accelerator.wait_for_everyone()


def main(args):
    """Load a trained IDM checkpoint and write gradient-saliency heatmaps for the test set.

    Builds the datasets and data loaders and constructs the model, optimizer and
    loss. It then loads weights from ``args.load_from`` (required), prepares
    everything with ``accelerate`` and runs ``eval`` on the test loader with the
    preprocessor's transforms disabled.

    Raises:
        ValueError: if ``args.load_from`` is missing or is not a file.
        RuntimeError: if the checkpoint cannot be read or loaded into the model.
    """
    import functools

    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    save_dir = os.path.join(args.save_dir, args.run_name)

    # Initialize wandb only if not in eval mode
    if accelerator.is_main_process and not args.eval_only:
        os.makedirs(save_dir, exist_ok=True)
        wandb.init(project=f"IDM_{args.model_name}", mode=args.wandb_mode, config=args.__dict__, name=args.run_name)

    if accelerator.is_main_process:
        print(f"{args.__dict__}")

    # Initialize preprocessor
    preprocessor = DinoPreprocessor(args)

    # load dataset
    dataset = DataSet(args, dataset_path=args.dataset_path, disable_pbar=not accelerator.is_main_process, preprocessor=preprocessor)
    test_dataset = DataSet(args, dataset_path=args.test_dataset_path, disable_pbar=not accelerator.is_main_process, type="test", preprocessor=preprocessor)
    dataset_size = len(dataset)
    # Validation split is args.ratio_eval of the data, capped at 10000 samples.
    val_dataset_size = min(int(args.ratio_eval * dataset_size), 10000)
    train_dataset_size = dataset_size - val_dataset_size
    train_dataset, val_dataset = random_split(dataset, [train_dataset_size, val_dataset_size])
    if accelerator.is_main_process:
        print('train_dataset_size', train_dataset_size, 'val_dataset_size', val_dataset_size, 'test_dataset_size', len(test_dataset))

    # functools.partial, unlike a lambda, is picklable, so worker processes also
    # work under the "spawn" multiprocessing start method.
    collate = functools.partial(collate_fn, preprocessor=preprocessor)
    loader_kwargs = {"num_workers": args.num_workers, "pin_memory": True, "collate_fn": collate}
    if args.num_workers > 0:
        # DataLoader rejects a non-default prefetch_factor when num_workers == 0.
        loader_kwargs["prefetch_factor"] = args.prefetch_factor
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, shuffle=False, drop_last=False, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=args.shuffle, drop_last=False, **loader_kwargs)

    # Parse learning_dim parameter ("0,1,2" -> [0, 1, 2]); default is all 14 dims.
    if args.learning_dim is not None:
        args.learning_dim = [int(dim) for dim in args.learning_dim.split(',')]
    else:
        args.learning_dim = list(range(14))

    if accelerator.is_main_process:
        print(f"Learning dimensions: {args.learning_dim}")

    # model init
    net = IDM(model_name=args.model_name, dinov2_name=args.dinov2_name, freeze_dinov2=args.freeze_dinov2, output_dim=14)

    # Parameters whose name contains 'dino_model' get 0.1x the base learning rate.
    # e.g. model.region_models.3.0.dino_model.encoder.layer.8.attention.attention.key.bias
    dino_params = []
    base_params = []
    for name, param in net.named_parameters():
        if 'dino_model' in name:
            dino_params.append(param)
        else:
            base_params.append(param)

    param_groups = [
        {"params": dino_params, "lr": args.learning_rate * 0.1},
        {"params": base_params, "lr": args.learning_rate}
    ]

    optimizer = AdamW(param_groups, weight_decay=args.weight_decay)
    net.train()
    loss_fn = WeightedSmoothL1Loss(beta=0.1, learning_dim=args.learning_dim)

    if not args.load_from or not os.path.isfile(args.load_from):
        raise ValueError("Must specify --load_from with a valid model path when using --eval_only")
    try:
        # Deserialize onto the CPU; accelerator.prepare below handles device
        # placement. Without map_location every rank would load tensors onto the
        # device they were saved from (typically the same GPU).
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        loaded_dict = torch.load(args.load_from, map_location="cpu")
        net.load_state_dict(loaded_dict["model_state_dict"])
    except Exception as e:
        raise RuntimeError(f"Failed to load checkpoint from {args.load_from}: {str(e)}") from e
    if accelerator.is_main_process:
        print(f"Loaded model from {args.load_from}")

    net, optimizer, train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(
        net, optimizer, train_dataloader, val_dataloader, test_dataloader)
    # Expose the unwrapped model's `normalize` on the (possibly wrapped) net.
    net.normalize = accelerator.unwrap_model(net).normalize

    # Disable the preprocessor's transforms for evaluation, then restore the setting.
    preprocessor.use_transform = False
    eval(accelerator, net, test_dataloader, loss_fn, 0, args.use_normalization, mode='test')
    preprocessor.use_transform = args.use_transform


if __name__ == "__main__":
    # Call main directly. `app.run(main(parse_args()))` ran main first and then
    # handed its return value (None) to absl's app.run. That call also re-parses
    # sys.argv as absl flags, which rejects these argparse options.
    main(parse_args())
