import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
import time
import os
import cv2
from tqdm import tqdm
import argparse
import matplotlib.pyplot as plt
import warnings
from show import *
from per_segment_anything import sam_model_registry, SamPredictor
from pathlib import Path

from datasets.dataset import FSSDataset
from utils.vis import Visualizer

def get_arguments():
    """Parse command-line options and derive the run's output paths.

    Returns:
        argparse.Namespace with all CLI options plus derived ``outdir``
        and ``pred_path`` attributes (e.g. ``persam_f/0.001/coco0``).
    """
    parser = argparse.ArgumentParser()

    # Dataset options
    parser.add_argument('--datapath', type=str, default='./data')
    parser.add_argument('--benchmark', type=str, default='coco',
                        choices=['fss', 'coco', 'pascal', 'lvis', 'paco_part', 'pascal_part'])
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--nworker', type=int, default=0)
    parser.add_argument('--fold', type=int, default=0)
    parser.add_argument('--nshot', type=int, default=1)
    parser.add_argument('--img-size', type=int, default=518)
    parser.add_argument('--use_original_imgsize', action='store_true')

    # Model / refinement options
    parser.add_argument('--outdir', type=str, default='persam_f')
    parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth')
    parser.add_argument('--sam_type', type=str, default='vit_h')
    parser.add_argument('--nested', type=int, default=10)
    parser.add_argument('--eta', type=float, default=1e-3)    # typical sweep: 1e-4, 1e-3
    parser.add_argument('--gamma', type=float, default=1e-1)  # typical sweep: 1e-1, 1e-2
    parser.add_argument('--f', type=str, default='KL')
    parser.add_argument('--device', type=str, default='0')
    parser.add_argument('--tracking', type=int, default=1)

    # Fine-tuning options
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--train_epoch', type=int, default=1000)
    parser.add_argument('--log_epoch', type=int, default=200)
    parser.add_argument('--ref_idx', type=str, default='1')

    args = parser.parse_args()

    # Derived output paths, e.g. 'persam_f/0.001/coco0'.
    args.outdir = f'persam_f/{args.eta}/{args.benchmark}{args.fold}'
    args.pred_path = args.outdir

    return args


