import torch 
import argparse
import numpy as np  
from utils  import get_network, set_dataset_specs, get_normalize_trans, eval,eval_mp , ImageFolderSubsampleTwins, cutmix
from torchvision import transforms      
from tqdm import tqdm       
import torchvision
import os 
import matplotlib.pyplot as plt 
import time 
from torch.optim.lr_scheduler import LambdaLR
import math 
import random
import multiprocessing  
from diffusers import  StableDiffusionXLImg2ImgPipeline
import torch.multiprocessing as mp 
from accelerate import PartialState 
import torch.distributed as dist

import warnings  
# Suppress every Python warning for the rest of the process.
warnings.simplefilter("ignore")

# Default device for tensors moved by this script: the current CUDA device
# when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

def generate_slerp_interpolations(x, x_twin, *, use_slerp=False, eps=1e-6):
    """Blend each sample of ``x`` with the matching sample of ``x_twin``.

    One mixing weight ``alpha ~ U[0, 1)`` is drawn per sample with a single
    ``torch.rand(batch_size)`` call. By default the result is the linear
    blend ``alpha * x + (1 - alpha) * x_twin``. This is what the training
    loop has always used, despite the function's name.

    With ``use_slerp=True`` the samples are spherically interpolated instead,
    using the same convention: ``alpha == 1`` yields ``x`` and ``alpha == 0``
    yields ``x_twin``. Samples whose flattened vectors are (nearly) parallel
    or anti-parallel, i.e. ``|sin(theta)| < eps``, fall back to the linear
    blend, where slerp is undefined.

    Args:
        x: tensor whose first dimension is the batch (the training loop
            passes 4-D latents).
        x_twin: tensor with the same shape as ``x``.
        use_slerp: keyword-only; use spherical instead of linear interpolation.
        eps: keyword-only; ``|sin(theta)|`` threshold for the slerp fallback.

    Returns:
        Tensor with the same shape as ``x``. ``alpha`` is float32, so
        half-precision inputs are promoted to float32 by the blend.
    """
    batch_size = x.shape[0]
    # Broadcastable per-sample weights, e.g. (batch, 1, 1, 1) for 4-D input.
    alphas = torch.rand(batch_size, device=x.device).view(
        batch_size, *([1] * (x.dim() - 1))
    )

    lerp = alphas * x + (1 - alphas) * x_twin
    if not use_slerp:
        return lerp

    # Angle between the flattened samples, computed in float32 and clamped so
    # that acos never sees values just outside [-1, 1] due to rounding.
    flat0 = x.reshape(batch_size, -1).float()
    flat1 = x_twin.reshape(batch_size, -1).float()
    cos_theta = torch.nn.functional.cosine_similarity(flat0, flat1, dim=1).clamp(-1.0, 1.0)
    theta = torch.acos(cos_theta).view(alphas.shape)
    sin_theta = torch.sin(theta)

    # Replace near-zero denominators before dividing so no NaN/inf is ever
    # produced; those samples take the linear blend below anyway.
    degenerate = sin_theta.abs() < eps
    safe_sin = torch.where(degenerate, torch.ones_like(sin_theta), sin_theta)
    w0 = torch.sin(alphas * theta) / safe_sin
    w1 = torch.sin((1 - alphas) * theta) / safe_sin
    slerp = w0 * x + w1 * x_twin

    return torch.where(degenerate, lerp, slerp)

def _resolve_syn_path(args):
    """Return the directory whose sub-folders are the synthetic training classes.

    Starting from ``args.syn_data_path``, descend into ``lvl_{args.lvl}`` when
    ``args.lvl`` is set, then into a ``train`` sub-folder if one exists, then
    into ``ipc_{args.ipc}`` if that exact folder exists.
    """
    syn_path = args.syn_data_path
    if args.lvl is not None:
        syn_path = os.path.join(syn_path, f'lvl_{args.lvl}')

    if 'train' in os.listdir(syn_path):
        syn_path = os.path.join(syn_path, 'train')

    # Match the concrete ``ipc_<N>`` folder name; a bare 'ipc_' entry never exists.
    ipc_dir = f'ipc_{args.ipc}'
    if args.ipc is not None and ipc_dir in os.listdir(syn_path):
        syn_path = os.path.join(syn_path, ipc_dir)
    return syn_path


