import torch
import logging
from tqdm import tqdm
from utils import save_config, save_model2, get_logger
import os
from torch.cuda.amp.autocast_mode import autocast
from typing import List, Union, Literal
from contextlib import contextmanager
import random
from torch.func import jvp, vmap
import timeit
from jacobian import JacobianReg
import torchmetrics
from data import TwoCropTransform
from torchmetrics import MetricCollection, Accuracy, Precision, Recall, F1Score, MeanMetric
from PIL import Image



def dict2str(dic):
    """Render a mapping as tab-separated ``key=value`` pairs.

    Every value is formatted with the ``.4f`` spec, so values must be numbers
    (or objects such as 0-dim tensors that accept that format spec).
    An empty mapping yields an empty string.
    """
    rendered = ("{}={:.4f}".format(name, value) for name, value in dic.items())
    return "\t".join(rendered)

class Hybrid:
    """Trainer for an encoder/projector/classifier model with optional Jacobian regularisation.

    ``args.method`` selects the objective:

    * ``'cl'`` / ``'scl'`` / ``'hybrid'`` -- two views (``imgs[1]``, ``imgs[2]``)
      go through encoder + projector; ``'hybrid'`` additionally feeds the
      encoder features to the classifier.
    * ``'sl'`` -- ``imgs[0]`` goes through encoder + classifier.
    * ``'ae'`` -- ``imgs[0]`` goes through the encoder and the projector
      reconstructs it.

    Batches are ``(imgs, target)`` with ``imgs`` a 3-element list/tuple.
    """

    def __init__(self, encoder, projector, classifier, train_loss_fn, val_loss_fn, args=None, device=None):
        """Move the networks to ``device`` and set up loggers and metrics.

        Args:
            encoder, projector, classifier: ``torch.nn.Module`` instances.
            train_loss_fn, val_loss_fn: loss callables; their call signature
                depends on ``args.method`` (see ``training_step`` /
                ``validation_step``).
            args: configuration namespace. Required despite the ``None``
                default: ``num_classes``, ``epochs``, ``test_mode``, ``method``,
                ``jac_reg`` and either ``save_folder`` (training) or
                ``result_path`` (test mode) are read here.
            device: device the networks and metrics are moved to.
        """
        super().__init__()
        self.args = args
        self.save_hyperparameters()
        self.num_classes = self.args.num_classes
        self.epochs = self.args.epochs
        self.train_loss_fn = train_loss_fn
        self.val_loss_fn = val_loss_fn
        self.test_mode = self.args.test_mode

        if not self.test_mode:
            self.mylogger = get_logger(logpath=f"{self.args.save_folder}/logs.log", displaying=False)
            print(f"logging file: {args.save_folder}/logs.log")
            self.train_logger = get_logger(logpath=f"{self.args.save_folder}/train.log", displaying=False)
            print(f"training logging file: {args.save_folder}/train.log")
        else:
            # Test mode writes a single log; ``train_logger`` is not created.
            print(f"logging file: {args.result_path}")
            self.mylogger = get_logger(logpath=args.result_path, displaying=False)

        self.device = device
        self.encoder = encoder.to(self.device)
        self.projector = projector.to(self.device)
        self.classifier = classifier.to(self.device)

        # The same three running means are tracked independently per split;
        # ``compute()`` returns keys prefixed with the split name.
        metrics = MetricCollection({
            'natural_accuracy': MeanMetric(),
            'loss': MeanMetric(),
            'loss_jac': MeanMetric()
        })
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')
        for split_metrics in (self.train_metrics, self.val_metrics, self.test_metrics):
            split_metrics.to(device)

    @torch.enable_grad()
    @torch.inference_mode(False)
    def f(self, x):
        """Return the network output whose Jacobian ``jac_reg_r`` penalises.

        Gradient tracking is forced on and inference mode off for this call.

        NOTE(review): for ``'sl'`` this returns ``projector(encoder(x))``,
        whereas ``training_step`` regularises ``classifier(encoder(x))`` when
        ``jac_sel == 1`` -- confirm which output is intended for ``jac_sel == 2``.

        Raises:
            ValueError: if ``args.method`` is not a known method.
        """
        if self.args.method in ['cl', 'scl', 'hybrid', 'sl']:
            return self.projector(self.encoder(x))
        if self.args.method in ['ae']:
            return self.encoder(x)
        raise ValueError(f"Unknown method: {self.args.method}")

    @contextmanager
    def _batchnorm_frozen(self, *models):
        """Temporarily put every BatchNorm layer in ``models`` into eval mode.

        Each layer's previous ``training`` flag is recorded and restored on
        exit -- also when the body raises -- so models that were already in
        eval mode stay in eval mode.
        """
        bn_layers = [
            module
            for model in models
            for module in model.modules()
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
        ]
        previous_modes = [layer.training for layer in bn_layers]
        for layer in bn_layers:
            layer.eval()
        try:
            yield
        finally:
            for layer, was_training in zip(bn_layers, previous_modes):
                layer.train(was_training)

    def f_stopBN(self, x):
        """Evaluate ``self.f(x)`` with encoder/projector BatchNorm layers in eval mode.

        BatchNorm layers therefore use (and do not update) their running
        statistics during the call; their previous modes are restored afterwards.
        """
        with self._batchnorm_frozen(self.encoder, self.projector):
            return self.f(x)

    def jvp_fn(self, x, tangent):
        """Return the Jacobian-vector product ``J_f(x) @ tangent``.

        Computed with ``torch.func.jvp`` while encoder/projector BatchNorm
        layers are held in eval mode (prior modes restored afterwards).
        """
        with self._batchnorm_frozen(self.encoder, self.projector):
            _, jv = jvp(self.f, (x,), (tangent,))
        return jv

    def run_jvp(self, x, S):
        """Return ``J_f(x) @ s`` for every direction ``s`` stacked along dim 0 of ``S``.

        ``vmap`` maps ``jvp_fn`` over the first dimension of ``S``; each slice
        must have the shape of ``x``. Runs inside ``torch.cuda.amp.autocast()``.
        """
        def single_jvp(tangent):
            return self.jvp_fn(x, tangent)

        with torch.cuda.amp.autocast():
            Jv = vmap(single_jvp)(S)
        return Jv  # J·S, stacked along dim 0

    def jac_reg_r(self, batch_x, num_proj=1):
        """Random-projection estimate of the Jacobian penalty of ``self.f`` at ``batch_x``.

        Draws ``num_proj`` Gaussian directions per sample, normalises each to
        unit length over the flattened ``C*H*W`` input and returns
        ``R = 0.5 * (C*H*W / num_proj) * mean_b sum_p ||J_b s_{b,p}||^2``.
        For unit directions drawn uniformly from the sphere this is a
        Monte-Carlo estimate of ``0.5 * mean_b ||J_b||_F^2``.

        Args:
            batch_x: input batch, unpacked as ``(batch, channels, height, width)``.
            num_proj: number of random directions per sample.

        Returns:
            0-dim tensor ``R``.
        """
        batch_size, channels, height, width = batch_x.shape
        input_dim = channels * height * width
        with torch.no_grad():
            S_proj = torch.randn(num_proj, batch_size, channels, height, width,
                                 device=batch_x.device, dtype=torch.float32)
            flat_vectors = S_proj.view(num_proj, batch_size, -1)
            flat_vectors = flat_vectors / (flat_vectors.norm(p=2, dim=2, keepdim=True) + 1e-8)
            # Sanity check: every direction must now have unit length.
            check_norms = flat_vectors.norm(p=2, dim=2)
            assert torch.allclose(check_norms, torch.ones_like(check_norms), rtol=1e-3, atol=1e-3), \
                "Vectors are not properly normalized to unit length"
            S_proj = flat_vectors.view_as(S_proj)

        with torch.cuda.amp.autocast(enabled=False):
            jv = self.run_jvp(batch_x, S_proj)  # (num_proj, batch, *output_shape)
            # Put the batch axis first and flatten projections + output dims
            # together; transpose (instead of a fixed 3-D permute) also handles
            # multi-dimensional outputs. Upcast so autocast (reduced-precision)
            # JVP outputs are squared and summed in float32.
            jv_flat = jv.transpose(0, 1).reshape(batch_size, -1).float()
            frob_norm_squared = jv_flat.pow(2).sum(dim=1)
            C = input_dim / num_proj
            J2 = C * frob_norm_squared.mean()
            R = 0.5 * J2
        return R

    def disable_batchnorm_running(self, model):
        """Put every BatchNorm layer of ``model`` into eval mode."""
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.eval()

    def enable_batchnorm_running(self, model):
        """Put every BatchNorm layer of ``model`` into train mode."""
        for module in model.modules():
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                module.train()

    def jac_reg_sel(self, in_x, out_x, sel = 1):
        """Dispatch to one of the Jacobian regularisers.

        Args:
            in_x: network inputs (presumably must require grad for ``sel == 1``;
                ``training_step`` enables it whenever ``args.jac_reg`` is set).
            out_x: outputs computed from ``in_x``; only used when ``sel == 1``.
            sel: ``1`` -> ``JacobianReg()(in_x, out_x)``;
                ``2`` -> ``self.jac_reg_r(in_x)``, which recomputes outputs via ``self.f``.

        Raises:
            ValueError: for any other ``sel``.
        """
        if sel == 1:
            return JacobianReg()(in_x, out_x)
        if sel == 2:
            return self.jac_reg_r(in_x)
        raise ValueError(f"Unknown Jacobian regulariser selector: {sel}")

    def training_step(self, epoch, batch):
        """Compute the training loss for one batch and update ``train_metrics``.

        Args:
            epoch: current epoch (unused; kept for interface symmetry).
            batch: ``(imgs, target)``; ``imgs`` is a 3-element list/tuple.
                ``imgs[0]`` feeds ``'sl'``/``'ae'`` and the accuracy probe,
                ``imgs[1]``/``imgs[2]`` feed the two-view methods.

        Returns:
            ``(loss, instant_natural_acc, Jac_reg)``: ``loss`` carries the
            autograd graph; ``instant_natural_acc`` is the batch accuracy of
            ``classifier(encoder(imgs[0]))`` (the negated loss for ``'ae'``);
            ``Jac_reg`` is the weighted penalty as a Python float, or ``0``
            when ``args.jac_reg`` is off.

        Raises:
            ValueError: if ``args.method`` is unknown.
        """
        imgs, target = batch
        assert isinstance(imgs, (list, tuple)) and len(imgs) == 3
        img0, img1, img2 = imgs[0].to(self.device), imgs[1].to(self.device), imgs[2].to(self.device)
        target = target.to(self.device)
        method = self.args.method
        # Input gradients are only needed by the Jacobian penalty.
        # NOTE(review): ``args.jac_reg_projector`` has no effect in this step --
        # confirm whether projector-level regularisation was intended.
        use_jac = bool(self.args.jac_reg)
        Jac_reg = 0
        pred0 = None  # classifier output on img0, reused by the accuracy probe

        if method in ['cl', 'scl', 'hybrid']:
            img1.requires_grad = use_jac
            img2.requires_grad = use_jac
            feature1, feature2 = self.encoder(img1), self.encoder(img2)
            projection1, projection2 = self.projector(feature1), self.projector(feature2)
            if method == 'scl':
                loss, _ = self.train_loss_fn(projection1, projection2, target)
            elif method == 'hybrid':
                pred1, pred2 = self.classifier(feature1), self.classifier(feature2)
                loss, _ = self.train_loss_fn(projection1, projection2, pred1, pred2, target)
            else:  # 'cl': labels are deliberately withheld
                loss, _ = self.train_loss_fn(projection1, projection2, labels=None)
            if use_jac:
                R1 = self.jac_reg_sel(img1, projection1, self.args.jac_sel)
                R2 = self.jac_reg_sel(img2, projection2, self.args.jac_sel)
                Jac_reg = self.args.lambda_JR * (R1 + R2)
                loss = loss + Jac_reg
        elif method == 'sl':
            img0.requires_grad = use_jac
            pred0 = self.classifier(self.encoder(img0))
            loss = self.train_loss_fn(pred0, target)
            if use_jac:
                Jac_reg = self.args.lambda_JR * self.jac_reg_sel(img0, pred0, self.args.jac_sel)
                loss = loss + Jac_reg
        elif method == 'ae':
            img0.requires_grad = use_jac
            features = self.encoder(img0)
            reconstructed = self.projector(features)
            loss = self.train_loss_fn(reconstructed, img0)
            if use_jac:
                Jac_reg = self.args.lambda_JR * self.jac_reg_sel(img0, features, self.args.jac_sel)
                loss = loss + Jac_reg
        else:
            raise ValueError(f"Unknown method: {method}")

        with torch.no_grad():
            if method == 'ae':
                # No class predictions for the autoencoder: report -loss so
                # that "higher is better" like an accuracy.
                instant_natural_acc = -loss.item()
            else:
                if pred0 is None:
                    # Accuracy probe on the first view; no graph is needed.
                    pred0 = self.classifier(self.encoder(img0))
                instant_natural_acc = (pred0.argmax(dim=1) == target).float().mean()

            self.train_metrics['loss'](loss.item())
            self.train_metrics['natural_accuracy'](instant_natural_acc)
            if use_jac:
                Jac_reg = float(Jac_reg)
                self.train_metrics['loss_jac'](Jac_reg)

        return loss, instant_natural_acc, Jac_reg

    def validation_step(self, epoch, batch):
        """Evaluate one batch without gradients and update ``val_metrics``.

        Returns ``(loss, instant_natural_acc)``. The accuracy is the classifier
        accuracy on ``imgs[0]`` for ``'scl'``/``'hybrid'``/``'sl'``, and the
        negated loss for ``'ae'``/``'cl'``, so larger is better for every
        method (``train`` keeps the checkpoint with the largest value).
        ``epoch`` is unused.

        Raises:
            ValueError: if ``args.method`` is unknown.
        """
        imgs, target = batch
        assert isinstance(imgs, (list, tuple)) and len(imgs) == 3
        img0, img1, img2 = imgs[0].to(self.device), imgs[1].to(self.device), imgs[2].to(self.device)
        target = target.to(self.device)

        with torch.no_grad():
            if self.args.method in ['scl', 'hybrid', 'sl']:
                predictions = self.classifier(self.encoder(img0))
                loss = self.val_loss_fn(predictions, target)
                instant_natural_acc = (predictions.argmax(dim=1) == target).float().mean()
            elif self.args.method in ['ae']:
                reconstructed = self.projector(self.encoder(img0))
                loss = self.val_loss_fn(reconstructed, img0)
                instant_natural_acc = - loss.item()
            elif self.args.method in ['cl']:
                feature1, feature2 = self.encoder(img1), self.encoder(img2)
                projection1, projection2 = self.projector(feature1), self.projector(feature2)
                loss, _ = self.val_loss_fn(projection1, projection2, labels=None)
                instant_natural_acc = - loss.item()
            else:
                raise ValueError(f"Unknown method: {self.args.method}")

            self.val_metrics['loss'](loss.item())
            self.val_metrics['natural_accuracy'].update(instant_natural_acc)

        return loss, instant_natural_acc

    def _set_modules_training(self, mode):
        """Switch encoder, projector and classifier to train (``True``) or eval (``False``) mode."""
        for module in (self.encoder, self.projector, self.classifier):
            module.train(mode)

    def _save_checkpoint(self, optimizer, epoch, filename):
        """Save the three networks via ``save_model2`` to ``args.save_folder/filename``."""
        save_model2(
            {"encoder": self.encoder, "projector": self.projector, "classifier": self.classifier},
            optimizer, self.args, epoch,
            os.path.join(self.args.save_folder, filename),
        )

    def _train_one_epoch(self, epoch, train_dataloader, optimizer):
        """Run one optimisation pass over ``train_dataloader`` with all networks in train mode."""
        self._set_modules_training(True)
        num_batches = len(train_dataloader)
        pbar = tqdm(train_dataloader, total=num_batches)
        for batch_idx, batch in enumerate(pbar):
            pbar.set_description(f"Train {self.args.method}: epoch [{epoch}/{self.epochs}] [{batch_idx}/{num_batches}]")
            optimizer.zero_grad()
            train_loss, instant_natural_acc, Jac_reg = self.training_step(epoch, batch)
            train_loss.backward()
            optimizer.step()
            pbar.set_postfix({
                "train_loss": f"{train_loss.item():.4f}",
                "train_acc": f"{instant_natural_acc:.4f}",
                "Jac_reg": f"{Jac_reg:.4f}"
            })

    def _validate_one_epoch(self, epoch, val_dataloader):
        """Evaluate every batch of ``val_dataloader`` with all networks in eval mode."""
        self._set_modules_training(False)
        num_batches = len(val_dataloader)
        self.mylogger.info(f"Validation set size: {len(val_dataloader.dataset)}, "
                           f"batch size: {val_dataloader.batch_size}, "
                           f"total batches: {num_batches}")
        pbar = tqdm(val_dataloader, total=num_batches)
        # Iterating the tqdm wrapper already advances the bar once per batch.
        for batch_idx, batch in enumerate(pbar):
            pbar.set_description(f"Val: epoch [{epoch}/{self.epochs}] [{batch_idx}/{num_batches}]")
            val_loss, instant_val_acc = self.validation_step(epoch, batch)
            pbar.set_postfix({
                "val_loss": f"{val_loss.item():.4f}",
                "val_acc": f"{instant_val_acc:.4f}"
            })

    def train(self, train_dataloader, val_dataloader):
        """Train for ``self.epochs`` epochs with Adam + cosine annealing.

        After each epoch the metrics are logged. ``best_model.pth`` is saved
        whenever validation natural accuracy improves, and
        ``epoch_XXXX_model.pth`` every ``args.save_freq`` epochs.
        ``last_model.pth`` is saved at the end.

        Raises:
            ValueError: if ``args.method`` is unknown.
        """
        save_config(self.args, self.args.save_folder)
        for item in (self.encoder, self.projector, self.classifier, self.args):
            self.mylogger.info(item)

        best_val_acc = float('-inf')
        best_epoch = 0

        if self.args.method not in ['cl', 'scl', 'hybrid', 'sl', 'ae']:
            raise ValueError(f"Unknown method: {self.args.method}")
        parameters = (
            list(self.encoder.parameters())
            + list(self.projector.parameters())
            + list(self.classifier.parameters())
        )
        optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.wd)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.epochs, eta_min=1e-6
        )

        for epoch in range(1, self.epochs + 1):
            self.train_metrics.reset()
            self.val_metrics.reset()

            self._train_one_epoch(epoch, train_dataloader, optimizer)
            lr_scheduler.step()
            train_metrics = self.train_metrics.compute()

            self._validate_one_epoch(epoch, val_dataloader)
            val_metrics = self.val_metrics.compute()

            self.train_logger.info(
                f"Epoch {epoch} Train metrics: {dict2str(train_metrics)} | "
                f"Detailed validation metrics: {dict2str(val_metrics)}"
            )
            self.mylogger.info(
                f"Epoch {epoch} |\n"
                f"Train: {dict2str(train_metrics)} |\n"
                f"Val: {dict2str(val_metrics)} |\n"
                f"LR: {lr_scheduler.get_last_lr()[0]:.6f}"
            )
            print(dict2str(train_metrics))
            print(dict2str(val_metrics))

            val_acc = float(val_metrics['val_natural_accuracy'])
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                self._save_checkpoint(optimizer, epoch, 'best_model.pth')
            # Periodic checkpoints are independent of the best model, so a new
            # best on a save_freq epoch does not suppress them.
            if epoch % self.args.save_freq == 0:
                self._save_checkpoint(optimizer, epoch, f'epoch_{epoch:04d}_model.pth')

        # self.epochs rather than the loop variable, which is unbound when epochs == 0.
        self._save_checkpoint(optimizer, self.epochs, 'last_model.pth')
        self.mylogger.info(f"Training finished. Best acc: {format(best_val_acc, '.4f')} at epoch {best_epoch}")

    def save_hyperparameters(self):
        """Store key hyperparameters in ``self.hparams``.

        ``lambda_JR`` is optional in ``args`` and recorded as ``None`` when absent.
        """
        self.hparams = {
            'num_classes': self.args.num_classes,
            'epochs': self.args.epochs,
            'method': self.args.method,
            'jac_reg': self.args.jac_reg,
            'lambda_JR': getattr(self.args, 'lambda_JR', None),
        }