def main():
    """Run PerSAM-F evaluation over a few-shot segmentation benchmark.

    Builds the test dataloader, loads a SAM backbone, then for every episode
    runs ``persam_f`` and accumulates per-refinement-step IoU / accuracy /
    Dice statistics. Saves prediction visualizations (when ``--tracking``)
    and one CSV of per-sample metrics per refinement step.
    """
    args = get_arguments()
    torch.manual_seed(42)

    # Dataset / dataloader initialization.
    img_size = None if args.use_original_imgsize else args.img_size
    FSSDataset.initialize(img_size=img_size,
                          datapath=args.datapath,
                          use_original_imgsize=args.use_original_imgsize)

    dataloader_test = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'test', args.nshot)

    # NOTE(review): this directory is created but never written to below —
    # all outputs go under './outputs_iclr_ver2/'. Kept for compatibility.
    if not os.path.exists('./outputs_iclr_x/'):
        os.mkdir('./outputs_iclr_x/')

    # One sub-directory per nested refinement step.
    output_path = './outputs_iclr_ver2/' + args.outdir
    os.makedirs(output_path, exist_ok=True)
    for t in range(args.nested):
        os.makedirs(os.path.join(output_path, f'{t}'), exist_ok=True)

    p_bar = tqdm(dataloader_test, total=len(dataloader_test))

    # Per-refinement-step running totals.
    mIoU = [0 for _ in range(args.nested)]
    mAcc = [0 for _ in range(args.nested)]
    mDice = [0 for _ in range(args.nested)]
    count = 0

    intersection_meter = [AverageMeter() for _ in range(args.nested)]
    union_meter = [AverageMeter() for _ in range(args.nested)]
    target_meter = [AverageMeter() for _ in range(args.nested)]
    total_meter = [AverageMeter() for _ in range(args.nested)]

    # One result table per refinement step; rows are episode indices.
    dfs = [pd.DataFrame(columns=['Obj_name', 'IoU', 'Acc', 'Dice', 'Points'],
                        index=list(range(len(dataloader_test))))
           for _ in range(args.nested)]

    # Visualization setup.
    Visualizer.initialize(args.nested, args)

    device = 'cuda:{}'.format(args.device) if torch.cuda.is_available() else "cpu"

    # Load the SAM backbone.
    if args.sam_type == 'vit_h':
        sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device)
    elif args.sam_type == 'vit_t':
        sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
        sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
        sam.eval()

    for batch_idx, batch in enumerate(p_bar):
        count += 1
        class_id = batch['class_id'].item()
        obj_name = f'{batch_idx}_class-{class_id}'

        sam.eval()

        try:
            # One dict per refinement step.
            out = persam_f(args, obj_name, batch, output_path, sam=sam)

            for t in range(args.nested):
                test_image = out[t]['test_image']
                final_mask = out[t]['final_mask']
                topk_xy, topk_label = out[t]['topk_xy'], out[t]['topk_label']

                intersection, union, target, total = (
                    out[t]['area_intersection'], out[t]['area_union'],
                    out[t]['area_target'], out[t]['area_total'])

                intersection_meter[t].update(intersection)
                union_meter[t].update(union)
                target_meter[t].update(target)
                total_meter[t].update(total)

                # Per-sample IoU; accuracy/Dice use the running sums so far.
                iou_class = intersection / (union + 1e-10)
                accuracy_class = intersection_meter[t].sum / (target_meter[t].sum + 1e-10)
                dice_class = intersection_meter[t].sum * 2 / (total_meter[t].sum + 1e-10)

                mIoU[t] += iou_class
                mAcc[t] += accuracy_class
                mDice[t] += dice_class
                dfs[t].iloc[batch_idx, :] = [obj_name, iou_class, accuracy_class, dice_class, str(topk_xy)]

                # Periodic and final progress report (same message either way).
                if batch_idx % 100 == 0 or batch_idx == len(dataloader_test) - 1:
                    print(f"[{batch_idx}/{len(dataloader_test)}] [{t}] mIoU: {mIoU[t]/count:.4f}, mAcc: {mAcc[t]/count:.4f}, mDice: {mDice[t]/count:.4f} ")

                if args.tracking:
                    # Save masks / prompt visualizations for this step.
                    output_path_t = os.path.join(output_path, f'{t}')
                    Visualizer.visualize_prediction(batch['support_imgs'].squeeze(0), batch['support_masks'].squeeze(0),
                                                    batch['query_img'].squeeze(0), batch['query_mask'].squeeze(0),
                                                    final_mask, batch['class_id'], batch_idx,
                                                    iou=iou_class, output_path_t=output_path_t, topk_xy=topk_xy)

        except Exception as e:  # was a bare `except:`; keep the run alive but surface the cause
            print(f'Error occurred: {e}')
            for t in range(args.nested):
                # Record zeros so row indices stay aligned across all tables.
                intersection_meter[t].update(0)
                union_meter[t].update(0)
                target_meter[t].update(0)
                total_meter[t].update(0)

                iou_class = accuracy_class = dice_class = 0

                mIoU[t] += iou_class
                mAcc[t] += accuracy_class
                mDice[t] += dice_class

                print(f"[{batch_idx}/{len(dataloader_test)}] [{t}] mIoU: {mIoU[t]/count:.4f}, mAcc: {mAcc[t]/count:.4f}, mDice: {mDice[t]/count:.4f} ")
                dfs[t].iloc[batch_idx, :] = [obj_name, iou_class, accuracy_class, dice_class, '']

        # NOTE(review): `t` is the last refinement index from the loops above,
        # so this always previews the final step's table.
        print(dfs[t].head())

    # Append a mean row to each table and save one CSV per refinement step.
    for t in range(args.nested):
        mean = dfs[t].iloc[:, 1:-1].mean(axis=0)
        mean['Obj_name'] = 'mean'
        mean['Points'] = ''

        dfs[t].loc[dfs[t].shape[0]] = mean
        dfs[t] = dfs[t].set_index('Obj_name')
        print(dfs[t].tail(1))
        dfs[t].to_csv(os.path.join(output_path, f'perfs{t}.csv'))