def _build_test_dataset(args, tst_trans):
    """Build the evaluation dataset selected by ``args.subset``.

    CIFAR-10/100 are downloaded into ``args.val_dir``. Any other subset is
    read as an ImageFolder from ``<args.val_dir>/val``.
    """
    if args.subset == 'cifar100':
        return torchvision.datasets.CIFAR100(root=args.val_dir, train=False, download=True, transform=tst_trans)
    if args.subset == 'cifar10':
        return torchvision.datasets.CIFAR10(root=args.val_dir, train=False, download=True, transform=tst_trans)
    return torchvision.datasets.ImageFolder(os.path.join(args.val_dir, 'val'), transform=tst_trans)


def train_soft_lbls(args):
    """Train ``args.arch`` on synthetic images against pre-computed soft labels.

    For each batch, unless skipped with probability ``args.aug_drop_p``, the
    images and their "twins" are mapped to latents, interpolated with
    ``generate_slerp_interpolations`` and re-rendered by the SDXL-Turbo
    img2img pipeline. The batch is split across processes and the results
    are all-gathered. The gathered batch is augmented (and optionally
    CutMix-ed) and fitted with a KL-divergence loss against the next
    ``batch`` rows of the soft-label tensor loaded from ``args.lbl_file``.
    Rank 0 evaluates after every epoch and saves the best checkpoint to the
    current directory.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Returns:
        The best test accuracy seen. On ranks other than 0 this is always
        ``-1``, since only rank 0 evaluates.

    Raises:
        ValueError: if ``args.lbl_file`` does not hold enough soft-label rows
            for every batch of every epoch.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    random.seed(args.seed)

    if args.teacher_arch is None:
        args.teacher_arch = args.arch

    model = get_network(args, args.arch, pretrained=False, data_parallel=True)

    if args.resume_training:
        model.load_state_dict(torch.load(args.model_ckpt_student, map_location='cpu'))
        print(f"Model loaded from {args.model_ckpt_student}")

    normalize, _ = get_normalize_trans(args)

    # Augmentations applied to the full, gathered batch right before the forward pass.
    train_trans_later = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(args.rnd_res_scale_st, 1)),
        transforms.RandomHorizontalFlip(),
        normalize,
    ])
    print(f'training trans size: {args.input_size}, rnd_res_scale_st: {args.rnd_res_scale_st}')

    train_trans = transforms.Compose([transforms.ToTensor()])

    if args.subset.startswith('imagenet'):
        tst_trans = transforms.Compose([transforms.Resize(args.init_resize), transforms.CenterCrop(args.input_size), transforms.ToTensor(), normalize])
    else:
        tst_trans = transforms.Compose([transforms.ToTensor(), normalize])

    print(f'test trans initial resize: {args.init_resize}, input size: {args.input_size}')

    syn_path = _resolve_syn_path(args)

    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")

    distributed_state = PartialState()
    pipe.to(distributed_state.device)

    ds_train = ImageFolderSubsampleTwins(syn_path, samples_per_class=-1, transform=train_trans)
    print('sub samples were loaded with sample per class:', args.ipc)
    print(f'training images are loaded from {syn_path}')

    ds_tst = _build_test_dataset(args, tst_trans)

    print(len(ds_train), len(ds_tst))

    train_dl = torch.utils.data.DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    tst_dl = torch.utils.data.DataLoader(ds_tst, batch_size=128, shuffle=False, num_workers=args.workers)

    criterion = torch.nn.KLDivLoss(reduction="batchmean")
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd, betas=(0.9, 0.999))
    # NOTE(review): 0.5 * (1 + cos(pi * step / (2 * epochs))) decays the LR factor
    # from 1 to 0.5 at step == epochs, not to 0. Kept as-is; confirm it is intended.
    sch = LambdaLR(optimizer, lambda step: 0.5 * (1.0 + math.cos(math.pi * step / args.epochs / 2)) if step <= args.epochs else 0, last_epoch=-1)

    best_acc = -1
    best_state_dict_cpu = None  # stays None unless rank 0 records an improvement

    # List the same directory the training set is read from, so class_names lines
    # up with the dataset's label indices (presumably sorted folder order, as in
    # torchvision's ImageFolder; confirm for ImageFolderSubsampleTwins).
    class_folder_names = sorted(f for f in os.listdir(syn_path) if not f.startswith('.'))
    dataset_dict = torch.load(args.dataset_name_dict)
    class_names = [dataset_dict[folder_name].replace('_', ' ') for folder_name in class_folder_names]

    print('.....................Training with soft labels.....................')

    resize_up = transforms.Resize(320)

    lbls = torch.load(args.lbl_file)

    timesteps, _ = pipe.get_timesteps(
        5,
        args.strength,
        'cuda',
        denoising_start=None,
    )
    latent_timestep = timesteps[:1].repeat(1)

    def encode_latents(images):
        """Preprocess images and turn them into fp16 latents via ``pipe.prepare_latents`` (no added noise)."""
        pixels = pipe.image_processor.preprocess(images).to('cuda')
        return pipe.prepare_latents(pixels, latent_timestep, 1, 1, generator=None, add_noise=False, device="cuda", dtype=torch.float16)

    # Running offset into `lbls`. It is deliberately NOT reset between epochs:
    # the label file is consumed sequentially across the whole run.
    tail_ind = 0

    if args.init_acc_st and dist.get_rank() == 0:
        acc = eval(model, tst_dl)
        print(f"Initial Accuracy: {np.round(acc, 2)}")

    for epoch in range(args.epochs):
        model.train()
        t1 = time.time()

        if dist.get_rank() == 0:
            print(f"Epoch: {epoch}")

        for x, x_twin, y in tqdm(train_dl, total=len(train_dl)):
            x_lq, y = x.to(device), y.to(device)
            x_twin = x_twin.to(device)

            prompts = [f'a photo of a {class_names[int(y_.item())]}' for y_ in y]

            x_lq = resize_up(x_lq)
            x_twin = resize_up(x_twin)

            in_dict = {'images': x_lq, 'prompts': prompts, 'labels': y, 'images_twin': x_twin}
            with distributed_state.split_between_processes(in_dict) as dic:
                with torch.no_grad():
                    # Every rank seeds np.random with args.seed, so presumably all
                    # ranks draw the same value and take the same branch.
                    drop_val = np.random.uniform(0., 1, size=1)
                    if drop_val > args.aug_drop_p:
                        x_gen = encode_latents(dic['images'])
                        x_twin_gen = encode_latents(dic['images_twin'])

                        x_gen = generate_slerp_interpolations(x_gen, x_twin_gen)
                        x_gen = pipe(prompt=dic['prompts'], image=x_gen, num_inference_steps=5, strength=args.strength, guidance_scale=0.,
                                     num_images_per_prompt=1, output_type='pt').images
                    else:
                        x_gen = dic['images']

                    local_y = dic['labels']

            # Re-assemble the full batch on every rank.
            # NOTE(review): all_gather expects equal-shaped tensors on all ranks; a final
            # batch not divisible by the world size may yield uneven shards. Confirm.
            gathered_results = [torch.zeros_like(x_gen) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered_results, x_gen)

            # Gathered labels are not used below (targets come from `lbls`), but the
            # collective is kept so every rank still issues the same calls.
            gathered_labels = [torch.zeros_like(local_y) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered_labels, local_y)

            x = torch.cat(gathered_results, dim=0)

            x = train_trans_later(x)

            if args.cutmix:
                x = cutmix(x)

            soft_lbls = lbls[tail_ind:tail_ind + x.shape[0]]
            if soft_lbls.shape[0] != x.shape[0]:
                # A short slice would either crash inside KLDivLoss or (1 row) silently broadcast.
                raise ValueError(
                    f"{args.lbl_file} holds {len(lbls)} soft-label rows, but rows "
                    f"{tail_ind}..{tail_ind + x.shape[0]} are needed at epoch {epoch}"
                )

            out = model(x)
            out_soft = torch.log_softmax(out / args.temp, dim=1)

            soft_lbls = soft_lbls.to(out_soft.device)
            tail_ind += x.shape[0]

            loss = criterion(out_soft, soft_lbls)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        sch.step()

        if dist.get_rank() == 0:
            acc_tst = eval(model, tst_dl)

            if acc_tst > best_acc:
                best_acc = acc_tst

                best_state_dict_cpu = {key: value.cpu() for key, value in model.state_dict().items()}
                checkpoint_save_name = f"model_{args.extra_desc}_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}_seed_{args.seed}.pth"
                torch.save(best_state_dict_cpu, checkpoint_save_name)

            print(f"Epoch: {epoch}, tst_acc: {np.round(acc_tst, 2)}, best_tst_acc: {np.round(best_acc, 2)}")

            elapsed_time = time.time() - t1
            print(f"Elapsed time: {elapsed_time:.2f} seconds")

    # Re-save the best weights with the best accuracy embedded in the file name.
    if dist.get_rank() == 0 and best_state_dict_cpu is not None:
        checkpoint_save_name = f"model_{args.extra_desc}_data_{args.subset}_arch_{args.arch}_ep_{args.epochs}_acc_{np.round(best_acc, 2)}_seed_{args.seed}.pth"
        torch.save(best_state_dict_cpu, checkpoint_save_name)

    print(f'final best acc: {best_acc}')

    return best_acc


if __name__ == "__main__":
    # Command-line entry point: parse options, fill dataset-specific defaults,
    # run `exp_num` experiments on consecutive seeds, then report mean +- std.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    add_arg("--cls_ind", type=int, default=None)

    add_arg("--subset", type=str, default="imagenette")
    add_arg("--size", type=int, default=224)
    add_arg("--st_cls", type=int, default=0)
    add_arg("--end_cls", type=int, default=None)
    add_arg("--eval_ratio", type=float, default=0.0)

    add_arg("--extra_desc", type=str, default="")
    add_arg("--model_ckpt", type=str, default=None)
    add_arg("--model_ckpt_student", type=str, default=None)

    add_arg("--nclass", type=int, default=1000)
    add_arg("--classes", type=list)
    add_arg("--init_resize", type=int, default=256)
    add_arg("--input_size", type=int, default=None)
    add_arg("--diff_input_size", type=int, default=None)

    add_arg("--root_dir", type=str, default=None)
    add_arg("--val_dir", type=str, default=None)
    add_arg("--root_aug", type=str, default=None)

    add_arg("--emb_root_dir", type=str, default=None)
    add_arg("--emb_file_name", type=str, default=None)

    add_arg("--arch", type=str, default='resnet18')
    add_arg("--teacher_arch", type=str, default=None)

    add_arg("--batch_size", type=int, default=256)
    add_arg("--epochs", type=int, default=200)
    add_arg("--lr", type=float, default=0.1)
    add_arg("--wd", type=float, default=1e-4)
    add_arg("--mom", type=float, default=0.9)
    add_arg("--init_acc", action='store_true', default=False)
    add_arg("--init_acc_st", action='store_true', default=False)

    add_arg("--force_rded_net", action='store_true', default=False)
    add_arg("--resume_training", action='store_true', default=False)

    add_arg("--use_json_ds", action='store_true', default=False)
    add_arg("--json_file_path", type=str, default=None)

    add_arg("--aug_drop_p", type=float, default=0.0)

    add_arg("--connection_string", type=str, default=None)
    add_arg("--container_name", type=str, default=None)

    add_arg("--lbl_file", type=str, required=True)

    add_arg("--soft_lbl", action='store_true', default=False)

    add_arg("--workers", type=int, default=8)

    add_arg("--aug", action='store_true', default=False)
    add_arg("--full_res", action='store_true', default=False)
    add_arg("--cutmix", action='store_true', default=False)

    add_arg("--dataset_name_dict", type=str)
    add_arg("--temp", type=float, default=1)
    add_arg("--syn_data_path", type=str, default=None)

    add_arg("--strength", type=float, default=0.7)
    add_arg("--rnd_res_scale_st", type=float, default=0.8)

    add_arg("--lvl", type=int, default=None)
    add_arg("--seed", type=int, default=0)
    add_arg("--exp_num", type=int, default=1)

    add_arg("--ipc", type=int, default=None)

    args = parser.parse_args()

    # Echo every parsed option, one "key: value" line each.
    print("\n".join(f"{key}: {value}" for key, value in vars(args).items()))

    set_dataset_specs(args)

    base_seed = args.seed
    run_accuracies = []
    for offset in range(args.exp_num):
        args.seed = base_seed + offset
        print(f'running experiment with seed {args.seed}')
        run_accuracies.append(train_soft_lbls(args))

    mean_acc = np.round(np.mean(run_accuracies), 2)
    std_acc = np.round(np.std(run_accuracies), 2)
    print(f"Average acc: {mean_acc} +- {std_acc}")

