import sys
sys.path.append('')

from myutils.extra_model import (
                                 GPNNMix4)
from myutils.mydataset import (
                               MixAns2,
                               MixAns3,
                               InfDataset,
                               MixLocal)
import json
from myutils.data_utils import (sample_appearance_indices,
                                )
import h5py
from myutils.config import *
from tqdm import tqdm
import argparse
import random
import logging
from torch import nn as nn
import os
import numpy as np
from myutils.data_utils import (
    add_weight_decay,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    save_checkpoint,
    getTimeStamp,
    MyEvaluatorActionGenome,
)
# Process-wide runtime settings applied at import time (presumably so they are
# in effect before any CUDA context is created -- TODO confirm no earlier
# import already initialised CUDA).
os.environ.update({
    # Only GPUs 0-3 are visible; --device (default "cuda:2") indexes within them.
    'CUDA_VISIBLE_DEVICES': "0,1,2,3",
    # One OpenMP thread per process.
    'OMP_NUM_THREADS': '1',
    # Presumably for deterministic cuBLAS kernels (see PyTorch reproducibility notes).
    'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    # Synchronous kernel launches: easier CUDA error tracing, slower execution.
    'CUDA_LAUNCH_BLOCKING': '1',
})
from transformers.trainer_pt_utils import SequentialDistributedSampler,distributed_concat
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.nn.functional as F
import warnings
import torch.optim as optim
import torch

# Uncomment the next line to turn every warning into an exception (debugging aid).
# warnings.filterwarnings('error')



def load_model_dict2(model,p1=None,p2=None):
    """Load a base checkpoint into *model* strictly, then optionally overlay a second one.

    Both checkpoints are read with ``torch.load(path, 'cpu')`` and must hold the
    state dict under the ``'model'`` key.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        p1: path of the base checkpoint; its keys must match ``model`` exactly.
        p2: optional path of an overlay checkpoint, loaded with ``strict=False``
            so it may cover only a subset of the parameters.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        FileNotFoundError: if *p1* is not given.
    """
    if p1 is None:
        raise FileNotFoundError("load_model_dict2: base checkpoint path p1 is required")
    model.load_state_dict(torch.load(p1,'cpu')['model'],strict=True)
    if p2 is None:
        return model
    # Non-strict overlay: parameters absent from p2 keep their p1 values.
    missing_keys,unexpected_keys=model.load_state_dict(torch.load(p2,'cpu')['model'],strict=False)
    if unexpected_keys:
        # Keys the model does not have are dropped; report them instead of
        # ignoring a possibly mismatched checkpoint silently.
        logging.getLogger(__name__).warning(
            "load_model_dict2: %d unexpected key(s) in %s ignored (e.g. %s); %d model key(s) kept from %s",
            len(unexpected_keys),p2,unexpected_keys[:5],len(missing_keys),p1)
    return model

def load_model_dict(p,model,s=True):
    """Copy the state dict stored under ``'model'`` in checkpoint *p* into *model*.

    Args:
        p: checkpoint path; it is deserialised onto the CPU.
        model: module receiving the weights (modified in place).
        s: forwarded as ``strict`` to ``model.load_state_dict``.

    Returns:
        *model* itself, so the call can be chained.
    """
    checkpoint_state = torch.load(p, map_location='cpu')['model']
    model.load_state_dict(checkpoint_state, strict=s)
    return model