def persam_f(args, obj_name, batch, output_path, sam= None):
    """PerSAM-F inference on one episode with nested gradient-flow refinement.

    First fine-tunes per-object mask-scale weights on the support image
    (PerSAM-F), then runs ``args.nested`` refinement iterations on the query
    image: each iteration predicts a mask, computes the gradient of the
    decoder logits w.r.t. the query embedding, perturbs the embedding along
    that gradient (plus Gaussian noise), and re-samples point prompts from
    the updated similarity map.

    Args:
        args: parsed CLI namespace (reads nested, eta, f, gamma, lr,
            train_epoch, device).
        obj_name: episode identifier string (unused inside this function).
        batch: episode dict with 'support_imgs', 'support_masks',
            'query_img', 'query_mask' tensors.
        output_path: output directory (unused inside this function).
        sam: a loaded SAM model; its parameters are frozen here.

    Returns:
        List of ``args.nested`` dicts, one per refinement step, each holding
        the sampled prompts, the predicted/true masks, and area statistics.
    """
    # NOTE(review): these timing accumulators are initialized but never read.
    total_time = 0
    regraf_time = 0
    start_time=time.time()
    device = 'cuda:{}'.format(args.device) if torch.cuda.is_available() else "cpu"
    T = args.nested                 # number of refinement iterations
    eta = args.eta                  # gradient-flow step size
    f = args.f                      # divergence type; only 'KL' is supported below
    gamma = args.gamma 
    noise_factor = np.sqrt(gamma)   # std-dev multiplier for the injected noise

    # Load images and masks from the episode (support = reference image).
    ref_image = batch['support_imgs'].squeeze().permute(1,2,0).numpy() # H, W, 3
    ref_image = (ref_image*255).astype(np.uint8)
    ref_mask = batch['support_masks'].squeeze(0).permute(1,2,0).repeat(1,1,3).numpy() # H, W, 3
    ref_mask = 255* ref_mask.astype(np.uint8)
    gt_mask = batch['support_masks'].squeeze().flatten()[None, ...].to(device) # 1, HxW
    
    # Freeze the entire SAM backbone; only Mask_Weights below are trained.
    for name, param in sam.named_parameters():
        param.requires_grad = False
    predictor = SamPredictor(sam)

    # Image features encoding (set_image also returns the transformed mask).
    ref_mask = predictor.set_image(ref_image, ref_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)

    # Downsample the support mask to the feature-map resolution.
    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Target feature extraction: blend of mean- and max-pooled foreground features.
    target_feat = ref_feat[ref_mask > 0]
    target_feat_mean = target_feat.mean(0)
    target_feat_max = torch.max(target_feat, dim=0)[0]

    target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)

    # Cosine similarity between the target feature and every support location.
    h, w, C = ref_feat.shape
    target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)
    ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)
    ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w)
    sim = target_feat @ ref_feat

    # Upsample the similarity map back to original image resolution.
    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
                    sim,
                    input_size=predictor.input_size,
                    original_size=predictor.original_size).squeeze()

    # Positive location prior on the support image.
    topk_xy, topk_label = point_selection(sim, device,topk=1)

    # Learnable three-scale mask blending weights (the only trained parameters).
    mask_weights = Mask_Weights().to(device)
    mask_weights.train()
    
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch)
    
    #########
    # PerSAM-f's fine tuning process (short fine-tuning)
    #########
    for train_idx in range(args.train_epoch):

        # Run the decoder on the support image with the fixed point prompt.
        masks, scores, logits, logits_high = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True)
        logits_high = logits_high.flatten(1)

        # Weighted sum of the three-scale masks; the first weight is derived
        # so that all three weights sum to 1.
        weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
        logits_high = logits_high * weights
        logits_high = logits_high.sum(0).unsqueeze(0)

        # Dice + focal loss against the support ground-truth mask.
        dice_loss = calculate_dice_loss(logits_high, gt_mask)
        focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask)
        loss = dice_loss + focal_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()


    mask_weights.eval()
    # Final blending weights after fine-tuning.
    weights = torch.cat((1 - mask_weights.weights.sum(0).unsqueeze(0), mask_weights.weights), dim=0)
    # NOTE(review): despite the name, this is a detached tensor, not a numpy array.
    weights_np = weights.detach()

    #########
    # First prediction (PerSAM-F's original process for extract prompts)
    #########
 
    test_image = batch['query_img'].squeeze().permute(1,2,0).numpy() # H, W, 3
    test_image = (test_image*255).astype(np.uint8)

    # Image feature encoding of the query image.
    predictor.set_image(test_image)
    test_feat = predictor.features.squeeze()
    
    # Cosine similarity between the target feature and every query location.
    C, h, w = test_feat.shape
    test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
    test_feat = test_feat.reshape(C, h * w)
    sim = target_feat @ test_feat

    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
                    sim,
                    input_size=predictor.input_size,
                    original_size=predictor.original_size).squeeze()
    
    
    # Positive location prior on the query image.
    topk_xy, topk_label = point_selection(sim,device ,topk=1)

    
    
    #########
    # Refinement loop for calculating and applying graident flow
    #########
    grad_map = torch.zeros_like(predictor.features)  # accumulated embedding perturbation
    perfs = []
    for t in range(T):  # Refinement loop (if T==1, it outputs same results with PerSAM-F)
        # First-step prediction
        with torch.no_grad():
            masks, scores, logits, logits_high = predictor.predict(
                        point_coords=topk_xy,
                        point_labels=topk_label,
                        multimask_output=True)

            # Weighted sum three-scale masks
            logits_high = logits_high * weights.unsqueeze(-1)
            logit_high = logits_high.sum(0)
            mask = (logit_high > 0).detach().cpu().numpy()

            logits = logits * weights_np[..., None]
            logit = logits.sum(0)

            # Cascaded Post-refinement-1: re-predict with the bounding box of
            # the current mask plus the blended low-res logits.
            y, x = np.nonzero(mask)
            x_min = x.min()
            x_max = x.max()
            y_min = y.min()
            y_max = y.max()
            input_box = np.array([x_min, y_min, x_max, y_max])
            masks, scores, logits, _ = predictor.predict(
                point_coords=topk_xy,
                point_labels=topk_label,
                box=input_box[None, :],
                mask_input=logit[None, :, :],
                multimask_output=True)
            best_idx = np.argmax(scores)
        # Cascaded Post-refinement-2: tighten the box around the best mask.
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])
        
        # Enable gradients w.r.t. the query embedding for the flow computation.
        predictor.features.requires_grad_(True)
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=input_box[None, :],
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)
        #########
        # Calculate gradient flow
        #########
        gradient_feature = torch.autograd.grad(outputs=logits[best_idx: best_idx + 1, :, :], inputs=predictor.features,
                    grad_outputs=torch.ones_like(logits[best_idx: best_idx + 1, :, :]),
                    create_graph=True, retain_graph=True )[0]
        
        if f == 'KL':
            s = torch.ones_like(predictor.features.detach())
        else:
            raise ValueError('not supported')
        
        # Clipped gradient step; the update is skipped on the first iteration.
        v = s.data * torch.clamp(gradient_feature.data, min=-0.1, max=0.1)
        if t != 0:
            # Langevin-style update: drift (eta * v) plus Gaussian noise.
            grad_map = grad_map + eta * v +\
                np.sqrt(2*eta) * noise_factor * torch.randn_like(predictor.features)
        else:
            grad_map = grad_map
        
        #########
        # Sample new prompts based on updated query embedding
        #########
        with torch.no_grad(): 
            test_feat = predictor.features.squeeze(0) + grad_map.squeeze(0)

            # Cosine similarity against the perturbed query embedding.
            C, h, w = test_feat.shape
            test_feat = test_feat / test_feat.norm(dim=0, keepdim=True)
            test_feat = test_feat.reshape(C, h * w)
            sim = target_feat @ test_feat

            sim = sim.reshape(1, 1, h, w)
            sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
            sim = predictor.model.postprocess_masks(
                            sim,
                            input_size=predictor.input_size,
                            original_size=predictor.original_size).squeeze()

        # Positive location prior for the next iteration.
        topk_xy, topk_label = point_selection(sim,device ,topk=1,t=t)

        # NOTE(review): this box is recomputed but not used again before the
        # next iteration overwrites it.
        y, x = np.nonzero(masks[best_idx])
        x_min = x.min()
        x_max = x.max()
        y_min = y.min()
        y_max = y.max()
        input_box = np.array([x_min, y_min, x_max, y_max])

        final_mask = masks[best_idx]

        # Evaluate this iteration's prediction against the query ground truth.
        true_mask = batch['query_mask'].squeeze().detach().cpu().numpy()
        area_intersection, area_union, area_target, area_total = intersectionAndUnion(final_mask, true_mask)
        
        #########
        # Save every iteration's information 
        #########
        perf = {    
                    'topk_xy': topk_xy, 
                    'topk_label': topk_label,
                    'test_image': test_image,
                    'true_mask': true_mask, 
                    'final_mask': final_mask, # = predicted masks
                    'area_intersection': area_intersection, 
                    'area_union': area_union, 
                    'area_target': area_target, 
                    'area_total': area_total
                }
        perfs.append(perf)

    return perfs
#########
# util function
#########
class Mask_Weights(nn.Module):
    """Two learnable scalars for blending SAM's three mask scales.

    The caller derives the third weight as ``1 - weights.sum()``, so all
    three effective weights start at 1/3.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter already enables gradients; initialize both entries to 1/3.
        self.weights = nn.Parameter(torch.full((2, 1), 1.0 / 3.0))


def point_selection(mask_sim, device, topk=1, t=0):
    """Select positive point prompt(s) from a 2-D similarity map.

    Only the single highest-similarity location is returned. The original
    implementation carried a quantile-based multi-point sampler that was
    disabled (its ``quantiles`` list was unconditionally cleared) and whose
    flat-index bookkeeping was internally inconsistent; that dead branch has
    been removed with behavior preserved.

    Args:
        mask_sim: 2-D similarity tensor (rows, cols).
        device: torch device (or device string) for the coordinate tensor.
        topk: kept for interface compatibility; only the top-1 point is used.
        t: refinement-iteration index; kept for interface compatibility.

    Returns:
        (coords, labels): coords is an (N, 2) int array of (x, y) pixel
        positions; labels is an (N,) int array of ones (positive prompts).
    """
    # Top-1 point: argmax over the flattened map, converted back to (row, col).
    n_rows, n_cols = mask_sim.shape
    top1_idx = mask_sim.view(-1).argmax()
    top1_row = top1_idx // n_cols
    top1_col = top1_idx % n_cols
    # SAM point prompts use (x, y) == (col, row) ordering.
    all_xy = torch.tensor([[top1_col.item(), top1_row.item()]]).to(device)

    # NOTE(review): kept on purpose — the original reseeded the global torch
    # RNG here, which affects downstream random draws (e.g. refinement noise).
    torch.manual_seed(42)

    all_labels = np.ones(all_xy.shape[0], dtype=int)

    return all_xy.cpu().numpy(), all_labels


def visualize_bounding_box(input_box, new_plot=False):
    """Draw a green rectangle for ``input_box`` on the current matplotlib axes.

    Args:
        input_box: sequence (x_top_left, y_top_left, x_bottom_right, y_bottom_right).
        new_plot: when True, open a fresh figure before drawing.
    """
    left, top, right, bottom = input_box

    if new_plot:
        plt.figure()

    # Closed polyline tracing the four corners back to the start.
    xs = [left, right, right, left, left]
    ys = [top, top, bottom, bottom, top]
    plt.plot(xs, ys, color='green', label='Bounding Box')

    plt.grid(True)

def calculate_dice_loss(inputs, targets, num_masks=1):
    """Soft Dice loss (a smooth IoU-style objective) over flattened masks.

    Args:
        inputs: raw logits of arbitrary shape; sigmoid is applied here.
        targets: binary labels (0/1) matching ``inputs`` after flattening
                 all but the first dimension.
        num_masks: normalizer applied to the summed per-mask loss.

    Returns:
        Scalar loss tensor: sum over masks divided by ``num_masks``.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = 2 * (probs * targets).sum(-1)
    totals = probs.sum(-1) + targets.sum(-1)
    # +1 smoothing in numerator and denominator avoids division by zero.
    dice = (overlap + 1) / (totals + 1)
    return (1 - dice).sum() / num_masks