# 16 x 16 grid
def test_oracle2(args,pretrain,p1):
    config=load_config()
    config.prompt.type=args.prompt
    config.max_epoch=args.epoch

    config.normtype=3
    
    write_path=''+str(args.dt)+'1.txt'
    write_ans_path=''+str(args.dt)+'1.txt'


    #     write_config(config,args,[os.path.join(loss_record_path_,time_stamp[8:12]+'config_args.txt')])
    flag=False
    if args.dt==2:
        flag=True
    
    device=args.device
    model=GPNNMix4(config,flag=pretrain,train_stage=args.stage).to(device)
    # load_model_dict
    model=load_model_dict(p1,model)
    min_index=1 if args.dt>1 else 2
    min_index=1

    model.eval()
    # common private total
    acc_list=[[ [ [] for k in range(j+1) ] for j in range(16) ] for i in range(4)]
    # names=['common','private','middle','total']
    # xx=[i for i in range(min_index,17)]
    print('dataset type',args.dt)
    f=open(write_path,'w')
    f2=open(write_ans_path,'w')
    with torch.no_grad():
        for test_i in range(min_index,17):
            # test_ii=test_i+1 if args.dt>1 else test_i
            test_ii=test_i+1
            for test_j in range(1,test_ii):
                # print('123',test_i,test_j)
                # test_dataset=TestMixAns('test',args.dt,test_i,sample_each_clip=16,train=False)
                # test_loader=DataLoader(test_dataset,batch_size=args.batchsize*4,num_workers=12)
                test_dataset=MixAns3('test',sample_each_clip=16,train=False,mapping_type=args.dt,test_i=test_i,test_j=test_j)

                test_loader=DataLoader(test_dataset,batch_size=args.batchsize*4,num_workers=12)

                evaluator = MyEvaluatorActionGenome(len(test_dataset),157)
                evaluator2 = MyEvaluatorActionGenome(len(test_dataset),158,flag)
                evaluator3 = MyEvaluatorActionGenome(len(test_dataset),158)
                # evaluator4 = MyEvaluatorActionGenome(len(test_dataset),157)

                evaluator.reset()
                evaluator2.reset()
                evaluator3.reset()
                # evaluator4.reset()

                for batch in test_loader:
                    frames,bbx,mask,label,cls_ids,cls_l,rel_l,private_label,common_label,token_tensor,mask_=batch
                    frames=frames.to(device)
                    bbx=bbx.to(device)
                    mask=mask.to(device)
                    cls_l=cls_l.to(device)
                    rel_l=rel_l.to(device)
                    cls_ids=cls_ids.to(device)
                    token_tensor=token_tensor.to(device)
                    private_label=private_label.to(device)
                    common_label=common_label.to(device)

                    label=label.to(device).squeeze()

                    c_ans,p_ans,t_ans=model(frames,cls_ids,rel_l,bbx,token_tensor,mask)
        
                    if len(label.shape)==1:
                        label.unsqueeze_(0)
                    if len(common_label):
                        common_label.unsqueeze_(0)
                    if len(private_label):
                        private_label.unsqueeze_(0)

                    evaluator.process(t_ans,label)
       
                    evaluator2.process(c_ans,common_label)
                    evaluator3.process(p_ans,private_label)
                    # evaluator4.process(m_ans,label)
                metrics = evaluator.evaluate()
                metrics2 = evaluator2.evaluate()
                metrics3 = evaluator3.evaluate()
                # metrics4 = evaluator4.evaluate()

                acc_list[0][test_i-1][test_j-1].append(metrics['map'])
                acc_list[1][test_i-1][test_j-1].append(metrics2['map'])
                acc_list[2][test_i-1][test_j-1].append(metrics3['map'])
                # acc_list[3][test_i-1][test_j-1].append(metrics4['map'])

        for i in range(min_index-1,16):
            t_ans_list=[]
            c_ans_list=[]
            p_ans_list=[]
            # m_ans_list=[]
            f.write('video:'+str(i)+'\n')
            kk=i+1
            for j in range(kk):
                t=np.mean(acc_list[0][i][j])
                t_ans_list.append(t)
                c=np.mean(acc_list[1][i][j])
                c_ans_list.append(c)
                p=np.mean(acc_list[2][i][j])
                p_ans_list.append(p)
                # m=np.mean(acc_list[3][i][j])
                # m_ans_list.append(m)
            print_acc=str(i+1)+' '+str(round(np.mean(c_ans_list)*100,2))+'%±'+str(round(np.std(c_ans_list)*100,1))+'%\t\n'
            # print_acc=str(i+1)+' '+str(round(np.mean(t_ans_list)*100,2))+'±'+str(round(np.std(t_ans_list)*100,1))+'\t'+\
            #             +'common:'+str(round(np.mean(c_ans_list)*100,2))+'%±'+str(round(np.std(c_ans_list)*100,1))+'%\t'\
                        # +'private:'+str(round(np.mean(p_ans_list)*100,2))+'%±'+str(round(np.std(p_ans_list)*100,1))+'%\t'+'\n'
                        # +'middle:'+str(round(np.mean(m_ans_list)*100,3))+'%±'+str(round(np.std(m_ans_list)*100,5))+'%\t\n'
            f.write(print_acc)
            f2.write(print_acc)
            f.write(str(acc_list[0][i])+'\n')
            f.write(str(acc_list[1][i])+'\n')
            f.write(str(acc_list[2][i])+'\n')
            # f.write(str(acc_list[3][i])+'\n')
            print(print_acc,end='')
    f.close()

    f2.close()

                
    # name='padding' if args.dt==2 else 'not padding'


    # draw_list_multi(acc_list,xx,name,names)

    # if pretrain:
    #     acc_str='t:'+str(round(metrics['map']*100,5))+'_c:'+str(round(metrics2['map']*100,5))+'_p:'+str(round(metrics3['map']*100,5))
    #     save_checkpoint(epoch+1,model,acc_str,optimizer,scheduler,time_stamp,'pretrain')
    #     print('saved')
    # else:
    #     save_checkpoint(epoch+1,model,round(metrics['map']*100,5),optimizer,scheduler,time_stamp,'train')

def set_seed(seed=3407):
    """Seed every RNG this script relies on with the same *seed*.

    Covers Python's ``random``, NumPy's global generator, the PyTorch CPU
    generator, the current CUDA device and all CUDA devices.  The CUDA calls
    are applied in the same order as before.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def test_iou(args,pretrain,p1):
    """Measure frame-level localisation of the model's ``c_v`` visual scores on MixLocal.

    For every clip with ``frame_flag == 1``, the 16 per-frame scores
    ``c_v[0][0]`` are thresholded at each ``thh1`` value.  The clip counts as a
    hit for its ``common_label`` class when the IoU with ``frame_ans`` exceeds
    each ``thh2`` value; otherwise it counts as a miss.  Per-class precision,
    TP / (TP + FP), is averaged and printed as a thh1 x thh2 table (see
    ``_print_iou_precision_table``).

    Args:
        args: parsed CLI namespace; reads ``prompt``, ``epoch``, ``device``,
            ``model``, ``stage``, ``ds`` and ``batchsize``.
        pretrain: forwarded to ``GPNNMix4`` as ``flag``.
        p1: checkpoint path, loaded strictly with ``load_model_dict``.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'mix4'``.
    """
    config=load_config()
    config.prompt.type=args.prompt
    config.max_epoch=args.epoch
    config.normtype=3

    device=args.device
    if args.model!='mix4':
        raise NotImplementedError
    model=GPNNMix4(config,flag=pretrain,train_stage=args.stage).to(device)
    model=load_model_dict(p1,model,True)
    model.set_visual(True)
    model.eval()

    thh1=[0.2,0.5,0.7]  # per-frame score thresholds (table rows)
    thh2=[0.2,0.5,0.7]  # IoU thresholds (table columns)
    num_classes=157
    TP_list=[[[0]*num_classes for _ in thh2] for _ in thh1]
    FP_list=[[[0]*num_classes for _ in thh2] for _ in thh1]
    with torch.no_grad():
        test_dataset=MixLocal('test',sample_each_clip=16,train=False,mapping_type=args.ds)
        test_loader=DataLoader(test_dataset,batch_size=args.batchsize*4,num_workers=12)
        for batch in tqdm(test_loader):
            (frames,bbx,mask,_label,cls_ids,_cls_l,rel_l,_private_label,common_label,
             token_tensor,_mask_,frame_ans,frame_flag)=batch
            model.clear_visual()
            # The forward pass is run for its effect on the visual buffers;
            # the returned (c_ans, p_ans, t_ans) are not needed here.
            model(frames.to(device),cls_ids.to(device),rel_l.to(device),bbx.to(device),
                  token_tensor.to(device),mask.to(device))
            _g_v,_p_v,c_v,_t_v=model.get_visual()

            # .cpu(): frame_ans comes straight from the loader (CPU) and the
            # results are converted with .numpy(); both need CPU tensors.
            predictions=c_v[0][0].reshape(-1,16).cpu()
            frame_ans=frame_ans.reshape(-1,16)
            # reshape(-1) instead of squeeze(): keeps a 1-D array for a size-1 batch.
            common_label=common_label.reshape(-1).numpy()
            frame_flag=frame_flag.reshape(-1).numpy()
            _accumulate_iou_hits(predictions,frame_ans,common_label,frame_flag,
                                 thh1,thh2,TP_list,FP_list)
    _print_iou_precision_table(TP_list,FP_list)


def _accumulate_iou_hits(predictions,frame_ans,common_label,frame_flag,thh1,thh2,TP_list,FP_list):
    """Add one batch's per-clip hits/misses into ``TP_list``/``FP_list[i][j][class]``."""
    for i,score_thr in enumerate(thh1):
        # With 0/1 masks (assumes frame_ans is binary -- TODO confirm), the sum
        # is 2 where both are set: clamp gives the union, sum - union the intersection.
        summed=(predictions>=score_thr).long()+frame_ans
        union=torch.clamp(summed,max=1)
        inter=summed-union
        # An empty union gives 0/0 = NaN, which never exceeds a threshold -> miss.
        iou=inter.sum(dim=-1)/union.sum(dim=-1)
        for j,iou_thr in enumerate(thh2):
            hits=(iou>iou_thr).long().numpy()
            for k in range(hits.shape[0]):
                if frame_flag[k]!=1:
                    continue
                counts=TP_list if hits[k]==1 else FP_list
                counts[i][j][common_label[k]]+=1