def calculate_sigmoid_focal_loss(inputs, targets, num_masks=1, alpha: float = 0.25, gamma: float = 2):
    """Sigmoid focal loss from RetinaNet (https://arxiv.org/abs/1708.02002).

    Args:
        inputs: raw logits of arbitrary shape.
        targets: binary labels (0/1) with the same shape as ``inputs``.
        num_masks: normalizer applied to the summed per-mask loss.
        alpha: class-balancing weight in (0, 1); negative disables it.
        gamma: focusing exponent on the (1 - p_t) modulating factor.

    Returns:
        Scalar loss tensor: per-mask mean, summed, divided by ``num_masks``.
    """
    probs = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the model's probability for the true class of each element.
    p_t = targets * probs + (1 - targets) * (1 - probs)
    focal = bce * (1 - p_t) ** gamma

    if alpha >= 0:
        # Down-weight the negative class relative to the positive class.
        alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
        focal = alpha_t * focal

    return focal.mean(1).sum() / num_masks


class AverageMeter(object):
    """Track the latest value, weighted sum, count, and running mean of a scalar."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = 0    # most recent value seen
        self.avg = 0    # running mean (sum / count)
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def intersectionAndUnion(output, target):
    """Compute area statistics between a predicted mask and the ground truth.

    Args:
        output: predicted binary mask (1-3 dims), same shape as ``target``.
        target: ground-truth binary mask.

    Returns:
        Tuple (intersection, union, target_area, total_area) where
        ``total_area`` is the sum of both masks' foreground areas.
    """
    assert output.ndim in (1, 2, 3)
    assert output.shape == target.shape

    # Flatten both masks to 1-D for element-wise set operations.
    pred_flat = output.reshape(output.size).copy()
    gt_flat = target.reshape(target.size)

    area_total = pred_flat.sum() + gt_flat.sum()
    area_intersection = np.logical_and(pred_flat, gt_flat).sum()
    area_union = np.logical_or(pred_flat, gt_flat).sum()
    area_target = gt_flat.sum()

    return area_intersection, area_union, area_target, area_total


# Script entry point: run the full evaluation when invoked directly.
if __name__ == '__main__':
    main()