def _print_iou_precision_table(TP_list,FP_list):
    """Print the class-averaged precision table in percent.

    There is one row per score threshold: one value per IoU threshold, then
    the row mean.  A final line gives the column means.
    """
    n_rows=len(TP_list)
    n_cols=len(TP_list[0])
    col_sum=[0.]*n_cols
    col_len=[0.]*n_cols
    for i in range(n_rows):
        row_sum,row_len=0.,0.
        for j in range(n_cols):
            # Per-class precision; classes with no counted clip contribute 0.
            precisions=[tp/(tp+fp) if tp+fp else 0
                        for tp,fp in zip(TP_list[i][j],FP_list[i][j])]
            total=sum(precisions)
            col_sum[j]+=total
            col_len[j]+=len(precisions)
            print(str(round(total/len(precisions)*100,2)),end=' ')
            row_sum+=total
            row_len+=len(precisions)
        print(str(round(row_sum/row_len*100,2)))
    for j in range(n_cols):
        print(str(round(col_sum[j]/col_len[j]*100,2)),end=' ')
    print('')

if __name__=='__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate GPNNMix4 checkpoints (--tp 1: oracle mAP grid, --tp 4: frame-level IoU).")

    parser.add_argument("--device", type=str, default="cuda:2", help="gpu device")
    parser.add_argument("--sup", type=str, default="nothing", help="free-form note")
    parser.add_argument("--epoch", type=int, default=20, help="train epochs (copied into config.max_epoch)")
    parser.add_argument("--warmup", type=int, default=2, help="warmup epochs")
    parser.add_argument("--batchsize", type=int, default=32, help="batchsize (evaluation uses 4x this)")
    parser.add_argument("--lr", type=float, default=2e-4, help="learning rate")
    parser.add_argument("--decay", type=float, default=1e-3, help="weight decay")
    parser.add_argument("--clip_val", type=float, default=5.0, help="The gradient clipping value.")
    parser.add_argument("--wr", type=float, default=.1, help="warm up rate for continue")
    parser.add_argument("--model", type=str, default="mix4", help="model")
    parser.add_argument("--ds", type=int, default=1, help="dataset mapping type for MixLocal (--tp 4)")
    parser.add_argument("--stage", type=int, default=4, help="train stage")
    parser.add_argument("--tp", type=int, default=0,
                        help="run type: 1 = oracle mAP grid (test_oracle2), 4 = frame-level IoU (test_iou)")
    parser.add_argument("--prompt", type=int, default=1, help="prompt type 0:simple 1:gpfp")
    parser.add_argument("--loss", type=int, default=1,
                        help="separation and reconstruction loss, 0 no loss, 1 loss")
    parser.add_argument("--p_index", type=int, default=0, help="index into the configured checkpoint paths")
    parser.add_argument("--dt", type=int, default=0, help="dataset type (mapping type for MixAns3, --tp 1)")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="checkpoint path; overrides the --p_index lookup when given")
    set_seed(seed=3407)
    args = parser.parse_args()

    # Checkpoint paths selectable with --p_index (fill in locally).
    p1=[]
    runners={1:test_oracle2,4:test_iou}
    if args.tp not in runners:
        raise NotImplementedError(f"--tp {args.tp} is not supported; use 1 (oracle grid) or 4 (IoU)")
    if args.ckpt is not None:
        ckpt=args.ckpt
    elif -len(p1)<=args.p_index<len(p1):
        ckpt=p1[args.p_index]
    else:
        parser.error(f"--p_index {args.p_index} is out of range for {len(p1)} configured "
                     f"checkpoint path(s); pass --ckpt PATH instead")
    runners[args.tp](args,False,ckpt)
    